Skip to content
Open
46 changes: 46 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,52 @@ response = robot_agent.act(instruction, image)[0]
response = robot_agent.act([instruction1, image1, instruction2, image2])[0]
```

### Motor Agent

The [Motor Agent](src/mbodied_agents/agents/motor/motor_agent.py) serves as the base class for various motor control models that dictate robotic actions. Each subclass implements the `act` method, tailoring it to produce specific action types based on the output of respective models.

For example, `RT1Agent` uses the Robotics Transformer 1 model to generate actions in its `act` method.

You can easily extend the `MotorAgent` by creating new subclasses that adhere to different motion control paradigms or models.

#### Example: Using RT1Agent
To harness the power of the Robotics Transformer 1 (RT1) model with your robot backend:

```python
from mbodied_agents.agents.motor.rt1.rt1_agent import RT1Agent

robot_agent = RT1Agent(config=rt1_config)
```

Here, `rt1_config` is a dictionary of configuration values for the Robotics Transformer 1 model (see `examples/rt1.py` for a complete example).

To perform an action based on model output:

```python
actions = robot_agent.act(robot_state=robot_state_data)
```

For example, robot_state_data could be a dictionary containing the current state of the robot, while actions would be a list of Motion objects that dictate the robot's movements.

#### CustomMotorAgent

Here's a simple example of how to implement a custom motor agent to suit your needs:

```python
from src.mbodied_agents.agents.motion.custom_motor_agent import CustomMotorAgent

# Initialize your custom motor agent with specific model output
robot_agent = CustomMotorAgent(model_output=custom_model_output)

