Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Development - |version|
* Add example script for shielded training with action replacement and action masking in `Shielded training with action masking and action replacement <examples/training_with_shield.ipynb>`_.
* Add ``bsk`` as a dependency in ``pyproject.toml``.
* Update the CI/CD workflows to build BSK-RL using the new ``bsk`` dependency.
* Optimize performance of AEOS environments, especially for high request counts.


Version 1.2.0
Expand Down
594 changes: 66 additions & 528 deletions examples/cloud_environment.ipynb

Large diffs are not rendered by default.

155 changes: 78 additions & 77 deletions examples/cloud_environment_with_reimaging.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,11 @@
"\n",
" def __init__(\n",
" self,\n",
" imaged: Optional[list[\"Target\"]] = None,\n",
" imaged: Optional[set[\"Target\"]] = None,\n",
" duplicates: int = 0,\n",
" known: Optional[list[\"Target\"]] = None,\n",
" cloud_covered: Optional[list[\"Target\"]] = None,\n",
" cloud_free: Optional[list[\"Target\"]] = None,\n",
" known: Optional[set[\"Target\"]] = None,\n",
" cloud_covered: Optional[set[\"Target\"]] = None,\n",
" cloud_free: Optional[set[\"Target\"]] = None,\n",
" ) -> None:\n",
" \"\"\"Construct unit of data to record unique images.\n",
"\n",
Expand All @@ -174,19 +174,19 @@
" ``cloud_covered`` and ``cloud_free`` based on the specified threshold.\n",
"\n",
" Args:\n",
" imaged: List of targets that are known to be imaged.\n",
" imaged: Set of targets that are known to be imaged.\n",
" duplicates: Count of target imaging duplication.\n",
" known: List of targets that are known to exist (imaged and unimaged).\n",
" cloud_covered: List of imaged targets that are known to be cloud covered.\n",
" cloud_free: List of imaged targets that are known to be cloud free.\n",
" known: Set of targets that are known to exist (imaged and unimaged).\n",
" cloud_covered: Set of imaged targets that are known to be cloud covered.\n",
" cloud_free: Set of imaged targets that are known to be cloud free.\n",
" \"\"\"\n",
" super().__init__(imaged=imaged, duplicates=duplicates, known=known)\n",
" if cloud_covered is None:\n",
" cloud_covered = []\n",
" cloud_covered = set()\n",
" if cloud_free is None:\n",
" cloud_free = []\n",
" self.cloud_covered = list(set(cloud_covered))\n",
" self.cloud_free = list(set(cloud_free))\n",
" cloud_free = set()\n",
" self.cloud_covered = set(cloud_covered)\n",
" self.cloud_free = set(cloud_free)\n",
"\n",
" def __add__(self, other: \"CloudImageBinaryData\") -> \"CloudImageBinaryData\":\n",
" \"\"\"Combine two units of data.\n",
Expand All @@ -198,17 +198,17 @@
" Combined unit of data.\n",
" \"\"\"\n",
"\n",
" imaged = list(set(self.imaged + other.imaged))\n",
" imaged = self.imaged | other.imaged\n",
" duplicates = (\n",
" self.duplicates\n",
" + other.duplicates\n",
" + len(self.imaged)\n",
" + len(other.imaged)\n",
" - len(imaged)\n",
" )\n",
" known = list(set(self.known + other.known))\n",
" cloud_covered = list(set(self.cloud_covered + other.cloud_covered))\n",
" cloud_free = list(set(self.cloud_free + other.cloud_free))\n",
" known = self.known | other.known\n",
" cloud_covered = self.cloud_covered | other.cloud_covered\n",
" cloud_free = self.cloud_free | other.cloud_free\n",
"\n",
" return self.__class__(\n",
" imaged=imaged,\n",
Expand Down Expand Up @@ -236,27 +236,25 @@
" Returns:\n",
" list: Targets imaged at new_state that were unimaged at old_state\n",
" \"\"\"\n",
" update_idx = np.where(new_state - old_state > 0)[0]\n",
" imaged = []\n",
" for idx in update_idx:\n",
" message = self.satellite.dynamics.storageUnit.storageUnitDataOutMsg\n",
" target_id = message.read().storedDataName[int(idx)]\n",
" imaged.append(\n",
" [target for target in self.data.known if target.id == target_id][0]\n",
" )\n",
"\n",
" cloud_covered = []\n",
" cloud_free = []\n",
" for target in imaged:\n",
" cloud_coverage = target.cloud_cover_true\n",
" if cloud_coverage > target.reward_threshold:\n",
" cloud_covered.append(target)\n",
" data_increase = new_state - old_state\n",
" if data_increase <= 0:\n",
" return UniqueImageData()\n",
" else:\n",
" assert self.satellite.latest_target is not None\n",
" self.update_target_colors([self.satellite.latest_target])\n",
" cloud_coverage = self.satellite.latest_target.cloud_cover_true\n",
" cloud_threshold = self.satellite.latest_target.reward_threshold\n",
" if cloud_coverage > cloud_threshold:\n",
" cloud_covered = [self.satellite.latest_target]\n",
" cloud_free = []\n",
" else:\n",
" cloud_free.append(target)\n",
"\n",
" return CloudImageBinaryData(\n",
" imaged=imaged, cloud_covered=cloud_covered, cloud_free=cloud_free\n",
" )\n",
" cloud_covered = []\n",
" cloud_free = [self.satellite.latest_target]\n",
" return CloudImageBinaryData(\n",
" imaged={self.satellite.latest_target},\n",
" cloud_covered=cloud_covered,\n",
" cloud_free=cloud_free,\n",
" )\n",
"\n",
"\n",
"class CloudImageBinaryRewarder(UniqueImageReward):\n",
Expand All @@ -276,14 +274,20 @@
" reward: Cumulative reward across satellites for one step\n",
" \"\"\"\n",
" reward = {}\n",
" imaged_counts = {}\n",
" for new_data in new_data_dict.values():\n",
" for target in new_data.imaged:\n",
" if target not in imaged_counts:\n",
" imaged_counts[target] = 0\n",
" imaged_counts[target] += 1\n",
"\n",
" for sat_id, new_data in new_data_dict.items():\n",
" reward[sat_id] = 0.0\n",
" for target in new_data.cloud_free:\n",
" reward[sat_id] += self.reward_fn(target.priority)\n",
"\n",
" for new_data in new_data_dict.values():\n",
" self.data += new_data\n",
" if target not in self.data.imaged:\n",
" reward[sat_id] += (\n",
" self.reward_fn(target.priority) / imaged_counts[target]\n",
" )\n",
" return reward\n",
"\n",
"\n",
Expand Down Expand Up @@ -326,9 +330,9 @@
" def __init__(\n",
" self,\n",
" imaged: Optional[list[\"Target\"]] = None,\n",
" imaged_complete: Optional[list[\"Target\"]] = None,\n",
" imaged_complete: Optional[set[\"Target\"]] = None,\n",
" list_belief_update_var: Optional[list[float]] = None,\n",
" known: Optional[list[\"Target\"]] = None,\n",
" known: Optional[set[\"Target\"]] = None,\n",
" ) -> None:\n",
" \"\"\"Construct unit of data to record unique images.\n",
"\n",
Expand All @@ -337,22 +341,22 @@
"\n",
" Args:\n",
" imaged: List of targets that are known to be imaged.\n",
" imaged_complete: List of targets that are known to be completely imaged (P(S=1) >= reward_threshold).\n",
" imaged_complete: Set of targets that are known to be completely imaged (P(S=1) >= reward_threshold).\n",
" list_belief_update_var: List of belief update variations for each target after each picture.\n",
" known: List of targets that are known to exist (imaged and not imaged)\n",
" \"\"\"\n",
" if imaged is None:\n",
" imaged = []\n",
" if imaged_complete is None:\n",
" imaged_complete = []\n",
" imaged_complete = set()\n",
" if list_belief_update_var is None:\n",
" list_belief_update_var = []\n",
" if known is None:\n",
" known = []\n",
" self.known = list(set(known))\n",
" known = set()\n",
" self.known = set(known)\n",
"\n",
" self.imaged = list(imaged)\n",
" self.imaged_complete = list(set(imaged_complete))\n",
" self.imaged = imaged\n",
" self.imaged_complete = imaged_complete\n",
" self.list_belief_update_var = list(list_belief_update_var)\n",
"\n",
" def __add__(\n",
Expand All @@ -367,13 +371,13 @@
" Combined unit of data.\n",
" \"\"\"\n",
"\n",
" imaged = list(self.imaged + other.imaged)\n",
" imaged_complete = list(set(self.imaged_complete + other.imaged_complete))\n",
" list_belief_update_var = list(\n",
" imaged = self.imaged + other.imaged\n",
" imaged_complete = self.imaged_complete | other.imaged_complete\n",
" list_belief_update_var = (\n",
" self.list_belief_update_var + other.list_belief_update_var\n",
" )\n",
"\n",
" known = list(set(self.known + other.known))\n",
" known = self.known | other.known\n",
" return self.__class__(\n",
" imaged=imaged,\n",
" imaged_complete=imaged_complete,\n",
Expand Down Expand Up @@ -401,9 +405,8 @@
" Returns:\n",
" array: storedData from satellite storage unit\n",
" \"\"\"\n",
" return np.array(\n",
" self.satellite.dynamics.storageUnit.storageUnitDataOutMsg.read().storedData\n",
" )\n",
" msg = self.satellite.dynamics.storageUnit.storageUnitDataOutMsg.read()\n",
" return msg.storedData[0]\n",
"\n",
" def compare_log_states(\n",
" self, old_state: np.ndarray, new_state: np.ndarray\n",
Expand All @@ -421,22 +424,18 @@
" Returns:\n",
" list: Targets imaged at new_state that were unimaged at old_state\n",
" \"\"\"\n",
" update_idx = np.where(new_state - old_state > 0)[0]\n",
" imaged = []\n",
" for idx in update_idx:\n",
" message = self.satellite.dynamics.storageUnit.storageUnitDataOutMsg\n",
" target_id = message.read().storedDataName[int(idx)]\n",
" imaged.append(\n",
" [target for target in self.data.known if target.id == target_id][0]\n",
" )\n",
"\n",
" list_imaged_complete = []\n",
" list_belief_update_var = []\n",
" data_increase = new_state - old_state\n",
" if data_increase <= 0:\n",
" return CloudImageProbabilityData()\n",
" else:\n",
" assert self.satellite.latest_target is not None\n",
" # return UniqueImageData(imaged={self.satellite.latest_target})\n",
"\n",
" current_sim_time = self.satellite.simulator.sim_time\n",
" belief_update_func = self.satellite.belief_update_func\n",
" target = self.satellite.latest_target\n",
" current_sim_time = self.satellite.simulator.sim_time\n",
" belief_update_func = self.satellite.belief_update_func\n",
"\n",
" for target in imaged:\n",
" target_prev_obs = (\n",
" target.prev_obs\n",
" ) # Time at which the target was previously observed\n",
Expand All @@ -457,18 +456,20 @@
" target.prev_obs = current_sim_time # Update the previous observation time\n",
"\n",
" if updated_belief[1] > target.reward_threshold:\n",
" list_imaged_complete.append(target)\n",
" list_belief_update_var.append(target.belief_update_var)\n",
" list_imaged_complete = [target]\n",
" else:\n",
" list_imaged_complete = []\n",
" list_belief_update_var = target.belief_update_var\n",
"\n",
" return CloudImageProbabilityData(\n",
" imaged=imaged,\n",
" imaged_complete=list_imaged_complete,\n",
" list_belief_update_var=list_belief_update_var,\n",
" )\n",
" return CloudImageProbabilityData(\n",
" imaged=[target],\n",
" imaged_complete=set(list_imaged_complete),\n",
" list_belief_update_var=[list_belief_update_var],\n",
" )\n",
"\n",
"\n",
"class CloudImageProbabilityRewarder(GlobalReward):\n",
" datastore_type = CloudImageProbabilityDataStore\n",
" data_store_type = CloudImageProbabilityDataStore\n",
"\n",
" def __init__(\n",
" self,\n",
Expand Down Expand Up @@ -505,6 +506,7 @@
" Returns:\n",
" reward: Cumulative reward across satellites for one step\n",
" \"\"\"\n",
"\n",
" reward = {}\n",
"\n",
" for sat_id, new_data in new_data_dict.items():\n",
Expand All @@ -519,7 +521,6 @@
" reward[sat_id] += self.reward_fn(\n",
" target.priority, None, self.alpha, reach_threshold=True\n",
" )\n",
"\n",
" return reward\n",
"\n",
"\n",
Expand Down
Loading