Creating a NN policy for Reinforcement Learning
Continuing from the previous post (Learning to use OpenAI Gym), I explore how to create a neural network policy that can keep the pole balanced for a longer period of time than the simple hard-coded policy. Using a small network, we should see an improvement by measuring the number of consecutive steps that the pole stays vertical.
Neural Network Policy
Use the import statements and plotting functions from the previous post
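If you don't have that post handy, these are roughly the imports the code below relies on (a sketch; the exact statements may differ slightly from the previous post, and plot_animation still comes from there):
import numpy as np
import tensorflow as tf
from tensorflow import keras
import gym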
# clear any previous models and make the run reproducible
keras.backend.clear_session()
tf.random.set_seed(42)
np.random.seed(42)
# the observation space has 4 values: cart position, cart velocity,
# pole angle, and pole angular velocity
num_inputs = 4
model = keras.models.Sequential([
    keras.layers.Dense(20, activation="relu", input_shape=(num_inputs,)),
    keras.layers.Dense(5, activation="relu"),
    # single output neuron: probability of pushing the cart to the left
    keras.layers.Dense(1, activation="sigmoid"),
])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 20)                100
dense_1 (Dense)              (None, 5)                 105
dense_2 (Dense)              (None, 1)                 6
=================================================================
Total params: 211
Trainable params: 211
Non-trainable params: 0
_________________________________________________________________
The input to our model will be the environment's observation, obs = [cart position, cart velocity, pole angle, pole angular velocity], and the output will be the action taken by the cart (either 0 or 1). Since there are only two possible actions, we only need a single output neuron with the sigmoid activation function; its output represents the probability of pushing the cart to the left (action 0). If there were more than two possible actions, each action would get its own output neuron and we would use the softmax activation function instead.
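For comparison, here is a minimal sketch of what the output layer might look like for a hypothetical environment with several discrete actions (num_actions is just an illustration and is not part of the CartPole setup):
# hypothetical multi-action case: one output neuron per action, softmax over them
num_actions = 5  # assumption for illustration; CartPole only has 2 actions
multi_action_model = keras.models.Sequential([
    keras.layers.Dense(20, activation="relu", input_shape=(num_inputs,)),
    keras.layers.Dense(5, activation="relu"),
    # softmax turns the outputs into a probability distribution over actions
    keras.layers.Dense(num_actions, activation="softmax"),
])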
In the code below, we decide our action by sampling against the model's output. For example, suppose the randomly generated number is 0.40 and the probability of going left produced by the model is 0.50. The statement below compares these two probabilities: since 0.40 > 0.50 is false, the boolean converts to the integer 0 and our action is to push left.
We make our decisions based on this random number because we want to find the right balance between exploring new actions and exploiting actions that we already know work well. Imagine you go to a cafe and randomly select a coffee. If you like it, the probability that you order it again next time increases. However, you shouldn't raise that probability to 100%, because there might be other coffees to try that are even better than the first.
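As a standalone sketch of that decision rule (with the values hard-coded purely for illustration):
import numpy as np

left_prob = 0.50          # probability of pushing left, as output by the model
p = np.random.rand()      # random number in [0, 1)

# True (1) means push right, False (0) means push left
action = int(p > left_prob)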
def render_policy(model, max_steps=200, seed=42):
    frames = []
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    np.random.seed(seed)
    # reset the environment; reset() returns (observation, info) in the newer Gym API
    obs, info = env.reset()
    obs = obs[np.newaxis]  # add a batch dimension for model.predict
    # keep track of how many consecutive steps the pole stays vertical
    totals = 0
    for step in range(max_steps):
        frames.append(env.render())
        left_prob = model.predict(obs)
        # generate a random number in [0, 1)
        p = np.random.rand()
        print('p', p)
        print('left_prob', left_prob)
        # turn the boolean into an integer action (True -> 1: right, False -> 0: left)
        action = int(p > left_prob)
        print('action', action)
        # step() returns (observation, reward, terminated, truncated, info)
        obs, reward, terminated, truncated, info = env.step(action)
        obs = obs[np.newaxis]
        totals += reward
        if terminated or truncated:
            break
    env.close()
    return frames, totals
For our cart-pole environment, we can ignore past observations and actions because at each step we can see the environment's full state.
For example, if the environment only revealed the cart's position and not its velocity, you would have to consider both past and current observations in order to estimate the current velocity. We do not have to worry about this in our case.
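If we did face that situation, one common trick (just a hypothetical sketch here, not something CartPole needs) is to stack the last few partial observations into a single input so the network can infer velocities from the differences between frames:
from collections import deque
import numpy as np

# hypothetical: keep the last 4 partial observations (each a 1-D array of
# whatever the environment does reveal) and feed them to the model together
history = deque(maxlen=4)

def stacked_input(partial_obs):
    history.append(partial_obs)
    # pad with copies of the earliest observation until the history is full
    while len(history) < history.maxlen:
        history.appendleft(history[0])
    return np.concatenate(history)[np.newaxis]  # add a batch dimension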
Let's see how well a randomly initialized policy network performs:
frames,totals = render_policy(model)
plot_animation(frames)
1/1 [==============================] - 0s 19ms/step
p 0.3745401188473625
left_prob [[0.49724284]]
action 0
1/1 [==============================] - 0s 20ms/step
p 0.9507143064099162
left_prob [[0.50142497]]
action 1
1/1 [==============================] - 0s 19ms/step
p 0.7319939418114051
left_prob [[0.4971607]]
action 1
1/1 [==============================] - 0s 21ms/step
p 0.5986584841970366
left_prob [[0.4774564]]
action 1
1/1 [==============================] - 0s 29ms/step
p 0.15601864044243652
left_prob [[0.4559578]]
action 0
1/1 [==============================] - 0s 24ms/step
p 0.15599452033620265
left_prob [[0.47929704]]
action 0
1/1 [==============================] - 0s 23ms/step
p 0.05808361216819946
left_prob [[0.4979668]]
action 0
1/1 [==============================] - 0s 22ms/step
p 0.8661761457749352
left_prob [[0.5000971]]
action 1
1/1 [==============================] - 0s 21ms/step
p 0.6011150117432088
left_prob [[0.49776265]]
action 1
1/1 [==============================] - 0s 24ms/step
p 0.7080725777960455
left_prob [[0.47974744]]
action 1
1/1 [==============================] - 0s 25ms/step
p 0.020584494295802447
left_prob [[0.45795083]]
action 0
1/1 [==============================] - 0s 24ms/step
p 0.9699098521619943
left_prob [[0.48107377]]
action 1
1/1 [==============================] - 0s 18ms/step
p 0.8324426408004217
left_prob [[0.45891747]]
action 1
1/1 [==============================] - 0s 23ms/step
p 0.21233911067827616
left_prob [[0.43718466]]
action 0
1/1 [==============================] - 0s 21ms/step
p 0.18182496720710062
left_prob [[0.4599171]]
action 0
1/1 [==============================] - 0s 22ms/step
p 0.18340450985343382
left_prob [[0.48222327]]
action 0
1/1 [==============================] - 0s 20ms/step
p 0.3042422429595377
left_prob [[0.49851432]]
action 0
1/1 [==============================] - 0s 25ms/step
p 0.5247564316322378
left_prob [[0.49408367]]
action 1
1/1 [==============================] - 0s 24ms/step
p 0.43194501864211576
left_prob [[0.49975875]]
action 0
1/1 [==============================] - 0s 20ms/step
p 0.2912291401980419
left_prob [[0.49443373]]
action 0
1/1 [==============================] - 0s 21ms/step
p 0.6118528947223795
left_prob [[0.49302602]]
action 1
1/1 [==============================] - 0s 28ms/step
p 0.13949386065204183
left_prob [[0.49531654]]
action 0
1/1 [==============================] - 0s 21ms/step
p 0.29214464853521815
left_prob [[0.49462938]]
action 0
1/1 [==============================] - 0s 23ms/step
p 0.3663618432936917
left_prob [[0.49462]]
action 0
1/1 [==============================] - 0s 23ms/step
p 0.45606998421703593
left_prob [[0.495313]]
action 0
1/1 [==============================] - 0s 21ms/step
p 0.7851759613930136
left_prob [[0.49673685]]
action 1
1/1 [==============================] - 0s 19ms/step
p 0.19967378215835974
left_prob [[0.50123394]]
action 0
1/1 [==============================] - 0s 17ms/step
p 0.5142344384136116
left_prob [[0.5028582]]
action 1
1/1 [==============================] - 0s 21ms/step
p 0.5924145688620425
left_prob [[0.5076048]]
action 1
1/1 [==============================] - 0s 16ms/step
p 0.046450412719997725
left_prob [[0.51172656]]
action 0
1/1 [==============================] - 0s 19ms/step
p 0.6075448519014384
left_prob [[0.5132017]]
action 1
1/1 [==============================] - 0s 25ms/step
p 0.17052412368729153
left_prob [[0.5151897]]
action 0
1/1 [==============================] - 0s 24ms/step
p 0.06505159298527952
left_prob [[0.51832235]]
action 0
The network here is able to keep the pole vertical for 51 consecutive steps, which is an improvement over the previous hard-coded policy! Woo!
print(totals)
51.0
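A single episode with a fixed seed is a noisy estimate, so a reasonable follow-up (a sketch, simply reusing the render_policy function defined above with a few different seed values) is to average the totals over several runs:
# average the episode length over several seeds for a less noisy estimate
scores = []
for s in range(5):
    _, total = render_policy(model, max_steps=200, seed=s)
    scores.append(total)

print("mean steps:", np.mean(scores), "min:", np.min(scores), "max:", np.max(scores))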