diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cb756aa..62f201d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,4 +34,4 @@ jobs: python-version-file: ".python-version" - name: Run tests - run: uv run pytest -n auto -v + run: uv run pytest -n 16 -v diff --git a/pyproject.toml b/pyproject.toml index 1f4b935..1d24b4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,8 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "tests/*" = [ + "D", # Test helpers don't need docstrings + "PLR2004", # Magic values are fine in test assertions "S", # It's tests, no need for security "PLC0415", # imports need to be late for pytest_in_docker to work. ] diff --git a/pytest_in_docker/_container.py b/pytest_in_docker/_container.py index 0f5342f..f92d1f8 100644 --- a/pytest_in_docker/_container.py +++ b/pytest_in_docker/_container.py @@ -5,8 +5,10 @@ import sys import tarfile import time +from types import FunctionType from typing import TYPE_CHECKING, Any +import cloudpickle import rpyc from pytest_in_docker._types import ContainerPrepareError @@ -173,76 +175,6 @@ def bootstrap_container( ) -def _make_picklable[T: Callable[..., Any]](func: T) -> T: - """Return a copy of *func* that cloudpickle will serialise by value. - - Two things prevent a test function from being naively pickled into a - remote container: - - 1. cloudpickle pickles importable functions *by reference* - (module + qualname), but the test module does not exist in the - container. - 2. pytest's assertion rewriter injects ``@pytest_ar`` into the - function's ``__globals__``, and that object drags in the test - module itself. - - We fix both by creating a **new** function object whose - ``__module__`` is ``"__mp_main__"`` (forces pickle-by-value) and - whose ``__globals__`` are a *shared* clean dict stripped of the - assertion-rewriter helper. All sibling callables (same module) are - cloned into the same ``clean_globals`` dict so transitive calls - between helpers resolve to the patched versions. - """ - import types # noqa: PLC0415 - - original_module = func.__module__ - - # First pass: build clean_globals with non-callable entries, - # collect names of same-module callables to patch. - clean_globals: dict[str, Any] = {} - to_patch: list[str] = [] - for k, v in func.__globals__.items(): - if k == "@pytest_ar": - continue - if ( - isinstance(v, types.FunctionType) - and getattr(v, "__module__", None) == original_module - ): - to_patch.append(k) - else: - clean_globals[k] = v - - # Second pass: clone callables so they all share clean_globals. - for k in to_patch: - orig = func.__globals__[k] - clone = types.FunctionType( - orig.__code__, - clean_globals, - orig.__name__, - orig.__defaults__, - orig.__closure__, - ) - clone.__module__ = "__mp_main__" - clone.__qualname__ = orig.__qualname__ - clone.__annotations__ = orig.__annotations__ - clone.__kwdefaults__ = orig.__kwdefaults__ - clean_globals[k] = clone - - # Clone the test function itself into the same shared dict. - clone = types.FunctionType( - func.__code__, - clean_globals, - func.__name__, - func.__defaults__, - func.__closure__, - ) - clone.__annotations__ = func.__annotations__ - clone.__kwdefaults__ = func.__kwdefaults__ - clone.__module__ = "__mp_main__" - clone.__qualname__ = func.__qualname__ - return clone # type: ignore[return-value] - - def run_pickled[T]( conn: Any, # noqa: ANN401 func: Callable[..., T], @@ -250,9 +182,30 @@ def run_pickled[T]( **kwargs: Any, # noqa: ANN401 ) -> T: """Serialize *func* with cloudpickle, send to container, execute there.""" - import cloudpickle # noqa: PLC0415 + module = sys.modules.get(func.__module__) + if module is not None: + cloudpickle.register_pickle_by_value(module) + + # Pickle a shallow copy of the function, excluding pytestmark — it + # references host-only objects (e.g. container factories) that can't + # be unpickled in the container. Copying avoids mutating the original. + func_copy = FunctionType( + func.__code__, + func.__globals__, + func.__name__, + func.__defaults__, + func.__closure__, + ) + func_copy.__kwdefaults__ = func.__kwdefaults__ + func_copy.__dict__.update( + {k: v for k, v in func.__dict__.items() if k != "pytestmark"} + ) + try: + payload = cloudpickle.dumps(func_copy) + finally: + if module is not None: + cloudpickle.unregister_pickle_by_value(module) - payload = cloudpickle.dumps(_make_picklable(func)) rpickle = conn.modules["pickle"] remote_func = rpickle.loads(payload) return remote_func(*args, **kwargs) diff --git a/tests/test_class_support.py b/tests/test_class_support.py new file mode 100644 index 0000000..250a3db --- /dev/null +++ b/tests/test_class_support.py @@ -0,0 +1,88 @@ +"""Tests proving class serialization works with module-level classes.""" + +from enum import Enum + +from pytest_in_docker import in_container + + +class Greeter: + def __init__(self, name: str) -> None: + self.name = name + + def greet(self) -> str: + return f"hello {self.name}" + + +@in_container("python:alpine") +def test_module_level_class() -> None: + """Test function instantiates a module-level class.""" + g = Greeter("world") + assert g.greet() == "hello world" + + +class Base: + def base_method(self) -> str: + return "base" + + +class Child(Base): + def __init__(self, x: int) -> None: + self.x = x + + @property + def value(self) -> int: + return self.x + + @staticmethod + def static_thing() -> str: + return "static" + + @classmethod + def from_string(cls, s: str) -> Child: + return cls(int(s)) + + def child_method(self) -> str: + return f"child {self.base_method()}" + + +@in_container("python:alpine") +def test_class_inheritance() -> None: + """Inherited classes serialize correctly.""" + c = Child(42) + assert c.value == 42 + assert c.static_thing() == "static" + assert c.child_method() == "child base" + c2 = Child.from_string("10") + assert c2.x == 10 + + +class Config: + def __init__(self, name: str) -> None: + self.name = name + + +class Service: + def __init__(self) -> None: + self.config = Config("default") + + def get_config_name(self) -> str: + return self.config.name + + +@in_container("python:alpine") +def test_class_referencing_another_class() -> None: + """A class whose methods instantiate another module-level class.""" + s = Service() + assert s.get_config_name() == "default" + + +class Color(Enum): + RED = 1 + GREEN = 2 + + +@in_container("python:alpine") +def test_module_level_enum() -> None: + """Enum classes serialize correctly.""" + assert Color.RED.value == 1 + assert Color.GREEN.name == "GREEN"