Skip to content

Commit b29c9f0

Browse files
authored
Fix RLDatasets.jl documentation (#467)
* Refine documentation in RLDatasets.jl * update docs in missed out files * fix link provided * Update atari_dataset.jl * Update d4rl_dataset.jl * Update d4rl_dataset.jl * fix typo * Fix type error * fix type error * update readme
1 parent 9208cb0 commit b29c9f0

File tree

5 files changed

+115
-90
lines changed

5 files changed

+115
-90
lines changed

src/ReinforcementLearningDatasets/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ pkg> add https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl:
1414
```julia
1515
using ReinforcementLearningDatasets
1616
ds = dataset("hopper-medium-replay-v0"; repo="d4rl")
17-
samples = Iterators.take!(ds)
17+
samples = Iterators.take(ds)
1818
```
1919
`ds` is of the type `D4RLDataset` which consists of the entire dataset along with some other information about the dataset. `samples` are in the form of `SARTS` with batch_size 256.
2020
#### RL Unplugged
2121
```julia
2222
using ReinforcementLearningDatasets
2323
ds = rl_unplugged_atari_dataset("pong", 1, [1, 2])
24-
samples = Iterators.take!(ds, 2)
24+
samples = take!(ds, 2)
2525
```
2626
`ds` is a `Channel{RLTransition}` that returns batches of type `RLTransition` when `take!` is used.
2727

src/ReinforcementLearningDatasets/src/atari/atari_dataset.jl

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,26 @@ using NPZ
22
using CodecZlib
33

44
"""
5-
Represents an iterable dataset of type AtariDataSet with the following fields:
6-
7-
`dataset`: Dict{Symbol, Any}, representation of the dataset as a Dictionary with style as `style`
8-
`epochs`: Vector{Int}, list of epochs to load
9-
`repo`: String, the repository from which the dataset is taken
10-
`length`: Integer, the length of the dataset
11-
`batch_size`: Integer, the size of the batches returned by `iterate`.
12-
`style`: Tuple, the type of the NamedTuple, for now SARTS and SART is supported.
13-
`rng`<: AbstractRNG.
14-
`meta`: Dict, the metadata provided along with the dataset
15-
`is_shuffle`: Bool, determines if the batches returned by `iterate` are shuffled.
5+
Represents an `Iterable` dataset with the following fields:
6+
7+
# Fields
8+
- `dataset::Dict{Symbol, Any}`: representation of the dataset as a Dictionary with style as `style`.
9+
- `epochs::Vector{Int}`: list of epochs to load.
10+
- `repo::String`: the repository from which the dataset is taken.
11+
- `length::Int`: the length of the dataset.
12+
- `batch_size::Int`: the size of the batches returned by `iterate`.
13+
- `style::Tuple{Symbol}`: the style of the `Iterator` that is returned, check out: [`SARTS`](@ref), [`SART`](@ref) and [`SA`](@ref)
14+
for types supported out of the box.
15+
- `rng<:AbstractRNG`.
16+
- `meta::Dict`: the metadata provided along with the dataset.
17+
- `is_shuffle::Bool`: determines if the batches returned by `iterate` are shuffled.
1618
"""
1719
struct AtariDataSet{T<:AbstractRNG} <:RLDataSet
1820
dataset::Dict{Symbol, Any}
1921
epochs::Vector{Int}
2022
repo::String
21-
length::Integer
22-
batch_size::Integer
23+
length::Int
24+
batch_size::Int
2325
style::Tuple
2426
rng::T
2527
meta::Dict
@@ -31,22 +33,30 @@ const atari_frame_size = 84
3133
const epochs_per_game = 50
3234

3335
"""
34-
dataset(dataset::String, epochs::Vector{Int}; repo::String, style::Tuple, rng<:AbstractRNG, is_shuffle::Bool, max_iters::Int64, batch_size::Int64)
35-
36-
Creates a dataset of enclosed in a AtariDataSet type and other related metadata for the `dataset` that is passed.
37-
The `AtariDataSet` type is an iterable that fetches batches when used in a for loop for convenience during offline training.
38-
39-
`dataset`: String, name of the datset.
40-
`index`: Int, analogous to v
41-
`epochs`: Vector{Int}, list of epochs to load
42-
`repo`: Name of the repository of the dataset
43-
`style`: the style of the iterator and the Dict inside AtariDataSet that is returned.
44-
`rng`: StableRNG
45-
`max_iters`: maximum number of iterations for the iterator.
46-
`is_shuffle`: whether the dataset is shuffled or not. `true` by default.
47-
`batch_size`: batch_size that is yielded by the iterator. Defaults to 256.
48-
49-
The returned type is an infinite iterator which can be called using `iterate` and will return batches as specified in the dataset.
36+
dataset(dataset, index, epochs; <keyword arguments>)
37+
38+
Create a dataset enclosed in a [`AtariDataSet`](@ref) [`Iterable`](@ref) type. Contain other related metadata
39+
for the `dataset` that is passed. The returned type is an infinite or a finite `Iterator`
40+
respectively depending upon whether is_shuffle is `true` or `false`. For more information regarding
41+
the dataset, refer to [google-research/batch_rl](https://github.com/google-research/batch_rl).
42+
43+
# Arguments
44+
- `dataset::String`: name of the datset.
45+
- `index::Int`: analogous to `v` and different values correspond to different `seed`s that
46+
are used for data collection. can be between `[1:5]`.
47+
- `epochs::Vector{Int}`: list of epochs to load. included epochs should be between `[0:50]`.
48+
- `style::Tuple{Symbol}=SARTS`: the style of the `Iterator` that is returned. can be [`SARTS`](@ref),
49+
[`SART`](@ref) or [`SA`](@ref).
50+
- `repo::String="atari-replay-datasets"`: name of the repository of the dataset.
51+
- `rng<:AbstractRNG=StableRNG(123)`.
52+
- `is_shuffle::Bool=true`: determines if the dataset is shuffled or not.
53+
- `batch_size::Int=256` batch_size that is yielded by the iterator.
54+
55+
!!! note
56+
57+
The dataset takes up significant amount of space in RAM. Therefore it is advised to
58+
load even one epoch with 20GB of RAM. We are looking for ways to use lazy data loading here
59+
and any contributions are welcome.
5060
"""
5161
function dataset(
5262
game::String,
@@ -172,4 +182,4 @@ function atari_verify(dataset::Dict, num_epochs::Int)
172182
@assert size(dataset["action"]) == (num_epochs * samples_per_epoch,)
173183
@assert size(dataset["reward"]) == (num_epochs * samples_per_epoch,)
174184
@assert size(dataset["terminal"]) == (num_epochs * samples_per_epoch,)
175-
end
185+
end
Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,24 @@
11
export SARTS
22
export SART
3+
export SA
34
export RLDataSet
45

56
abstract type RLDataSet end
67

8+
"""
9+
(:state, :action, :reward, :terminal, :next_state)
10+
type of the returned batches.
11+
"""
712
const SARTS = (:state, :action, :reward, :terminal, :next_state)
8-
const SART = (:state, :action, :reward, :terminal)
13+
14+
"""
15+
(:state, :action, :reward, :terminal)
16+
type of the returned batches.
17+
"""
18+
const SART = (:state, :action, :reward, :terminal)
19+
20+
"""
21+
(:state, :action)
22+
type of the returned batches.
23+
"""
24+
const SA = (:state, :action)

src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@ export dataset
99
export D4RLDataSet
1010

1111
"""
12-
Represents an iterable dataset of type `D4RLDataSet` with the following fields:
13-
14-
`dataset`: Dict{Symbol, Any}, representation of the dataset as a Dictionary with style as `style`
15-
`repo`: String, the repository from which the dataset is taken
16-
`dataset_size`: Integer, the size of the dataset
17-
`batch_size`: Integer, the size of the batches returned by `iterate`.
18-
`style`: Tuple, the type of the NamedTuple, for now SARTS and SART is supported.
19-
`rng`<: AbstractRNG.
20-
`meta`: Dict, the metadata provided along with the dataset
21-
`is_shuffle`: Bool, determines if the batches returned by `iterate` are shuffled.
12+
Represents an `Iterable` dataset with the following fields:
13+
14+
# Fields
15+
- `dataset::Dict{Symbol, Any}`: representation of the dataset as a Dictionary with style as `style`.
16+
- `repo::String`: the repository from which the dataset is taken.
17+
- `dataset_size::Int`, the number of samples in the dataset.
18+
- `batch_size::Int`: the size of the batches returned by `iterate`.
19+
- `style::Tuple{Symbol}`: the style of the `Iterator` that is returned, check out: [`SARTS`](@ref), [`SART`](@ref) and [`SA`](@ref)
20+
for types supported out of the box.
21+
- `rng<:AbstractRNG`.
22+
- `meta::Dict`: the metadata provided along with the dataset.
23+
- `is_shuffle::Bool`: determines if the batches returned by `iterate` are shuffled.
2224
"""
2325
struct D4RLDataSet{T<:AbstractRNG} <: RLDataSet
2426
dataset::Dict{Symbol, Any}
@@ -35,24 +37,31 @@ end
3537
# TO-DO: enable the users providing their own paths to datasets if they already have it
3638
# TO-DO: add additional env arg to do complete verify function
3739
"""
38-
dataset(dataset::String; style::Tuple, rng<:AbstractRNG, is_shuffle::Bool, max_iters::Int64, batch_size::Int64)
39-
40-
Creates a dataset of enclosed in a D4RLDataSet type and other related metadata for the `dataset` that is passed.
41-
The `D4RLDataSet` type is an iterable that fetches batches when used in a for loop for convenience during offline training.
42-
43-
`dataset`: Dict{Symbol, Any}, Name of the datset.
44-
`repo`: Name of the repository of the dataset.
45-
`style`: the style of the iterator and the Dict inside D4RLDataSet that is returned.
46-
`rng`: StableRNG
47-
`max_iters`: maximum number of iterations for the iterator.
48-
`is_shuffle`: whether the dataset is shuffled or not. `true` by default.
49-
`batch_size`: batch_size that is yielded by the iterator. Defaults to 256.
50-
51-
The returned type is an infinite iterator which can be called using `iterate` and will return batches as specified in the dataset.
40+
dataset(dataset; <keyword arguments>)
41+
42+
Create a dataset enclosed in a [`D4RLDataSet`](@ref) `Iterable` type. Contain other related metadata
43+
for the `dataset` that is passed. The returned type is an infinite or a finite `Iterator`
44+
respectively depending upon whether `is_shuffle` is `true` or `false`. For more information regarding
45+
the dataset, refer to [D4RL](https://github.com/rail-berkeley/d4rl).
46+
47+
# Arguments
48+
- `dataset::String`: name of the datset.
49+
- `repo::String="d4rl"`: name of the repository of the dataset.
50+
- `style::Tuple{Symbol}=SARTS`: the style of the `Iterator` that is returned. can be [`SARTS`](@ref),
51+
[`SART`](@ref) or [`SA`](@ref).
52+
- `rng<:AbstractRNG=StableRNG(123)`.
53+
- `is_shuffle::Bool=true`: determines if the dataset is shuffled or not.
54+
- `batch_size::Int=256` batch_size that is yielded by the iterator.
55+
56+
!!! note
57+
58+
[`FLOW`](https://flow-project.github.io/) and [`CARLA`](https://github.com/rail-berkeley/d4rl/wiki/CARLA-Setup) supported by [D4RL](https://github.com/rail-berkeley/d4rl) have not
59+
been tested in this package yet.
5260
"""
53-
function dataset(dataset::String;
54-
style=SARTS,
61+
function dataset(
62+
dataset::String;
5563
repo = "d4rl",
64+
style=SARTS,
5665
rng = StableRNG(123),
5766
is_shuffle = true,
5867
batch_size=256
@@ -139,4 +148,4 @@ function d4rl_verify(data::Dict{String, Any})
139148
N_samples = size(data["observations"])[2]
140149
@assert size(data["rewards"]) == (N_samples,) || size(data["rewards"]) == (1, N_samples)
141150
@assert size(data["terminals"]) == (N_samples,) || size(data["terminals"]) == (1, N_samples)
142-
end
151+
end

src/ReinforcementLearningDatasets/src/rl_unplugged/atari/rl_unplugged_atari.jl

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,7 @@ using PNGFiles
1111
"""
1212
RLTransition
1313
14-
Represents an RLTransition. It can also be used to represent a batch by adding another dimension.
15-
16-
The constructor decodes the incoming `TFRecord.Example` to be ready to use.
17-
18-
Fields:
19-
- `state`
20-
- `action`
21-
- `reward`
22-
- `terminal`
23-
- `next_state`
24-
- `next_action`
25-
- `episode_id`
26-
- `episode_return`
14+
Represent an RLTransition and can also represent a batch.
2715
"""
2816
struct RLTransition
2917
state
@@ -66,21 +54,23 @@ function RLTransition(example::TFRecord.Example)
6654
RLTransition(s, a, r, t, s′, a′, episode_id, episode_return)
6755
end
6856
"""
69-
rl_unplugged_atari_dataset(game::String, run::Int, shards::Vector{Int}; (optional_args))
57+
rl_unplugged_atari_dataset(game, run, shards; <keyword arguments>)
7058
71-
Returns a buffered `Channel` of `RLTransition` batches which supports multi threading.
59+
Returns a buffered `Channel` of [`RLTransition`](@ref) batches which supports
60+
multi threaded loading.
7261
73-
### Arguments and optional_args:
62+
# Arguments
7463
75-
- `game::String`, The name of the env.
76-
- `run::Int`, The run number. Can be in the range 1:5.
77-
- `shards::Vector{Int}` The shards that are to be loaded.
78-
optional_args:
79-
- `shuffle_buffer_size=10_000`, This is the size of the shuffle_buffer used in loading RLTransitions.
80-
- `tf_reader_bufsize=1*1024*1024`, The size of the buffer `bufsize` that is used internally in `TFRecord.read`.
81-
- `tf_reader_sz=10_000`, The size of the `Channel`, `channel_size` that is returned by `TFRecord.read`.
82-
- `batch_size=256`, The size of the batches that are returned by the Channel that is finally returned.
83-
- `n_preallocations`, The size of the buffer in the `Channel` that is returned.
64+
- `game::String`: name of the dataset.
65+
- `run::Int`: run number. can be in the range `1:5`.
66+
- `shards::Vector{Int}`: the shards that are to be loaded.
67+
- `shuffle_buffer_size::Int=10_000`: size of the shuffle_buffer used in loading RLTransitions.
68+
- `tf_reader_bufsize::Int=1*1024*1024`: the size of the buffer `bufsize` that is used internally
69+
in `TFRecord.read`.
70+
- `tf_reader_sz::Int=10_000`: the size of the `Channel`, `channel_size` that is returned by
71+
`TFRecord.read`.
72+
- `batch_size::Int=256`: The number of samples within the batches that are returned by the `Channel`.
73+
- `n_preallocations::Int=nthreads()*12`: the size of the buffer in the `Channel` that is returned.
8474
8575
!!! note
8676
@@ -90,11 +80,11 @@ function rl_unplugged_atari_dataset(
9080
game::String,
9181
run::Int,
9282
shards::Vector{Int};
93-
shuffle_buffer_size = 10_000,
94-
tf_reader_bufsize = 1*1024*1024,
95-
tf_reader_sz = 10_000,
96-
batch_size = 256,
97-
n_preallocations = nthreads() * 12
83+
shuffle_buffer_size=10_000,
84+
tf_reader_bufsize=1*1024*1024,
85+
tf_reader_sz=10_000,
86+
batch_size=256,
87+
n_preallocations=nthreads()*12
9888
)
9989
n = nthreads()
10090

0 commit comments

Comments
 (0)