Skip to content

Commit a183c71

Browse files
committed
1. Updated dataclasses to include a copy method and replaced the raise on duplicate with a warning
2. Removed unnecessary imports from __init__ after deleting regression_dataclass
3. Updated components and structural classes to only utilize dataclasses and pull other objects from <foo>_info dataclasses
4. Updated tests to conform to the dataclass API
1 parent 92e333f commit a183c71

File tree

7 files changed

+191
-731
lines changed

7 files changed

+191
-731
lines changed

pymc_extras/statespace/core/properties.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import warnings
2+
13
from collections.abc import Iterator
4+
from copy import deepcopy
25
from dataclasses import dataclass, fields
36
from typing import Generic, Self, TypeVar
47

@@ -36,8 +39,8 @@ def __post_init__(self):
3639
missing_attr.append(item)
3740
continue
3841
key = getattr(item, self.key_field)
39-
if key in index:
40-
raise ValueError(f"Duplicate {self.key_field} '{key}' detected.")
42+
# if key in index:
43+
# raise ValueError(f"Duplicate {self.key_field} '{key}' detected.") # This needs to be possible for shared states
4144
index[key] = item
4245
if missing_attr:
4346
raise AttributeError(f"Items missing attribute '{self.key_field}': {missing_attr}")
@@ -72,6 +75,9 @@ def __str__(self) -> str:
7275
def names(self) -> tuple[str, ...]:
7376
return tuple(self._index.keys())
7477

78+
def copy(self) -> "Info[T]":
79+
return deepcopy(self)
80+
7581

7682
@dataclass(frozen=True)
7783
class Parameter(Property):
@@ -90,13 +96,13 @@ def add(self, parameter: Parameter) -> "ParameterInfo":
9096
# return a new ParameterInfo with parameter appended
9197
return ParameterInfo(parameters=[*list(self.items), parameter])
9298

93-
def merge(self, other: "ParameterInfo") -> "ParameterInfo":
99+
def merge(self, other: "ParameterInfo", allow_duplicates: bool = False) -> "ParameterInfo":
94100
"""Combine parameters from two ParameterInfo objects."""
95101
if not isinstance(other, ParameterInfo):
96102
raise TypeError(f"Cannot merge {type(other).__name__} with ParameterInfo")
97103

98104
overlapping = set(self.names) & set(other.names)
99-
if overlapping:
105+
if overlapping and not allow_duplicates:
100106
raise ValueError(f"Duplicate parameter names found: {overlapping}")
101107

102108
return ParameterInfo(parameters=list(self.items) + list(other.items))
@@ -119,20 +125,24 @@ def __init__(self, data: list[Data]):
119125
def needs_exogenous_data(self) -> bool:
120126
return any(d.is_exogenous for d in self.items)
121127

128+
@property
129+
def exogenous_names(self) -> tuple[str, ...]:
130+
return tuple(d.name for d in self.items if d.is_exogenous)
131+
122132
def __str__(self) -> str:
123133
return f"data: {[d.name for d in self.items]}\nneeds exogenous data: {self.needs_exogenous_data}"
124134

125135
def add(self, data: Data) -> "DataInfo":
126136
# return a new DataInfo with data appended
127137
return DataInfo(data=[*list(self.items), data])
128138

129-
def merge(self, other: "DataInfo") -> "DataInfo":
139+
def merge(self, other: "DataInfo", allow_duplicates: bool = False) -> "DataInfo":
130140
"""Combine data from two DataInfo objects."""
131141
if not isinstance(other, DataInfo):
132142
raise TypeError(f"Cannot merge {type(other).__name__} with DataInfo")
133143

134144
overlapping = set(self.names) & set(other.names)
135-
if overlapping:
145+
if overlapping and not allow_duplicates:
136146
raise ValueError(f"Duplicate data names found: {overlapping}")
137147

138148
return DataInfo(data=list(self.items) + list(other.items))
@@ -164,7 +174,7 @@ def default_coords_from_model(
164174
Self
165175
): # TODO: Need to figure out how to include Component type was causing circular import issues
166176
states = tuple(model.state_names)
167-
obs_states = tuple(model.observed_state_names)
177+
obs_states = tuple(model.observed_states)
168178
shocks = tuple(model.shock_names)
169179

