Skip to content

[Performance] Faster tensorclass getattr #1254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: gh/vmoens/48/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -2696,10 +2696,9 @@ def _stack_onto_at_(
return self

def _get_str(self, key, default):
first_key = key
out = self._tensordict.get(first_key)
out = self._tensordict.get(key)
if out is None:
return self._default_get(first_key, default)
return self._default_get(key, default)
return out

def _get_tuple(self, key, default):
Expand Down
22 changes: 12 additions & 10 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
_GENERIC_NESTED_ERR,
_is_dataclass as is_dataclass,
_is_json_serializable,
_is_tensorclass,
_is_tensorclass,_is_non_tensor,
_LOCK_ERROR,
_td_fields,
_TENSORCLASS_MEMO,
Expand Down Expand Up @@ -1473,11 +1473,17 @@ def _setstate(self, state: dict[str, Any]) -> None: # noqa: D417

def _getattr(self, item: str) -> Any:
_tensordict = self._tensordict
__dataclass_fields__ = type(self).__expected_keys__

if item in __dataclass_fields__:
_non_tensordict = self._non_tensordict
if _non_tensordict:
out = _tensordict._get_str(item, NO_DEFAULT)
if out is not NO_DEFAULT:
if isinstance(out, (NonTensorData, NonTensorStack)):
return out.data if not isinstance(out, NonTensorStack) else out.tolist()
return out

_non_tensordict = self._non_tensordict
if _non_tensordict:
__dataclass_fields__ = type(self).__expected_keys__
if item in __dataclass_fields__:
out = _non_tensordict.get(item, NO_DEFAULT)
if out is not NO_DEFAULT:
if (
Expand All @@ -1487,16 +1493,12 @@ def _getattr(self, item: str) -> Any:
):
return _from_shared_nontensor(out)
return out
out = _tensordict._get_str(item, NO_DEFAULT)
if is_non_tensor(out):
return out.data if not isinstance(out, NonTensorStack) else out.tolist()
return out

out = getattr(_tensordict, item, NO_DEFAULT)
if out is not NO_DEFAULT:
if not callable(out) and not is_non_tensor(out):
return out
if is_non_tensor(out):
if _is_non_tensor(type(out)):
return out.data if hasattr(out, "data") else out.tolist()
return _wrap_method(self, item, out)
raise AttributeError(item)
Expand Down
Loading