diff --git a/octo/data/dataset.py b/octo/data/dataset.py
index fe247f34..233c9ae9 100644
--- a/octo/data/dataset.py
+++ b/octo/data/dataset.py
@@ -132,7 +132,9 @@ def apply_trajectory_transforms(
         num_parallel_calls,
     )
 
-    # chunks observations and actions
+    # chunks observations and actions, giving them a new axis at index 1 of size `window_size` and
+    # `window_size + future_action_window_size`, respectively
+
     dataset = dataset.traj_map(
         partial(
             traj_transforms.chunk_act_obs,
@@ -391,7 +393,9 @@ def is_nonzero_length(traj):
         full_dataset = full_dataset.filter(ModuleSpec.instantiate(filter_fcn_spec))
     if ignore_errors:
         full_dataset = full_dataset.ignore_errors()
+
     full_dataset = full_dataset.traj_map(restructure).filter(is_nonzero_length)
+
     # tries to load from cache, otherwise computes on the fly
     dataset_statistics = get_dataset_statistics(
         full_dataset,
@@ -454,13 +458,13 @@ def is_nonzero_length(traj):
 
     return dataset, dataset_statistics
 
-
 def make_single_dataset(
     dataset_kwargs: dict,
     *,
     train: bool,
     traj_transform_kwargs: dict = {},
     frame_transform_kwargs: dict = {},
+    user_modify_traj=lambda traj: traj,
 ) -> dl.DLataset:
     """Creates a single dataset from kwargs. Returns a dataset of trajectories.
 
@@ -474,6 +478,9 @@ def make_single_dataset(
         **dataset_kwargs,
         train=train,
     )
+
+    dataset = dataset.traj_map(user_modify_traj)
+
     dataset = apply_trajectory_transforms(dataset, **traj_transform_kwargs, train=train)
     dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train)
 
diff --git a/octo/data/utils/data_utils.py b/octo/data/utils/data_utils.py
index 55196a35..5ffb3bec 100644
--- a/octo/data/utils/data_utils.py
+++ b/octo/data/utils/data_utils.py
@@ -257,6 +257,7 @@ def normalize_action_and_proprio(
         mask = metadata[key].get(
             "mask", tf.ones_like(metadata[key]["mean"], dtype=tf.bool)
         )
+
         traj = dl.transforms.selective_tree_map(
             traj,
             match=lambda k, _: k == traj_key,
diff --git a/octo/model/components/action_heads.py b/octo/model/components/action_heads.py
index b56db26e..e60badfb 100644
--- a/octo/model/components/action_heads.py
+++ b/octo/model/components/action_heads.py
@@ -244,6 +244,8 @@ class DiscreteActionHead(nn.Module, ActionHead):
     action_dim: int = 7
     vocab_size: int = 256
     normalization_type: str = "uniform"
+    low: Optional[float] = None
+    high: Optional[float] = None
 
     def setup(self):
         total_output = self.action_horizon * self.action_dim * self.vocab_size
@@ -267,6 +269,8 @@ def setup(self):
         self.action_tokenizer = BinTokenizer(
             n_bins=self.vocab_size,
             bin_type=self.normalization_type,
+            low=self.low,
+            high=self.high
         )
 
     def __call__(
diff --git a/octo/model/components/tokenizers.py b/octo/model/components/tokenizers.py
index 0956ba23..91951238 100644
--- a/octo/model/components/tokenizers.py
+++ b/octo/model/components/tokenizers.py
@@ -244,11 +244,14 @@ class BinTokenizer(nn.Module):
     n_bins: int = 256
     bin_type: str = "uniform"
-    low: float = 0
-    high: float = 1
+    low: Optional[float] = None
+    high: Optional[float] = None
 
     def setup(self):
         if self.bin_type == "uniform":
+            if self.low is None or self.high is None:
+                raise ValueError("Low and high must be provided for uniform normalization")
+
             self.thresholds = jnp.linspace(self.low, self.high, self.n_bins + 1)
         elif self.bin_type == "normal":
             self.thresholds = norm.ppf(jnp.linspace(EPS, 1 - EPS, self.n_bins + 1))
         else:
diff --git a/octo/model/octo_model.py b/octo/model/octo_model.py
index 23cc497f..27d139e5 100644
--- a/octo/model/octo_model.py
+++ b/octo/model/octo_model.py
@@ -250,6 +250,162 @@ def sample_actions(
             else:
                 raise ValueError(f"Unknown normalization type: {normalization_type}")
         return action
+
+    @partial(jax.jit, static_argnames=("train", "sample_shape", "argmax", "beam"))
+    def sample_future_actions(
+        self,
+        observations: Data,
+        tasks: Data,
+        unnormalization_statistics: Optional[Data] = None,
+        normalization_type: NormalizationType = NormalizationType.NORMAL,
+        beam: int = 1,
+        timestep_pad_mask: Optional[ArrayLike] = None,
+        train: bool = False,
+        argmax: bool = False,
+        sample_shape: Tuple[int, ...] = (),
+        rng: Optional[PRNGKey] = None,
+        temperature: float = 1.0,
+    ):
+        """Samples actions from the model. See `action_heads.py` for more info.
+
+        Args:
+            observations: dictionary of arrays of shape (batch_size, window_size, *)
+            tasks: dict of tasks of shape (batch_size, *)
+            unnormalization_statistics: dict of statistics for unnormalizing actions (must contain "mean",
+                "std", and optionally "mask")
+            normalization_type: type of normalization applied to the actions
+            timestep_pad_mask: (batch_size, window_size) Boolean mask that is False when the timestep corresponds to padding
+            train: whether to run in train mode
+            ...see `action_heads.py` for the rest of the kwargs.
+
+        Returns:
+            actions: (*sample_shape, batch_size, action_horizon, action_dim)
+        """
+        if timestep_pad_mask is None:
+            timestep_pad_mask = observations["pad_mask"]
+
+        transformer_outputs = self.run_transformer(
+            observations, tasks, timestep_pad_mask, train=train
+        )
+        action_head = self.module.bind({"params": self.params}).heads[
+            "action"
+        ]
+
+        action_logits = action_head(transformer_outputs, train=train)[:, -1]
+
+        action_distribution = jax.nn.softmax(action_logits, axis=-1)
+
+        action_tokens = jnp.argsort(action_distribution, axis=-1)[..., -beam:].astype(jnp.int32)
+        confidence = jnp.take_along_axis(action_distribution, action_tokens, axis=-1)
+
+        action_tokens = jnp.broadcast_to(
+            action_tokens, sample_shape + action_tokens.shape
+        )
+
+        action = action_head.action_tokenizer.decode(action_tokens)
+
+        if unnormalization_statistics is not None:
+            if normalization_type == NormalizationType.NORMAL:
+                mask = unnormalization_statistics.get(
+                    "mask",
+                    jnp.ones_like(unnormalization_statistics["mean"], dtype=bool),
+                )
+                action = action[..., : len(mask)]
+                action = jnp.where(
+                    mask,
+                    (action * unnormalization_statistics["std"])
+                    + unnormalization_statistics["mean"],
+                    action,
+                )
+            elif normalization_type == NormalizationType.BOUNDS:
+                mask = unnormalization_statistics.get(
+                    "mask", jnp.ones_like(unnormalization_statistics["p01"], dtype=bool)
+                )
+                action = action[..., : len(mask)]
+                action = jnp.where(
+                    mask,
+                    (action + 1)
+                    * (
+                        unnormalization_statistics["p99"]
+                        - unnormalization_statistics["p01"]
+                    )
+                    / 2
+                    + unnormalization_statistics["p01"],
+                    action,
+                )
+            else:
+                raise ValueError(f"Unknown normalization type: {normalization_type}")
+
+        return action, confidence
+
+    @partial(jax.jit, static_argnames=("train", "sample_shape", "beam"))
+    def sample_trajectory(
+        self,
+        observations: Data,
+        next_action,
+        tasks: Data,
+        unnormalization_statistics: Optional[Data] = None,
+        normalization_type: NormalizationType = NormalizationType.NORMAL,
+        beam: int = 1,
+        timestep_pad_mask: Optional[ArrayLike] = None,
+        train: bool = False,
+        argmax: bool = False,
+        sample_shape: Tuple[int, ...] = (),
+        rng: Optional[PRNGKey] = None,
+        temperature: float = 1.0,
+    ):
+        if timestep_pad_mask is None:
+            timestep_pad_mask = observations["pad_mask"]
+
+        transformer_outputs = self.run_transformer(
+            observations, tasks, timestep_pad_mask, train=train
+        )
+
+        trajectory_head = self.module.bind({"params": self.params}).heads[
+            "trajectory"
+        ]
+
+        action = trajectory_head.predict_action(
+            transformer_outputs,
+            train=train,
+            argmax=argmax,
+            sample_shape=sample_shape,
+            rng=rng,
+            temperature=temperature,
+        )
+
+        if unnormalization_statistics is not None:
+            if normalization_type == NormalizationType.NORMAL:
+                mask = unnormalization_statistics.get(
+                    "mask",
+                    jnp.ones_like(unnormalization_statistics["mean"], dtype=bool),
+                )
+                action = action[..., : len(mask)]
+                action = jnp.where(
+                    mask,
+                    (action * unnormalization_statistics["std"])
+                    + unnormalization_statistics["mean"],
+                    action,
+                )
+            elif normalization_type == NormalizationType.BOUNDS:
+                mask = unnormalization_statistics.get(
+                    "mask", jnp.ones_like(unnormalization_statistics["p01"], dtype=bool)
+                )
+                action = action[..., : len(mask)]
+                action = jnp.where(
+                    mask,
+                    (action + 1)
+                    * (
+                        unnormalization_statistics["p99"]
+                        - unnormalization_statistics["p01"]
+                    )
+                    / 2
+                    + unnormalization_statistics["p01"],
+                    action,
+                )
+            else:
+                raise ValueError(f"Unknown normalization type: {normalization_type}")
+
+        return action
 
     @classmethod
     def load_pretrained(
@@ -277,6 +433,8 @@ def load_pretrained(
             tf.io.gfile.join(checkpoint_path, "config.json"), "r"
         ) as f:
             config = json.load(f)
+        if 'readouts' in config['model']:
+            config['model']['readout_tokenizers'] = config['model'].pop('readouts')
 
         # shim to support old configs
         if "pred_horizon" in config["model"]["heads"]["action"]["kwargs"]:
diff --git a/octo/model/octo_module.py b/octo/model/octo_module.py
index c14343ac..7abf9b47 100644
--- a/octo/model/octo_module.py
+++ b/octo/model/octo_module.py
@@ -79,7 +79,7 @@ class OctoTransformer(nn.Module):
     observation_tokenizers: Dict[str, nn.Module]
     task_tokenizers: Dict[str, nn.Module]
-    readouts: Dict[str, int]
+    readout_tokenizers: Dict[str, int | nn.Module]
     transformer_kwargs: Dict
     token_embedding_size: int
     max_horizon: int
@@ -92,7 +92,7 @@ def __call__(
         observations: Data,
         tasks: Data,
         timestep_pad_mask: jax.Array,
-        readouts: Optional[Sequence[str]] = None,
+        readout_tokenizers: Optional[Sequence[str]] = None,
         train: bool = False,
         verbose: bool = False,
     ) -> Dict[str, TokenGroup]:
@@ -114,15 +114,15 @@ def __call__(
 
         Note: Horizon can be anything <= max_horizon.
         """
-        if readouts is None:
-            readouts = list(self.readouts.keys())
+        if readout_tokenizers is None:
+            readout_tokenizers = list(self.readout_tokenizers.keys())
 
         #
        # Check that all inputs are valid
         #
-        assert set(readouts).issubset(
-            set(self.readouts.keys())
+        assert set(readout_tokenizers).issubset(
+            set(self.readout_tokenizers.keys())
         ), "readouts must be specified in the model config"
 
         batch_size, horizon = jax.tree_util.tree_leaves(observations)[0].shape[:2]
@@ -241,32 +241,58 @@ def __call__(
         #
         # Finally, add the readout tokens
         #
-        for readout_name in readouts:
-            group_name = f"readout_{readout_name}"
-            # Readouts do not correspond to any inputs, just positional embeddings
-            n_tokens_for_readout = self.readouts[readout_name]
-            readout_tokens = jnp.zeros(
-                (batch_size, horizon, n_tokens_for_readout, self.token_embedding_size)
-            )
+        for name, tok in self.readout_tokenizers.items():
+            group_name = f"readout_{name}"
+            if isinstance(tok, nn.Module):
+                tokenizer_output: TokenGroup = tok(observations, tasks, train=train)
+                if tokenizer_output is None:
+                    logging.warning(f"Skipping observation tokenizer: {group_name}")
+                    continue
 
-            # Add positional embedding
-            readout_tokens += self._create_positional_embedding(
-                group_name, readout_tokens
-            )
-            readout_mask = jnp.ones((batch_size, horizon, n_tokens_for_readout))
-            readout_attention_rules = {
-                "task_*": AttentionRule.CAUSAL,
-                "obs_*": AttentionRule.CAUSAL,
-                group_name: AttentionRule.CAUSAL,
-            }  # Attend to tasks, all previous observations, and *only it's own own readout*
+                obs_tokens = nn.Dense(
+                    self.token_embedding_size, name=f"{group_name}_projection"
+                )(tokenizer_output.tokens)
+                # obs_tokens shape is (batch, horizon, n_tokens, token_embedding_size)
 
-            all_timestep_groups.append(
-                TimestepGroup(
-                    tokens=readout_tokens,
-                    mask=readout_mask,
-                    name=group_name,
-                    attention_rules=readout_attention_rules,
+                # Add positional embedding
+                obs_tokens += self._create_positional_embedding(group_name, obs_tokens)
+
+                # Update mask to account for which timesteps are padding
+                obs_pad_mask = jnp.logical_and(timestep_pad_mask[:, :, None], tokenizer_output.mask)
+
+                all_timestep_groups.append(
+                    TimestepGroup(
+                        tokens=obs_tokens,
+                        mask=obs_pad_mask,
+                        name=group_name,
+                        attention_rules=observation_attention_rules,
+                    )
+                )
+            elif isinstance(tok, int):
+                # Readouts do not correspond to any inputs, just positional embeddings
+                n_tokens_for_readout = self.readout_tokenizers[name]
+                readout_tokens = jnp.zeros(
+                    (batch_size, horizon, n_tokens_for_readout, self.token_embedding_size)
                 )
+
+                # Add positional embedding
+                readout_tokens += self._create_positional_embedding(
+                    group_name, readout_tokens
+                )
+                readout_mask = jnp.ones((batch_size, horizon, n_tokens_for_readout))
+                readout_attention_rules = {
+                    "task_*": AttentionRule.CAUSAL,
+                    "obs_*": AttentionRule.CAUSAL,
+                    group_name: AttentionRule.CAUSAL,
+                }  # Attend to tasks, all previous observations, and *only it's own own readout*
+
+                all_timestep_groups.append(
+                    TimestepGroup(
+                        tokens=readout_tokens,
+                        mask=readout_mask,
+                        name=group_name,
+                        attention_rules=readout_attention_rules,
+                    )
             )
 
         # Run the transformer!
@@ -373,7 +399,7 @@ def create(
         observation_tokenizers: Dict[str, ModuleSpec],
         task_tokenizers: Dict[str, ModuleSpec],
         heads: Dict[str, ModuleSpec],
-        readouts: Dict[str, int],
+        readout_tokenizers: Dict[str, int | ModuleSpec],
         transformer_kwargs: Dict,
         token_embedding_size: int,
         max_horizon: int,
@@ -407,13 +433,17 @@ def create(
         task_tokenizer_defs = {
             k: ModuleSpec.instantiate(spec)() for k, spec in task_tokenizers.items()
         }
+        readout_tokenizer_defs = {
+            k: ModuleSpec.instantiate(spec)() if isinstance(spec, dict) else spec
+            for k, spec in readout_tokenizers.items()
+        }
 
         head_defs = {k: ModuleSpec.instantiate(spec)() for k, spec in heads.items()}
 
         model_def = OctoTransformer(
             observation_tokenizers=observation_tokenizer_defs,
             task_tokenizers=task_tokenizer_defs,
-            readouts=readouts,
+            readout_tokenizers=readout_tokenizer_defs,
             token_embedding_size=token_embedding_size,
             max_horizon=max_horizon,
             repeat_task_tokens=repeat_task_tokens,