[BugFix] Ensure that Composite.set returns self as TensorDict does
ghstack-source-id: ee47d30c335a95b5a100ba4a32f10c578cacdbdc
Pull Request resolved: #2784
vmoens committed Feb 12, 2025
1 parent fd0645e commit ed9130d
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions torchrl/data/tensor_specs.py
@@ -4665,7 +4665,11 @@ def separates(self, *keys: NestedKey, default: Any = None) -> Composite:
             out[key] = result
         return out
 
-    def set(self, name, spec):
+    def set(self, name: str, spec: TensorSpec) -> Composite:
+        """Sets a spec in the Composite spec."""
+        if not isinstance(name, str):
+            self[name] = spec
+            return self
         if self.locked:
             raise RuntimeError("Cannot modify a locked Composite.")
         if spec is not None and self.device is not None and spec.device != self.device:
@@ -4698,6 +4702,7 @@ def set(self, name, spec):
                     f"Composite.shape={self.shape}."
                 )
         self._specs[name] = spec
+        return self
 
     def __init__(
         self, *args, shape: torch.Size = None, device: torch.device = None, **kwargs
@@ -5733,9 +5738,10 @@ def ndim(self):
     def ndimension(self):
        return len(self.shape)
 
-    def set(self, name, spec):
+    def set(self, name: str, spec: TensorSpec) -> StackedComposite:
         for sub_spec, sub_item in zip(self._specs, spec.unbind(self.dim)):
             sub_spec[name] = sub_item
+        return self
 
     @property
     def shape(self):
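With this change, Composite.set and StackedComposite.set return the spec itself, mirroring TensorDict.set, so registration calls can be chained; a non-string name is additionally routed through __setitem__ before returning. A minimal usage sketch follows, assuming the Unbounded spec class exported by torchrl.data; the key names and shapes are illustrative:

from torchrl.data import Composite, Unbounded

# Because set() now returns self (as TensorDict.set does), calls can be chained.
spec = (
    Composite()
    .set("observation", Unbounded(shape=(3,)))
    .set("reward", Unbounded(shape=(1,)))
)
print(spec)

Previously, Composite.set had no return value (it returned None), so each spec had to be registered in a separate statement.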
