diff --git a/tensordict/_td.py b/tensordict/_td.py index 1e6a9aeb9..346f0edd3 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2696,10 +2696,9 @@ def _stack_onto_at_( return self def _get_str(self, key, default): - first_key = key - out = self._tensordict.get(first_key) + out = self._tensordict.get(key) if out is None: - return self._default_get(first_key, default) + return self._default_get(key, default) return out def _get_tuple(self, key, default): diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index f993952f7..b3f27265b 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -50,7 +50,7 @@ _GENERIC_NESTED_ERR, _is_dataclass as is_dataclass, _is_json_serializable, - _is_tensorclass, + _is_tensorclass,_is_non_tensor, _LOCK_ERROR, _td_fields, _TENSORCLASS_MEMO, @@ -1473,11 +1473,17 @@ def _setstate(self, state: dict[str, Any]) -> None: # noqa: D417 def _getattr(self, item: str) -> Any: _tensordict = self._tensordict - __dataclass_fields__ = type(self).__expected_keys__ - if item in __dataclass_fields__: - _non_tensordict = self._non_tensordict - if _non_tensordict: + out = _tensordict._get_str(item, NO_DEFAULT) + if out is not NO_DEFAULT: + if isinstance(out, (NonTensorData, NonTensorStack)): + return out.data if not isinstance(out, NonTensorStack) else out.tolist() + return out + + _non_tensordict = self._non_tensordict + if _non_tensordict: + __dataclass_fields__ = type(self).__expected_keys__ + if item in __dataclass_fields__: out = _non_tensordict.get(item, NO_DEFAULT) if out is not NO_DEFAULT: if ( @@ -1487,16 +1493,12 @@ def _getattr(self, item: str) -> Any: ): return _from_shared_nontensor(out) return out - out = _tensordict._get_str(item, NO_DEFAULT) - if is_non_tensor(out): - return out.data if not isinstance(out, NonTensorStack) else out.tolist() - return out out = getattr(_tensordict, item, NO_DEFAULT) if out is not NO_DEFAULT: if not callable(out) and not is_non_tensor(out): return out - if is_non_tensor(out): + if _is_non_tensor(type(out)): return out.data if hasattr(out, "data") else out.tolist() return _wrap_method(self, item, out) raise AttributeError(item)