Skip to content

Commit 4bcd419

Browse files
committed
Merge branch 'feature/osi3trace-mcap-support' into temp/osi3trace-mcap-support-fix
2 parents 9afa3ce + 8d990b7 commit 4bcd419

File tree

3 files changed

+421
-5
lines changed

3 files changed

+421
-5
lines changed

osi3trace/osi_trace.py

Lines changed: 218 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,14 @@
33
"""
44

55
import lzma
6+
from pathlib import Path
67
import struct
78

9+
from abc import ABC, abstractmethod
10+
11+
from mcap_protobuf.decoder import DecoderFactory
12+
from mcap.reader import make_reader
13+
814
from osi3.osi_sensorview_pb2 import SensorView
915
from osi3.osi_sensorviewconfiguration_pb2 import SensorViewConfiguration
1016
from osi3.osi_groundtruth_pb2 import GroundTruth
@@ -32,8 +38,8 @@
3238

3339

3440
class OSITrace:
35-
"""This class can import and decode OSI trace files."""
36-
41+
"""This class can import and decode OSI single- and multi-channel trace files."""
42+
3743
@staticmethod
3844
def map_message_type(type_name):
3945
"""Map the type name to the protobuf message type."""
@@ -43,9 +49,156 @@ def map_message_type(type_name):
4349
def message_types():
4450
"""Message types that OSITrace supports."""
4551
return list(MESSAGES_TYPE.keys())
52+
53+
_legacy_ositrace_attributes = {
54+
"type",
55+
"file",
56+
"current_index",
57+
"message_offsets",
58+
"read_complete",
59+
"message_cache",
60+
}
61+
62+
def __getattr__(self, name):
63+
"""
64+
This method forwards the getattr call for unsuccessful legacy attribute
65+
name lookups to the reader in case it is an OSITraceSingle instance.
66+
"""
67+
if name in self._legacy_ositrace_attributes and isinstance(self.reader, OSITraceSingle):
68+
return getattr(self.reader, name)
69+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
70+
71+
def __setattr__(self, name, value):
72+
"""
73+
This method overwrites the default setter and forwards setattr calls for
74+
legacy attribute names to the reader in case the reader is an
75+
OSITraceSingle instance. Otherwise it uses the default setter.
76+
"""
77+
reader = super().__getattribute__("reader") if "reader" in self.__dict__ else None
78+
if name in self._legacy_ositrace_attributes and isinstance(reader, OSITraceSingle):
79+
setattr(reader, name, value)
80+
else:
81+
super().__setattr__(name, value)
82+
83+
def __dir__(self):
84+
attrs = super().__dir__()
85+
if isinstance(self.reader, OSITraceSingle):
86+
attrs += list(self._legacy_ositrace_attributes)
87+
return attrs
88+
89+
def __init__(self, path=None, type_name="SensorView", cache_messages=False, topic=None):
90+
"""
91+
Initializes the trace reader depending on the trace file format.
92+
93+
Args:
94+
path (str): The path to the trace file.
95+
type_name (str): The type name of the messages in the trace; check supported message types with `OSITrace.message_types()`.
96+
cache_messages (bool): Whether to cache messages in memory (only applies to single-channel traces).
97+
topic (str): The topic name for multi-channel traces (only applies to multi-channel traces); Using the first available topic if not specified.
98+
"""
99+
self.reader = None
100+
101+
if path is not None:
102+
self.reader = self._init_reader(Path(path), type_name, cache_messages, topic)
103+
104+
def _init_reader(self, path, type_name, cache_messages, topic):
105+
if not path.exists():
106+
raise FileNotFoundError("File not found")
107+
108+
if path.suffix.lower() == ".mcap":
109+
reader = OSITraceMulti(path, topic)
110+
if reader.get_message_type() != type_name:
111+
raise ValueError(f"Channel message type '{reader.get_message_type()}' does not match expected type '{type_name}'")
112+
return reader
113+
elif path.suffix.lower() in [".osi", ".lzma", ".xz"]:
114+
return OSITraceSingle(str(path), type_name, cache_messages)
115+
else:
116+
raise ValueError(f"Unsupported file format: '{path.suffix}'")
117+
118+
def from_file(self, path, type_name="SensorView", cache_messages=False, topic=None):
119+
"""
120+
Initializes the trace reader depending on the trace file format.
121+
122+
Args:
123+
path (str): The path to the trace file.
124+
type_name (str): The type name of the messages in the trace; check supported message types with `OSITrace.message_types()`.
125+
cache_messages (bool): Whether to cache messages in memory (only applies to single-channel traces).
126+
topic (str): The topic name for multi-channel traces (only applies to multi-channel traces); Using the first available topic if not specified.
127+
"""
128+
self.reader = self._init_reader(Path(path), type_name, cache_messages, topic)
129+
130+
def restart(self, index=None):
131+
"""
132+
Restart the trace reader.
133+
134+
Note:
135+
Multi-channel traces don't support restarting from a specific index.
136+
"""
137+
return self.reader.restart(index)
138+
139+
def __iter__(self):
140+
return self.reader.__iter__()
141+
142+
def close(self):
143+
return self.reader.close()
46144

145+
def retrieve_offsets(self, limit=None):
146+
if isinstance(self.reader, OSITraceSingle):
147+
return self.reader.retrieve_offsets(limit)
148+
raise NotImplementedError("Offsets are only supported for single-channel traces.")
149+
150+
def retrieve_message(self, index=None, skip=False):
151+
if isinstance(self.reader, OSITraceSingle):
152+
return self.reader.retrieve_message(index, skip)
153+
raise NotImplementedError("Index-based message retrieval is only supported for single-channel traces.")
154+
155+
def get_message_by_index(self, index):
156+
if isinstance(self.reader, OSITraceSingle):
157+
return self.reader.get_message_by_index(index)
158+
raise NotImplementedError("Index-based message retrieval is only supported for single-channel traces.")
159+
160+
def get_messages_in_index_range(self, begin, end):
161+
if isinstance(self.reader, OSITraceSingle):
162+
return self.reader.get_messages_in_index_range(begin, end)
163+
raise NotImplementedError("Index-based message retrieval is only supported for single-channel traces.")
164+
165+
def get_available_topics(self):
166+
if isinstance(self.reader, OSITraceMulti):
167+
return self.reader.get_available_topics()
168+
raise NotImplementedError("Getting available topics is only supported for multi-channel traces.")
169+
170+
def get_file_metadata(self):
171+
if isinstance(self.reader, OSITraceMulti):
172+
return self.reader.get_file_metadata()
173+
raise NotImplementedError("Getting file metadata is only supported for multi-channel traces.")
174+
175+
def get_channel_metadata(self):
176+
if isinstance(self.reader, OSITraceMulti):
177+
return self.reader.get_channel_metadata()
178+
raise NotImplementedError("Getting channel metadata is only supported for multi-channel traces.")
179+
180+
181+
class ReaderBase(ABC):
182+
"""Common interface for trace readers"""
183+
184+
@abstractmethod
185+
def restart(self, index=None):
186+
pass
187+
188+
@abstractmethod
189+
def __iter__(self):
190+
pass
191+
192+
@abstractmethod
193+
def close(self):
194+
pass
195+
196+
197+
class OSITraceSingle(ReaderBase):
198+
"""OSI single-channel trace reader"""
199+
47200
def __init__(self, path=None, type_name="SensorView", cache_messages=False):
48-
self.type = self.map_message_type(type_name)
201+
self.type = OSITrace.map_message_type(type_name)
49202
self.file = None
50203
self.current_index = None
51204
self.message_offsets = None
@@ -57,7 +210,7 @@ def __init__(self, path=None, type_name="SensorView", cache_messages=False):
57210

58211
def from_file(self, path, type_name="SensorView", cache_messages=False):
59212
"""Import a trace from a file"""
60-
self.type = self.map_message_type(type_name)
213+
self.type = OSITrace.map_message_type(type_name)
61214

62215
if path.lower().endswith((".lzma", ".xz")):
63216
self.file = lzma.open(path, "rb")
@@ -186,3 +339,64 @@ def close(self):
186339
self.read_complete = False
187340
self.read_limit = None
188341
self.type = None
342+
343+
344+
class OSITraceMulti(ReaderBase):
345+
"""OSI multi-channel trace reader"""
346+
347+
def __init__(self, path, topic):
348+
self._file = open(path, "rb")
349+
self._mcap_reader = make_reader(self._file, decoder_factories=[DecoderFactory()])
350+
self._iter = None
351+
self._summary = self._mcap_reader.get_summary()
352+
available_topics = self.get_available_topics()
353+
if topic == None:
354+
topic = available_topics[0]
355+
if topic not in available_topics:
356+
raise ValueError(f"The requested topic '{topic}' is not present in the trace file.")
357+
self.topic = topic
358+
359+
def restart(self, index=None):
360+
if index != None:
361+
raise NotImplementedError("Restarting from a given index is not supported for multi-channel traces.")
362+
self._iter = None
363+
364+
def __iter__(self):
365+
"""Stateful iterator over the channel's messages in log time order."""
366+
if self._iter is None:
367+
self._iter = self._mcap_reader.iter_decoded_messages(topics=[self.topic])
368+
for message in self._iter:
369+
yield message.decoded_message
370+
371+
def close(self):
372+
if self._file:
373+
self._file.close()
374+
self._file = None
375+
self._mcap_reader = None
376+
self._summary = None
377+
self._iter = None
378+
379+
def get_available_topics(self):
380+
return [channel.topic for id, channel in self._summary.channels.items()]
381+
382+
def get_file_metadata(self):
383+
metadata = []
384+
for metadata_entry in self._mcap_reader.iter_metadata():
385+
metadata.append(metadata_entry)
386+
return metadata
387+
388+
def get_channel_metadata(self):
389+
for id, channel in self._summary.channels.items():
390+
if channel.topic == self.topic:
391+
return channel.metadata
392+
return None
393+
394+
def get_message_type(self):
395+
for channel_id, channel in self._summary.channels.items():
396+
if channel.topic == self.topic:
397+
schema = self._summary.schemas[channel.schema_id]
398+
if schema.name.startswith("osi3."):
399+
return schema.name[len("osi3.") :]
400+
else:
401+
raise ValueError(f"Schema '{schema.name}' is not an 'osi3.' schema.")
402+
return None

0 commit comments

Comments
 (0)