# Generate actions based on the robot state
actions = robot_agent.act(robot_state=current_state)
```

In this example, CustomMotorAgent should be a subclass of MotorAgent where the act method is implemented to use custom_model_output to generate actions.

Stay tuned for more motor control models and enhancements to the MotorAgent framework!


### Controls

The [controls](src/mbodied_agents/types/controls.py) module defines various motions to control a robot as Pydantic models. They are also subclassed from `Sample`, thus possessing all the capability of `Sample` as mentioned above. These controls cover a range of actions, from simple joint movements to complex poses and full robot control.
Expand Down
45 changes: 45 additions & 0 deletions examples/rt1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2024 Mbodi AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from mbodied_agents.agents.motor.rt1.rt1_agent import RT1Agent


def main() -> None:
    """Instantiate an RT1Agent, run it once on random dummy inputs, and print the result."""
    # The configuration is handed to RT1Agent as a plain dictionary.
    agent = RT1Agent(
        {
            "num_layers": 8,
            "layer_size": 128,
            "observation_history_size": 6,
            "future_prediction": 6,
            "token_embedding_dim": 512,
            "causal_attention": True,
        },
    )

    # Random tensors stand in for a real image and a real instruction embedding.
    dummy_image = torch.rand(224, 224, 3)
    dummy_instruction_emb = torch.rand(1, 512)

    # Ask the agent for actions given the dummy observation.
    predicted_actions = agent.act(
        image=dummy_image,
        instruction_emb=dummy_instruction_emb,
    )

    print("Actions received from RT1Agent:")
    print(predicted_actions)


if __name__ == "__main__":
    main()
Empty file.
38 changes: 38 additions & 0 deletions src/mbodied_agents/agents/motor/motor_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024 Mbodi AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import List

from mbodied_agents.types.controls import Motion


class MotorAgent(ABC):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's call it MotionAgent. Also rename the directory.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initially it was MotionAgent. @sebbyjp suggested MotorAgent

"""Abstract base class for a Motion Agent.

Subclasses must implement the `act` method, which generates a list of
Motion objects based on given parameters.
"""

@abstractmethod
def act(self, **kwargs) -> List['Motion']:
    """Produce the motions the agent wants to execute for the given inputs.

    This is the abstract hook that every concrete agent overrides; the
    base implementation has no behavior of its own.

    Args:
        **kwargs: Free-form keyword inputs from which a subclass derives its motions.

    Returns:
        List[Motion]: The motions selected for the supplied inputs.
    """
    ...
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2024 Mbodi AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torch import nn
from torchvision.models import EfficientNet_B5_Weights, efficientnet_b5
from torchvision.models.efficientnet import MBConv

from .film import FilmLayer


class EfficientNetB5(nn.Module):
    """EfficientNet-B5 feature stack with a FiLM layer after each MBConv block.

    The backbone is loaded with the ``IMAGENET1K_V1`` weights. One
    ``FilmLayer`` is created per MBConv block found in ``features``, except
    for the last MBConv block, which is left unconditioned.
    """

    def __init__(self, context_dim: int = 512):
        """Build the backbone and its FiLM layers.

        Args:
            context_dim: Size of the context vector passed to every ``FilmLayer``.
        """
        super().__init__()
        backbone = efficientnet_b5(weights=EfficientNet_B5_Weights.IMAGENET1K_V1)

        # One FiLM layer per MBConv block, in the order the blocks are traversed.
        per_block_film = [
            FilmLayer(block.out_channels, context_dim)
            for stage in backbone.features
            for block in stage
            if isinstance(block, MBConv)
        ]

        # Don't add a film layer to the last layer
        self.film_layers = nn.ModuleList(per_block_film[:-1])
        self.features = backbone.features

    def forward(self, x, context):
        """Run ``x`` through the backbone, modulating MBConv outputs with ``context``.

        Traversal stops right after the first MBConv block for which no FiLM
        layer is left (the last MBConv block), so anything that follows it in
        ``features`` is not applied.

        Args:
            x: Input passed to the first sub-layer of ``features``.
            context: Conditioning input forwarded to every FiLM layer.

        Returns:
            The output of the last sub-layer that was applied.
        """
        remaining_film = iter(self.film_layers)
        active_film = next(remaining_film, None)

        for stage in self.features:
            for block in stage:
                x = block(x)
                if not isinstance(block, MBConv):
                    continue
                if active_film is None:
                    # FiLM layers are exhausted: this was the final MBConv block.
                    return x
                x = active_film(x, context)
                active_film = next(remaining_film, None)
        return x

# import torchinfo
# from torchinfo import summary
# model = EfficientNetB5()
# summary(model, input_size=[(6, 3, 300, 300),(6, 512)])
94 changes: 94 additions & 0 deletions src/mbodied_agents/agents/motor/rt1/film_efficientnet/film.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2024 Mbodi AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from functools import partial

import torch
import torch.nn as nn
from einops import rearrange
from torchvision.models.efficientnet import MBConv


def film_conditioned(cls):
    """Class decorator that adds FiLM conditioning to a feature-extractor class.

    The decorated class has its ``__init__`` and ``forward`` replaced:

    * ``__init__``: all arguments are forwarded unchanged to the original
      initializer. If a non-``None`` ``context_dim`` keyword argument was
      given, one ``FilmLayer`` is then created for every ``MBConv`` sub-layer
      found in ``self.features`` (all but the last one are kept) and stored in
      ``self.film_layers``.
    * ``forward``: if a non-``None`` ``context`` keyword argument is given and
      ``conditioning_funcs`` is not, an iterator of the FiLM layers bound to
      that context is passed to the original ``forward`` as
      ``conditioning_funcs``. Otherwise the call is forwarded untouched.

    Args:
        cls: The class to decorate. Its initializer receives ``context_dim``
            as-is when the caller supplies it, and its ``forward`` receives
            both ``context`` and ``conditioning_funcs`` as keyword arguments
            on the conditioned path.

    Returns:
        The same class object, modified in place.
    """
    # Local import so the block does not rely on a new file-level import.
    from functools import wraps

    init_func = cls.__init__

    @wraps(init_func)
    def init_with_film_layers(self, *args, **kwargs):
        # __init__ returns None, so the instance must be taken from the bound
        # ``self`` argument and never from the initializer's return value.
        init_func(self, *args, **kwargs)

        context_dim = kwargs.get("context_dim")
        if context_dim is None:
            return

        film_layers = [
            FilmLayer(sublayer.out_channels, context_dim)
            for layer in self.features
            for sublayer in layer
            if isinstance(sublayer, MBConv)
        ]

        # Don't add a film layer to the last layer
        self.film_layers = nn.ModuleList(film_layers[:-1])
        # Deliberately no return value: returning anything but None from
        # __init__ makes instantiation raise TypeError.

    cls.__init__ = init_with_film_layers

    forward_func = cls.forward

    @wraps(forward_func)
    def forward_with_film_layers(*args, **kwargs):
        # Pass through when the caller supplies its own conditioning
        # functions or gives no context to condition on.
        if 'conditioning_funcs' in kwargs or kwargs.get('context', None) is None:
            return forward_func(*args, **kwargs)

        self = args[0]
        context = kwargs['context']
        conditioners = iter([partial(film_layer, context=context)
                             for film_layer in self.film_layers])
        return forward_func(*args, **kwargs, conditioning_funcs=conditioners)

    cls.forward = forward_with_film_layers
    return cls


class FilmLayer(nn.Module):
    """Feature-wise linear modulation of an input tensor by a context tensor.

    Two bias-free linear projections turn the context into a per-channel
    scale (``gamma``) and shift (``beta``). Both weight matrices start at
    zero, so a freshly built layer returns its input unchanged.
    """

    def __init__(
        self,
        num_channels: int,
        context_dim: int = 512,
    ):
        """Create the two context projections.

        Args:
            num_channels: Output size of each projection (channels to modulate).
            context_dim: Input size of each projection (context feature size).
        """
        super().__init__()
        self.beta = nn.Linear(context_dim, num_channels, bias=False)
        self.gamma = nn.Linear(context_dim, num_channels, bias=False)

        # Zeroed weights give scale 1 and shift 0, i.e. the identity op.
        for projection in (self.beta, self.gamma):
            nn.init.constant_(projection.weight, 0)

    def forward(self, x: torch.Tensor, context: torch.Tensor):
        """Return ``(1 + gamma(context)) * x + beta(context)``.

        Args:
            x: Tensor to modulate. The projections are reshaped to
                ``b c 1 1`` before being combined with it.
            context: Conditioning tensor; moved to ``x``'s device first. Its
                projections must be 2-D (``b c``) for the reshape to succeed.

        Returns:
            The modulated tensor.
        """
        context = context.to(x.device)

        shift = rearrange(self.beta(context), 'b c -> b c 1 1')
        scale = rearrange(self.gamma(context), 'b c -> b c 1 1')

        return (1 + scale) * x + shift
Loading