-
Notifications
You must be signed in to change notification settings - Fork 363
/
Copy pathtransforms.py
11021 lines (9736 loc) · 452 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import functools
import hashlib
import importlib.util
import multiprocessing as mp
import time
import warnings
import weakref
from copy import copy
from enum import IntEnum
from functools import wraps
from textwrap import indent
from typing import Any, Callable, Mapping, OrderedDict, Sequence, TypeVar, Union
import numpy as np
import torch
from tensordict import (
is_tensor_collection,
LazyStackedTensorDict,
NonTensorData,
NonTensorStack,
set_lazy_legacy,
TensorDict,
TensorDictBase,
unravel_key,
unravel_key_list,
)
from tensordict.base import _is_leaf_nontensor
from tensordict.nn import dispatch, TensorDictModuleBase
from tensordict.utils import (
_unravel_key_to_tuple,
_zip_strict,
expand_as_right,
expand_right,
NestedKey,
)
from torch import nn, Tensor
from torch.utils._pytree import tree_map
from torchrl._utils import (
_append_last,
_ends_with,
_make_ordinal_device,
_replace_last,
auto_unwrap_transformed_env,
logger as torchrl_logger,
)
from torchrl.data.tensor_specs import (
Binary,
Bounded,
BoundedContinuous,
Categorical,
Composite,
ContinuousBox,
MultiCategorical,
MultiOneHot,
OneHot,
TensorSpec,
Unbounded,
UnboundedContinuous,
)
from torchrl.envs.common import (
_do_nothing,
_EnvPostInit,
_maybe_unlock,
EnvBase,
make_tensordict,
)
from torchrl.envs.transforms import functional as F
from torchrl.envs.transforms.utils import (
_get_reset,
_set_missing_tolerance,
check_finite,
)
from torchrl.envs.utils import (
_sort_keys,
_update_during_reset,
make_composite_from_td,
step_mdp,
)
_has_tv = importlib.util.find_spec("torchvision", None) is not None
IMAGE_KEYS = ["pixels"]
_MAX_NOOPS_TRIALS = 10
FORWARD_NOT_IMPLEMENTED = "class {} cannot be executed without a parent environment."
T = TypeVar("T", bound="Transform")
def _apply_to_composite(function):
    """Decorator that lets a single-spec transform method accept a ``Composite``.

    The decorated method has the signature ``function(self, spec)``. When the
    wrapper receives a ``Composite``, ``function`` is applied to a clone of each
    entry found under ``self.in_keys`` and the result is stored under the
    matching entry of ``self.out_keys``; a ``Composite`` with the same shape and
    device is returned. Any other spec is forwarded to ``function`` unchanged.
    """

    @wraps(function)
    def wrapped(self, observation_spec):
        # Non-composite specs are transformed directly.
        if not isinstance(observation_spec, Composite):
            return function(self, observation_spec)

        # NOTE(review): presumably ``_specs`` is the spec's internal mapping and
        # not a copy, in which case the writes below are also visible through
        # ``observation_spec`` (and to later iterations of the loop) -- confirm.
        entries = observation_spec._specs
        for src_key, dst_key in _zip_strict(self.in_keys, self.out_keys):
            # The key set is re-read at every iteration on purpose: an earlier
            # out_key may be a later in_key.
            if src_key in observation_spec.keys(True, True):
                entries[dst_key] = function(self, observation_spec[src_key].clone())
        return Composite(
            entries, shape=observation_spec.shape, device=observation_spec.device
        )

    return wrapped
def _apply_to_composite_inv(function):
    """Decorator that applies a single-spec inverse transform to an input spec.

    Changes the input_spec following a transform function.
    The usage is: if an env expects a certain input (e.g. a double tensor)
    but the input has to be transformed (e.g. it is float), this function will
    modify the spec to get a spec that from the outside matches what is given
    (i.e. a float).

    The wrapper accepts either a full input spec (one that has a
    ``"full_action_spec"`` entry) or a ``full_action_spec`` / ``full_state_spec``
    passed directly. In the first case a new ``Composite`` is built from
    transformed clones of the action and state specs; in the second case the
    spec that was passed in is edited and returned.
    """

    @wraps(function)
    def wrapped(self, input_spec):
        # ``direct`` is True when we were handed an action/state spec directly
        # rather than the full input spec.
        direct = "full_action_spec" not in input_spec.keys()
        if direct:
            # Empty placeholders: only ``input_spec`` itself can match below.
            action_spec = state_spec = Composite()
        else:
            action_spec = input_spec["full_action_spec"].clone()
            state_spec = input_spec["full_state_spec"]
            if state_spec is None:
                state_spec = Composite(shape=input_spec.shape, device=input_spec.device)
            else:
                state_spec = state_spec.clone()

        for raw_in_key, raw_out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
            in_key = unravel_key(raw_in_key)
            out_key = unravel_key(raw_out_key)
            # Pick the first spec that holds ``in_key``: action, then state,
            # then the root of the input spec (leaves only, not nested).
            if in_key in action_spec.keys(True, True):
                target = action_spec
            elif in_key in state_spec.keys(True, True):
                target = state_spec
            elif in_key in input_spec.keys(False, True):
                target = input_spec
            else:
                continue
            target[out_key] = function(self, target[in_key].clone())
            if in_key != out_key:
                # The entry was renamed: drop the old name.
                del target[in_key]

        if direct:
            return input_spec
        return Composite(
            full_state_spec=state_spec,
            full_action_spec=action_spec,
            shape=input_spec.shape,
            device=input_spec.device,
        )

    return wrapped
