Commit 0831e38

Update Agent constructor in jax

1 parent ca647a9 commit 0831e38

File tree

10 files changed: +148 additions, -299 deletions

skrl/agents/jax/a2c/a2c.py

Lines changed: 14 additions & 31 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Union
 
 import copy
 import functools
@@ -186,32 +186,26 @@ def _value_loss(params):
 class A2C(Agent):
     def __init__(
         self,
-        models: Mapping[str, Model],
-        memory: Optional[Union[Memory, Tuple[Memory]]] = None,
-        observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
-        action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
+        *,
+        models: Optional[Mapping[str, Model]] = None,
+        memory: Optional[Memory] = None,
+        observation_space: Optional[gymnasium.Space] = None,
+        state_space: Optional[gymnasium.Space] = None,
+        action_space: Optional[gymnasium.Space] = None,
         device: Optional[Union[str, jax.Device]] = None,
         cfg: Optional[dict] = None,
     ) -> None:
         """Advantage Actor Critic (A2C)
 
         https://arxiv.org/abs/1602.01783
 
-        :param models: Models used by the agent
-        :type models: dictionary of skrl.models.jax.Model
-        :param memory: Memory to storage the transitions.
-            If it is a tuple, the first element will be used for training and
-            for the rest only the environment transitions will be added
-        :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
-        :param observation_space: Observation/state space or shape (default: ``None``)
-        :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
-        :param action_space: Action space or shape (default: ``None``)
-        :type action_space: int, tuple or list of int, gymnasium.Space or None, optional
-        :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
-            If None, the device will be either ``"cuda"`` if available or ``"cpu"``
-        :type device: str or jax.Device, optional
-        :param cfg: Configuration dictionary
-        :type cfg: dict
+        :param models: Agent's models.
+        :param memory: Memory to storage agent's data and environment transitions.
+        :param observation_space: Observation space.
+        :param state_space: State space.
+        :param action_space: Action space.
+        :param device: Data allocation and computation device. If not specified, the default device will be used.
+        :param cfg: Agent's configuration.
 
         :raises KeyError: If the models dictionary is missing a required key
         """
@@ -425,17 +419,6 @@ def record_transition(
                 log_prob=self._current_log_prob,
                 values=values,
             )
-            for memory in self.secondary_memories:
-                memory.add_samples(
-                    states=states,
-                    actions=actions,
-                    rewards=rewards,
-                    next_states=next_states,
-                    terminated=terminated,
-                    truncated=truncated,
-                    log_prob=self._current_log_prob,
-                    values=values,
-                )
 
     def pre_interaction(self, timestep: int, timesteps: int) -> None:
         """Callback called before the interaction with the environment

skrl/agents/jax/base.py

Lines changed: 22 additions & 32 deletions
@@ -1,10 +1,11 @@
-from typing import Any, Mapping, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Union
 
 import collections
 import copy
 import datetime
 import os
 import pickle
+from abc import ABC, abstractmethod
 import gymnasium
 
 import flax
@@ -16,55 +17,46 @@
 from skrl.models.jax import Model
 
 
-class Agent:
+class Agent(ABC):
     def __init__(
         self,
-        models: Mapping[str, Model],
-        memory: Optional[Union[Memory, Tuple[Memory]]] = None,
-        observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
-        action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
+        *,
+        models: Optional[Mapping[str, Model]] = None,
+        memory: Optional[Memory] = None,
+        observation_space: Optional[gymnasium.Space] = None,
+        state_space: Optional[gymnasium.Space] = None,
+        action_space: Optional[gymnasium.Space] = None,
         device: Optional[Union[str, jax.Device]] = None,
         cfg: Optional[dict] = None,
     ) -> None:
-        """Base class that represent a RL agent
-
-        :param models: Models used by the agent
-        :type models: dictionary of skrl.models.jax.Model
-        :param memory: Memory to storage the transitions.
-            If it is a tuple, the first element will be used for training and
-            for the rest only the environment transitions will be added
-        :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
-        :param observation_space: Observation/state space or shape (default: ``None``)
-        :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
-        :param action_space: Action space or shape (default: ``None``)
-        :type action_space: int, tuple or list of int, gymnasium.Space or None, optional
-        :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
-            If None, the device will be either ``"cuda"`` if available or ``"cpu"``
-        :type device: str or jax.Device, optional
-        :param cfg: Configuration dictionary
-        :type cfg: dict
+        """Base class that represent a RL agent/algorithm.
+
+        :param models: Agent's models.
+        :param memory: Memory to storage agent's data and environment transitions.
+        :param observation_space: Observation space.
+        :param state_space: State space.
+        :param action_space: Action space.
+        :param device: Data allocation and computation device. If not specified, the default device will be used.
+        :param cfg: Agent's configuration.
         """
         self._jax = config.jax.backend == "jax"
+        self.training = True
 
         self.models = models
+        self.memory = memory
         self.observation_space = observation_space
+        self.state_space = state_space
        self.action_space = action_space
         self.cfg = cfg if cfg is not None else {}
 
         self.device = config.jax.parse_device(device)
 
-        if type(memory) is list:
-            self.memory = memory[0]
-            self.secondary_memories = memory[1:]
-        else:
-            self.memory = memory
-            self.secondary_memories = []
-
         # convert the models to their respective device
         for model in self.models.values():
             if model is not None:
                 pass
 
+        # data tracking
         self.tracking_data = collections.defaultdict(list)
         self.write_interval = self.cfg.get("experiment", {}).get("write_interval", "auto")
@@ -73,8 +65,6 @@ def __init__(
         self._cumulative_rewards = None
         self._cumulative_timesteps = None
 
-        self.training = True
-
         # checkpoint
         self.checkpoint_modules = {}
         self.checkpoint_interval = self.cfg.get("experiment", {}).get("checkpoint_interval", "auto")
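
The bare * added to the signature makes every argument keyword-only. That matters here because state_space was inserted between observation_space and action_space: an old positional call would otherwise silently bind its action space to state_space. A minimal, self-contained illustration of the mechanism (plain Python, not skrl code):

class Example:
    def __init__(self, *, observation_space=None, state_space=None, action_space=None):
        self.observation_space = observation_space
        self.state_space = state_space
        self.action_space = action_space

Example(observation_space="obs", action_space="act")  # OK: keywords cannot mis-bind
# Example("obs", "act")  # TypeError: __init__() takes 1 positional argument but 3 were given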

skrl/agents/jax/cem/cem.py

Lines changed: 14 additions & 29 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Union
 
 import copy
 import gymnasium
@@ -54,32 +54,26 @@
 class CEM(Agent):
     def __init__(
         self,
-        models: Mapping[str, Model],
-        memory: Optional[Union[Memory, Tuple[Memory]]] = None,
-        observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
-        action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
+        *,
+        models: Optional[Mapping[str, Model]] = None,
+        memory: Optional[Memory] = None,
+        observation_space: Optional[gymnasium.Space] = None,
+        state_space: Optional[gymnasium.Space] = None,
+        action_space: Optional[gymnasium.Space] = None,
         device: Optional[Union[str, jax.Device]] = None,
         cfg: Optional[dict] = None,
     ) -> None:
         """Cross-Entropy Method (CEM)
 
         https://ieeexplore.ieee.org/abstract/document/6796865/
 
-        :param models: Models used by the agent
-        :type models: dictionary of skrl.models.jax.Model
-        :param memory: Memory to storage the transitions.
-            If it is a tuple, the first element will be used for training and
-            for the rest only the environment transitions will be added
-        :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
-        :param observation_space: Observation/state space or shape (default: ``None``)
-        :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
-        :param action_space: Action space or shape (default: ``None``)
-        :type action_space: int, tuple or list of int, gymnasium.Space or None, optional
-        :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
-            If None, the device will be either ``"cuda"`` if available or ``"cpu"``
-        :type device: str or jax.Device, optional
-        :param cfg: Configuration dictionary
-        :type cfg: dict
+        :param models: Agent's models.
+        :param memory: Memory to storage agent's data and environment transitions.
+        :param observation_space: Observation space.
+        :param state_space: State space.
+        :param action_space: Action space.
+        :param device: Data allocation and computation device. If not specified, the default device will be used.
+        :param cfg: Agent's configuration.
 
         :raises KeyError: If the models dictionary is missing a required key
         """
@@ -235,15 +229,6 @@ def record_transition(
                 terminated=terminated,
                 truncated=truncated,
             )
-            for memory in self.secondary_memories:
-                memory.add_samples(
-                    states=states,
-                    actions=actions,
-                    rewards=rewards,
-                    next_states=next_states,
-                    terminated=terminated,
-                    truncated=truncated,
-                )
 
         # track episodes internally
         if self._rollout:
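
The same deletion is applied to record_transition in every agent touched by this commit: transitions now go to exactly one memory, and the base class no longer special-cases a list of memories. A user who relied on the old secondary-memories behaviour could recover it outside the agent with a small delegating wrapper; this is a hypothetical sketch, not part of skrl or of this commit:

class MirroredMemory:
    """Hypothetical helper: forwards writes to a primary memory plus mirrors."""

    def __init__(self, primary, *mirrors):
        self.primary = primary
        self.mirrors = mirrors

    def add_samples(self, **samples):
        # write the same transition to every memory, as the removed loop did
        self.primary.add_samples(**samples)
        for mirror in self.mirrors:
            mirror.add_samples(**samples)

    def __getattr__(self, name):
        # delegate sampling and everything else to the primary memory
        return getattr(self.primary, name)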

skrl/agents/jax/ddpg/ddpg.py

Lines changed: 14 additions & 29 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Union
 
 import copy
 import functools
@@ -117,32 +117,26 @@ def _policy_loss(policy_params, critic_params):
 class DDPG(Agent):
     def __init__(
         self,
-        models: Mapping[str, Model],
-        memory: Optional[Union[Memory, Tuple[Memory]]] = None,
-        observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
-        action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
+        *,
+        models: Optional[Mapping[str, Model]] = None,
+        memory: Optional[Memory] = None,
+        observation_space: Optional[gymnasium.Space] = None,
+        state_space: Optional[gymnasium.Space] = None,
+        action_space: Optional[gymnasium.Space] = None,
         device: Optional[Union[str, jax.Device]] = None,
         cfg: Optional[dict] = None,
     ) -> None:
         """Deep Deterministic Policy Gradient (DDPG)
 
         https://arxiv.org/abs/1509.02971
 
-        :param models: Models used by the agent
-        :type models: dictionary of skrl.models.jax.Model
-        :param memory: Memory to storage the transitions.
-            If it is a tuple, the first element will be used for training and
-            for the rest only the environment transitions will be added
-        :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
-        :param observation_space: Observation/state space or shape (default: ``None``)
-        :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
-        :param action_space: Action space or shape (default: ``None``)
-        :type action_space: int, tuple or list of int, gymnasium.Space or None, optional
-        :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
-            If None, the device will be either ``"cuda"`` if available or ``"cpu"``
-        :type device: str or jax.Device, optional
-        :param cfg: Configuration dictionary
-        :type cfg: dict
+        :param models: Agent's models.
+        :param memory: Memory to storage agent's data and environment transitions.
+        :param observation_space: Observation space.
+        :param state_space: State space.
+        :param action_space: Action space.
+        :param device: Data allocation and computation device. If not specified, the default device will be used.
+        :param cfg: Agent's configuration.
 
         :raises KeyError: If the models dictionary is missing a required key
         """
@@ -388,15 +382,6 @@ def record_transition(
                 terminated=terminated,
                 truncated=truncated,
             )
-            for memory in self.secondary_memories:
-                memory.add_samples(
-                    states=states,
-                    actions=actions,
-                    rewards=rewards,
-                    next_states=next_states,
-                    terminated=terminated,
-                    truncated=truncated,
-                )
 
     def pre_interaction(self, timestep: int, timesteps: int) -> None:
         """Callback called before the interaction with the environment

skrl/agents/jax/dqn/ddqn.py

Lines changed: 14 additions & 29 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Union
 
 import copy
 import functools
@@ -97,32 +97,26 @@ def _q_network_loss(params):
 class DDQN(Agent):
     def __init__(
         self,
-        models: Mapping[str, Model],
-        memory: Optional[Union[Memory, Tuple[Memory]]] = None,
-        observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
-        action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
+        *,
+        models: Optional[Mapping[str, Model]] = None,
+        memory: Optional[Memory] = None,
+        observation_space: Optional[gymnasium.Space] = None,
+        state_space: Optional[gymnasium.Space] = None,
+        action_space: Optional[gymnasium.Space] = None,
         device: Optional[Union[str, jax.Device]] = None,
         cfg: Optional[dict] = None,
    ) -> None:
         """Double Deep Q-Network (DDQN)
 
         https://ojs.aaai.org/index.php/AAAI/article/view/10295
 
-        :param models: Models used by the agent
-        :type models: dictionary of skrl.models.jax.Model
-        :param memory: Memory to storage the transitions.
-            If it is a tuple, the first element will be used for training and
-            for the rest only the environment transitions will be added
-        :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
-        :param observation_space: Observation/state space or shape (default: ``None``)
-        :type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
-        :param action_space: Action space or shape (default: ``None``)
-        :type action_space: int, tuple or list of int, gymnasium.Space or None, optional
-        :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
-            If None, the device will be either ``"cuda"`` if available or ``"cpu"``
-        :type device: str or jax.Device, optional
-        :param cfg: Configuration dictionary
-        :type cfg: dict
+        :param models: Agent's models.
+        :param memory: Memory to storage agent's data and environment transitions.
+        :param observation_space: Observation space.
+        :param state_space: State space.
+        :param action_space: Action space.
+        :param device: Data allocation and computation device. If not specified, the default device will be used.
+        :param cfg: Agent's configuration.
 
         :raises KeyError: If the models dictionary is missing a required key
         """
@@ -326,15 +320,6 @@ def record_transition(
                 terminated=terminated,
                 truncated=truncated,
             )
-            for memory in self.secondary_memories:
-                memory.add_samples(
-                    states=states,
-                    actions=actions,
-                    rewards=rewards,
-                    next_states=next_states,
-                    terminated=terminated,
-                    truncated=truncated,
-                )
 
     def pre_interaction(self, timestep: int, timesteps: int) -> None:
         """Callback called before the interaction with the environment
