10 changes: 5 additions & 5 deletions agent/mappo_agent.py
@@ -60,8 +60,8 @@ class MAPPOAgent(Agent):
def __init__(self, obs_dim, num_agents, action_dim = 49, lr=3e-4, gamma=0.99,
lam=0.95, clip=0.2, epochs=4, batch_size=64, hidden_dim=128):
self.num_agents = num_agents
- # The observation now contains TWO maps, so we multiply the base obs_dim by 2.
- self.obs_dim = (obs_dim * 2) + num_agents
+ # The observation is now the distance vector + one-hot agent ID
+ self.obs_dim = obs_dim + num_agents
self.action_dim = action_dim
self.gamma = gamma
self.lam = lam
@@ -75,10 +75,10 @@ def __init__(self, obs_dim, num_agents, action_dim = 49, lr=3e-4, gamma=0.99,
self.buffer = RolloutBuffer()

def _process_obs(self, obs, agent_id):
- obs_flat = torch.tensor(obs.ravel(), dtype=torch.float32)
+ obs_tensor = torch.tensor(obs, dtype=torch.float32)
one_hot = torch.zeros(self.num_agents)
one_hot[agent_id] = 1.0
- return torch.cat([obs_flat, one_hot])
+ return torch.cat([obs_tensor, one_hot])

def select_action(self, obs, agent_id, mask=None):
with torch.no_grad():
@@ -95,7 +95,7 @@ def select_action(self, obs, agent_id, mask=None):
return action.item(), log_prob.item()

def store(self, obs, agent_id, action, log_prob, reward, done, mask):
- self.buffer.states.append(obs.flatten())
+ self.buffer.states.append(obs)  # No longer need .flatten()
self.buffer.agent_ids.append(agent_id)
self.buffer.actions.append(action)
self.buffer.log_probs.append(log_prob)
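For reference, a minimal standalone sketch of the new per-agent observation handling (sizes and names below are illustrative, not taken from the PR): the agent now receives a flat vector of target-minus-current pairwise distances and appends a one-hot agent ID, so the effective network input size is obs_dim + num_agents.

import numpy as np
import torch

num_pairs = 6                                  # n * (n - 1) // 2 for n = 4 modules
num_agents = 4
obs = np.zeros(num_pairs, dtype=np.float32)    # stand-in for the distance-difference vector

obs_tensor = torch.tensor(obs, dtype=torch.float32)
one_hot = torch.zeros(num_agents)
one_hot[2] = 1.0                               # agent_id = 2
x = torch.cat([obs_tensor, one_hot])
print(x.shape)                                 # torch.Size([10]) == num_pairs + num_agents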
94 changes: 94 additions & 0 deletions ogm/occupancy_grid_map.py
@@ -435,6 +435,100 @@ def calc_pairwise_norms(self, mod_pos):
def check_final(self, tol=1e-6):
return np.allclose(self.final_pairwise_norms, self.curr_pairwise_norms, atol=tol)

def compute_pairwise_sqdist(self, positions):
"""Compute pairwise squared distances as integers.

Args:
positions: Dictionary mapping module numbers to their positions (x,y,z)

Returns:
n×n integer matrix of squared distances
"""
n = len(positions)
D = np.zeros((n, n), dtype=np.int64)

for i in range(n):
xi = positions[i + 1] # modules are 1-indexed
for j in range(i + 1, n):
xj = positions[j + 1] # modules are 1-indexed
dx = xi[0] - xj[0]
dy = xi[1] - xj[1]
dz = xi[2] - xj[2]
sqdist = dx*dx + dy*dy + dz*dz
D[i, j] = sqdist
D[j, i] = sqdist

return D
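Since module positions are integer grid coordinates, the squared distances are exact integers, which is what lets the bounty logic in ogm_env.py below test curr_sqdist[i, j] == final_sqdist[i, j] with exact equality rather than a float tolerance. A small illustrative check (positions invented for the example):

import numpy as np

positions = {1: (0, 0, 0), 2: (1, 2, 2), 3: (3, 0, 4)}   # toy module positions, 1-indexed
n = len(positions)
D = np.zeros((n, n), dtype=np.int64)
for i in range(n):
    for j in range(i + 1, n):
        dx, dy, dz = (a - b for a, b in zip(positions[i + 1], positions[j + 1]))
        D[i, j] = D[j, i] = dx * dx + dy * dy + dz * dz
print(D)   # [[ 0  9 25] [ 9  0 12] [25 12  0]] -- exact integers, safe to compare with ==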

def pair_index(self, i, j, n):
"""Convert pair (i,j) to canonical index in upper triangular matrix.

Args:
i, j: Module indices (0-indexed, i < j)
n: Total number of modules

Returns:
Index in canonical pair ordering
"""
if i >= j:
raise ValueError("i must be less than j")
return i * n - (i * (i + 1)) // 2 + (j - i - 1)
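As a quick sanity check of the canonical ordering this formula produces (illustrative, n = 4), pairs are enumerated row by row along the upper triangle:

n = 4
for i in range(n):
    for j in range(i + 1, n):
        print((i, j), i * n - (i * (i + 1)) // 2 + (j - i - 1))
# prints (0, 1) 0, (0, 2) 1, (0, 3) 2, (1, 2) 3, (1, 3) 4, (2, 3) 5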

def generate_pair_bounties(self, D_final, base_value=1.0):
"""Generate bounty structure for all pairs.

Args:
D_final: Final squared distance matrix
base_value: Base bounty value for each pair

Returns:
Tuple of (pairs_list, base_values_array)
"""
n = D_final.shape[0]
pairs = []
base_values = []

for i in range(n):
for j in range(i + 1, n):
pairs.append((i, j))
base_values.append(base_value)

return pairs, np.array(base_values)

def get_pairwise_dist_vector(self, positions):
"""
Computes the vech(D(sigma)) vector of pairwise distances.
Returns the upper triangular entries of the distance matrix (excluding diagonal).
"""
n = len(self.modules)
if n < 2:
return np.array([])

# The number of unique pairs is n * (n - 1) / 2
num_pairs = n * (n - 1) // 2
dist_vector = np.zeros(num_pairs)

idx = 0
for i in range(1, n + 1):
for j in range(i + 1, n + 1):
pos_i = np.array(positions[i])
pos_j = np.array(positions[j])
dist_vector[idx] = np.linalg.norm(pos_i - pos_j)
idx += 1

return dist_vector

def get_state_observation_vector(self):
"""
Generates the full observation vector for the agent, containing
the difference between target and current pairwise distances.
"""
current_dists = self.get_pairwise_dist_vector(self.module_positions)
target_dists = self.get_pairwise_dist_vector(self.final_module_positions)

# The observation is the difference between target and current distances
return target_dists - current_dists
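A toy 3-module example of the resulting observation (positions invented for illustration): the vector has length n(n-1)/2 and is ordered (1,2), (1,3), (2,3); a positive entry means that pair is currently closer together than it is in the goal configuration, and an all-zero vector means every pairwise distance already matches the target.

import numpy as np

current = {1: (0, 0, 0), 2: (1, 0, 0), 3: (0, 2, 0)}
target  = {1: (0, 0, 0), 2: (2, 0, 0), 3: (0, 2, 0)}

def dist_vector(positions):
    n = len(positions)
    return np.array([np.linalg.norm(np.array(positions[i]) - np.array(positions[j]))
                     for i in range(1, n + 1) for j in range(i + 1, n + 1)])

obs = dist_vector(target) - dist_vector(current)
print(obs)   # approximately [1.0, 0.0, 0.59]: pairs (1,2) and (2,3) still need to move apart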

# need to calculate edges first
def calculate_edges(self, modules, module_positions):
edges = []
182 changes: 163 additions & 19 deletions ogm/ogm_env.py
@@ -7,15 +7,47 @@
Environment wrapper for OGM
"""
class OGMEnv:
- def __init__(self, step_cost = -0.01, max_steps = None):
+ def __init__(self, step_cost = -0.005, max_steps = None, enable_bounty_reward = True,
+ bounty_gamma = 0.999, bounty_eta = 2.0, bounty_base_value = 1.0,
+ bounty_total_frac_of_success = 0.2, bounty_cap_per_step = 20.0,
+ enable_potential_reward = True, potential_scale = 1.0,
+ potential_normalize = 'n2', success_bonus = 100.0):
self.step_cost = step_cost
self.max_steps = max_steps
self.steps_taken = 0
self.ogm = None
self.action_buffer = []
self.action_count = 0
self.num_modules = None
self.initial_norm_diff = None

# Bounty system parameters
self.enable_bounty_reward = enable_bounty_reward
self.bounty_params = {
'gamma': bounty_gamma, # decay factor
'eta': bounty_eta, # multiplier coefficient
'base_value': bounty_base_value
}
self.bounty_total_frac_of_success = bounty_total_frac_of_success
self.bounty_cap_per_step = bounty_cap_per_step

# Bounty system state
self.final_sqdist = None
self.curr_sqdist = None
self.pairs = None
self.bounty_base_values = None
self.bounty_multipliers = None
self.bounty_available = None
self.R0 = None # initial number of bounties
self.last_bounty_raw = 0.0
self.last_num_matches = 0

# Potential-based reward parameters/state
self.enable_potential_reward = enable_potential_reward
self.potential_scale = potential_scale
self.potential_normalize = potential_normalize # 'n2' or 'max' (only 'n2' used currently)
self.success_bonus = success_bonus
self.prev_frob_norm = None

def reset(self, initial_config, final_config):
self.ogm = OccupancyGridMap(initial_config, final_config, len(initial_config))
@@ -26,6 +58,45 @@ def reset(self, initial_config, final_config):
self.initial_norm_diff = np.linalg.norm(
self.ogm.final_pairwise_norms - self.ogm.curr_pairwise_norms, 'fro'
)

# Initialize squared-distance matrices for potential/bounty systems
if self.enable_bounty_reward or self.enable_potential_reward:
self.final_sqdist = self.ogm.compute_pairwise_sqdist(self.ogm.final_module_positions)
self.curr_sqdist = self.ogm.compute_pairwise_sqdist(self.ogm.module_positions)

# Initialize bounty system
if self.enable_bounty_reward:
self.pairs, self.bounty_base_values = self.ogm.generate_pair_bounties(
self.final_sqdist, base_value=self.bounty_params['base_value']
)
self.bounty_multipliers = np.ones(len(self.pairs), dtype=float)

# Auto-scale bounty base values so total ~= fraction * success_bonus
desired_frac = getattr(self, 'bounty_total_frac_of_success', 0.2)
success_bonus = getattr(self, 'success_bonus', 100.0)
total_pairs = len(self.pairs) if self.pairs is not None else 0

if total_pairs > 0:
desired_total = float(desired_frac) * float(success_bonus)
current_sum = float(np.sum(self.bounty_base_values)) if self.bounty_base_values is not None else 0.0
if current_sum > 0.0:
scale = desired_total / current_sum
self.bounty_base_values = np.array(self.bounty_base_values, dtype=float) * float(scale)
# initialize availability and mark already-correct pairs as unavailable
self.bounty_available = np.ones(len(self.pairs), dtype=bool)
for idx, (i, j) in enumerate(self.pairs):
if self.curr_sqdist[i, j] == self.final_sqdist[i, j]:
self.bounty_available[idx] = False
self.R0 = int(self.bounty_available.sum())
else:
self.R0 = 0

# Initialize potential state
if self.enable_potential_reward:
self.prev_frob_norm = float(np.linalg.norm(self.curr_sqdist - self.final_sqdist, 'fro'))
else:
self.prev_frob_norm = None

return self.get_observation()
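To make the bounty auto-scaling concrete (numbers below are the defaults plus an invented module count): with 8 modules there are 28 pairs, so uniform base values of 1.0 sum to 28; the desired total is bounty_total_frac_of_success * success_bonus = 0.2 * 100 = 20, giving a per-pair scale factor of 20/28 ≈ 0.714.

import numpy as np

num_pairs = 8 * 7 // 2                        # 28 pairs for 8 modules
base_values = np.ones(num_pairs)              # bounty_base_value = 1.0
desired_total = 0.2 * 100.0                   # bounty_total_frac_of_success * success_bonus
base_values *= desired_total / base_values.sum()
print(base_values[0], base_values.sum())      # ~0.714 per pair, ~20.0 in total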

def step(self, action):
@@ -49,42 +120,115 @@ def step(self, action):

# Execute buffered actions
for module, pivot in self.action_buffer:
- is_valid_action = self.ogm.calc_possible_actions()[module][pivot-1]
+ # Action 49 (no-op) is always valid
+ if pivot == 49:
+ is_valid_action = True
+ else:
+ is_valid_action = self.ogm.calc_possible_actions()[module][pivot-1]

if is_valid_action:
self.ogm.take_action(module, pivot)
else:
invalid_move = True

- final_norm_diff = np.linalg.norm(
- self.ogm.final_pairwise_norms - self.ogm.curr_pairwise_norms, 'fro'
- )
+ # Update current squared distances after actions (for potential and/or bounty)
+ if self.enable_bounty_reward or self.enable_potential_reward:
+ self.curr_sqdist = self.ogm.compute_pairwise_sqdist(self.ogm.module_positions)

- # Calculate reward
- # the Frobenius norm scales with n^2 (for n modules). This could result in large reward values for large n
- # so normalizing the norm difference by n^2 keeps rewards in a reasonable range
- potential_reward = (self.initial_norm_diff - final_norm_diff) / (self.num_modules ** 2)
- success_bonus = 100.0 if self.ogm.check_final() else 0.0
- invalid_move_penalty = -1.0 if invalid_move else 0.0
- reward = potential_reward + success_bonus + invalid_move_penalty + self.step_cost
+ # Calculate potential-based reward using squared-distance Frobenius norm
+ potential_reward = 0.0
+ if self.enable_potential_reward:
+ frob = float(np.linalg.norm(self.curr_sqdist - self.final_sqdist, 'fro'))
+ # Improvement: previous - current (positive if closer)
+ raw_improvement = (self.prev_frob_norm - frob)
+ if self.potential_normalize == 'n2':
+ norm_factor = (self.num_modules ** 2)
+ else:
+ norm_factor = 1.0
+ potential_reward = (raw_improvement / norm_factor) * self.potential_scale
+ # Update previous
+ self.prev_frob_norm = frob

- self.initial_norm_diff = final_norm_diff
+ success_bonus = self.success_bonus if self.ogm.check_final() else 0.0
+ invalid_move_penalty = -1.0 if invalid_move else 0.0

# Calculate bounty reward
if self.enable_bounty_reward:
bounty_reward, bounties_collected, num_matches = self._compute_bounty_reward()
else:
bounty_reward, bounties_collected, num_matches = 0.0, 0, 0

reward = potential_reward + success_bonus + invalid_move_penalty + bounty_reward + self.step_cost

done = self.ogm.check_final() or (self.max_steps is not None and self.steps_taken >= self.max_steps)

self.action_buffer = []
self.action_count = 0

observation = self.get_observation()
- info = {'step': self.steps_taken}
+ info = {
+ 'step': self.steps_taken,
+ 'potential_reward': potential_reward,
+ 'bounty_reward': bounty_reward,
+ 'bounty_reward_raw': getattr(self, 'last_bounty_raw', 0.0),
+ 'bounties_collected': bounties_collected,
+ 'num_matches': num_matches,
+ 'bounties_remaining': int(self.bounty_available.sum()) if self.enable_bounty_reward and self.bounty_available is not None else 0,
+ }
return observation, reward, done, info

def _compute_bounty_reward(self):
"""Compute bounty reward for newly matched pairs.

Returns:
Tuple of (bounty_reward_capped, number_of_bounties_collected, num_matches)
- bounty_reward_capped: final bounty paid this step (after cap)
- number_of_bounties_collected: how many pair-bounties were claimed
- num_matches: same as number_of_bounties_collected (kept for clarity)
"""
newly_matched_idx = []

# Find pairs that newly match their target distances
for idx, (i, j) in enumerate(self.pairs):
if self.bounty_available[idx]:
if self.curr_sqdist[i, j] == self.final_sqdist[i, j]:
newly_matched_idx.append(idx)

num_matches = len(newly_matched_idx)
bounty_reward_raw = 0.0
R_before = int(self.bounty_available.sum()) if self.bounty_available is not None else 0

# Compute multiplier once based on R_before (order-independent within step)
multiplier = 1.0 + self.bounty_params['eta'] * (R_before / self.R0) if self.R0 and R_before > 0 else 1.0

# Sum raw payments (before any capping)
for idx in newly_matched_idx:
b0 = self.bounty_base_values[idx]
decay = self.bounty_params['gamma'] ** self.steps_taken
pay = b0 * decay * multiplier
bounty_reward_raw += pay
# mark as claimed
self.bounty_available[idx] = False

# store raw for logging/inspection
self.last_bounty_raw = float(bounty_reward_raw)
self.last_num_matches = int(num_matches)

# apply per-step cap if configured
cap = getattr(self, 'bounty_cap_per_step', None)
if cap is not None:
bounty_reward_capped = float(min(bounty_reward_raw, cap))
else:
bounty_reward_capped = float(bounty_reward_raw)

return bounty_reward_capped, len(newly_matched_idx), num_matches
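Putting the pieces together, each newly matched pair pays base_value * gamma**steps_taken * (1 + eta * R_before / R0): the bounty decays over time, is boosted while many bounties remain unclaimed, and the per-step total is clipped at bounty_cap_per_step. An illustrative payout with the default hyper-parameters (step count and match counts invented):

gamma, eta, cap = 0.999, 2.0, 20.0
b0 = 0.714                      # scaled base value from the 8-module example above
t, R0, R_before = 50, 28, 21    # steps taken, bounties at reset, bounties left before this step

multiplier = 1.0 + eta * (R_before / R0)   # 1 + 2.0 * 0.75 = 2.5
pay = b0 * gamma ** t * multiplier         # ~0.714 * 0.951 * 2.5 ~ 1.70 per matched pair
print(min(3 * pay, cap))                   # three matches this step: ~5.1 total, well under the cap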

def get_observation(self):
"""
- Stacks the current and final grid maps to create a 2-channel observation.
- This allows the agent to see both its current state and its goal state.
+ Returns the observation for the agent, which is the vector of
+ differences between target and current pairwise distances.
"""
if self.ogm is None:
raise Exception("Environment not set. call reset function")

- # Shape becomes (2, grid_size, grid_size, grid_size)
- return np.stack([self.ogm.curr_grid_map, self.ogm.final_grid_map], axis=0)
+ return self.ogm.get_state_observation_vector()
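Rough wiring sketch of how the new observation feeds the agent (the constructor arguments and initial_config / final_config are placeholders, not taken from the PR):

num_modules = 8
num_pairs = num_modules * (num_modules - 1) // 2   # 28: length of get_state_observation_vector()

env = OGMEnv(step_cost=-0.005, max_steps=500)
agent = MAPPOAgent(obs_dim=num_pairs, num_agents=num_modules)

obs = env.reset(initial_config, final_config)      # 1-D vector of target - current distances
assert obs.shape == (num_pairs,)
# MAPPOAgent._process_obs appends the one-hot agent ID, so the actor/critic
# inputs have size num_pairs + num_modules, matching self.obs_dim in the agent.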