diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst index 9d0dd8f6..8e110b61 100644 --- a/docs/source/release_notes.rst +++ b/docs/source/release_notes.rst @@ -47,6 +47,8 @@ Development - |version| `RSO Inspection `_ example. * Add a maximum duration option to :class:`~bsk_rl.act.Image`. * Fix a bug where a satellite's initial data was never added to the rewarder. +* Fix a bug where using multiple of the same rewarder would cause some settings to be + overwritten. Version 1.1.0 diff --git a/src/bsk_rl/data/composition.py b/src/bsk_rl/data/composition.py index 3422bf1f..8e6a2e1a 100644 --- a/src/bsk_rl/data/composition.py +++ b/src/bsk_rl/data/composition.py @@ -75,7 +75,7 @@ def __init__( satellite: "Satellite", *data_store_types: type[DataStore], initial_data: Optional[ComposedData] = None, - data_store_kwargs: Optional[dict] = None, + data_store_kwargs: Optional[list] = None, ): """DataStore for composed data types. @@ -84,14 +84,21 @@ def __init__( data_store_types: DataStore types to compose. initial_data: Initial data to start the store with. Usually comes from :class:`~bsk_rl.data.GlobalReward.initial_data`. - data_store_kwargs: Dictionary mapping data_store types to their kwargs. + data_store_kwargs: List of data_store kwargs matching data_store_types. """ self.data: ComposedData super().__init__(satellite, initial_data) if data_store_kwargs is None: - data_store_kwargs = {ds: {} for ds in data_store_types} + data_store_kwargs = [{} for _ in data_store_types] + + if len(data_store_types) != len(data_store_kwargs): + raise ValueError( + "data_store_types and data_store_kwargs must have the same length." + ) + self.data_stores = tuple( - [ds(satellite, **data_store_kwargs[ds]) for ds in data_store_types] + ds(satellite, **kwargs) + for ds, kwargs in zip(data_store_types, data_store_kwargs) ) self.pass_data() @@ -197,9 +204,7 @@ def create_data_store(self, satellite: Satellite) -> None: satellite, *[r.data_store_type for r in self.rewarders], initial_data=self.initial_data(satellite), - data_store_kwargs={ - r.data_store_type: r.data_store_kwargs for r in self.rewarders - }, + data_store_kwargs=[r.data_store_kwargs for r in self.rewarders], ) self.cum_reward[satellite.name] = 0.0 for rewarder in self.rewarders: