Continuing from the previous post (Learning to use OpenAI Gym), I explore how to create a neural network policy that can balance the pole for longer than the simple hard-coded policy. Using a simple network, we should see an improvement, measured by the number of consecutive steps that the pole can stay vertical.

Neural Network Policy

Use the import statements and plotting functions from the previous post.
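For reference, the setup assumed here looks roughly like the sketch below (the plot_animation helper is the one defined in the previous post and is not repeated here):

# Assumed setup carried over from the previous post
import gym
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt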

keras.backend.clear_session()
tf.random.set_seed(42)
np.random.seed(42)
# number of values in the observation space:
# cart position, cart velocity, pole angle, pole angular velocity
num_inputs = 4


If you change the parameters of the model, make sure to clear the session before retraining.


model = keras.models.Sequential([
    keras.layers.Dense(20,activation="relu",input_shape=(num_inputs,)),
    keras.layers.Dense(5,activation="relu"),
    keras.layers.Dense(1,activation="sigmoid"),
])

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 20)                100
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 105
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 6
=================================================================
Total params: 211
Trainable params: 211
Non-trainable params: 0
_________________________________________________________________
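Each layer's parameter count is its weights plus biases: the first hidden layer has 4 × 20 weights + 20 biases = 100, the second has 20 × 5 + 5 = 105, and the output layer has 5 × 1 + 1 = 6, for 211 trainable parameters in total.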

The input to our model is going to be the environment's observation, obs = [x, x, x, x], and the output will be the action taken by the cart (either 0 or 1). Since there are only two possible actions, we only need one output neuron with the sigmoid activation function: its output is interpreted as the probability of going left (action 0). If there were more than two possible actions, each action would be represented by its own output neuron and we would use the softmax activation function.
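As an aside, a policy head for an environment with more than two discrete actions might look like the sketch below. This is purely illustrative: num_actions is a made-up value, not part of CartPole.

# Hypothetical sketch: a softmax policy head for an environment with several actions
num_actions = 3
multi_action_model = keras.models.Sequential([
    keras.layers.Dense(20, activation="relu", input_shape=(num_inputs,)),
    keras.layers.Dense(5, activation="relu"),
    # one output neuron per action; softmax yields a probability for each action
    keras.layers.Dense(num_actions, activation="softmax"),
])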

In the code below, we decide our action based on a randomly generated probability. For example, let's say the randomly generated value is 0.40 and the probability of going left produced by the model is 0.50. The code compares these two values: since 0.40 > 0.50 is false, the boolean converts to the integer 0 and our action is to go left.
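With those example numbers, the comparison looks like this:

p = 0.40          # randomly generated value
left_prob = 0.50  # model's probability of going left
action = int(p > left_prob)  # 0.40 > 0.50 is False, so action = 0 (go left)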

We make our decisions based on this random probability because we want to find the right balance between exploring new actions and exploiting actions that we already know work well. Imagine you go to a cafe and randomly select a coffee. If you like it, the probability that you order it again next time you go is increased. However, you shouldn’t increase that probability to 100% because there might be other coffees to try that are even better than the first.

def render_policy(model, max_steps=200, seed=42):
  frames = []
  env = gym.make("CartPole-v1", render_mode="rgb_array")
  np.random.seed(seed)
  # reset the environment; reset() returns (obs, info), so keep only the observation
  obs, info = env.reset()
  obs = np.array([obs])  # add a batch dimension for model.predict
  # keep track of how many consecutive steps the pole stays vertical
  totals = 0
  for step in range(max_steps):
    frames.append(env.render())
    left_prob = model.predict(obs)
    #print('left_prob', left_prob)
    # generate a random number in [0, 1)
    p = np.random.rand()
    #print('p', p)
    # turn the boolean into an integer action (True -> 1: push right, False -> 0: push left)
    action = int(p > left_prob)
    #print('action', action)
    # step() returns (obs, reward, terminated, truncated, info)
    obs, reward, terminated, truncated, info = env.step(action)
    obs = np.array([obs])
    totals += reward
    if terminated or truncated:
      break
  env.close()
  return frames, totals

For our cart-pole environment, we can ignore past observations and actions because at each step we can see the environment's full state.

For example, if the environment only revealed the cart's position and not its velocity, you would have to compare past and current observations to estimate the current velocity. We do not have to worry about this in our case.
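As an illustration (not needed for CartPole, where velocity is already part of the observation), velocity could be estimated from two consecutive position readings:

# Hypothetical sketch: estimating velocity from consecutive positions when an
# environment only exposes position. dt is an assumed time between steps.
def estimate_velocity(prev_position, curr_position, dt=0.02):
    # finite-difference approximation of velocity
    return (curr_position - prev_position) / dt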

Let's see how well a randomly initialized policy network performs:

frames,totals = render_policy(model)
plot_animation(frames)
    1/1 [==============================] - 0s 19ms/step
    p 0.3745401188473625
    left_prob [[0.49724284]]
    action 0
    1/1 [==============================] - 0s 20ms/step
    p 0.9507143064099162
    left_prob [[0.50142497]]
    action 1
    1/1 [==============================] - 0s 19ms/step
    p 0.7319939418114051
    left_prob [[0.4971607]]
    action 1
    1/1 [==============================] - 0s 21ms/step
    p 0.5986584841970366
    left_prob [[0.4774564]]
    action 1
    1/1 [==============================] - 0s 29ms/step
    p 0.15601864044243652
    left_prob [[0.4559578]]
    action 0
    1/1 [==============================] - 0s 24ms/step
    p 0.15599452033620265
    left_prob [[0.47929704]]
    action 0
    1/1 [==============================] - 0s 23ms/step
    p 0.05808361216819946
    left_prob [[0.4979668]]
    action 0
    1/1 [==============================] - 0s 22ms/step
    p 0.8661761457749352
    left_prob [[0.5000971]]
    action 1
    1/1 [==============================] - 0s 21ms/step
    p 0.6011150117432088
    left_prob [[0.49776265]]
    action 1
    1/1 [==============================] - 0s 24ms/step
    p 0.7080725777960455
    left_prob [[0.47974744]]
    action 1
    1/1 [==============================] - 0s 25ms/step
    p 0.020584494295802447
    left_prob [[0.45795083]]
    action 0
    1/1 [==============================] - 0s 24ms/step
    p 0.9699098521619943
    left_prob [[0.48107377]]
    action 1
    1/1 [==============================] - 0s 18ms/step
    p 0.8324426408004217
    left_prob [[0.45891747]]
    action 1
    1/1 [==============================] - 0s 23ms/step
    p 0.21233911067827616
    left_prob [[0.43718466]]
    action 0
    1/1 [==============================] - 0s 21ms/step
    p 0.18182496720710062
    left_prob [[0.4599171]]
    action 0
    1/1 [==============================] - 0s 22ms/step
    p 0.18340450985343382
    left_prob [[0.48222327]]
    action 0
    1/1 [==============================] - 0s 20ms/step
    p 0.3042422429595377
    left_prob [[0.49851432]]
    action 0
    1/1 [==============================] - 0s 25ms/step
    p 0.5247564316322378
    left_prob [[0.49408367]]
    action 1
    1/1 [==============================] - 0s 24ms/step
    p 0.43194501864211576
    left_prob [[0.49975875]]
    action 0
    1/1 [==============================] - 0s 20ms/step
    p 0.2912291401980419
    left_prob [[0.49443373]]
    action 0
    1/1 [==============================] - 0s 21ms/step
    p 0.6118528947223795
    left_prob [[0.49302602]]
    action 1
    1/1 [==============================] - 0s 28ms/step
    p 0.13949386065204183
    left_prob [[0.49531654]]
    action 0
    1/1 [==============================] - 0s 21ms/step
    p 0.29214464853521815
    left_prob [[0.49462938]]
    action 0
    1/1 [==============================] - 0s 23ms/step
    p 0.3663618432936917
    left_prob [[0.49462]]
    action 0
    1/1 [==============================] - 0s 23ms/step
    p 0.45606998421703593
    left_prob [[0.495313]]
    action 0
    1/1 [==============================] - 0s 21ms/step
    p 0.7851759613930136
    left_prob [[0.49673685]]
    action 1
    1/1 [==============================] - 0s 19ms/step
    p 0.19967378215835974
    left_prob [[0.50123394]]
    action 0
    1/1 [==============================] - 0s 17ms/step
    p 0.5142344384136116
    left_prob [[0.5028582]]
    action 1
    1/1 [==============================] - 0s 21ms/step
    p 0.5924145688620425
    left_prob [[0.5076048]]
    action 1
    1/1 [==============================] - 0s 16ms/step
    p 0.046450412719997725
    left_prob [[0.51172656]]
    action 0
    1/1 [==============================] - 0s 19ms/step
    p 0.6075448519014384
    left_prob [[0.5132017]]
    action 1
    1/1 [==============================] - 0s 25ms/step
    p 0.17052412368729153
    left_prob [[0.5151897]]
    action 0
    1/1 [==============================] - 0s 24ms/step
    p 0.06505159298527952
    left_prob [[0.51832235]]
    action 0


The network here is able to keep the pole vertical for 51 consecutive steps, which is an improvement over the previous hard-coded policy! Woo!

print(totals)
51.0
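A single episode is a noisy estimate, so as a sanity check we could average the episode length over a few runs. This is just a minimal sketch; the number of runs and the seeds are arbitrary.

# Minimal sketch: average over a few episodes for a less noisy estimate
scores = []
for seed in range(5):
    _, total = render_policy(model, max_steps=200, seed=seed)
    scores.append(total)
print("mean steps:", np.mean(scores), "max steps:", np.max(scores))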