170180
dim_to_labels = (
@@ -186,13 +196,13 @@ def add(self, coord: Coord) -> "CoordInfo":
186196
# return a new CoordInfo with data appended
187197
return CoordInfo(coords=[*list(self.items), coord])
188198

189-
def merge(self, other: "CoordInfo") -> "CoordInfo":
199+
def merge(self, other: "CoordInfo", allow_duplicates: bool = False) -> "CoordInfo":
190200
"""Combine data from two CoordInfo objects."""
191201
if not isinstance(other, CoordInfo):
192202
raise TypeError(f"Cannot merge {type(other).__name__} with CoordInfo")
193203

194204
overlapping = set(self.names) & set(other.names)
195-
if overlapping:
205+
if overlapping and not allow_duplicates:
196206
raise ValueError(f"Duplicate coord names found: {overlapping}")
197207

198208
return CoordInfo(coords=list(self.items) + list(other.items))
@@ -216,21 +226,37 @@ def __str__(self) -> str:
216226
)
217227

218228
@property
219-
def observed_states(self) -> tuple[State, ...]:
229+
def observed_states(self) -> tuple[State, ...]: # Is this needed??
220230
return tuple(s for s in self.items if s.observed)
221231

232+
@property
233+
def observed_state_names(self) -> tuple[State, ...]:
234+
return tuple(s.name for s in self.items if s.observed)
235+
236+
@property
237+
def unobserved_state_names(self) -> tuple[State, ...]:
238+
return tuple(s.name for s in self.items if not s.observed)
239+
222240
def add(self, state: State) -> "StateInfo":
223241
# return a new StateInfo with state appended
224242
return StateInfo(states=[*list(self.items), state])
225243

226-
def merge(self, other: "StateInfo") -> "StateInfo":
244+
def merge(self, other: "StateInfo", allow_duplicates: bool = False) -> "StateInfo":
227245
"""Combine states from two StateInfo objects."""
228246
if not isinstance(other, StateInfo):
229247
raise TypeError(f"Cannot merge {type(other).__name__} with StateInfo")
230248

231249
overlapping = set(self.names) & set(other.names)
232-
if overlapping:
233-
raise ValueError(f"Duplicate state names found: {overlapping}")
250+
if overlapping and not allow_duplicates:
251+
# This is necessary for shared states
252+
warnings.warn(
253+
f"Duplicate state names found: {overlapping}. Merge will ONLY retain unique states",
254+
UserWarning,
255+
)
256+
return StateInfo(
257+
states=list(self.items)
258+
+ [item for item in other.items if item.name not in overlapping]
259+
)
234260

235261
return StateInfo(states=list(self.items) + list(other.items))
236262

@@ -249,13 +275,13 @@ def add(self, shock: Shock) -> "ShockInfo":
249275
# return a new ShockInfo with shock appended
250276
return ShockInfo(shocks=[*list(self.items), shock])
251277

252-
def merge(self, other: "ShockInfo") -> "ShockInfo":
278+
def merge(self, other: "ShockInfo", allow_duplicates: bool = False) -> "ShockInfo":
253279
"""Combine shocks from two ShockInfo objects."""
254280
if not isinstance(other, ShockInfo):
255281
raise TypeError(f"Cannot merge {type(other).__name__} with ShockInfo")
256282

257283
overlapping = set(self.names) & set(other.names)
258-
if overlapping:
284+
if overlapping and not allow_duplicates:
259285
raise ValueError(f"Duplicate shock names found: {overlapping}")
260286

261287
return ShockInfo(shocks=list(self.items) + list(other.items))

pymc_extras/statespace/models/structural/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
from pymc_extras.statespace.models.structural.components.level_trend import LevelTrendComponent
66
from pymc_extras.statespace.models.structural.components.measurement_error import MeasurementError
77
from pymc_extras.statespace.models.structural.components.regression import RegressionComponent
8-
from pymc_extras.statespace.models.structural.components.regression_dataclass import (
9-
RegressionComponent as RegressionComponentDataClass,
10-
)
118
from pymc_extras.statespace.models.structural.components.seasonality import (
129
FrequencySeasonality,
1310
TimeSeasonality,
@@ -20,6 +17,5 @@
2017
"LevelTrendComponent",
2118
"MeasurementError",
2219
"RegressionComponent",
23-
"RegressionComponentDataClass",
2420
"TimeSeasonality",
2521
]

pymc_extras/statespace/models/structural/components/regression.py

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -211,33 +211,23 @@ def _set_parameters(self) -> None:
211211
k_endog_effective = 1 if self.share_states else k_endog
212212
k_states = self.k_states // k_endog_effective
213213

