Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ Development - |version|
`RSO Inspection <examples/rso_inspection.ipynb>`_ example.
* Add a maximum duration option to :class:`~bsk_rl.act.Image`.
* Fix a bug where a satellite's initial data was never added to the rewarder.
* Fix a bug where using multiple of the same rewarder would cause some settings to be
overwritten.


Version 1.1.0
Expand Down
19 changes: 12 additions & 7 deletions src/bsk_rl/data/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
satellite: "Satellite",
*data_store_types: type[DataStore],
initial_data: Optional[ComposedData] = None,
data_store_kwargs: Optional[dict] = None,
data_store_kwargs: Optional[list] = None,
):
"""DataStore for composed data types.

Expand All @@ -84,14 +84,21 @@ def __init__(
data_store_types: DataStore types to compose.
initial_data: Initial data to start the store with. Usually comes from
:class:`~bsk_rl.data.GlobalReward.initial_data`.
data_store_kwargs: Dictionary mapping data_store types to their kwargs.
data_store_kwargs: List of data_store kwargs matching data_store_types.
"""
self.data: ComposedData
super().__init__(satellite, initial_data)
if data_store_kwargs is None:
data_store_kwargs = {ds: {} for ds in data_store_types}
data_store_kwargs = [{} for _ in data_store_types]

if len(data_store_types) != len(data_store_kwargs):
raise ValueError(
"data_store_types and data_store_kwargs must have the same length."
)

self.data_stores = tuple(
[ds(satellite, **data_store_kwargs[ds]) for ds in data_store_types]
ds(satellite, **kwargs)
for ds, kwargs in zip(data_store_types, data_store_kwargs)
)
self.pass_data()

Expand Down Expand Up @@ -197,9 +204,7 @@ def create_data_store(self, satellite: Satellite) -> None:
satellite,
*[r.data_store_type for r in self.rewarders],
initial_data=self.initial_data(satellite),
data_store_kwargs={
r.data_store_type: r.data_store_kwargs for r in self.rewarders
},
data_store_kwargs=[r.data_store_kwargs for r in self.rewarders],
)
self.cum_reward[satellite.name] = 0.0
for rewarder in self.rewarders:
Expand Down