class Transform(nn.Module):
    """Base class for environment transforms, which modify or create new data in a tensordict.

    Transforms are used to manipulate the input and output data of an environment. They can be used to preprocess
    observations, modify rewards, or transform actions. Transforms can be composed together to create more complex
    transformations.

    A transform receives a tensordict as input and returns (the same or another) tensordict as output, where a series
    of values have been modified or created with a new key.

    Attributes:
        parent: The parent environment of the transform.
        container: The container that holds the transform.
        in_keys: The keys of the input tensordict that the transform will read from.
        out_keys: The keys of the output tensordict that the transform will write to.

    .. seealso:: :ref:`TorchRL transforms <transforms>`.

    Subclassing `Transform`:

        There are various ways of subclassing a transform. The things to take into consideration are:

        - Is the transform identical for each tensor / item being transformed? Use
          :meth:`~torchrl.envs.Transform._apply_transform` and :meth:`~torchrl.envs.Transform._inv_apply_transform`.
        - The transform needs access to the input data to env.step as well as output? Rewrite
          :meth:`~torchrl.envs.Transform._step`.
          Otherwise, rewrite :meth:`~torchrl.envs.Transform._call` (or :meth:`~torchrl.envs.Transform._inv_call`).
        - Is the transform to be used within a replay buffer? Overwrite :meth:`~torchrl.envs.Transform.forward`,
          :meth:`~torchrl.envs.Transform.inv`, :meth:`~torchrl.envs.Transform._apply_transform` or
          :meth:`~torchrl.envs.Transform._inv_apply_transform`.
        - Within a transform, you can access (and make calls to) the parent environment using
          :attr:`~torchrl.envs.Transform.parent` (the base env + all transforms till this one) or
          :attr:`~torchrl.envs.Transform.container` (The object that encapsulates the transform).
        - Don't forget to edit the specs if needed: top level: :meth:`~torchrl.envs.Transform.transform_output_spec`,
          :meth:`~torchrl.envs.Transform.transform_input_spec`.
          Leaf level: :meth:`~torchrl.envs.Transform.transform_observation_spec`,
          :meth:`~torchrl.envs.Transform.transform_action_spec`, :meth:`~torchrl.envs.Transform.transform_state_spec`,
          :meth:`~torchrl.envs.Transform.transform_reward_spec` and
          :meth:`~torchrl.envs.Transform.transform_done_spec`.

        For practical examples, see the methods listed above.

    Methods:
        clone: creates a copy of the transform, without parent (a transform object can only have one parent).
        set_container: Sets the container for the transform, and in turn the parent if the container is, or
            contains, an environment.
        reset_parent: resets the parent and container caches.

    """

    invertible = False
    enable_inv_on_reset = False

    def __init__(
        self,
        in_keys: Sequence[NestedKey] | None = None,
        out_keys: Sequence[NestedKey] | None = None,
        in_keys_inv: Sequence[NestedKey] | None = None,
        out_keys_inv: Sequence[NestedKey] | None = None,
    ):
        super().__init__()
        self.in_keys = in_keys
        self.out_keys = out_keys
        self.in_keys_inv = in_keys_inv
        self.out_keys_inv = out_keys_inv
        self._missing_tolerance = False
        # we use __dict__ to avoid having nn.Module placing these objects in the module list
        self.__dict__["_container"] = None
        self.__dict__["_parent"] = None

    @staticmethod
    def _normalize_keys(value):
        """Returns ``value`` as a list of unravelled keys, or ``None`` if ``value`` is ``None``.

        A single key (a string or a tuple) is wrapped in a one-element list first.
        """
        if value is None:
            return None
        if isinstance(value, (str, tuple)):
            value = [value]
        return [unravel_key(val) for val in value]

    @property
    def in_keys(self):
        """Keys read by the forward transform (an empty list when unset)."""
        in_keys = self.__dict__.get("_in_keys", None)
        if in_keys is None:
            return []
        return in_keys

    @in_keys.setter
    def in_keys(self, value):
        self._in_keys = self._normalize_keys(value)

    @property
    def out_keys(self):
        """Keys written by the forward transform (an empty list when unset)."""
        out_keys = self.__dict__.get("_out_keys", None)
        if out_keys is None:
            return []
        return out_keys

    @out_keys.setter
    def out_keys(self, value):
        self._out_keys = self._normalize_keys(value)

    @property
    def in_keys_inv(self):
        """Keys used by the inverse transform as ``in_keys_inv`` (an empty list when unset)."""
        in_keys_inv = self.__dict__.get("_in_keys_inv", None)
        if in_keys_inv is None:
            return []
        return in_keys_inv

    @in_keys_inv.setter
    def in_keys_inv(self, value):
        self._in_keys_inv = self._normalize_keys(value)

    @property
    def out_keys_inv(self):
        """Keys used by the inverse transform as ``out_keys_inv`` (an empty list when unset)."""
        out_keys_inv = self.__dict__.get("_out_keys_inv", None)
        if out_keys_inv is None:
            return []
        return out_keys_inv

    @out_keys_inv.setter
    def out_keys_inv(self, value):
        self._out_keys_inv = self._normalize_keys(value)

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        """Resets a transform if it is stateful."""
        return tensordict_reset

    def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Inverts the input to :meth:`TransformedEnv._reset`, if needed."""
        if self.enable_inv_on_reset and tensordict is not None:
            with _set_missing_tolerance(self, True):
                tensordict = self._inv_call(tensordict)
        return tensordict

    def init(self, tensordict) -> None:
        """Runs init steps for the transform."""

    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        """Applies the transform to a tensor or a leaf.

        This operation can be called multiple times (if multiple keys of the
        tensordict match the keys of the transform) for each entry in ``self.in_keys``
        after the `TransformedEnv().base_env.step` is undertaken.

        Raises:
            NotImplementedError: always, in this base class. Subclasses must override it.

        Examples:
            >>> class AddOneToObs(Transform):
            ...     '''A transform that adds 1 to the observation tensor.'''
            ...     def __init__(self):
            ...         super().__init__(in_keys=["observation"], out_keys=["observation"])
            ...
            ...     def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
            ...         return obs + 1

        """
        # Each literal ends with a space so the concatenated message reads correctly.
        raise NotImplementedError(
            f"{self.__class__.__name__}._apply_transform is not coded. If the transform is coded in "
            "transform._call, make sure that this method is called instead of "
            "transform.forward, which is reserved for usage inside nn.Modules "
            "or appended to a replay buffer."
        )

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        """The parent method of a transform during the ``env.step`` execution.

        This method should be overwritten whenever the :meth:`_step` needs to be
        adapted. Unlike :meth:`_call`, it is assumed that :meth:`_step`
        will execute some operation with the parent env or that it requires
        access to the content of the tensordict at time ``t`` and not only
        ``t+1`` (the ``"next"`` entry in the input tensordict).

        :meth:`_step` will only be called by :meth:`TransformedEnv.step` and
        not by :meth:`TransformedEnv.reset`.

        Args:
            tensordict (TensorDictBase): data at time t
            next_tensordict (TensorDictBase): data at time t+1

        Returns: the data at t+1

        Examples:
            >>> class AddActionToObservation(Transform):
            ...     '''A transform that adds the action to the observation tensor.'''
            ...     def _step(
            ...         self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
            ...     ) -> TensorDictBase:
            ...         # This can only be done if we have access to the 'root' tensordict
            ...         next_tensordict["observation"] += tensordict["action"]
            ...         return next_tensordict

        """
        next_tensordict = self._call(next_tensordict)
        return next_tensordict

    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
        """Reads the input tensordict, and for the selected keys, applies the transform.

        ``_call`` can be re-written whenever a modification of the output of env.step needs to be modified independently
        of the data collected in the previous step (including actions and states).

        For any operation that relates exclusively to the parent env (e.g. ``FrameSkip``),
        modify the :meth:`~torchrl.envs.Transform._step` method instead.
        :meth:`_call` should only be overwritten if a modification of the input tensordict is needed.

        :meth:`_call` will be called by :meth:`~torchrl.envs.TransformedEnv.step` and
        :meth:`~torchrl.envs.TransformedEnv.reset` but not during :meth:`~torchrl.envs.Transform.forward`.

        Raises:
            KeyError: if an in_key is absent and ``missing_tolerance`` is off.

        """
        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
            value = next_tensordict.get(in_key, default=None)
            if value is not None:
                observation = self._apply_transform(value)
                next_tensordict.set(
                    out_key,
                    observation,
                )
            elif not self.missing_tolerance:
                raise KeyError(
                    f"{self}: '{in_key}' not found in tensordict {next_tensordict}"
                )
        return next_tensordict

    @dispatch(source="in_keys", dest="out_keys")
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Reads the input tensordict, and for the selected keys, applies the transform.

        By default, this method:

        - calls directly :meth:`~torchrl.envs.Transform._apply_transform`.
        - does not call :meth:`~torchrl.envs.Transform._step` or :meth:`~torchrl.envs.Transform._call`.

        This method is not called within `env.step` at any point. However, it is called within
        :meth:`~torchrl.data.ReplayBuffer.sample`.

        .. note:: ``forward`` also works with regular keyword arguments using :class:`~tensordict.nn.dispatch` to cast the args
            names to the keys.

        Raises:
            KeyError: if an in_key is absent and ``missing_tolerance`` is off.

        Examples:
            >>> class TransformThatMeasuresBytes(Transform):
            ...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
            ...     def __init__(self):
            ...         super().__init__(in_keys=[], out_keys=["bytes"])
            ...
            ...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
            ...         bytes_in_td = tensordict.bytes()
            ...         tensordict["bytes"] = bytes_in_td
            ...         return tensordict
            >>> t = TransformThatMeasuresBytes()
            >>> env = env.append_transform(t) # works within envs
            >>> t(TensorDict(a=0))  # Works offline too.

        """
        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
            data = tensordict.get(in_key, None)
            if data is not None:
                data = self._apply_transform(data)
                tensordict.set(out_key, data)
            elif not self.missing_tolerance:
                raise KeyError(f"'{in_key}' not found in tensordict {tensordict}")
        return tensordict

    def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
        """Applies the inverse transform to a tensor or a leaf.

        This operation can be called multiple times (if multiple keys of the
        tensordict match the keys of the transform) for each entry in ``self.in_keys_inv``
        before the `TransformedEnv().base_env.step` is undertaken.

        In this base class the input is returned unchanged unless ``invertible``
        is set, in which case ``NotImplementedError`` is raised.

        Examples:
            >>> class AddOneToAction(Transform):
            ...     '''A transform that adds 1 to the action tensor.'''
            ...     def __init__(self):
            ...         super().__init__(in_keys=[], out_keys=[], in_keys_inv=["action"], out_keys_inv=["action"])
            ...
            ...     def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor:
            ...         return action + 1

        """
        if self.invertible:
            raise NotImplementedError
        else:
            return state

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Reads and possibly modify the input tensordict before it is passed to :meth:`~torchrl.envs.EnvBase.step`.

        For each ``(in_key, out_key)`` pair of ``in_keys_inv`` / ``out_keys_inv``, the
        entry stored under ``out_key`` is read, passed through
        :meth:`_inv_apply_transform` and written under ``in_key``.

        Raises:
            KeyError: if an entry is absent and ``missing_tolerance`` is off.

        Examples:
            >>> class AddOneToAllTensorDictBeforeStep(Transform):
            ...     '''Adds 1 to the whole content of the input to the env before the step is taken.'''
            ...
            ...     def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
            ...         return tensordict + 1

        """
        if not self.in_keys_inv:
            return tensordict
        for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
            data = tensordict.get(out_key, None)
            if data is not None:
                item = self._inv_apply_transform(data)
                tensordict.set(in_key, item)
            elif not self.missing_tolerance:
                raise KeyError(f"'{out_key}' not found in tensordict {tensordict}")
        return tensordict

    @dispatch(source="in_keys_inv", dest="out_keys_inv")
    def inv(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Reads the input tensordict, and for the selected keys, applies the inverse transform.

        The input is shallow-copied first, then handed to
        :meth:`~torchrl.envs.Transform._inv_call`.

        .. note:: ``inv`` also works with regular keyword arguments using :class:`~tensordict.nn.dispatch` to cast the args
            names to the keys.

        .. note:: ``inv`` is called by :meth:`~torchrl.data.ReplayBuffer.extend`.

        """

        def clone(data):
            try:
                # we privilege speed for tensordicts
                return data.clone(recurse=False)
            except (AttributeError, TypeError):
                # no ``clone`` method, or one that does not accept ``recurse``:
                # rebuild the container through pytree instead
                return tree_map(lambda x: x, data)

        out = self._inv_call(clone(tensordict))
        return out

    def transform_env_device(self, device: torch.device):
        """Transforms the device of the parent env."""
        return device

    def transform_env_batch_size(self, batch_size: torch.Size):
        """Transforms the batch-size of the parent env."""
        return batch_size

    def transform_output_spec(self, output_spec: Composite) -> Composite:
        """Transforms the output spec such that the resulting spec matches transform mapping.

        This method should generally be left untouched. Changes should be implemented using
        :meth:`transform_observation_spec`, :meth:`transform_reward_spec` and :meth:`transform_done_spec`.

        A ``FutureWarning`` is emitted for every out_key that is not also an in_key
        and that cannot be found in the transformed spec.

        Args:
            output_spec (TensorSpec): spec before the transform

        Returns:
            expected spec after the transform

        """
        output_spec = output_spec.clone()
        output_spec["full_observation_spec"] = self.transform_observation_spec(
            output_spec["full_observation_spec"]
        )
        if "full_reward_spec" in output_spec.keys():
            output_spec["full_reward_spec"] = self.transform_reward_spec(
                output_spec["full_reward_spec"]
            )
        if "full_done_spec" in output_spec.keys():
            output_spec["full_done_spec"] = self.transform_done_spec(
                output_spec["full_done_spec"]
            )
        # Nested keys with their first element (the "full_*_spec" prefix) dropped.
        output_spec_keys = [
            unravel_key(k[1:]) for k in output_spec.keys(True) if isinstance(k, tuple)
        ]
        out_keys = {unravel_key(k) for k in self.out_keys}
        in_keys = {unravel_key(k) for k in self.in_keys}
        for key in out_keys - in_keys:
            if unravel_key(key) not in output_spec_keys:
                warnings.warn(
                    f"The key '{key}' is unaccounted for by the transform (expected keys {output_spec_keys}). "
                    f"Every new entry in the tensordict resulting from a call to a transform must be "
                    f"registered in the specs for torchrl rollouts to be consistently built. "
                    f"Make sure transform_output_spec/transform_observation_spec/... is coded correctly. "
                    "This warning will trigger a KeyError in v0.9, make sure to adapt your code accordingly.",
                    category=FutureWarning,
                )
        return output_spec

    def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
        """Transforms the input spec such that the resulting spec matches transform mapping.

        Args:
            input_spec (TensorSpec): spec before the transform

        Returns:
            expected spec after the transform

        """
        input_spec = input_spec.clone()
        input_spec["full_state_spec"] = self.transform_state_spec(
            input_spec["full_state_spec"]
        )
        input_spec["full_action_spec"] = self.transform_action_spec(
            input_spec["full_action_spec"]
        )
        return input_spec

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        """Transforms the observation spec such that the resulting spec matches transform mapping.

        Args:
            observation_spec (TensorSpec): spec before the transform

        Returns:
            expected spec after the transform

        """
        return observation_spec

    def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
        """Transforms the reward spec such that the resulting spec matches transform mapping.

        Args:
            reward_spec (TensorSpec): spec before the transform

        Returns:
            expected spec after the transform

        """
        return reward_spec

    def transform_done_spec(self, done_spec: TensorSpec) -> TensorSpec:
        """Transforms the done spec such that the resulting spec matches transform mapping.

        Args:
            done_spec (TensorSpec): spec before the transform

        Returns:
            expected spec after the transform

        """
        return done_spec

    def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec:
        """Transforms the action spec such that the resulting spec matches transform mapping.

        Args:
            action_spec (TensorSpec): spec before the transform

        Returns:
            expected spec after the transform

        """
        return action_spec

    def transform_state_spec(self, state_spec: TensorSpec) -> TensorSpec:
        """Transforms the state spec such that the resulting spec matches transform mapping.

        Args:
            state_spec (TensorSpec): spec before the transform

        Returns:
            expected spec after the transform

        """
        return state_spec

    def dump(self, **kwargs) -> None:
        """No-op hook in the base class."""
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(keys={self.in_keys})"

    def set_container(self, container: Transform | EnvBase) -> None:
        """Registers ``container`` (held through a weak reference) and clears the cached parent.

        Raises:
            AttributeError: if the transform already has a parent.
        """
        if self.parent is not None:
            raise AttributeError(
                f"parent of transform {type(self)} already set. "
                "Call `transform.clone()` to get a similar transform with no parent set."
            )
        self.__dict__["_container"] = (
            weakref.ref(container) if container is not None else None
        )
        self.__dict__["_parent"] = None

    def reset_parent(self) -> None:
        """Clears both the container reference and the cached parent."""
        self.__dict__["_container"] = None
        self.__dict__["_parent"] = None

    def clone(self) -> T:
        """Returns a shallow copy of the transform with no container and no parent."""
        self_copy = copy(self)
        state = copy(self.__dict__)
        state["_container"] = None
        state["_parent"] = None
        self_copy.__dict__.update(state)
        return self_copy

    @property
    def container(self):
        """Returns the env containing the transform.

        Returns ``None`` when the transform (or the chain of ``Compose``
        transforms holding it) is not attached to an environment.

        Raises:
            AttributeError: if the container slot was never initialised.
            ValueError: if an object in the chain is neither a ``Compose`` nor an environment.

        Examples:
            >>> from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter
            >>> from torchrl.envs.libs.gym import GymEnv
            >>> env = TransformedEnv(GymEnv("Pendulum-v1"), Compose(RewardSum(), StepCounter()))
            >>> env.transform[0].container is env
            True

        """
        if "_container" not in self.__dict__:
            raise AttributeError("transform parent uninitialized")
        container_weakref = self.__dict__["_container"]
        if container_weakref is not None:
            container = container_weakref()
        else:
            container = container_weakref
        if container is None:
            return container
        while not isinstance(container, EnvBase):
            # if it's not an env, it should be a Compose transform
            if not isinstance(container, Compose):
                raise ValueError(
                    "A transform parent must be either another Compose transform or an environment object."
                )
            compose = container
            container_weakref = compose.__dict__.get("_container")
            if container_weakref is not None:
                # container is a weakref
                container = container_weakref()
            else:
                container = container_weakref
            if container is None:
                # The enclosing Compose is not attached to anything (or its
                # container is gone): report "no container", as is done above
                # for the first level, instead of raising a misleading error.
                return None
        return container

    def __getstate__(self):
        # Weak references cannot be pickled: store the referent itself.
        result = self.__dict__.copy()
        container = result["_container"]
        if container is not None:
            container = container()
        result["_container"] = container
        return result

    def __setstate__(self, state):
        # Work on a copy so that the caller's mapping is left untouched.
        state = dict(state)
        container = state["_container"]
        state["_container"] = weakref.ref(container) if container is not None else None
        self.__dict__.update(state)

    @property
    def parent(self) -> EnvBase | None:
        """Returns the parent env of the transform.

        The parent env is the env that contains all the transforms up until the current one.
        The result is cached; see :meth:`empty_cache`.

        Examples:
            >>> from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter
            >>> from torchrl.envs.libs.gym import GymEnv
            >>> env = TransformedEnv(GymEnv("Pendulum-v1"), Compose(RewardSum(), StepCounter()))
            >>> env.transform[1].parent
            TransformedEnv(
                env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=cpu),
                transform=Compose(
                        RewardSum(keys=['reward'])))

        """
        # TODO: ideally parent should be a weakref, like container, to avoid keeping track of a parent that
        #  is de facto out of scope.
        parent = self.__dict__.get("_parent")
        if parent is None:
            if "_container" not in self.__dict__:
                raise AttributeError("transform parent uninitialized")
            container_weakref = self.__dict__["_container"]
            if container_weakref is None:
                return container_weakref
            container = container_weakref()
            if container is None:
                torchrl_logger.info(
                    "transform container out of scope. Returning None for parent."
                )
                return container
            parent = None
            if not isinstance(container, EnvBase):
                # if it's not an env, it should be a Compose transform
                if not isinstance(container, Compose):
                    raise ValueError(
                        "A transform parent must be either another Compose transform or an environment object."
                    )
                parent, _ = container._rebuild_up_to(self)
            elif isinstance(container, TransformedEnv):
                parent = TransformedEnv(container.base_env, auto_unwrap=False)
            else:
                raise ValueError(f"container is of type {type(container)}")
            self.__dict__["_parent"] = parent
        return parent

    def empty_cache(self):
        """Drops the cached parent env."""
        self.__dict__["_parent"] = None

    def set_missing_tolerance(self, mode=False):
        """Sets whether missing input keys are tolerated instead of raising ``KeyError``."""
        self._missing_tolerance = mode

    @property
    def missing_tolerance(self):
        """Whether missing input keys are tolerated."""
        return self._missing_tolerance

    def to(self, *args, **kwargs):
        # remove the parent, because it could have the wrong device associated
        self.empty_cache()
        return super().to(*args, **kwargs)
