Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@ data/vavam-driver/*
# uv.lock is only used in local development. On the server, we use uv tool run,
# which does not require uv.lock.
**/uv.lock

# local dir
my_run/
outputs/
logs/
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ members = [
"src/physics",
"src/tools",
"src/driver",
"src/render",
]

# This is a hacky workaround to skip installing flash attention
Expand Down
13 changes: 12 additions & 1 deletion src/grpc/alpasim_grpc/v0/sensorsim.proto
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,18 @@ message RGBRenderRequest {
fixed64 frame_start_us = 5;
fixed64 frame_end_us = 6;

PosePair sensor_pose = 7;
// Deprecated: Use ego_pose instead. sensor_pose is camera pose in world frame.
// Kept for backward compatibility. If ego_pose is provided, it will be used instead.
PosePair sensor_pose = 7 [deprecated = true];

// Ego vehicle pose in world frame (preferred for MTGS renderer)
// If provided, this will be used instead of deriving from sensor_pose
PosePair ego_pose = 13;

// Rig to camera transformation (camera pose relative to ego/rig frame)
// If provided, this will be used for accurate camera2ego transformation
// If not provided, will be loaded from asset folder or data_source
common.Pose rig_to_camera = 14;

repeated DynamicObject dynamic_objects = 8;

Expand Down
Empty file added src/render/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions src/render/base_renderer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from abc import ABC
from dataclasses import dataclass, field
from functools import cached_property
from typing import Dict, List, Optional, Tuple, Type, Union, Any
from typing_extensions import Literal

import torch
from torch.nn import Module, Parameter


class RenderState(dict):
    """Plain ``dict`` subclass that names the keys used to describe a frame.

    The class attributes below are the string keys; the instance itself
    stores whatever values the caller puts under them.
    """

    CAMERAS = "cameras"
    LIDAR = "lidar"
    AGENT_STATE = "agent_state"  # {object_id: np.ndarray([x, y, heading])}
    TIMESTAMP = "timestamp"


class BaseRenderer(ABC):
    """Minimal renderer skeleton that concrete renderers build upon.

    The base class only tracks the target device and a ``sensors`` slot.
    ``set_asset`` and ``render`` are no-ops here, ``background_asset`` is
    always ``None`` and ``physical_world`` is left unimplemented.
    """

    def __init__(
        self,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Remember the target device and initialise the mutable state.

        Args:
            device: Device identifier. The default is chosen once, when the
                class body is evaluated: ``"cuda"`` if available, else ``"cpu"``.
        """
        self.device = device
        self.__from_scratch__()

    def __from_scratch__(self):
        """Put all mutable renderer state back to its initial value."""
        self.sensors = None

    def reset(self):
        """Discard accumulated state by re-running ``__from_scratch__``."""
        self.__from_scratch__()

    @property
    def background_asset(self):
        """Background asset in use; the base renderer never has one."""
        return None

    def _check_for_reliance(self):
        """Ensure a background asset is present.

        Raises:
            ValueError: If ``background_asset`` is ``None``.
        """
        asset = self.background_asset
        if asset is not None:
            return
        raise ValueError("No background asset set for renderer.")

    def set_asset(self, asset):
        """Accept an asset; the base implementation ignores it."""
        return None

    def render(self, render_state: RenderState):
        """Render a frame; the base implementation does nothing."""
        return None

    def physical_world(self, agent_state):
        """Not implemented in the base class.

        Args:
            agent_state: Agent state given as ``[x, y, heading]``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
41 changes: 41 additions & 0 deletions src/render/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "alpasim_render"
version = "0.1.0"
description = "DigitalTwin renderer module for Alpasim"
requires-python = ">=3.11,<3.12"
dependencies = [
"torch>=2.0.0",
"numpy>=1.24.0",
"opencv-python>=4.8.0",
"pyquaternion>=0.9.0",
"gsplat>=1.0.0",
"alpasim_utils", # For geometry_utils and other utilities
]

authors = [
{name = "Alpasim Team"},
]

[tool.uv.sources]
alpasim_utils = {workspace = true}

# Note: render module uses 'src.render' and 'src.utils' import path structure
# The module structure is:
# - src/render/base_renderer.py -> src.render.base_renderer
# - src/render/src/digitaltwin.py -> src.render.src.digitaltwin
# - src/render/src/utils/ -> src.utils (gaussian_utils, portable_utils)
# - src/render/src/gaussian_model/ -> src.gaussian_model

# Package discovery configuration
# render module has a special structure where:
# - base_renderer.py is at the root of render/
# - src/ subdirectory contains the main code
# - src/utils/ contains utilities that are imported as src.utils.*
[tool.setuptools.packages.find]
where = ["."]
include = ["src*"]
exclude = ["tests*", "*.tests*"]
Empty file added src/render/src/__init__.py
Empty file.
Empty file.
152 changes: 152 additions & 0 deletions src/render/src/gaussian_model/rigid_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import logging

import torch
from torch.nn import Parameter, Module
try:
from gsplat.cuda._wrapper import spherical_harmonics
except ImportError:
print("Please install gsplat>=1.0.0")

from ..utils.gaussian_utils import quat_mult, quat_to_rotmat, interpolate_quats, IDFT
from .vanilla_gaussian_splatting import VanillaPortableGaussianModel
logger = logging.getLogger(__name__)

class RigidPortableSubModel(VanillaPortableGaussianModel):
    """Portable Gaussian Splatting model for a rigidly moving object.

    The Gaussians are stored in the object's local frame and are moved into
    the global frame with a single rotation + translation per query. The
    pose comes either from the caller or from a pose log stored alongside
    the model weights.

    Args:
        asset: portable Gaussian Model with config include
            - type
            - sh_degree
            - scale_dim
            - fourier_features_dim
            - fourier_features_scale
            - fourier_in_space
            - log_timestamps
        log_replay: optional keyword flag, stored as ``self.log_replay``
            (defaults to ``False``).
    """
    MODEL_TYPE = "rigid"

    def __init__(self, **kwargs):
        # Assigned before the parent constructor runs.
        # NOTE(review): presumably because the parent may call
        # `load_state_dict`, which reads this flag -- confirm.
        self.log_replay = kwargs.get("log_replay", False)
        super().__init__(**kwargs)
        # TODO
        # self.visible_range = kwargs.get("visible_range", None)

    def _update_default_config(self, config):
        """Fill in defaults for the Fourier-feature and log-related options.

        Keys already present in ``config`` keep their value.
        """
        config = super()._update_default_config(config)
        config["fourier_features_dim"] = config.get("fourier_features_dim", None)
        config["fourier_features_scale"] = config.get("fourier_features_scale", 1.0)
        config["fourier_in_space"] = config.get("fourier_in_space", 'temporal')
        config["log_timestamps"] = config.get("log_timestamps", None)
        return config

    def load_state_dict(self, dict: dict):
        """Load the Gaussian parameters and, when present, the logged poses.

        The pose log is read from the ``instance_trans`` / ``instance_quats``
        entries. After squeezing, a 1-D translation means the object has a
        single (static) pose; otherwise ``config.log_timestamps`` must be set
        and is stored relative to its minimum.
        """
        super().load_state_dict(dict)
        self.exist_log = ('instance_trans' in dict.keys())
        if self.exist_log:
            self.log_trans = Parameter(dict['instance_trans'].squeeze())
            self.log_quats = Parameter(dict['instance_quats'].squeeze())
            self.static_in_log = (self.log_trans.dim() == 1)
            if not self.static_in_log:
                assert getattr(self.config, "log_timestamps", None) is not None
                self.log_start_time = self.config.log_timestamps.min().item()
                self.log_timestamps = self.config.log_timestamps.squeeze() - self.log_start_time

                # TODO:
                # maybe make some transition trajectories at the beginning and the ending to avoid sudden appearance or disappearance

            log_type = "static" if self.static_in_log else "dynamic"
            logger.debug(f"Log pose for `{self.model_name_abbr}` loaded. LOG TYPE: `{log_type}`")
            if self.log_replay:
                logger.debug(f"Log replay for `{self.model_name_abbr}` enabled.")

    def set_static_force(self):
        """Collapse a per-frame pose log into one static pose.

        The first logged pose whose translation z-component is below 1000 is
        kept. NOTE(review): larger z values presumably mark frames in which
        the object is absent -- confirm against the exporter.
        """
        self.static_in_log = True
        if self.log_trans.dim() == 1:
            # Already a single pose: indexing `[:, 2]` below would raise.
            return
        in_frame_mask = self.log_trans[:, 2] < 1000
        self.log_trans = Parameter(self.log_trans[in_frame_mask][0])
        self.log_quats = Parameter(self.log_quats[in_frame_mask][0])

    def get_means(self, global_quat, global_trans):
        """Rotate and translate the local Gaussian centres by the given pose.

        The result is cached on ``self.global_means`` because
        ``get_gaussian_rgbs`` reads it to build view directions.
        """
        local_means = self.gauss_params['means']
        rot_cur_frame = quat_to_rotmat(global_quat[None, ...])[0, ...]
        self.global_means = local_means @ rot_cur_frame.T + global_trans
        return self.global_means

    def get_quats(self, global_quat, global_trans=None):
        """Compose the object rotation with each normalised local quaternion.

        ``global_trans`` is accepted for signature symmetry and is unused.
        """
        local_quats = self.quats / self.quats.norm(dim=-1, keepdim=True)
        global_quats = quat_mult(global_quat[None, ...], local_quats)
        return global_quats

    def get_fourier_features(self, x):
        """Weight ``features_dc`` by the ``IDFT`` basis at ``x`` and sum over dim 1."""
        scaled_x = x * self.config.fourier_features_scale
        input_is_normalized = (self.config.fourier_in_space == 'temporal')
        idft_base = IDFT(scaled_x, self.config.fourier_features_dim, input_is_normalized).to(self.device)
        return torch.sum(self.features_dc * idft_base[..., None], dim=1, keepdim=False)

    def get_true_features_dc(self, timestamp=None, cam_obj_yaw=None):
        """Return the DC colour features, evaluating Fourier features if enabled.

        With ``fourier_features_dim`` unset the stored ``features_dc`` is
        returned unchanged. Otherwise the Fourier input is ``timestamp`` when
        ``fourier_in_space == 'temporal'`` and ``cam_obj_yaw`` otherwise.
        """
        if self.config.fourier_features_dim is None:
            return self.features_dc
        normalized_x = timestamp if self.config.fourier_in_space == 'temporal' else cam_obj_yaw
        assert normalized_x is not None
        return self.get_fourier_features(normalized_x)

    def get_gaussian_rgbs(self, camera_to_worlds, timestamp, device=None):
        """Compute per-Gaussian RGB colours for the given camera.

        Must be called after ``get_means`` because the view directions are
        built from the cached ``self.global_means``.

        Args:
            camera_to_worlds: camera pose(s); the translation column
                ``[..., :3, 3]`` is used as the camera position.
            timestamp: forwarded to ``get_true_features_dc``.
            device: target device; defaults to ``self.device``.

        Raises:
            AssertionError: if spherical harmonics must be evaluated
                (``sh_degree > 0``) on a CPU device.
        """
        device = device if device is not None else self.device
        true_features_dc = self.get_true_features_dc(timestamp, None)
        colors = torch.cat((true_features_dc[:, None, :], self.features_rest), dim=1).to(device)
        if self.sh_degree > 0:
            # Normalise through torch.device so that plain strings such as
            # "cpu" are recognised too; comparing a str against a
            # torch.device object never reports equality.
            assert torch.device(device).type != "cpu", "`sphereical_harmonics` in `gsplat` only supports CUDA"
            viewdirs = self.global_means.detach().to(device) - camera_to_worlds[..., :3, 3].to(device)  # (N, 3)
            viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
            rgbs = spherical_harmonics(self.sh_degree, viewdirs, colors)
            rgbs = torch.clamp(rgbs + 0.5, 0.0, 1.0)
        else:
            rgbs = torch.sigmoid(colors[:, 0, :])

        return rgbs

    def get_opacity(self):
        """Return sigmoid-activated opacities with the trailing axis squeezed."""
        return torch.sigmoid(self.gauss_params['opacities']).squeeze(-1)

    def _get_log_pose_from_timestamp(self, timestamp):
        """Look up the logged pose at ``timestamp``, interpolating between frames.

        Timestamps outside the logged range are clamped to the nearest
        logged frame instead of being extrapolated.

        Returns:
            ``(quat, trans, timestamp)`` with ``timestamp`` passed through.
        """
        self.log_timestamps = self.log_timestamps.to(self.device)
        relative_timestamp = timestamp - self.log_start_time

        diffs = relative_timestamp - self.log_timestamps
        at_or_before = diffs >= 0
        at_or_after = diffs <= 0
        prev_frame = torch.argmin(torch.where(at_or_before, diffs, float('inf')))
        next_frame = torch.argmin(torch.where(at_or_after, -diffs, float('inf')))

        # When one side has no candidate, argmin over an all-inf tensor
        # returns index 0, which is unrelated to the query. Fall back to the
        # frame found on the other side so the pose is clamped.
        if not bool(at_or_after.any()):
            next_frame = prev_frame
        elif not bool(at_or_before.any()):
            prev_frame = next_frame

        if next_frame == prev_frame:
            # Timestamp exactly matches a frame (or was clamped), no interpolation needed
            return self.log_quats[next_frame], self.log_trans[next_frame], timestamp

        # Calculate interpolation factor
        t = (relative_timestamp - self.log_timestamps[prev_frame]) / (self.log_timestamps[next_frame] - self.log_timestamps[prev_frame])

        # Interpolate quaternions (using slerp) and translations
        quat_interp = interpolate_quats(self.log_quats[prev_frame], self.log_quats[next_frame], t).squeeze()
        trans_interp = torch.lerp(self.log_trans[prev_frame], self.log_trans[next_frame], t)

        return quat_interp, trans_interp, timestamp

    def _decide_global_pose(self, quat=None, trans=None, timestamp=None):
        """Pick the pose source: static log, caller-supplied pose, or timed log.

        A static log always wins. Otherwise an explicit ``quat``/``trans``
        pair is used when both are given, and the log is sampled at
        ``timestamp`` as the last resort.
        """
        if self.static_in_log:
            return self.log_quats, self.log_trans, timestamp
        if quat is not None and trans is not None:
            assert quat.shape == (4,) and trans.shape == (3,)
            return quat.float(), trans.float(), timestamp
        return self._get_log_pose_from_timestamp(timestamp)

    def get_global_gaussians(self, quat=None, trans=None, timestamp=None, **kwargs):
        """Return the Gaussians posed in the global frame.

        Returns:
            dict with ``means``, ``scales``, ``quats`` and ``opacities``.
        """
        quat, trans, timestamp = self._decide_global_pose(
            quat=quat,
            trans=trans,
            timestamp=timestamp,
        )

        return {
            "means": self.get_means(global_trans=trans, global_quat=quat),
            "scales": self.get_scales(),
            "quats": self.get_quats(global_trans=trans, global_quat=quat),
            "opacities": self.get_opacity(),
        }
83 changes: 83 additions & 0 deletions src/render/src/gaussian_model/rigid_object_mirrored.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import torch
try:
from gsplat.cuda._wrapper import spherical_harmonics
except ImportError:
print("Please install gsplat>=1.0.0")

from ..utils.gaussian_utils import quat_mult, quat_to_rotmat
from .rigid_object import RigidPortableSubModel


def flip_spherical_harmonics(coeff):
    """
    Flip the spherical harmonics coefficients along the y-axis.

    Mirroring y -> -y negates exactly the basis functions with m < 0, so the
    matching coefficients change sign and all others are left untouched.
    Coefficients are assumed to be stored band by band with m running from
    -l to l inside each band, i.e. band l occupies indices [l^2, (l+1)^2)
    and its first l entries are the m < 0 terms.

    Args:
        coeff (torch.Tensor): A tensor of shape [N, K, C], where N is the number of Gaussians,
            K is the number of spherical harmonics coefficients (16 for degree l=3,
            but any count is accepted), and C is the feature dimension.

    Returns:
        torch.Tensor: The flipped spherical harmonics coefficients (a new tensor).
    """
    num_coeffs = coeff.shape[1]

    # Collect the flat indices of every m < 0 term that exists in `coeff`.
    indices_m_negative = []
    band = 1
    while band * band < num_coeffs:
        band_start = band * band
        indices_m_negative.extend(range(band_start, min(band_start + band, num_coeffs)))
        band += 1

    # Create a flip factor tensor of ones and minus ones
    flip_factors = torch.ones(num_coeffs, device=coeff.device)
    if indices_m_negative:
        flip_factors[indices_m_negative] = -1

    # Reshape flip_factors to [1, K, 1] for broadcasting
    flip_factors = flip_factors.view(1, -1, 1)

    # Apply the flip factors to the coefficients
    return coeff * flip_factors


class MirroredRigidPortableSubModel(RigidPortableSubModel):
    """Rigid object model that renders each Gaussian twice: as stored and y-mirrored.

    Every getter returns the original Gaussians followed by their mirrored
    copies, so all outputs have twice as many rows as the stored parameters.
    """
    MODEL_TYPE = "mirrored"

    def get_means(self, global_quat, global_trans):
        """Return global centres for the original and the y-mirrored Gaussians.

        The result is cached on ``self.global_means`` because
        ``get_gaussian_rgbs`` reads it to build view directions.
        """
        local_means: torch.Tensor = self.gauss_params['means']
        # `repeat` copies, so the in-place flip below leaves the parameters intact.
        local_means = local_means.unsqueeze(0).repeat(2, 1, 1)
        local_means[1, :, :] = local_means[1, :, :] * local_means.new_tensor([1, -1, 1]).view(1, 3)  # flip y
        local_means = local_means.view(-1, 3)

        rot_cur_frame = quat_to_rotmat(global_quat[None, ...])[0, ...]
        self.global_means = local_means @ rot_cur_frame.T + global_trans
        return self.global_means

    def get_quats(self, global_quat, global_trans=None):
        """Return global rotations for the original and the y-mirrored Gaussians.

        ``global_trans`` is accepted for signature symmetry and is unused.
        """
        local_quats = self.quats / self.quats.norm(dim=-1, keepdim=True)
        local_quats = local_quats.unsqueeze(0).repeat(2, 1, 1)
        local_quats[1, :, :] = local_quats[1, :, :] * local_quats.new_tensor([1, -1, 1, -1]).view(1, 4)  # flip quats at y axis
        local_quats = local_quats.view(-1, 4)

        global_quats = quat_mult(global_quat[None, ...], local_quats)
        return global_quats

    def get_scales(self):
        """Return exponentiated scales, duplicated for the mirrored copies."""
        return torch.exp(self.scales).repeat(2, 1)

    def get_gaussian_rgbs(self, camera_to_worlds, timestamp, device=None):
        """Compute RGB colours for the original and the y-mirrored Gaussians.

        Must be called after ``get_means`` because the view directions are
        built from the cached ``self.global_means``.

        Args:
            camera_to_worlds: camera pose(s); the translation column
                ``[..., :3, 3]`` is used as the camera position.
            timestamp: forwarded to ``get_true_features_dc``.
            device: target device; defaults to ``self.device``.

        Raises:
            AssertionError: if spherical harmonics must be evaluated
                (``sh_degree > 0``) on a CPU device.
        """
        device = device if device is not None else self.device
        true_features_dc = self.get_true_features_dc(timestamp, None)
        colors = torch.cat((true_features_dc[:, None, :], self.features_rest), dim=1).to(device)
        colors = colors.unsqueeze(0).repeat(2, 1, 1, 1)
        colors[1, ...] = flip_spherical_harmonics(colors[1, ...])
        # Merge the (original, mirrored) axis into the Gaussian axis while
        # keeping however many SH coefficients / channels the model has.
        colors = colors.view(-1, colors.shape[-2], colors.shape[-1])
        if self.sh_degree > 0:
            # Normalise through torch.device so that plain strings such as
            # "cpu" are recognised too; comparing a str against a
            # torch.device object never reports equality.
            assert torch.device(device).type != "cpu", "`sphereical_harmonics` in `gsplat` only supports CUDA"
            viewdirs = self.global_means.detach().to(device) - camera_to_worlds[..., :3, 3].to(device)  # (C, N, 3)
            viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
            rgbs = spherical_harmonics(self.sh_degree, viewdirs, colors)
            rgbs = torch.clamp(rgbs + 0.5, 0.0, 1.0)
        else:
            rgbs = torch.sigmoid(colors[:, 0, :])

        return rgbs

    def get_opacity(self):
        """Return sigmoid-activated opacities, duplicated for the mirrored copies."""
        return torch.sigmoid(self.gauss_params['opacities']).squeeze(-1).repeat(2)
Loading