Creating a NN policy for Reinforcement Learning
Continuing from the previous post (Learning to use OpenAI Gym), I explore how to create a neural network policy that keeps the cart-pole balanced for longer than the simple hard-coded policy did. Using a small network, we should see an improvement, measured by the number of consecutive steps the pole stays vertical.
Neural Network Policy
Use the import statements and plotting functions from the previous post.
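For reference, the rest of this post assumes roughly the following setup (the exact imports, along with the plot_animation helper, come from the previous post, so treat this as a sketch rather than the definitive list):

import gym
import numpy as np
import tensorflow as tf
from tensorflow import keras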
keras.backend.clear_session()
tf.random.set_seed(42)
np.random.seed(42)

# the observation space has 4 values:
# cart position, cart velocity, pole angle, pole angular velocity
num_inputs = 4

model = keras.models.Sequential([
    keras.layers.Dense(20, activation="relu", input_shape=(num_inputs,)),
    keras.layers.Dense(5, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 20)                100
dense_1 (Dense)              (None, 5)                 105
dense_2 (Dense)              (None, 1)                 6
=================================================================
Total params: 211
Trainable params: 211
Non-trainable params: 0
_________________________________________________________________
(The parameter counts follow directly from the layer sizes: 4×20 + 20 = 100, 20×5 + 5 = 105, and 5×1 + 1 = 6.) The input to our model is the environment's observation, a vector of four values (cart position, cart velocity, pole angle, pole angular velocity), and the output is the action the cart will take (either 0 or 1). Since there are only two possible actions, we only need a single output neuron with the sigmoid activation function, which outputs the probability of going left (action 0). If there were more than two possible actions, each action would be represented by its own output neuron and we would use the softmax activation function instead.
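As a hedged sketch of that multi-action case (the three-action environment, num_actions value, and multi_action_model name are purely hypothetical), the model would look something like this:

num_inputs = 4   # same four observation values as before
num_actions = 3  # hypothetical environment with three discrete actions

multi_action_model = keras.models.Sequential([
    keras.layers.Dense(20, activation="relu", input_shape=(num_inputs,)),
    keras.layers.Dense(5, activation="relu"),
    # one output neuron per action; softmax turns them into a probability distribution
    keras.layers.Dense(num_actions, activation="softmax"),
])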
In the code below, we decide the action by comparing the model's output to a randomly generated number. For example, suppose the randomly generated number is 0.40 and the probability of going left produced by the model is 0.50. The statement below compares these two probabilities: since 0.40 > 0.50 is false, the boolean converts to the integer 0 and our action is to go left.
We make our decisions based on this random probability because we want to find the right balance between exploring new actions and exploiting actions that we already know work well. Imagine you go to a cafe and randomly select a coffee. If you like it, the probability that you order it again next time you go is increased. However, you shouldn’t increase that probability to 100% because there might be other coffees to try that are even better than the first.
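Here is that decision rule in isolation, using the numbers from the example above (the values are purely illustrative):

left_prob = 0.50               # model's predicted probability of going left (action 0)
p = 0.40                       # randomly generated number
action = int(p > left_prob)    # 0.40 > 0.50 is False, int(False) == 0, so we go left
print(action)                  # 0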
def render_policy(model, max_steps=200, seed=42):
    frames = []
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    np.random.seed(seed)
    # reset the environment; reset() returns (observation, info)
    obs, info = env.reset()
    # add a batch dimension so the model can predict on a single observation
    obs = np.array([obs])
    # keep track of how many consecutive steps the pole stays vertical
    totals = 0
    for step in range(max_steps):
        frames.append(env.render())
        left_prob = model.predict(obs)
        #print('left_prob', left_prob)
        # generate a random number between 0 and 1
        p = np.random.rand()
        #print('p', p)
        # turn the boolean into an integer action (True -> 1, False -> 0)
        action = int(p > left_prob)
        #print('action', action)
        # step() returns (observation, reward, terminated, truncated, info)
        obs, reward, terminated, truncated, info = env.step(action)
        obs = np.array([obs])
        totals += reward
        if terminated or truncated:
            break
    env.close()
    return frames, totals
For our cart-pole environment, we can ignore past observations and actions because at each step we can see the environment's full state.
For example, if the environment only revealed the cart's position and not its velocity, you would have to combine past and current observations to estimate the current velocity. We do not have to worry about this in our case.
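As a hedged illustration of what that would involve, here is a sketch of approximating velocity by differencing two consecutive position readings (the partially observed setting and the augment_with_velocity helper are hypothetical; tau is assumed to be CartPole's 0.02 s time step):

import numpy as np

tau = 0.02  # assumed time step between observations, in seconds

def augment_with_velocity(prev_position, curr_position):
    # approximate the hidden velocity from two consecutive position readings
    velocity = (curr_position - prev_position) / tau
    return np.array([curr_position, velocity])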
Let's see how well a randomly initialized policy network performs:
frames,totals = render_policy(model)
plot_animation(frames)
1/1 [==============================] - 0s 19ms/step p 0.3745401188473625 left_prob [[0.49724284]] action 0
1/1 [==============================] - 0s 20ms/step p 0.9507143064099162 left_prob [[0.50142497]] action 1
1/1 [==============================] - 0s 19ms/step p 0.7319939418114051 left_prob [[0.4971607]] action 1
1/1 [==============================] - 0s 21ms/step p 0.5986584841970366 left_prob [[0.4774564]] action 1
1/1 [==============================] - 0s 29ms/step p 0.15601864044243652 left_prob [[0.4559578]] action 0
1/1 [==============================] - 0s 24ms/step p 0.15599452033620265 left_prob [[0.47929704]] action 0
1/1 [==============================] - 0s 23ms/step p 0.05808361216819946 left_prob [[0.4979668]] action 0
1/1 [==============================] - 0s 22ms/step p 0.8661761457749352 left_prob [[0.5000971]] action 1
1/1 [==============================] - 0s 21ms/step p 0.6011150117432088 left_prob [[0.49776265]] action 1
1/1 [==============================] - 0s 24ms/step p 0.7080725777960455 left_prob [[0.47974744]] action 1
1/1 [==============================] - 0s 25ms/step p 0.020584494295802447 left_prob [[0.45795083]] action 0
1/1 [==============================] - 0s 24ms/step p 0.9699098521619943 left_prob [[0.48107377]] action 1
1/1 [==============================] - 0s 18ms/step p 0.8324426408004217 left_prob [[0.45891747]] action 1
1/1 [==============================] - 0s 23ms/step p 0.21233911067827616 left_prob [[0.43718466]] action 0
1/1 [==============================] - 0s 21ms/step p 0.18182496720710062 left_prob [[0.4599171]] action 0
1/1 [==============================] - 0s 22ms/step p 0.18340450985343382 left_prob [[0.48222327]] action 0
1/1 [==============================] - 0s 20ms/step p 0.3042422429595377 left_prob [[0.49851432]] action 0
1/1 [==============================] - 0s 25ms/step p 0.5247564316322378 left_prob [[0.49408367]] action 1
1/1 [==============================] - 0s 24ms/step p 0.43194501864211576 left_prob [[0.49975875]] action 0
1/1 [==============================] - 0s 20ms/step p 0.2912291401980419 left_prob [[0.49443373]] action 0
1/1 [==============================] - 0s 21ms/step p 0.6118528947223795 left_prob [[0.49302602]] action 1
1/1 [==============================] - 0s 28ms/step p 0.13949386065204183 left_prob [[0.49531654]] action 0
1/1 [==============================] - 0s 21ms/step p 0.29214464853521815 left_prob [[0.49462938]] action 0
1/1 [==============================] - 0s 23ms/step p 0.3663618432936917 left_prob [[0.49462]] action 0
1/1 [==============================] - 0s 23ms/step p 0.45606998421703593 left_prob [[0.495313]] action 0
1/1 [==============================] - 0s 21ms/step p 0.7851759613930136 left_prob [[0.49673685]] action 1
1/1 [==============================] - 0s 19ms/step p 0.19967378215835974 left_prob [[0.50123394]] action 0
1/1 [==============================] - 0s 17ms/step p 0.5142344384136116 left_prob [[0.5028582]] action 1
1/1 [==============================] - 0s 21ms/step p 0.5924145688620425 left_prob [[0.5076048]] action 1
1/1 [==============================] - 0s 16ms/step p 0.046450412719997725 left_prob [[0.51172656]] action 0
1/1 [==============================] - 0s 19ms/step p 0.6075448519014384 left_prob [[0.5132017]] action 1
1/1 [==============================] - 0s 25ms/step p 0.17052412368729153 left_prob [[0.5151897]] action 0
1/1 [==============================] - 0s 24ms/step p 0.06505159298527952 left_prob [[0.51832235]] action 0
The network keeps the pole vertical for 51 consecutive steps, which is an improvement over the previous hard-coded policy! Woo!
print(totals)
51.0
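A single episode can be lucky or unlucky, so a fairer measure is to average the total reward over several episodes. Here is a hedged sketch of that idea (the evaluate_policy name and the n_episodes value are just illustrative):

def evaluate_policy(model, n_episodes=10, max_steps=200):
    # run several episodes and report the average number of steps survived
    episode_totals = []
    for episode in range(n_episodes):
        # vary the seed so the sampled actions differ from episode to episode
        frames, total = render_policy(model, max_steps=max_steps, seed=episode)
        episode_totals.append(total)
    return np.mean(episode_totals)

# e.g. print(evaluate_policy(model))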