class _TEnvPostInit(_EnvPostInit):
    """Metaclass whose ``__call__`` bypasses the one defined by ``_EnvPostInit``."""

    def __call__(self, *args, **kwargs):
        # ``super(_EnvPostInit, self)`` starts the lookup *after* _EnvPostInit,
        # so its own ``__call__`` is not run: we skip the materialization of the
        # specs, because this can't be done with lazy transforms such as
        # ObservationNorm.
        return super(_EnvPostInit, self).__call__(*args, **kwargs)
class TransformedEnv(EnvBase, metaclass=_TEnvPostInit):
"""A transformed environment.
Args:
env (EnvBase): original environment to be transformed.
transform (Transform or callable, optional): transform to apply to the tensordict resulting
from :obj:`env.step(td)`. If none is provided, an empty Compose
placeholder in an eval mode is used.
.. note:: If ``transform`` is a callable, it must receive as input a single tensordict
and output a tensordict as well. The callable will be called at ``step``
and ``reset`` time: if it acts on the reward (which is absent at
reset time), a check needs to be implemented to ensure that
the transform will run smoothly:
>>> def add_1(data):
... if "reward" in data.keys():
... return data.set("reward", data.get("reward") + 1)
... return data
>>> env = TransformedEnv(base_env, add_1)
cache_specs (bool, optional): if ``True``, the specs will be cached once
and for all after the first call (i.e. the specs will be
transformed only once). If the transform changes during
training, the original spec transform may not be valid anymore,
in which case this value should be set to `False`. Default is
`True`.
Keyword Args:
auto_unwrap (bool, optional): if ``True``, wrapping a transformed env in transformed env
unwraps the transforms of the inner TransformedEnv in the outer one (the new instance).
Defaults to ``True``.
.. note:: This behavior will switch to ``False`` in v0.9.
.. seealso:: :class:`~torchrl.set_auto_unwrap_transformed_env`
Examples:
>>> env = GymEnv("Pendulum-v0")
>>> transform = RewardScaling(0.0, 1.0)
>>> transformed_env = TransformedEnv(env, transform)
>>> # check auto-unwrap
>>> transformed_env = TransformedEnv(transformed_env, StepCounter())
>>> # The inner env has been unwrapped
>>> assert isinstance(transformed_env.base_env, GymEnv)
"""
def __init__(
    self,
    env: EnvBase,
    transform: Transform | Callable | None = None,
    cache_specs: bool = True,
    *,
    auto_unwrap: bool | None = None,
    **kwargs,
):
    """Wraps ``env`` and attaches ``transform`` to it.

    Args:
        env (EnvBase): environment to wrap. It is moved to ``device`` when a
            ``device`` keyword argument is given.
        transform (Transform or callable, optional): transform to attach.
            Callables are wrapped in a ``_CallableTransform``. ``None`` means
            "no extra transform": an empty ``Compose`` is used, or, when the
            wrapped env is unwrapped, only its own transforms are kept.
        cache_specs (bool, optional): stored as ``self.cache_specs``.

    Keyword Args:
        auto_unwrap (bool, optional): whether a wrapped ``TransformedEnv`` is
            unwrapped and its transforms merged with ``transform``. Only
            considered when both ``env`` and ``self`` are exactly of type
            ``TransformedEnv``. When ``None``, the value is taken from
            ``auto_unwrap_transformed_env``; if that is ``None`` too, a
            ``FutureWarning`` is emitted and ``True`` is used.
        device: popped from ``kwargs``; defaults to ``env.device``.
        **kwargs: forwarded to the parent constructor.

    Raises:
        ValueError: if ``transform`` is neither a ``Transform``, a callable
            nor ``None`` (checked here on the unwrap path).
    """
    self._transform = None
    device = kwargs.pop("device", None)
    if device is not None:
        env = env.to(device)
    else:
        device = env.device
    super().__init__(device=None, allow_done_after_reset=None, **kwargs)

    # Type matching must be exact here, because subtyping could introduce differences in behavior that must
    # be contained within the subclass.
    if type(env) is TransformedEnv and type(self) is TransformedEnv:
        if auto_unwrap is None:
            auto_unwrap = auto_unwrap_transformed_env(allow_none=True)
            if auto_unwrap is None:
                warnings.warn(
                    "The default behavior of TransformedEnv will change in version 0.9. "
                    "Nested TransformedEnvs will no longer be automatically unwrapped by default. "
                    "To prepare for this change, use set_auto_unwrap_transformed_env(val: bool) "
                    "as a decorator or context manager, or set the environment variable "
                    "AUTO_UNWRAP_TRANSFORMED_ENV to 'False'.",
                    FutureWarning,
                    stacklevel=2,
                )
                auto_unwrap = True
    else:
        auto_unwrap = False

    if auto_unwrap:
        self._set_env(env.base_env, device)
        if type(transform) is not Compose:
            # we don't use isinstance as some transforms may be subclassed from
            # Compose but with other features that we don't want to lose.
            if transform is None:
                # ``None`` is the documented default and must be handled before
                # the type check, which would otherwise reject it.
                transform = []
            else:
                if not isinstance(transform, Transform):
                    if callable(transform):
                        transform = _CallableTransform(transform)
                    else:
                        raise ValueError(
                            "Invalid transform type, expected a Transform instance or a callable "
                            f"but got an object of type {type(transform)}."
                        )
                transform = [transform]
        else:
            for t in transform:
                t.reset_parent()
        # The inner env's transforms come first, followed by the new ones.
        env_transform = env.transform.clone()
        if type(env_transform) is not Compose:
            env_transform = [env_transform]
        else:
            for t in env_transform:
                t.reset_parent()
        transform = Compose(*env_transform, *transform).to(device)
    else:
        self._set_env(env, device)
        if transform is None:
            transform = Compose()

    self.transform = transform

    self._last_obs = None
    self.cache_specs = cache_specs
    # Written through __dict__, bypassing attribute assignment on the instance.
    self.__dict__["_input_spec"] = None
    self.__dict__["_output_spec"] = None
