
Commit 77f512b

Migrating Actor Critic Method example to Keras 3 (TF-Only) (#1759)
* Migrated the example to the TF-only backend
* Added the .md and .ipynb files
1 parent ce463ed commit 77f512b

File tree: 3 files changed, +35 -24 lines changed

examples/rl/actor_critic_cartpole.py

Lines changed: 12 additions & 7 deletions
@@ -2,9 +2,10 @@
 Title: Actor Critic Method
 Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)
 Date created: 2020/05/13
-Last modified: 2020/05/13
+Last modified: 2024/02/22
 Description: Implement Actor Critic Method in CartPole environment.
 Accelerator: NONE
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
 """

 """
@@ -40,11 +41,15 @@
 ## Setup
 """

+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
 import gym
 import numpy as np
+import keras
+from keras import ops
+from keras import layers
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers

 # Configuration parameters for the whole setup
 seed = 42
@@ -98,8 +103,8 @@
             # env.render(); Adding this line would show the attempts
             # of the agent in a pop up window.

-            state = tf.convert_to_tensor(state)
-            state = tf.expand_dims(state, 0)
+            state = ops.convert_to_tensor(state)
+            state = ops.expand_dims(state, 0)

             # Predict action probabilities and estimated future rewards
             # from environment state
@@ -108,7 +113,7 @@

             # Sample action from action probability distribution
             action = np.random.choice(num_actions, p=np.squeeze(action_probs))
-            action_probs_history.append(tf.math.log(action_probs[0, action]))
+            action_probs_history.append(ops.log(action_probs[0, action]))

             # Apply the sampled action in our environment
             state, reward, done, _ = env.step(action)
@@ -152,7 +157,7 @@
             # The critic must be updated so that it predicts a better estimate of
             # the future rewards.
             critic_losses.append(
-                huber_loss(tf.expand_dims(value, 0), tf.expand_dims(ret, 0))
+                huber_loss(ops.expand_dims(value, 0), ops.expand_dims(ret, 0))
             )

         # Backpropagation
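The substance of the migration is swapping `tf.*` tensor utilities for their backend-agnostic `keras.ops` counterparts. Below is a minimal sketch (not part of the commit) of the state-preprocessing step under the new imports, with a hard-coded array standing in for a CartPole observation; it assumes Keras 3 and TensorFlow are installed.

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before importing keras

import numpy as np
from keras import ops

# Stand-in for a CartPole observation (4 floats); the real example gets this from env.reset().
state = np.array([0.01, -0.02, 0.03, 0.04], dtype="float32")

# Backend-agnostic replacements for tf.convert_to_tensor / tf.expand_dims.
state = ops.convert_to_tensor(state)
state = ops.expand_dims(state, 0)  # add a batch dimension -> shape (1, 4)
print(ops.shape(state))
```

With the TensorFlow backend these calls return `tf.Tensor`s, so the example's `tf.GradientTape` training loop keeps working unchanged.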

examples/rl/ipynb/actor_critic_cartpole.ipynb

Lines changed: 13 additions & 10 deletions
@@ -10,7 +10,7 @@
 "\n",
 "**Author:** [Apoorv Nandan](https://twitter.com/NandanApoorv)<br>\n",
 "**Date created:** 2020/05/13<br>\n",
-"**Last modified:** 2020/05/13<br>\n",
+"**Last modified:** 2024/02/22<br>\n",
 "**Description:** Implement Actor Critic Method in CartPole environment."
 ]
 },
@@ -60,17 +60,20 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
 "outputs": [],
 "source": [
+"import os\n",
+"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
 "import gym\n",
 "import numpy as np\n",
+"import keras\n",
+"from keras import ops\n",
+"from keras import layers\n",
 "import tensorflow as tf\n",
-"from tensorflow import keras\n",
-"from tensorflow.keras import layers\n",
 "\n",
 "# Configuration parameters for the whole setup\n",
 "seed = 42\n",
@@ -101,7 +104,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -130,7 +133,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -152,8 +155,8 @@
 " # env.render(); Adding this line would show the attempts\n",
 " # of the agent in a pop up window.\n",
 "\n",
-" state = tf.convert_to_tensor(state)\n",
-" state = tf.expand_dims(state, 0)\n",
+" state = ops.convert_to_tensor(state)\n",
+" state = ops.expand_dims(state, 0)\n",
 "\n",
 " # Predict action probabilities and estimated future rewards\n",
 " # from environment state\n",
@@ -162,7 +165,7 @@
 "\n",
 " # Sample action from action probability distribution\n",
 " action = np.random.choice(num_actions, p=np.squeeze(action_probs))\n",
-" action_probs_history.append(tf.math.log(action_probs[0, action]))\n",
+" action_probs_history.append(ops.log(action_probs[0, action]))\n",
 "\n",
 " # Apply the sampled action in our environment\n",
 " state, reward, done, _ = env.step(action)\n",
@@ -206,7 +209,7 @@
 " # The critic must be updated so that it predicts a better estimate of\n",
 " # the future rewards.\n",
 " critic_losses.append(\n",
-" huber_loss(tf.expand_dims(value, 0), tf.expand_dims(ret, 0))\n",
+" huber_loss(ops.expand_dims(value, 0), ops.expand_dims(ret, 0))\n",
 " )\n",
 "\n",
 " # Backpropagation\n",

examples/rl/md/actor_critic_cartpole.md

Lines changed: 10 additions & 7 deletions
@@ -2,7 +2,7 @@

 **Author:** [Apoorv Nandan](https://twitter.com/NandanApoorv)<br>
 **Date created:** 2020/05/13<br>
-**Last modified:** 2020/05/13<br>
+**Last modified:** 2024/02/22<br>
 **Description:** Implement Actor Critic Method in CartPole environment.


@@ -46,11 +46,14 @@ remains upright. The agent, therefore, must learn to keep the pole from falling


 ```python
+import os
+os.environ["KERAS_BACKEND"] = "tensorflow"
 import gym
 import numpy as np
+import keras
+from keras import ops
+from keras import layers
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers

 # Configuration parameters for the whole setup
 seed = 42
@@ -112,8 +115,8 @@ while True:  # Run until solved
             # env.render(); Adding this line would show the attempts
             # of the agent in a pop up window.

-            state = tf.convert_to_tensor(state)
-            state = tf.expand_dims(state, 0)
+            state = ops.convert_to_tensor(state)
+            state = ops.expand_dims(state, 0)

             # Predict action probabilities and estimated future rewards
             # from environment state
@@ -122,7 +125,7 @@ while True:  # Run until solved

             # Sample action from action probability distribution
             action = np.random.choice(num_actions, p=np.squeeze(action_probs))
-            action_probs_history.append(tf.math.log(action_probs[0, action]))
+            action_probs_history.append(ops.log(action_probs[0, action]))

             # Apply the sampled action in our environment
             state, reward, done, _ = env.step(action)
@@ -166,7 +169,7 @@ while True:  # Run until solved
             # The critic must be updated so that it predicts a better estimate of
             # the future rewards.
             critic_losses.append(
-                huber_loss(tf.expand_dims(value, 0), tf.expand_dims(ret, 0))
+                huber_loss(ops.expand_dims(value, 0), ops.expand_dims(ret, 0))
             )

         # Backpropagation
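In the critic update, `keras.losses.Huber` is already backend-agnostic, so only the `tf.expand_dims` calls change. A minimal sketch (not part of the commit) with hypothetical scalar values for the critic's prediction and the discounted return:

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # select the backend before importing keras

import keras
from keras import ops

huber_loss = keras.losses.Huber()

# Hypothetical critic prediction and discounted return for a single timestep.
value = ops.convert_to_tensor(0.8)
ret = ops.convert_to_tensor(1.2)

# Both scalars get a leading batch dimension before computing the loss,
# exactly as in the migrated line above.
loss = huber_loss(ops.expand_dims(value, 0), ops.expand_dims(ret, 0))
print(float(loss))  # ~0.08: 0.5 * (1.2 - 0.8) ** 2, since |error| < the default delta of 1.0
```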
