1+ import warnings
2+
13from collections .abc import Iterator
4+ from copy import deepcopy
25from dataclasses import dataclass , fields
36from typing import Generic , Self , TypeVar
47
@@ -36,8 +39,8 @@ def __post_init__(self):
3639 missing_attr .append (item )
3740 continue
3841 key = getattr (item , self .key_field )
39- if key in index :
40- raise ValueError (f"Duplicate { self .key_field } '{ key } ' detected." )
42+ # if key in index:
43+ # raise ValueError(f"Duplicate {self.key_field} '{key}' detected.") # This needs to be possible for shared states
4144 index [key ] = item
4245 if missing_attr :
4346 raise AttributeError (f"Items missing attribute '{ self .key_field } ': { missing_attr } " )
@@ -72,6 +75,9 @@ def __str__(self) -> str:
7275 def names (self ) -> tuple [str , ...]:
7376 return tuple (self ._index .keys ())
7477
78+ def copy (self ) -> "Info[T]" :
79+ return deepcopy (self )
80+
7581
7682@dataclass (frozen = True )
7783class Parameter (Property ):
@@ -90,13 +96,13 @@ def add(self, parameter: Parameter) -> "ParameterInfo":
9096 # return a new ParameterInfo with parameter appended
9197 return ParameterInfo (parameters = [* list (self .items ), parameter ])
9298
93- def merge (self , other : "ParameterInfo" ) -> "ParameterInfo" :
99+ def merge (self , other : "ParameterInfo" , allow_duplicates : bool = False ) -> "ParameterInfo" :
94100 """Combine parameters from two ParameterInfo objects."""
95101 if not isinstance (other , ParameterInfo ):
96102 raise TypeError (f"Cannot merge { type (other ).__name__ } with ParameterInfo" )
97103
98104 overlapping = set (self .names ) & set (other .names )
99- if overlapping :
105+ if overlapping and not allow_duplicates :
100106 raise ValueError (f"Duplicate parameter names found: { overlapping } " )
101107
102108 return ParameterInfo (parameters = list (self .items ) + list (other .items ))
@@ -119,20 +125,24 @@ def __init__(self, data: list[Data]):
119125 def needs_exogenous_data (self ) -> bool :
120126 return any (d .is_exogenous for d in self .items )
121127
128+ @property
129+ def exogenous_names (self ) -> tuple [str , ...]:
130+ return tuple (d .name for d in self .items if d .is_exogenous )
131+
122132 def __str__ (self ) -> str :
123133 return f"data: { [d .name for d in self .items ]} \n needs exogenous data: { self .needs_exogenous_data } "
124134
125135 def add (self , data : Data ) -> "DataInfo" :
126136 # return a new DataInfo with data appended
127137 return DataInfo (data = [* list (self .items ), data ])
128138
129- def merge (self , other : "DataInfo" ) -> "DataInfo" :
139+ def merge (self , other : "DataInfo" , allow_duplicates : bool = False ) -> "DataInfo" :
130140 """Combine data from two DataInfo objects."""
131141 if not isinstance (other , DataInfo ):
132142 raise TypeError (f"Cannot merge { type (other ).__name__ } with DataInfo" )
133143
134144 overlapping = set (self .names ) & set (other .names )
135- if overlapping :
145+ if overlapping and not allow_duplicates :
136146 raise ValueError (f"Duplicate data names found: { overlapping } " )
137147
138148 return DataInfo (data = list (self .items ) + list (other .items ))
@@ -164,7 +174,7 @@ def default_coords_from_model(
164174 Self
165175 ): # TODO: Need to figure out how to include Component type was causing circular import issues
166176 states = tuple (model .state_names )
167- obs_states = tuple (model .observed_state_names )
177+ obs_states = tuple (model .observed_states )
168178 shocks = tuple (model .shock_names )
169179
170180 dim_to_labels = (
@@ -186,13 +196,13 @@ def add(self, coord: Coord) -> "CoordInfo":
186196 # return a new CoordInfo with data appended
187197 return CoordInfo (coords = [* list (self .items ), coord ])
188198
189- def merge (self , other : "CoordInfo" ) -> "CoordInfo" :
199+ def merge (self , other : "CoordInfo" , allow_duplicates : bool = False ) -> "CoordInfo" :
190200 """Combine data from two CoordInfo objects."""
191201 if not isinstance (other , CoordInfo ):
192202 raise TypeError (f"Cannot merge { type (other ).__name__ } with CoordInfo" )
193203
194204 overlapping = set (self .names ) & set (other .names )
195- if overlapping :
205+ if overlapping and not allow_duplicates :
196206 raise ValueError (f"Duplicate coord names found: { overlapping } " )
197207
198208 return CoordInfo (coords = list (self .items ) + list (other .items ))
@@ -216,21 +226,37 @@ def __str__(self) -> str:
216226 )
217227
218228 @property
219- def observed_states (self ) -> tuple [State , ...]:
229+ def observed_states (self ) -> tuple [State , ...]: # Is this needed??
220230 return tuple (s for s in self .items if s .observed )
221231
232+ @property
233+ def observed_state_names (self ) -> tuple [State , ...]:
234+ return tuple (s .name for s in self .items if s .observed )
235+
236+ @property
237+ def unobserved_state_names (self ) -> tuple [State , ...]:
238+ return tuple (s .name for s in self .items if not s .observed )
239+
222240 def add (self , state : State ) -> "StateInfo" :
223241 # return a new StateInfo with state appended
224242 return StateInfo (states = [* list (self .items ), state ])
225243
226- def merge (self , other : "StateInfo" ) -> "StateInfo" :
244+ def merge (self , other : "StateInfo" , allow_duplicates : bool = False ) -> "StateInfo" :
227245 """Combine states from two StateInfo objects."""
228246 if not isinstance (other , StateInfo ):
229247 raise TypeError (f"Cannot merge { type (other ).__name__ } with StateInfo" )
230248
231249 overlapping = set (self .names ) & set (other .names )
232- if overlapping :
233- raise ValueError (f"Duplicate state names found: { overlapping } " )
250+ if overlapping and not allow_duplicates :
251+ # This is necessary for shared states
252+ warnings .warn (
253+ f"Duplicate state names found: { overlapping } . Merge will ONLY retain unique states" ,
254+ UserWarning ,
255+ )
256+ return StateInfo (
257+ states = list (self .items )
258+ + [item for item in other .items if item .name not in overlapping ]
259+ )
234260
235261 return StateInfo (states = list (self .items ) + list (other .items ))
236262
@@ -249,13 +275,13 @@ def add(self, shock: Shock) -> "ShockInfo":
249275 # return a new ShockInfo with shock appended
250276 return ShockInfo (shocks = [* list (self .items ), shock ])
251277
252- def merge (self , other : "ShockInfo" ) -> "ShockInfo" :
278+ def merge (self , other : "ShockInfo" , allow_duplicates : bool = False ) -> "ShockInfo" :
253279 """Combine shocks from two ShockInfo objects."""
254280 if not isinstance (other , ShockInfo ):
255281 raise TypeError (f"Cannot merge { type (other ).__name__ } with ShockInfo" )
256282
257283 overlapping = set (self .names ) & set (other .names )
258- if overlapping :
284+ if overlapping and not allow_duplicates :
259285 raise ValueError (f"Duplicate shock names found: { overlapping } " )
260286
261287 return ShockInfo (shocks = list (self .items ) + list (other .items ))
0 commit comments