Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Mar 28, 2021
0 parents commit beb3ad2
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 0 deletions.
13 changes: 13 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
**/__pycache__/
**/.ipynb_checkpoints/
*.py[cod]
.idea/
.vs/
build/
dist/
*.egg_info/
*.egg
*.so
*.egg-info/
**/.mypy_cache/
env/
60 changes: 60 additions & 0 deletions examples/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#####################
# torchtyping is designed to be highly extensible.
#####################

from __future__ import annotations

import torch
from torchtyping import TensorType
import typeguard

from typing import Any, Tuple


#####################
# It's possible to check any other property of a tensor, as well as just the defaults.
#
# Here we check that the tensor has an attribute called "foo" on it, which should
# take a particular value.
#####################
class TensorTypeFooChecker(TensorType):
foo = None

@classmethod
def fields(cls) -> Tuple[str]:
return super().fields() + ('foo',)

@classmethod
def check(cls, instance: Any) -> bool:
check = super().check(instance)
if cls.foo is not None:
check = check and hasattr(instance, "foo") and instance.foo == cls.foo
return check

@classmethod
def getitem(cls, item: Any) -> TensorTypeFooChecker:
foo = cls.foo
if isinstance(item, slice):
if item.start == "foo":
foo = item.stop
item = None
dict = super().getitem(item)
dict.update(foo=foo)
return dict


@typeguard.typechecked
def foo_checker(tensor: TensorTypeFooChecker["foo":"good-foo"][float]):
pass


def valid_foo():
x = torch.rand(3)
x.foo = "good-foo"
foo_checker(x)


def invalid_foo():
x = torch.rand(3)
x.foo = "bad-foo"
foo_checker(x)
1 change: 1 addition & 0 deletions torchtyping/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .tensor import TensorType
80 changes: 80 additions & 0 deletions torchtyping/tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

import torch

from typing import Any, Tuple


class _TensorTypeMeta(type):
_cache = {}

def __repr__(cls) -> str:
return cls.__name__

def __instancecheck__(cls, instance: Any) -> bool:
return cls.check(instance)

def __getitem__(cls, item: Any) -> _TensorTypeMeta:
if item is None:
# Corresponding to how None is allow in TensorType.getitem: it has a
# special value there, so we disallow it here.
raise ValueError(f"{item} not a valid type argument.")

if cls._is_getitem_subclass:
assert len(cls.__bases__) == 1
base_cls = cls.__bases__[0]
else:
base_cls = cls
name = base_cls.__name__
dict = cls.getitem(item)
for field in cls.fields():
value = dict[field]
if value is not None:
name += f"[{field}={value}]"
dict["_is_getitem_subclass"] = True
try:
return type(cls)._cache[name, base_cls]
except KeyError:
out = type(cls)(name, (base_cls,), dict)
type(cls)._cache[name, base_cls] = out
return out


class TensorType(metaclass=_TensorTypeMeta):
_is_getitem_subclass = False

def __new__(cls, *args, **kwargs):
raise RuntimeError(f"Class {cls.__name__} cannot be instantiated.")

dtype = None
layout = None

@classmethod
def fields(cls) -> Tuple[str]:
return ('dtype', 'layout')

@classmethod
def check(cls, instance: Any) -> bool:
return isinstance(instance, torch.Tensor) and (cls.dtype in (None, instance.dtype)) and (cls.layout in (None, instance.layout))

@classmethod
def getitem(cls, item: Any) -> TensorType:
dtype = cls.dtype
layout = cls.layout

if item is int:
dtype = torch.long
elif item is float:
dtype = torch.get_default_dtype()
elif item is bool:
dtype = torch.bool
elif isinstance(item, torch.dtype):
dtype = item
elif isinstance(item, torch.layout):
layout = item
elif item is None:
pass # To allow subclasses to pass item=None to indicate no further processing.
else:
raise ValueError(f"{item} not a valid type argument.")

return dict(dtype=dtype, layout=layout)

0 comments on commit beb3ad2

Please sign in to comment.