Skip to content

Commit

Permalink
Add register_pickle_by_value (#16)
Browse files Browse the repository at this point in the history
By enabling register_pickle_by_value=True, the function is serialized by value instead of being referenced by its module path. This embeds the function unpickled even if the original module is unavailable on the remote computer.

Note: If the function contains import statements, the imported modules must still be installed on the remote computer.
  • Loading branch information
superstar54 authored Feb 11, 2025
1 parent 839e3ba commit 3b0f48f
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 33 deletions.
27 changes: 27 additions & 0 deletions docs/gallery/autogen/how_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,33 @@ def add(x, y):
print("exit_status:", node.exit_status)
print("exit_message:", node.exit_message)

######################################################################
# Using `register_pickle_by_value`
# --------------------------------
#
# If the function is defined inside an external module that is **not installed** on
# the remote computer, this can cause import errors during execution.
#
# **Solution:**
# By enabling `register_pickle_by_value=True`, the function is serialized **by value**
# instead of being referenced by its module path. This embeds the function unpickled
# even if the original module is unavailable on the remote computer.
#
# **Example:**
#
# .. code-block:: python
#
# inputs = prepare_pythonjob_inputs(
# my_function,
# function_inputs={"x": 1, "y": 2},
# computer="localhost",
# register_pickle_by_value=True, # Ensures function is embedded
# )
#
# **Important Considerations:**: If the function **contains import statements**,
# the imported modules **must still be installed** on the remote computer.
#


######################################################################
# Define your data serializer and deserializer
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ dependencies = [
]

[project.optional-dependencies]
dev = [
"hatch",
]
pre-commit = [
'pre-commit~=3.5',
]
Expand Down
3 changes: 2 additions & 1 deletion src/aiida_pythonjob/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def prepare_pythonjob_inputs(
function_data: dict | None = None,
deserializers: dict | None = None,
serializers: dict | None = None,
register_pickle_by_value: bool = False,
**kwargs: Any,
) -> Dict[str, Any]:
"""Prepare the inputs for PythonJob"""
Expand All @@ -33,7 +34,7 @@ def prepare_pythonjob_inputs(
raise ValueError("Only one of function or function_data should be provided")
# if function is a function, inspect it and get the source code
if function is not None and inspect.isfunction(function):
function_data = build_function_data(function)
function_data = build_function_data(function, register_pickle_by_value=register_pickle_by_value)
new_upload_files = {}
# change the string in the upload files to SingleFileData, or FolderData
for key, source in upload_files.items():
Expand Down
47 changes: 28 additions & 19 deletions src/aiida_pythonjob/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, _SpecialForm, get_type_hints

Expand All @@ -6,8 +7,6 @@


def import_from_path(path: str) -> Any:
import importlib

module_name, object_name = path.rsplit(".", 1)
module = importlib.import_module(module_name)
try:
Expand Down Expand Up @@ -47,24 +46,39 @@ def add_imports(type_hint):
return imports


def inspect_function(func: Callable) -> Dict[str, Any]:
def inspect_function(
func: Callable, inspect_source: bool = False, register_pickle_by_value: bool = False
) -> Dict[str, Any]:
"""Serialize a function for storage or transmission."""
# we need save the source code explicitly, because in the case of jupyter notebook,
# the source code is not saved in the pickle file
import cloudpickle

from aiida_pythonjob.data.pickled_data import PickledData

try:
source_code = inspect.getsource(func)
# Split the source into lines for processing
source_code_lines = source_code.split("\n")
source_code = "\n".join(source_code_lines)
except OSError:
source_code = "Failed to retrieve source code."
if inspect_source:
try:
source_code = inspect.getsource(func)
# Split the source into lines for processing
source_code_lines = source_code.split("\n")
source_code = "\n".join(source_code_lines)
except OSError:
source_code = "Failed to retrieve source code."
else:
source_code = ""

if register_pickle_by_value:
module = importlib.import_module(func.__module__)
cloudpickle.register_pickle_by_value(module)
pickled_function = PickledData(value=func)
cloudpickle.unregister_pickle_by_value(module)
else:
pickled_function = PickledData(value=func)

return {"source_code": source_code, "mode": "use_pickled_function", "pickled_function": PickledData(value=func)}
return {"source_code": source_code, "mode": "use_pickled_function", "pickled_function": pickled_function}


def build_function_data(func: Callable) -> Dict[str, Any]:
def build_function_data(func: Callable, register_pickle_by_value: bool = False) -> Dict[str, Any]:
"""Inspect the function and return a dictionary with the function data."""
import types

Expand All @@ -73,15 +87,10 @@ def build_function_data(func: Callable) -> Dict[str, Any]:
function_data = {"name": func.__name__}
if func.__module__ == "__main__" or "." in func.__qualname__.split(".", 1)[-1]:
# Local or nested callable, so pickle the callable
function_data.update(inspect_function(func))
function_data.update(inspect_function(func, inspect_source=True))
else:
# Global callable (function/class), store its module and name for reference
function_data.update(
{
"mode": "use_module_path",
"source_code": f"from {func.__module__} import {func.__name__}",
}
)
function_data.update(inspect_function(func, register_pickle_by_value=register_pickle_by_value))
else:
raise TypeError("Provided object is not a callable function or class.")
return function_data
Expand Down
55 changes: 42 additions & 13 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,46 @@
import pytest
from aiida_pythonjob.utils import build_function_data


def test_build_function_data():
from math import sqrt

function_data = build_function_data(sqrt)
assert function_data == {
"name": "sqrt",
"mode": "use_module_path",
"source_code": "from math import sqrt",
}
#
try:
function_data = build_function_data(1)
except Exception as e:
assert str(e) == "Provided object is not a callable function or class."
"""Test the build_function_data function behavior."""

with pytest.raises(TypeError, match="Provided object is not a callable function or class."):
build_function_data(1)

function_data = build_function_data(build_function_data)
assert function_data["name"] == "build_function_data"
assert "source_code" in function_data
assert "pickled_function" in function_data
node = function_data["pickled_function"]
with node.base.repository.open(node.FILENAME, mode="rb") as f:
text = f.read()
assert b"cloudpickle" not in text

function_data = build_function_data(build_function_data, register_pickle_by_value=True)
assert function_data["name"] == "build_function_data"
assert "source_code" in function_data
assert "pickled_function" in function_data
node = function_data["pickled_function"]
with node.base.repository.open(node.FILENAME, mode="rb") as f:
text = f.read()
assert b"cloudpickle" in text

def local_function(x, y):
return x + y

function_data = build_function_data(local_function)
assert function_data["name"] == "local_function"
assert "source_code" in function_data
assert function_data["mode"] == "use_pickled_function"

def outer_function():
def nested_function(x, y):
return x + y

return nested_function

nested_func = outer_function()
function_data = build_function_data(nested_func)
assert function_data["name"] == "nested_function"
assert function_data["mode"] == "use_pickled_function"

0 comments on commit 3b0f48f

Please sign in to comment.