214-
beta_param_name = f"beta_{self.name}"
215-
beta_param_shape = (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,)
216-
beta_param_dims = (
217-
(f"endog_{self.name}", f"state_{self.name}")
218-
if k_endog_effective > 1
219-
else (f"state_{self.name}",)
220-
)
221-
222-
beta_param_constraints = None
223214
beta_parameter = Parameter(
224-
name=beta_param_name,
225-
shape=beta_param_shape,
226-
dims=beta_param_dims,
227-
constraints=beta_param_constraints,
215+
name=f"beta_{self.name}",
216+
shape=(k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,),
217+
dims=(
218+
(f"endog_{self.name}", f"state_{self.name}")
219+
if k_endog_effective > 1
220+
else (f"state_{self.name}",)
221+
),
222+
constraints=None,
228223
)
229224

230225
if self.innovations:
231-
sigma_param_name = f"sigma_beta_{self.name}"
232-
sigma_param_dims = (f"state_{self.name}",)
233-
sigma_param_shape = (k_states,)
234-
sigma_param_constraints = "Positive"
235-
236226
sigma_parameter = Parameter(
237-
name=sigma_param_name,
238-
shape=sigma_param_shape,
239-
dims=sigma_param_dims,
240-
constraints=sigma_param_constraints,
227+
name=f"sigma_beta_{self.name}",
228+
shape=(k_states,),
229+
dims=(f"state_{self.name}",),
230+
constraints="Positive",
241231
)
242232

243233
self.param_info = ParameterInfo(parameters=[beta_parameter, sigma_parameter])
@@ -251,11 +241,12 @@ def _set_data(self) -> None:
251241
k_endog_effective = 1 if self.share_states else k_endog
252242
k_states = self.k_states // k_endog_effective
253243

254-
data_name = f"data_{self.name}"
255-
data_shape = (None, k_states)
256-
data_dims = (TIME_DIM, f"state_{self.name}")
257-
258-
data_prop = Data(name=data_name, shape=data_shape, dims=data_dims, is_exogenous=True)
244+
data_prop = Data(
245+
name=f"data_{self.name}",
246+
shape=(None, k_states),
247+
dims=(TIME_DIM, f"state_{self.name}"),
248+
is_exogenous=True,
249+
)
259250
self.data_info = DataInfo(data=[data_prop])
260251
self.data_names = self.data_info.names
261252

@@ -274,19 +265,35 @@ def _set_states(self) -> None:
274265
if self.share_states:
275266
state_names = [f"{name}[{self.name}_shared]" for name in self.base_names]
276267
self.state_info = StateInfo(
277-
states=[State(name=name, observed=True, shared=True) for name in state_names]
268+
states=[State(name=name, observed=False, shared=True) for name in state_names]
269+
)
270+
self.state_info = self.state_info.merge(
271+
StateInfo(
272+
states=[
273+
State(name=name, observed=True, shared=False)
274+
for name in self.observed_state_names
275+
]
276+
)
278277
)
279-
self.state_names = self.state_info.names
278+
self.state_names = self.state_info.unobserved_state_names
280279
else:
281280
state_names = [
282281
f"{name}[{obs_name}]"
283282
for obs_name in self.observed_state_names
284283
for name in self.base_names
285284
]
286285
self.state_info = StateInfo(
287-
states=[State(name=name, observed=True, shared=False) for name in state_names]
286+
states=[State(name=name, observed=False, shared=False) for name in state_names]
288287
)
289-
self.state_names = self.state_info.names
288+
self.state_info = self.state_info.merge(
289+
StateInfo(
290+
states=[
291+
State(name=name, observed=True, shared=False)
292+
for name in self.observed_state_names
293+
]
294+
)
295+
)
296+
self.state_names = self.state_info.unobserved_state_names
290297

291298
def _set_coords(self) -> None:
292299
regression_state_coord = Coord(
@@ -296,20 +303,11 @@ def _set_coords(self) -> None:
296303
dimension=f"endog_{self.name}", labels=[state for state in self.observed_state_names]
297304
)
298305

299-
self.coords = CoordInfo(coords=[regression_state_coord, endogenous_state_coord])
306+
self.coords_info = CoordInfo(coords=[regression_state_coord, endogenous_state_coord])
300307

301308
def populate_component_properties(self) -> None:
302-
# Set parameter info
303309
self._set_parameters()
304-
305-
# Set data info
306310
self._set_data()
307-
308-
# Set shock info
309311
self._set_shocks()
310-
311-
# Set states info
312312
self._set_states()
313-
314-
# Set coordinates info
315313
self._set_coords()

0 commit comments

Comments
 (0)