diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py index 1667cb7a916a..af8eabbbeff0 100644 --- a/sdks/python/apache_beam/coders/typecoders.py +++ b/sdks/python/apache_beam/coders/typecoders.py @@ -129,6 +129,10 @@ def get_coder(self, typehint: Any) -> coders.Coder: # See https://github.com/apache/beam/issues/21541 # TODO(robertwb): Remove once all runners are portable. typehint = getattr(typehint, '__name__', str(typehint)) + if hasattr(typehint, '__supertype__'): + # Typehint is a typing.NewType. We need to get the underlying type. + while hasattr(typehint, '__supertype__'): + typehint = typehint.__supertype__ coder = self._coders.get( typehint.__class__ if isinstance(typehint, typehints.TypeConstraint) else typehint, diff --git a/sdks/python/apache_beam/coders/typecoders_test.py b/sdks/python/apache_beam/coders/typecoders_test.py index 3adc8255409d..46ea12fdfa96 100644 --- a/sdks/python/apache_beam/coders/typecoders_test.py +++ b/sdks/python/apache_beam/coders/typecoders_test.py @@ -16,6 +16,7 @@ # """Unit tests for the typecoders module.""" +import typing # pytype: skip-file import unittest @@ -121,6 +122,12 @@ def test_iterable_coder(self): self.assertEqual(expected_coder, real_coder) self.assertEqual(real_coder.encode(values), expected_coder.encode(values)) + def test_newtype_coder(self): + UserID = typing.NewType('UserID', str) + expected_coder = typecoders.registry.get_coder(str) + real_coder = typecoders.registry.get_coder(UserID) + self.assertEqual(expected_coder, real_coder) + @unittest.skip('https://github.com/apache/beam/issues/21658') def test_list_coder(self): real_coder = typecoders.registry.get_coder(typehints.List[bytes]) diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index 55653ecec19b..a1524fae83e0 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -319,12 +319,6 @@ def convert_to_beam_type(typ): # TODO(https://github.com/apache/beam/issues/19954): Currently unhandled. _LOGGER.info('Converting string literal type hint to Any: "%s"', typ) return typehints.Any - elif sys.version_info >= (3, 10) and isinstance(typ, typing.NewType): # pylint: disable=isinstance-second-argument-not-valid-type - # Special case for NewType, where, since Python 3.10, NewType is now a class - # rather than a function. - # TODO(https://github.com/apache/beam/issues/20076): Currently unhandled. - _LOGGER.info('Converting NewType type hint to Any: "%s"', typ) - return typehints.Any elif typ_module == 'apache_beam.typehints.native_type_compatibility' and \ getattr(typ, "__name__", typ.__origin__.__name__) == 'TypedWindowedValue': # Need to pass through WindowedValue class so that it can be converted @@ -343,9 +337,6 @@ def convert_to_beam_type(typ): return typ type_map = [ - # TODO(https://github.com/apache/beam/issues/20076): Currently - # unsupported. - _TypeMapEntry(match=is_new_type, arity=0, beam_type=typehints.Any), # TODO(https://github.com/apache/beam/issues/19954): Currently # unsupported. _TypeMapEntry(match=is_forward_ref, arity=0, beam_type=typehints.Any), diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py index 3f3603c2c978..0242c08511db 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py @@ -278,10 +278,6 @@ def test_generator_converted_to_iterator(self): typehints.Iterator[int], convert_to_beam_type(typing.Generator[int, None, None])) - def test_newtype(self): - self.assertEqual( - typehints.Any, convert_to_beam_type(typing.NewType('Number', int))) - def test_pattern(self): # TODO(https://github.com/apache/beam/issues/20489): Unsupported. self.assertEqual(typehints.Any, convert_to_beam_type(typing.Pattern)) diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 0e18e887c2a0..5d6a7391ec73 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -1287,6 +1287,15 @@ def normalize(x, none_as_type=False): }) +# newtype cases +# is_consistent_with(sub=UserID, base=int) +# is_consistent_with(sub=int, base=UserID) +# is_consistent_with(sub=SuperUserID, base=UserID) +# is_consistent_with(sub=UserID, base=SuperUserID) +def _is_newtype(type_hint): + return hasattr(type_hint, '__supertype__') + + def is_consistent_with(sub, base): """Checks whether sub a is consistent with base. @@ -1303,6 +1312,17 @@ def is_consistent_with(sub, base): return True if isinstance(sub, AnyTypeConstraint) or isinstance(base, AnyTypeConstraint): return True + if _is_newtype(base) and not _is_newtype(sub): + return False # non-newtypes are never subtypes of newtypes + if _is_newtype(sub): + supertypes = [sub.__supertype__] + while _is_newtype(supertypes[-1]): + supertypes.append(supertypes[-1].__supertype__) + if _is_newtype(base): + return base in supertypes + else: + return is_consistent_with(supertypes[-1], base) + # Per PEP484, ints are considered floats and complexes and # floats are considered complexes. if sub is int and base in (float, complex): diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index 6611dcecab01..392eb0bf2f48 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -166,6 +166,24 @@ def test_any_compatibility(self): self.assertCompatible(object, typehints.Any) self.assertCompatible(typehints.Any, object) + def test_newtype_compatibility(self): + UserID = typing.NewType('UserID', str) + SuperUserID = typing.NewType('SuperUserID', UserID) + + self.assertCompatible(UserID, typehints.Any) + self.assertCompatible(typehints.Any, UserID) + + self.assertCompatible(UserID, UserID) + self.assertCompatible(str, UserID) + self.assertNotCompatible(UserID, str) + + self.assertCompatible(SuperUserID, SuperUserID) + self.assertCompatible(UserID, SuperUserID) + self.assertNotCompatible(SuperUserID, UserID) + + self.assertCompatible(str, SuperUserID) + self.assertNotCompatible(SuperUserID, str) + def test_int_float_complex_compatibility(self): self.assertCompatible(float, int) self.assertCompatible(complex, int) diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index 8000c24f28aa..3c2bc631aacb 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -28,7 +28,7 @@ requires = [ # Numpy headers "numpy>=1.14.3,<2.3.0", # Update setup.py as well. # having cython here will create wheels that are platform dependent. - "cython>=3.0,<4", + #"cython>=3.0,<4", ## deps for generating external transform wrappers: # also update PyYaml bounds in sdks:python:generateExternalTransformsConfig 'pyyaml>=3.12,<7.0.0',