From a325e4ef00bd2726efbc0c83cf1466bccfc217d2 Mon Sep 17 00:00:00 2001 From: Mason Peterson Date: Wed, 15 Jan 2025 09:01:20 -0500 Subject: [PATCH 1/3] enable demo to be run on heterogeneous robot teams Previously the same data.yaml params had to be used for all robots (meaning RGB-D images and pose information had to be of the same format). Changes to the params python classes (data_params.py, fastsam_params.py, and mapper_params.py) enable different image and pose information formats. --- demo/demo.py | 20 +++-- demo/mapping.py | 134 ++++++++++++++++++++------------- roman/params/data_params.py | 27 ++++--- roman/params/fastsam_params.py | 2 +- roman/params/mapper_params.py | 2 +- 5 files changed, 116 insertions(+), 69 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index 4e9d6be..5b5a5ce 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -88,10 +88,6 @@ if not args.skip_map: - # TODO: support including different data params for different runs - # currently can only use the same data params for all runs - args.run = None - for i, run in enumerate(data_params.runs): if args.skip_indices and i in args.skip_indices: continue @@ -104,7 +100,21 @@ os.environ[data_params.run_env] = run print(f"Mapping: {run}") - mapping.mapping(args) + mapping_viz_params = \ + mapping.VisualizationParams( + viz_map=args.viz_map, + viz_observations=args.viz_observations, + viz_3d=args.viz_3d, + vid_rate=args.vid_rate, + save_img_data=args.save_img_data + ) + mapping.mapping( + params_path=args.params, + output_path=args.output, + run_name=run, + max_time=args.max_time, + viz_params=mapping_viz_params + ) if not args.skip_align: # TODO: support ground truth pose file for validation diff --git a/demo/mapping.py b/demo/mapping.py index 6b8c6c3..4a756c1 100644 --- a/demo/mapping.py +++ b/demo/mapping.py @@ -15,17 +15,14 @@ from matplotlib.animation import FuncAnimation, FFMpegWriter import argparse import pickle -import tqdm -import yaml import time import cv2 as cv -import signal -import sys import open3d as o3d import logging import os from os.path import expandvars from threading import Thread +from dataclasses import dataclass from roman.map.run import ROMANMapRunner from roman.params.data_params import DataParams @@ -36,6 +33,15 @@ from robotdatapy.data import ImgData from merge_demo_output import merge_demo_output +@dataclass +class VisualizationParams: + viz_map: bool = False + viz_observations: bool = False + viz_3d: bool = False + vid_rate: float = 1.0 + save_img_data: bool = False + + def extract_params(data_params_path, fastsam_params_path, mapper_params_path, run_name=None): assert os.path.exists(data_params_path), "Data params file does not exist." data_params = DataParams.from_yaml(data_params_path, run=run_name) @@ -52,15 +58,21 @@ def extract_params(data_params_path, fastsam_params_path, mapper_params_path, ru return data_params, fastsam_params, mapper_params -def run(args, data_params: DataParams, fastsam_params: FastSAMParams, mapper_params: MapperParams): +def run( + data_params: DataParams, + fastsam_params: FastSAMParams, + mapper_params: MapperParams, + output_path: str, + viz_params: VisualizationParams = VisualizationParams() +): runner = ROMANMapRunner(data_params=data_params, fastsam_params=fastsam_params, mapper_params=mapper_params, - verbose=True, viz_map=args.viz_map, - viz_observations=args.viz_observations, - viz_3d=args.viz_3d, - save_viz=args.save_img_data) + verbose=True, viz_map=viz_params.viz_map, + viz_observations=viz_params.viz_observations, + viz_3d=viz_params.viz_3d, + save_viz=viz_params.save_img_data) # Setup logging # TODO: add support for logfile @@ -74,11 +86,11 @@ def run(args, data_params: DataParams, fastsam_params: FastSAMParams, mapper_par print("Running segment tracking! Start time {:.1f}, end time {:.1f}".format(runner.t0, runner.tf)) wc_t0 = time.time() - vid = args.viz_map or args.viz_observations + vid = viz_params.viz_map or viz_params.viz_observations if vid: fc = cv.VideoWriter_fourcc(*"mp4v") - video_file = os.path.expanduser(expandvars(args.output)) + ".mp4" - fps = int(np.max([5., args.vid_rate*1/data_params.dt])) + video_file = os.path.expanduser(expandvars(output_path)) + ".mp4" + fps = int(np.max([5., viz_params.vid_rate*1/data_params.dt])) if fastsam_params.rotate_img not in ['CCW', 'CW']: width = runner.img_data.camera_params.width height = runner.img_data.camera_params.height @@ -86,11 +98,11 @@ def run(args, data_params: DataParams, fastsam_params: FastSAMParams, mapper_par width = runner.img_data.camera_params.height height = runner.img_data.camera_params.width num_panes = 0 - if args.viz_map: + if viz_params.viz_map: num_panes += 1 - if args.viz_observations: + if viz_params.viz_observations: num_panes += 1 - if args.viz_3d: + if viz_params.viz_3d: num_panes += 1 video = cv.VideoWriter(video_file, fc, fps, (width*num_panes, height)) @@ -110,25 +122,25 @@ def run(args, data_params: DataParams, fastsam_params: FastSAMParams, mapper_par print(f"Number of poses: {len(runner.mapper.poses_flu_history)}.") - if args.output: - pkl_path = os.path.expanduser(expandvars(args.output)) + ".pkl" - pkl_file = open(pkl_path, 'wb') - pickle.dump(runner.mapper.get_roman_map(), pkl_file, -1) - logging.info(f"Saved tracker, poses_flu_history to file: {pkl_path}.") - pkl_file.close() - - timing_file = os.path.expanduser(expandvars(args.output)) + ".time.txt" - with open(timing_file, 'w') as f: - f.write(f"dt: {data_params.dt}\n\n") - f.write(f"AVERAGE TIMES\n") - f.write(f"fastsam: {np.mean(runner.processing_times.fastsam_times):.3f}\n") - f.write(f"segment_track: {np.mean(runner.processing_times.map_times):.3f}\n") - f.write(f"total: {np.mean(runner.processing_times.total_times):.3f}\n") - f.write(f"TOTAL TIMES\n") - f.write(f"total: {np.sum(runner.processing_times.total_times):.2f}\n") + # Output results + pkl_path = os.path.expanduser(expandvars(output_path)) + ".pkl" + pkl_file = open(pkl_path, 'wb') + pickle.dump(runner.mapper.get_roman_map(), pkl_file, -1) + logging.info(f"Saved tracker, poses_flu_history to file: {pkl_path}.") + pkl_file.close() + + timing_file = os.path.expanduser(expandvars(output_path)) + ".time.txt" + with open(timing_file, 'w') as f: + f.write(f"dt: {data_params.dt}\n\n") + f.write(f"AVERAGE TIMES\n") + f.write(f"fastsam: {np.mean(runner.processing_times.fastsam_times):.3f}\n") + f.write(f"segment_track: {np.mean(runner.processing_times.map_times):.3f}\n") + f.write(f"total: {np.mean(runner.processing_times.total_times):.3f}\n") + f.write(f"TOTAL TIMES\n") + f.write(f"total: {np.sum(runner.processing_times.total_times):.2f}\n") - if args.save_img_data: - img_data_path = os.path.expanduser(expandvars(args.output)) + ".img_data.npz" + if viz_params.save_img_data: + img_data_path = os.path.expanduser(expandvars(output_path)) + ".img_data.npz" print(f"Saving visualization to {img_data_path}") img_data = ImgData(times=runner.mapper.times_history, imgs=runner.viz_imgs, data_type='raw') img_data.to_npz(img_data_path) @@ -136,46 +148,48 @@ def run(args, data_params: DataParams, fastsam_params: FastSAMParams, mapper_par del runner return -def mapping(args): - # TODO: start to fill in here. I think we need the option to either have one set of data params - # or to input one for each robot. Need to think about how to do this cleanly. - data_params_path = expandvars_recursive(f"{args.params}/data.yaml") - mapper_params_path = expandvars_recursive(f"{args.params}/mapper.yaml") - fastsam_params_path = expandvars_recursive(f"{args.params}/fastsam.yaml") +def mapping( + params_path: str, + output_path: str, + run_name: str = None, + max_time: float = None, + viz_params: VisualizationParams = VisualizationParams() +): + data_params_path = expandvars_recursive(f"{params_path}/data.yaml") + mapper_params_path = expandvars_recursive(f"{params_path}/mapper.yaml") + fastsam_params_path = expandvars_recursive(f"{params_path}/fastsam.yaml") - if args.max_time is not None: - output = args.output + if max_time is not None: try: mapping_iter = 0 while True: data_params, fastsam_params, mapper_params = \ - extract_params(data_params_path, fastsam_params_path, mapper_params_path, run_name=args.run) + extract_params(data_params_path, fastsam_params_path, mapper_params_path, run_name=run_name) data_params.time_params = { - 't0': args.max_time * mapping_iter, - 'tf': args.max_time * (mapping_iter + 1), + 't0': max_time * mapping_iter, + 'tf': max_time * (mapping_iter + 1), 'relative': True} - args.output = f"{output}_{mapping_iter}" - - run(args, data_params, fastsam_params, mapper_params) + run(data_params, fastsam_params, mapper_params, + output_path=f"{output_path}_{mapping_iter}", viz_params=viz_params) mapping_iter += 1 except: - demo_output_files = [f"{output}_{mi}.pkl" for mi in range(mapping_iter)] - merge_demo_output(demo_output_files, f"{output}.pkl") + demo_output_files = [f"{output_path}_{mi}.pkl" for mi in range(mapping_iter)] + merge_demo_output(demo_output_files, f"{output_path}.pkl") else: data_params, fastsam_params, mapper_params = \ - extract_params(data_params_path, fastsam_params_path, mapper_params_path, run_name=args.run) - run(args, data_params, fastsam_params, mapper_params) + extract_params(data_params_path, fastsam_params_path, mapper_params_path, run_name=run_name) + run(data_params, fastsam_params, mapper_params, output_path, viz_params) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-p', '--params', type=str, help='Path to params file', required=True) + parser.add_argument('-o', '--output', type=str, help='Path to output file', required=True) parser.add_argument('--max-time', type=float, default=None) - parser.add_argument('-o', '--output', type=str, help='Path to output file', required=False, default=None) parser.add_argument('-m', '--viz-map', action='store_true', help='Visualize map') parser.add_argument('-v', '--viz-observations', action='store_true', help='Visualize observations') parser.add_argument('-3', '--viz-3d', action='store_true', help='Visualize in 3D') @@ -184,4 +198,18 @@ def mapping(args): parser.add_argument('-r', '--run', type=str, help='Robot run', default=None) args = parser.parse_args() - mapping(args) \ No newline at end of file + viz_params = VisualizationParams( + viz_map=args.viz_map, + viz_observations=args.viz_observations, + viz_3d=args.viz_3d, + vid_rate=args.vid_rate, + save_img_data=args.save_img_data + ) + + mapping( + params_path=args.params, + output_path=args.output, + run_name=args.run, + max_time=args.max_time, + viz_params=viz_params + ) \ No newline at end of file diff --git a/roman/params/data_params.py b/roman/params/data_params.py index e150fc0..49ef71f 100644 --- a/roman/params/data_params.py +++ b/roman/params/data_params.py @@ -100,7 +100,7 @@ def _find_transformation(self, param_dict) -> np.array: else: raise ValueError("Invalid string.") elif param_dict['input_type'] == 'tf': - img_file_path = expandvars_recursive(self.params["img_data"]["path"]) + img_file_path = expandvars_recursive(self.params_dict["path"]) T = PoseData.static_tf_from_bag( expandvars_recursive(img_file_path), expandvars_recursive(param_dict['parent']), @@ -135,17 +135,26 @@ def __post_init__(self): def from_yaml(cls, yaml_path: str, run: str = None): with open(yaml_path) as f: data = yaml.safe_load(f) - if run is not None: - data = data[run] + if run is None: + return cls( + None, None, None, + dt=data['dt'] if 'dt' in data else 1/6, + runs=data['runs'] if 'runs' in data else None, + run_env=data['run_env'] if 'run_env' in data else None + ) + elif run in data: + run_data = data[run] + else: + run_data = data return cls( - ImgDataParams.from_dict(data['img_data']), - ImgDataParams.from_dict(data['depth_data']), - PoseDataParams.from_dict(data['pose_data']), - dt=data['dt'] if 'dt' in data else 1/6, + ImgDataParams.from_dict(run_data['img_data']), + ImgDataParams.from_dict(run_data['depth_data']), + PoseDataParams.from_dict(run_data['pose_data']), + dt=run_data['dt'] if 'dt' in run_data else 1/6, runs=data['runs'] if 'runs' in data else None, run_env=data['run_env'] if 'run_env' in data else None, - time_params=data['time_params'] if 'time_params' in data else None, - kitti=data['kitti'] if 'kitti' in data else False + time_params=run_data['time_params'] if 'time_params' in run_data else None, + kitti=run_data['kitti'] if 'kitti' in run_data else False ) @cached_property diff --git a/roman/params/fastsam_params.py b/roman/params/fastsam_params.py index 2431490..b3ca3e7 100644 --- a/roman/params/fastsam_params.py +++ b/roman/params/fastsam_params.py @@ -71,7 +71,7 @@ class FastSAMParams: def from_yaml(cls, yaml_path: str, run: str = None): with open(yaml_path) as f: data = yaml.safe_load(f) - if run is not None: + if run is not None and run in data: data = data[run] return cls(**data) \ No newline at end of file diff --git a/roman/params/mapper_params.py b/roman/params/mapper_params.py index 4f61f2b..a500066 100644 --- a/roman/params/mapper_params.py +++ b/roman/params/mapper_params.py @@ -57,6 +57,6 @@ class MapperParams(): def from_yaml(cls, yaml_path: str, run: str = None): with open(yaml_path) as f: data = yaml.safe_load(f) - if run is not None: + if run is not None and run in data: data = data[run] return cls(**data) \ No newline at end of file From 8f61f0de7a06e2c998e70f056179a2560bfa4666 Mon Sep 17 00:00:00 2001 From: Mason Peterson Date: Wed, 15 Jan 2025 16:43:14 -0500 Subject: [PATCH 2/3] Move offline rpgo params and submap align params to params submodule Previously were located in offline_rpgo and submap_align modules respectively. Small clean up. --- demo/demo.py | 25 +++++++++---------- demo/loop_closures.py | 2 +- demo/params/demo/offline_rpgo.yaml | 3 ++- roman/align/results.py | 2 +- roman/align/submap_align.py | 2 +- roman/map/map.py | 2 +- .../offline_rpgo_params.py} | 16 +++++++++++- .../submap_align_params.py} | 12 +++++++++ roman/viz.py | 1 + 9 files changed, 46 insertions(+), 19 deletions(-) rename roman/{offline_rpgo/params.py => params/offline_rpgo_params.py} (64%) rename roman/{align/params.py => params/submap_align_params.py} (96%) diff --git a/demo/demo.py b/demo/demo.py index 5b5a5ce..1232dfc 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -19,7 +19,7 @@ import os import yaml -from roman.align.params import SubmapAlignInputOutput, SubmapAlignParams +from roman.params.submap_align_params import SubmapAlignInputOutput, SubmapAlignParams from roman.align.submap_align import submap_align from roman.offline_rpgo.extract_odom_g2o import roman_map_pkl_to_g2o from roman.offline_rpgo.g2o_file_fusion import create_config, g2o_file_fusion @@ -28,19 +28,19 @@ from roman.offline_rpgo.g2o_and_time_to_pose_data import g2o_and_time_to_pose_data from roman.offline_rpgo.evaluate import evaluate from roman.offline_rpgo.edit_g2o_edge_information import edit_g2o_edge_information -from roman.offline_rpgo.params import OfflineRPGOParams +from roman.params.offline_rpgo_params import OfflineRPGOParams from roman.params.data_params import DataParams import mapping if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-p', '--params', default=None, type=str, help='Path to params directory.' + - ' Params can include mapping.yaml with mapping parameters. gt_pose.yaml with ground' + - ' truth parameters for evaluation. submap_align.yaml for loop closure parameters.' + - ' offline_rpgo.yaml for pose graph optimization parameters. Only mapping.yaml is required,' + - ' although a list of mapping params can be given with --mapping-params instead.' + - ' When loading these files, --run-env will be set with the run name.', required=True) + parser.add_argument('-p', '--params', default=None, type=str, help='Path to params directory. ' + + 'Params can include the following files: data.yaml, fastsam.yaml, ' + + 'mapper.yaml, submap_align.yaml, and offline_rpgo.yaml. Only data.yaml ' + + 'is required to be provided. Parameter defaults and definitions can be ' + + 'found in the roman.params module. Additional information can be found ' + + 'here: https://github.com/mit-acl/ROMAN/blob/main/demo/README.md', required=True) parser.add_argument('-o', '--output-dir', type=str, help='Path to output directory', required=True, default=None) parser.add_argument('-m', '--viz-map', action='store_true', help='Visualize map') @@ -48,7 +48,6 @@ parser.add_argument('-3', '--viz-3d', action='store_true', help='Visualize 3D') parser.add_argument('--vid-rate', type=float, help='Video playback rate', default=1.0) parser.add_argument('-d', '--save-img-data', action='store_true', help='Save video frames as ImgData class') - parser.add_argument('-s', '--sparse-pgo', action='store_true', help='Use sparse pose graph optimization') parser.add_argument('-n', '--num-req-assoc', type=int, help='Number of required associations', default=4) # parser.add_argument('--set-env-vars', type=str) parser.add_argument('--max-time', type=float, default=None, help='If the input data is too large, this allows a maximum time' + @@ -142,7 +141,7 @@ submap_align(sm_params=submap_align_params, sm_io=sm_io) if not args.skip_rpgo: - min_keyframe_dist = 0.01 if not args.sparse_pgo else 2.0 + min_keyframe_dist = 0.01 if not offline_rpgo_params.sparsified else 2.0 # Create g2o files for odometry for i, run in enumerate(data_params.runs): roman_map_pkl_to_g2o( @@ -196,7 +195,7 @@ g2o_file_fusion(g2o_fusion_config, dense_g2o_file, thresh=args.num_req_assoc) # Add loop closures to odometry g2o files - if args.sparse_pgo: + if offline_rpgo_params.sparsified: final_g2o_file = os.path.join(args.output_dir, "offline_rpgo", "odom_and_lc.g2o") combine_loop_closures( g2o_reference=odom_sparse_all_g2o_file, @@ -251,7 +250,7 @@ # Save csv files with resulting trajectories for i, run in enumerate(data_params.runs): pose_data = g2o_and_time_to_pose_data(result_g2o_file, - odom_sparse_all_time_file if args.sparse_pgo else odom_dense_all_time_file, + odom_sparse_all_time_file if offline_rpgo_params.sparsified else odom_dense_all_time_file, robot_id=i) pose_data.to_csv(os.path.join(args.output_dir, "offline_rpgo", f"{run}.csv")) print(f"Saving {run} pose data to {os.path.join(args.output_dir, 'offline_rpgo', f'{run}.csv')}") @@ -262,7 +261,7 @@ print("============") print(evaluate( result_g2o_file, - odom_sparse_all_time_file if args.sparse_pgo else odom_dense_all_time_file, + odom_sparse_all_time_file if offline_rpgo_params.sparsified else odom_dense_all_time_file, {i: gt_files[i] for i in range(len(gt_files))}, {i: data_params.runs[i] for i in range(len(data_params.runs))}, data_params.run_env, diff --git a/demo/loop_closures.py b/demo/loop_closures.py index 1a713cc..a9ad715 100644 --- a/demo/loop_closures.py +++ b/demo/loop_closures.py @@ -1,6 +1,6 @@ import argparse -from roman.align.params import SubmapAlignInputOutput, SubmapAlignParams +from roman.params.submap_align_params import SubmapAlignInputOutput, SubmapAlignParams from roman.align.submap_align import submap_align if __name__ == '__main__': diff --git a/demo/params/demo/offline_rpgo.yaml b/demo/params/demo/offline_rpgo.yaml index fc7ce97..2022a56 100644 --- a/demo/params/demo/offline_rpgo.yaml +++ b/demo/params/demo/offline_rpgo.yaml @@ -3,4 +3,5 @@ odom_t_std: 0.02 # 2 cm odom_r_std: 0.001745 # .1 deg lc_t_std: 2.0 # 2 m -lc_r_std: 0.1745 # 10 deg \ No newline at end of file +lc_r_std: 0.1745 # 10 deg +sparsified: False \ No newline at end of file diff --git a/roman/align/results.py b/roman/align/results.py index bc6c936..a6f7fc1 100644 --- a/roman/align/results.py +++ b/roman/align/results.py @@ -11,7 +11,7 @@ from roman.utils import transform_rm_roll_pitch from roman.map.map import ROMANMap, SubmapParams -from roman.align.params import SubmapAlignInputOutput, SubmapAlignParams +from roman.params.submap_align_params import SubmapAlignInputOutput, SubmapAlignParams from roman.object.segment import Segment diff --git a/roman/align/submap_align.py b/roman/align/submap_align.py index 775d0cc..bd9fb3a 100644 --- a/roman/align/submap_align.py +++ b/roman/align/submap_align.py @@ -22,7 +22,7 @@ from roman.align.object_registration import InsufficientAssociationsException from roman.align.dist_reg_with_pruning import GravityConstraintError from roman.utils import object_list_bounds, transform_rm_roll_pitch -from roman.align.params import SubmapAlignParams, SubmapAlignInputOutput +from roman.params.submap_align_params import SubmapAlignParams, SubmapAlignInputOutput from roman.align.results import save_submap_align_results, SubmapAlignResults OVERLAP_EPS = 0.1 diff --git a/roman/map/map.py b/roman/map/map.py index a90228f..bae0a4c 100644 --- a/roman/map/map.py +++ b/roman/map/map.py @@ -11,7 +11,7 @@ from robotdatapy.data.pose_data import PoseData -from roman.align.params import SubmapAlignParams +from roman.params.submap_align_params import SubmapAlignParams from roman.object.segment import Segment, SegmentMinimalData from roman.utils import transform_rm_roll_pitch diff --git a/roman/offline_rpgo/params.py b/roman/params/offline_rpgo_params.py similarity index 64% rename from roman/offline_rpgo/params.py rename to roman/params/offline_rpgo_params.py index c0ca052..9a6b415 100644 --- a/roman/offline_rpgo/params.py +++ b/roman/params/offline_rpgo_params.py @@ -1,3 +1,15 @@ +########################################################### +# +# offline_rpgo_params.py +# +# Params for ROMAN offline RPGO SLAM. +# +# Authors: Mason Peterson +# +# Jan. 15, 2025 +# +########################################################### + import numpy as np from dataclasses import dataclass @@ -15,7 +27,9 @@ class OfflineRPGOParams: # loop closure covariance params lc_t_std: float = 0.5 lc_r_std: float = np.deg2rad(0.5) - + + # sparse or dense + sparsified: bool = True @classmethod def from_yaml(cls, yaml_file): diff --git a/roman/align/params.py b/roman/params/submap_align_params.py similarity index 96% rename from roman/align/params.py rename to roman/params/submap_align_params.py index 5aed71f..ded40cc 100644 --- a/roman/align/params.py +++ b/roman/params/submap_align_params.py @@ -1,3 +1,15 @@ +########################################################### +# +# submap_align_params.py +# +# Params for ROMAN object registration. +# +# Authors: Mason Peterson +# +# Jan. 15, 2025 +# +########################################################### + import numpy as np from dataclasses import dataclass, field diff --git a/roman/viz.py b/roman/viz.py index dcdb681..056a5de 100644 --- a/roman/viz.py +++ b/roman/viz.py @@ -77,6 +77,7 @@ def visualize_observations_on_img(t, img, mapper, observations, reprojected_bbox cv.rectangle(img_fastsam, np.array([bbox[0][0], bbox[0][1]]).astype(np.int32), np.array([bbox[1][0], bbox[1][1]]).astype(np.int32), color=rand_color.tolist(), thickness=2) +# TODO: rename, this is confusing. This is visualizing the 3D world (does not write on top of another image) def visualize_3d_on_img(t: float, pose_flu: np.ndarray, mapper: Mapper) -> np.ndarray: """ Visualizes a 3D map onto an image (with camera params used by the mapper). From 21c57df14e0985c771e7b185e0ee4964eeb88027 Mon Sep 17 00:00:00 2001 From: Mason Peterson Date: Wed, 15 Jan 2025 18:08:01 -0500 Subject: [PATCH 3/3] Add option to override runs in demo data.yaml Also, save ate results to output file. --- demo/demo.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index 1232dfc..942dc27 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -43,6 +43,8 @@ 'here: https://github.com/mit-acl/ROMAN/blob/main/demo/README.md', required=True) parser.add_argument('-o', '--output-dir', type=str, help='Path to output directory', required=True, default=None) + parser.add_argument('-r', '--runs', type=str, nargs='+', required=False, default=None, + help='Run names. Overrides runs field in data.yaml') parser.add_argument('-m', '--viz-map', action='store_true', help='Visualize map') parser.add_argument('-v', '--viz-observations', action='store_true', help='Visualize observations') parser.add_argument('-3', '--viz-3d', action='store_true', help='Visualize 3D') @@ -69,6 +71,8 @@ offline_rpgo_params = OfflineRPGOParams.from_yaml(offline_rpgo_params_path) \ if os.path.exists(os.path.join(args.params, "offline_rpgo.yaml")) else OfflineRPGOParams() data_params = DataParams.from_yaml(os.path.join(args.params, "data.yaml")) + if args.runs is not None: + data_params.runs = args.runs # ground truth pose files if os.path.exists(os.path.join(params_dir, "gt_pose.yaml")): @@ -257,13 +261,18 @@ # Report ATE results if has_gt: - print("ATE results:") - print("============") - print(evaluate( + ate_rmse = evaluate( result_g2o_file, odom_sparse_all_time_file if offline_rpgo_params.sparsified else odom_dense_all_time_file, {i: gt_files[i] for i in range(len(gt_files))}, {i: data_params.runs[i] for i in range(len(data_params.runs))}, data_params.run_env, output_dir=args.output_dir - )) \ No newline at end of file + ) + print("ATE results:") + print("============") + print(ate_rmse) + with open(os.path.join(args.output_dir, "offline_rpgo", "ate_rmse.txt"), 'w') as f: + print(ate_rmse, file=f) + f.close() + \ No newline at end of file