@property
def batch_size(self) -> torch.Size:
    """Batch size of the base env, as modified by the transform (if any).

    Falls back to an empty ``torch.Size`` when an ``AttributeError`` is raised
    while computing it.
    """
    try:
        current = self.transform
        if current is None:
            return self.base_env.batch_size
        return current.transform_env_batch_size(self.base_env.batch_size)
    except AttributeError:
        # during init, the base_env is not yet defined
        return torch.Size([])


@batch_size.setter
def batch_size(self, value: torch.Size) -> None:
    """Always raises: the batch size must be changed on the base env."""
    raise RuntimeError(
        "Cannot modify the batch-size of a transformed env. Change the batch size of the base_env instead."
    )
def add_truncated_keys(self) -> TransformedEnv:
    """Calls ``add_truncated_keys`` on the base env, empties the cache and returns ``self``."""
    base = self.base_env
    base.add_truncated_keys()
    # Drop whatever was cached before the base env changed.
    self.empty_cache()
    return self
def _set_env(self, env: EnvBase, device) -> None:
    """Stores ``env`` as ``self.base_env``, moving it to ``device`` first if needed.

    Args:
        env (EnvBase): the environment to store.
        device: the device the stored environment should be on.
    """
    needs_move = device != env.device
    self.base_env = env.to(device) if needs_move else env
    # updates need not be inplace, as transforms may modify values out-place
    self.base_env._inplace_update = False
