|  | 
|  | 1 | +# ----------------------------------------------------------------------------- | 
|  | 2 | +# Copyright (c) 2025, Oracle and/or its affiliates. | 
|  | 3 | +# | 
|  | 4 | +# Licensed under the Universal Permissive License v 1.0 as shown at | 
|  | 5 | +# http://oss.oracle.com/licenses/upl. | 
|  | 6 | +# ----------------------------------------------------------------------------- | 
|  | 7 | + | 
|  | 8 | +import inspect | 
|  | 9 | +from collections.abc import Mapping, Sequence, Set | 
|  | 10 | +from functools import wraps | 
|  | 11 | +from typing import Any, get_args, get_origin, get_type_hints | 
|  | 12 | + | 
|  | 13 | +NoneType = type(None) | 
|  | 14 | + | 
|  | 15 | + | 
|  | 16 | +def _match(value, annot) -> bool: | 
|  | 17 | +    """Recursively validate value against a typing annotation.""" | 
|  | 18 | +    if annot is Any: | 
|  | 19 | +        return True | 
|  | 20 | + | 
|  | 21 | +    origin = get_origin(annot) | 
|  | 22 | +    args = get_args(annot) | 
|  | 23 | + | 
|  | 24 | +    # Handle Annotated[T, ...] → treat as T | 
|  | 25 | +    if origin is getattr(__import__("typing"), "Annotated", None): | 
|  | 26 | +        annot = args[0] | 
|  | 27 | +        origin = get_origin(annot) | 
|  | 28 | +        args = get_args(annot) | 
|  | 29 | + | 
|  | 30 | +    # Optional[T] is Union[T, NoneType] | 
|  | 31 | +    if origin is getattr(__import__("typing"), "Union", None): | 
|  | 32 | +        return any(_match(value, a) for a in args) | 
|  | 33 | + | 
|  | 34 | +    # Literal[…] | 
|  | 35 | +    if origin is getattr(__import__("typing"), "Literal", None): | 
|  | 36 | +        return any(value == lit for lit in args) | 
|  | 37 | + | 
|  | 38 | +    # Tuple cases | 
|  | 39 | +    if origin is tuple: | 
|  | 40 | +        if not isinstance(value, tuple): | 
|  | 41 | +            return False | 
|  | 42 | +        if len(args) == 2 and args[1] is Ellipsis: | 
|  | 43 | +            # tuple[T, ...] | 
|  | 44 | +            return all(_match(v, args[0]) for v in value) | 
|  | 45 | +        if len(args) != len(value): | 
|  | 46 | +            return False | 
|  | 47 | +        return all(_match(v, a) for v, a in zip(value, args)) | 
|  | 48 | + | 
|  | 49 | +    # Mappings (dict-like) | 
|  | 50 | +    if origin in (dict, Mapping): | 
|  | 51 | +        if not isinstance(value, Mapping): | 
|  | 52 | +            return False | 
|  | 53 | +        k_annot, v_annot = args if args else (Any, Any) | 
|  | 54 | +        return all( | 
|  | 55 | +            _match(k, k_annot) and _match(v, v_annot) for k, v in value.items() | 
|  | 56 | +        ) | 
|  | 57 | + | 
|  | 58 | +    # Sequences (list, Sequence) – but not str/bytes | 
|  | 59 | +    if origin in (list, Sequence): | 
|  | 60 | +        if isinstance(value, (str, bytes)): | 
|  | 61 | +            return False | 
|  | 62 | +        if not isinstance(value, Sequence): | 
|  | 63 | +            return False | 
|  | 64 | +        elem_annot = args[0] if args else Any | 
|  | 65 | +        return all(_match(v, elem_annot) for v in value) | 
|  | 66 | + | 
|  | 67 | +    # Sets | 
|  | 68 | +    if origin in (set, frozenset, Set): | 
|  | 69 | +        if not isinstance(value, (set, frozenset)): | 
|  | 70 | +            return False | 
|  | 71 | +        elem_annot = args[0] if args else Any | 
|  | 72 | +        return all(_match(v, elem_annot) for v in value) | 
|  | 73 | + | 
|  | 74 | +    # Fall back to normal isinstance for non-typing classes | 
|  | 75 | +    if isinstance(annot, type): | 
|  | 76 | +        return isinstance(value, annot) | 
|  | 77 | + | 
|  | 78 | +    # If annot is a typing alias like 'list' without args | 
|  | 79 | +    if origin is not None: | 
|  | 80 | +        # Treat bare containers as accepting anything inside | 
|  | 81 | +        return isinstance(value, origin) | 
|  | 82 | + | 
|  | 83 | +    # Unknown/unsupported typing form: accept conservatively | 
|  | 84 | +    return True | 
|  | 85 | + | 
|  | 86 | + | 
|  | 87 | +def enforce_types(func): | 
|  | 88 | +    # Resolve ForwardRefs using function globals (handles "User" as a string, etc.) | 
|  | 89 | +    hints = get_type_hints( | 
|  | 90 | +        func, globalns=func.__globals__, include_extras=True | 
|  | 91 | +    ) | 
|  | 92 | +    sig = inspect.signature(func) | 
|  | 93 | + | 
|  | 94 | +    def _check(bound): | 
|  | 95 | +        for name, val in bound.arguments.items(): | 
|  | 96 | +            if name in hints: | 
|  | 97 | +                annot = hints[name] | 
|  | 98 | +                if not _match(val, annot): | 
|  | 99 | +                    raise TypeError( | 
|  | 100 | +                        f"Argument '{name}' failed type check: expected {annot!r}, " | 
|  | 101 | +                        f"got {type(val).__name__} -> {val!r}" | 
|  | 102 | +                    ) | 
|  | 103 | + | 
|  | 104 | +    if inspect.iscoroutinefunction(func): | 
|  | 105 | + | 
|  | 106 | +        @wraps(func) | 
|  | 107 | +        async def aw(*args, **kwargs): | 
|  | 108 | +            bound = sig.bind(*args, **kwargs) | 
|  | 109 | +            bound.apply_defaults() | 
|  | 110 | +            _check(bound) | 
|  | 111 | +            return await func(*args, **kwargs) | 
|  | 112 | + | 
|  | 113 | +        return aw | 
|  | 114 | +    else: | 
|  | 115 | + | 
|  | 116 | +        @wraps(func) | 
|  | 117 | +        def w(*args, **kwargs): | 
|  | 118 | +            bound = sig.bind(*args, **kwargs) | 
|  | 119 | +            bound.apply_defaults() | 
|  | 120 | +            _check(bound) | 
|  | 121 | +            return func(*args, **kwargs) | 
|  | 122 | + | 
|  | 123 | +        return w | 
0 commit comments