Tensor duck #40
base: master
Conversation
…se the protocol instead of torch.Tensor directly.
torchtyping/typechecker.py
@@ -301,7 +300,7 @@ def check_type(*args, **kwargs):
     # Now check if it's annotating a tensor
     if is_torchtyping_annotation:
         base_cls, *all_metadata = get_args(expected_type)
-        if not issubclass(base_cls, torch.Tensor):
+        if not isinstance(base_cls(), TensorLike):
I'm not sure about this last change. As mentioned, the protocol class only supports isinstance() (not issubclass()) because it has properties. This means I had to require default construction.
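For context, a minimal sketch of what such a runtime-checkable protocol might look like (hypothetical; the PR's actual TensorLike may declare more members):

from typing import Protocol, runtime_checkable
import torch

@runtime_checkable
class TensorLike(Protocol):
    # Non-method members like these are what rule out issubclass().
    @property
    def shape(self) -> torch.Size: ...
    @property
    def dtype(self) -> torch.dtype: ...

isinstance(torch.zeros(2, 3), TensorLike)  # True
issubclass(torch.Tensor, TensorLike)       # TypeError: Protocols with
                                           # non-method members don't
                                           # support issubclass()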
But I think this test may be unnecessary - after all the other tests, don't we already know this is a TensorLike element?
I think it might be better to just get rid of this test. @patrick-kidger
In fact, it seems I have a strong motivation to remove this: the case where I want to apply it is checking shape signatures on an abstract base class, where default construction may not be an option.
I have updated the PR accordingly.
Thanks for the PR! Unfortunately, this isn't quite the direction I had in mind. Following on from the discussion in #39, perhaps it's worth making clear that I don't intend to make TorchTyping depend on JAX. Rather, the plan is to simply copy over the non-JAX parts of the code. (Which is most of it.) The idea would be to end up with jaxtyping-style annotations (see the sketch below).
At a technical level this should be essentially simple. The main hurdle - and the reason I've been putting off doing it - is writing up documentation that makes this transition clear.
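Presumably those annotations would follow jaxtyping's existing style, applied to torch.Tensor - a plausible sketch (the exact spelling is an assumption, not confirmed in this thread):

from jaxtyping import Float
import torch

# Dtype is carried by Float; "batch m n" etc. are symbolic
# dimension names that must be consistent across arguments.
def batched_matmul(
    x: Float[torch.Tensor, "batch m n"],
    y: Float[torch.Tensor, "batch n p"],
) -> Float[torch.Tensor, "batch m p"]:
    return x @ y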
Thanks for the clarification! I can totally see why you want to pull over the jaxtyping code and have a single code base. I understand that this PR is perhaps not what you were looking for, but I think it could actually be a very important step in generalizing what you have, and maybe even in merging the two code bases. Let's take an example snippet from jaxtyping where the dtype is extracted (array_types.py, line 129 onwards):
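Roughly, the extraction is this hard-coded chain (a reconstruction, based on the same logic quoted in the reply below):

if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
    # JAX, numpy
    dtype = obj.dtype.type.__name__
elif hasattr(obj.dtype, "as_numpy_dtype"):
    # TensorFlow
    dtype = obj.dtype.as_numpy_dtype.__name__
else:
    # PyTorch: parse repr(obj.dtype), e.g. "torch.float32"
    repr_dtype = repr(obj.dtype).split(".")
    if len(repr_dtype) == 2 and repr_dtype[0] == "torch":
        dtype = repr_dtype[1]
    else:
        raise RuntimeError(
            "Unrecognised array/tensor type to extract dtype from"
        )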
I think you would agree that it's a bit awkward and somewhat hard to extend since the supported classes have to be coded in advance.
Suppose we instead define an abstract interface that exposes exactly what the type checker needs. Then, to type-check a concrete class like numpy.ndarray or torch.Tensor, we just use the adapter pattern to map the specialized methods onto that interface (as an example of a simple name remapper, see "Adapter Method – Python Design Patterns"). This would make it easy for folks like me to extend your library to array-like objects such as LinearOperator, just by writing an adapter to the interface specified by the library. In addition, I think it could also let you merge these two libraries and make your life easier.
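A minimal sketch of that adapter idea (ArrayInterface and TorchTensorAdapter are invented names, not part of either library):

from abc import ABC, abstractmethod
import torch

class ArrayInterface(ABC):
    # The interface the type checker would program against.
    @property
    @abstractmethod
    def shape(self) -> tuple: ...

    @property
    @abstractmethod
    def dtype_name(self) -> str: ...

class TorchTensorAdapter(ArrayInterface):
    # Remaps torch.Tensor's attributes onto the interface.
    def __init__(self, tensor: torch.Tensor):
        self._tensor = tensor

    @property
    def shape(self) -> tuple:
        return tuple(self._tensor.shape)

    @property
    def dtype_name(self) -> str:
        return repr(self._tensor.dtype).split(".")[-1]  # e.g. "float32"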
Hmm. I suppose the practical implementation of such an adaptor would be via a registry:

import functools as ft

@ft.singledispatch
def get_dtype(obj):
    # Note that this default implementation does not explicitly
    # depend on any of PyTorch/etc; thus the singledispatch
    # hook is made available just for the sake of user-defined
    # custom types.
    if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
        # JAX, numpy
        dtype = obj.dtype.type.__name__
    elif hasattr(obj.dtype, "as_numpy_dtype"):
        # TensorFlow
        dtype = obj.dtype.as_numpy_dtype.__name__
    else:
        # PyTorch
        repr_dtype = repr(obj.dtype).split(".")
        if len(repr_dtype) == 2 and repr_dtype[0] == "torch":
            dtype = repr_dtype[1]
        else:
            raise RuntimeError(
                "Unrecognised array/tensor type to extract dtype from"
            )
    return dtype

class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        ...
        dtype = get_dtype(obj)
        ...
...and then in your user code, you could add a custom overload for your type. I'd be willing to accept a PR for this over in jaxtyping.
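For instance (a sketch; MyLinearOperator is a hypothetical user-defined type, not a real import):

class MyLinearOperator:
    # An array-like type that carries its dtype as a plain string.
    def __init__(self, dtype_name: str):
        self.dtype_name = dtype_name

@get_dtype.register
def _(obj: MyLinearOperator) -> str:
    # Dispatches on MyLinearOperator, bypassing the attribute
    # heuristics in the default implementation.
    return obj.dtype_name

get_dtype(MyLinearOperator("float32"))  # -> "float32"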
As promised, here is the PR to upgrade the library to define a 'torch-like' protocol and use it as the base type, rather than using torch.Tensor directly. This lets users perform dimension checking on classes that support the Tensor interface but do not directly inherit from torch.Tensor. I think the change is fairly clear-cut; I have added a test case to demonstrate and verify that dimensions are actually checked (a sketch of the idea follows below).
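As a sketch of the intent (not the PR's actual test; the exact attributes required are defined by the PR's protocol), a duck-typed class can now pass shape checking:

import torch
from torchtyping import TensorType
from typeguard import typechecked

class DuckTensor:
    # Forwards the tensor attributes the checker inspects,
    # without inheriting from torch.Tensor.
    def __init__(self, data: torch.Tensor):
        self._data = data

    @property
    def shape(self) -> torch.Size:
        return self._data.shape

    @property
    def dtype(self) -> torch.dtype:
        return self._data.dtype

    @property
    def names(self):
        return self._data.names

@typechecked
def row_count(x: TensorType["rows", "cols"]) -> int:
    return x.shape[0]

row_count(DuckTensor(torch.zeros(3, 4)))  # passes: two dimensions
row_count(DuckTensor(torch.zeros(3)))     # raises TypeError: wrong rank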
The only question I have is about the change to line 304 in typechecker.py (the last change below).
Is this test really necessary?
I had to change it to use default construction because protocols don't support issubclass() if they have properties (isinstance() still works).