@property
def transform(self) -> Transform:
    """The transform attached to this env, or ``None`` if none has been set."""
    return getattr(self, "_transform", None)


@transform.setter
def transform(self, transform: Transform):
    """Attaches ``transform`` to this env.

    Callables that are not ``Transform`` instances are wrapped in a
    ``_CallableTransform``. The previously attached transform (if any) has its
    cache emptied and its parent reset. The new transform is moved to
    ``self.device``, given ``self`` as container and put in eval mode.

    Raises:
        ValueError: if ``transform`` is neither a ``Transform`` nor a callable.
    """
    if not isinstance(transform, Transform):
        if not callable(transform):
            raise ValueError(
                "Expected a transform of type torchrl.envs.transforms.Transform or a callable,\n"
                f"but got an object of type {type(transform)}."
            )
        transform = _CallableTransform(transform)

    # Detach whatever was attached before.
    previous = getattr(self, "_transform", None)
    if previous is not None:
        previous.empty_cache()
        previous.reset_parent()

    transform = transform.to(self.device)
    transform.set_container(self)
    transform.eval()
    self._transform = transform
@property
def device(self) -> torch.device:
    """Device of the base env, as modified by the transform (if any).

    Returns:
        the base env's device when no transform is set, otherwise the result
        of ``self.transform.transform_env_device`` applied to it.
    """
    device = self.base_env.device
    if self.transform is None:
        # during init, the device is checked
        return device
    return self.transform.transform_env_device(device)


@device.setter
def device(self, value):
    """Always raises: the device cannot be assigned."""
    raise RuntimeError("device is a read-only property")
@property
def batch_locked(self) -> bool:
    """The ``batch_locked`` value of the base env."""
    base = self.base_env
    return base.batch_locked


@batch_locked.setter
def batch_locked(self, value):
    """Always raises: the value cannot be assigned."""
    message = "batch_locked is a read-only property"
    raise RuntimeError(message)
@property
def run_type_checks(self) -> bool:
    """The ``run_type_checks`` value of the base env."""
    base = self.base_env
    return base.run_type_checks


@run_type_checks.setter
def run_type_checks(self, value):
    """Always raises: the value cannot be assigned on a transformed env."""
    message = "run_type_checks is a read-only property for TransformedEnvs"
    raise RuntimeError(message)