Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add register_pickle_by_value #16

Merged
merged 1 commit into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/gallery/autogen/how_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,33 @@ def add(x, y):
print("exit_status:", node.exit_status)
print("exit_message:", node.exit_message)

######################################################################
# Using `register_pickle_by_value`
# --------------------------------
#
# If the function is defined inside an external module that is **not installed** on
# the remote computer, this can cause import errors during execution.
#
# **Solution:**
# By enabling `register_pickle_by_value=True`, the function is serialized **by value**
# instead of being referenced by its module path. This embeds the function unpickled
# even if the original module is unavailable on the remote computer.
#
# **Example:**
#
# .. code-block:: python
#
# inputs = prepare_pythonjob_inputs(
# my_function,
# function_inputs={"x": 1, "y": 2},
# computer="localhost",
# register_pickle_by_value=True, # Ensures function is embedded
# )
#
# **Important Considerations:**: If the function **contains import statements**,
# the imported modules **must still be installed** on the remote computer.
#


######################################################################
# Define your data serializer and deserializer
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ dependencies = [
]

[project.optional-dependencies]
dev = [
"hatch",
]
pre-commit = [
'pre-commit~=3.5',
]
Expand Down
3 changes: 2 additions & 1 deletion src/aiida_pythonjob/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def prepare_pythonjob_inputs(
function_data: dict | None = None,
deserializers: dict | None = None,
serializers: dict | None = None,
register_pickle_by_value: bool = False,
**kwargs: Any,
) -> Dict[str, Any]:
"""Prepare the inputs for PythonJob"""
Expand All @@ -33,7 +34,7 @@ def prepare_pythonjob_inputs(
raise ValueError("Only one of function or function_data should be provided")
# if function is a function, inspect it and get the source code
if function is not None and inspect.isfunction(function):
function_data = build_function_data(function)
function_data = build_function_data(function, register_pickle_by_value=register_pickle_by_value)
new_upload_files = {}
# change the string in the upload files to SingleFileData, or FolderData
for key, source in upload_files.items():
Expand Down
47 changes: 28 additions & 19 deletions src/aiida_pythonjob/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, _SpecialForm, get_type_hints

Expand All @@ -6,8 +7,6 @@


def import_from_path(path: str) -> Any:
import importlib

module_name, object_name = path.rsplit(".", 1)
module = importlib.import_module(module_name)
try:
Expand Down Expand Up @@ -47,24 +46,39 @@ def add_imports(type_hint):
return imports


def inspect_function(func: Callable) -> Dict[str, Any]:
def inspect_function(
func: Callable, inspect_source: bool = False, register_pickle_by_value: bool = False
) -> Dict[str, Any]:
"""Serialize a function for storage or transmission."""
# we need save the source code explicitly, because in the case of jupyter notebook,
# the source code is not saved in the pickle file
import cloudpickle

from aiida_pythonjob.data.pickled_data import PickledData

try:
source_code = inspect.getsource(func)
# Split the source into lines for processing
source_code_lines = source_code.split("\n")
source_code = "\n".join(source_code_lines)
except OSError:
source_code = "Failed to retrieve source code."
if inspect_source:
try:
source_code = inspect.getsource(func)
# Split the source into lines for processing
source_code_lines = source_code.split("\n")
source_code = "\n".join(source_code_lines)
except OSError:
source_code = "Failed to retrieve source code."
else:
source_code = ""

if register_pickle_by_value:
module = importlib.import_module(func.__module__)
cloudpickle.register_pickle_by_value(module)
pickled_function = PickledData(value=func)
cloudpickle.unregister_pickle_by_value(module)
else:
pickled_function = PickledData(value=func)

return {"source_code": source_code, "mode": "use_pickled_function", "pickled_function": PickledData(value=func)}
return {"source_code": source_code, "mode": "use_pickled_function", "pickled_function": pickled_function}


def build_function_data(func: Callable) -> Dict[str, Any]:
def build_function_data(func: Callable, register_pickle_by_value: bool = False) -> Dict[str, Any]:
"""Inspect the function and return a dictionary with the function data."""
import types

Expand All @@ -73,15 +87,10 @@ def build_function_data(func: Callable) -> Dict[str, Any]:
function_data = {"name": func.__name__}
if func.__module__ == "__main__" or "." in func.__qualname__.split(".", 1)[-1]:
# Local or nested callable, so pickle the callable
function_data.update(inspect_function(func))
function_data.update(inspect_function(func, inspect_source=True))
else:
# Global callable (function/class), store its module and name for reference
function_data.update(
{
"mode": "use_module_path",
"source_code": f"from {func.__module__} import {func.__name__}",
}
)
function_data.update(inspect_function(func, register_pickle_by_value=register_pickle_by_value))
else:
raise TypeError("Provided object is not a callable function or class.")
return function_data
Expand Down
55 changes: 42 additions & 13 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,46 @@
import pytest
from aiida_pythonjob.utils import build_function_data


def test_build_function_data():
from math import sqrt

function_data = build_function_data(sqrt)
assert function_data == {
"name": "sqrt",
"mode": "use_module_path",
"source_code": "from math import sqrt",
}
#
try:
function_data = build_function_data(1)
except Exception as e:
assert str(e) == "Provided object is not a callable function or class."
"""Test the build_function_data function behavior."""

with pytest.raises(TypeError, match="Provided object is not a callable function or class."):
build_function_data(1)

function_data = build_function_data(build_function_data)
assert function_data["name"] == "build_function_data"
assert "source_code" in function_data
assert "pickled_function" in function_data
node = function_data["pickled_function"]
with node.base.repository.open(node.FILENAME, mode="rb") as f:
text = f.read()
assert b"cloudpickle" not in text

function_data = build_function_data(build_function_data, register_pickle_by_value=True)
assert function_data["name"] == "build_function_data"
assert "source_code" in function_data
assert "pickled_function" in function_data
node = function_data["pickled_function"]
with node.base.repository.open(node.FILENAME, mode="rb") as f:
text = f.read()
assert b"cloudpickle" in text

def local_function(x, y):
return x + y

function_data = build_function_data(local_function)
assert function_data["name"] == "local_function"
assert "source_code" in function_data
assert function_data["mode"] == "use_pickled_function"

def outer_function():
def nested_function(x, y):
return x + y

return nested_function

nested_func = outer_function()
function_data = build_function_data(nested_func)
assert function_data["name"] == "nested_function"
assert function_data["mode"] == "use_pickled_function"
Loading