Setting and resetting LSTM hidden states in Tensorflow 2
Getting control of the hidden state using stateful and stateless LSTMs.
Tensorflow 2 is currently in alpha, which means the old ways of doing things have changed. I’m working on a project where I want fine-grained control of the hidden state of an LSTM layer. After a bit of hacking around I settled on the solution below (note - the TF 2.0 docs say that you should be able to pass an initial_state when calling the layer - I couldn’t get this to work).
Using a stateful LSTM
This solution requires using a stateful LSTM - stateful here means that the final states of batch i will be used as the initial states of batch i+1. Often this isn’t the behaviour we want (when training, each batch is independent of the others), but it is required to be able to call tf.keras.layers.RNN().reset_states(states).
Having a stateful LSTM means that you will need to reset the hidden state in between batches yourself if you do want independent batches. The default initial hidden state in Tensorflow is all zeros.
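To build intuition for why carried state matters, here’s a minimal numpy sketch of a single LSTM cell step - the weights are random placeholders, not the Keras layer’s parameters:

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_hidden = 3, 4

#  placeholder parameters for one LSTM cell (gates stacked: i, f, g, o)
W = rng.normal(size=(4 * n_hidden, n_in))
U = rng.normal(size=(4 * n_hidden, n_hidden))
b = np.zeros(4 * n_hidden)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h, c):
    z = W @ x + U @ h + b
    i, f, g, o = np.split(z, 4)
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new

x = rng.normal(size=n_in)
h0, c0 = np.zeros(n_hidden), np.zeros(n_hidden)

h1, c1 = lstm_step(x, h0, c0)  # first pass, starting from the zero state
h2, _ = lstm_step(x, h1, c1)   # second pass reuses the carried state - different output
h3, _ = lstm_step(x, h0, c0)   # "resetting" to zeros reproduces the first pass

assert np.allclose(h1, h3)
assert not np.allclose(h1, h2)
```

The same input gives a different output once the state has moved - which is exactly what we will see with the Keras model below.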
First let’s set up a simple, single layer LSTM with a fully connected output layer. I use tf.keras.Model rather than tf.keras.Sequential so that I can have multiple outputs (i.e. so I can access the hidden state after a forward pass):
import numpy as np
import tensorflow as tf

np.random.seed(42)
tf.random.set_seed(42)

input_dim = 3
output_dim = 3
num_timesteps = 2
batch_size = 10
nodes = 10

input_layer = tf.keras.Input(shape=(num_timesteps, input_dim), batch_size=batch_size)

cell = tf.keras.layers.LSTMCell(
    nodes,
    kernel_initializer='glorot_uniform',
    recurrent_initializer='glorot_uniform',
    bias_initializer='zeros',
)

lstm = tf.keras.layers.RNN(
    cell,
    return_state=True,
    return_sequences=True,
    stateful=True,
)

lstm_out, hidden_state, cell_state = lstm(input_layer)
output = tf.keras.layers.Dense(output_dim)(lstm_out)

mdl = tf.keras.Model(
    inputs=input_layer,
    outputs=[hidden_state, cell_state, output]
)
We can now test what’s going on by passing a batch through the network (look Ma, no tf.Session!):
x = np.random.rand(batch_size, num_timesteps, input_dim).astype(np.float32)
h_state, c_state, out = mdl(x)
print(np.mean(out))
-0.011644869
If we pass this same batch again, we get a different result, as the hidden state has been changed:
h_state, c_state, out = mdl(x)
print(np.mean(out))
-0.015350263
If we reset the hidden state, we can recover our initial output:
lstm.reset_states(states=[np.zeros((batch_size, nodes)), np.zeros((batch_size, nodes))])
h_state, c_state, out = mdl(x)
print(np.mean(out))
-0.011644869
This method also allows us to use values other than all zeros for the hidden state:
lstm.reset_states(states=[np.ones((batch_size, nodes)), np.ones((batch_size, nodes))])
h_state, c_state, out = mdl(x)
print(np.mean(out))
-0.21755001
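One useful consequence of a stateful layer: a long sequence can be processed in chunks, with the layer carrying (h, c) across calls, and the result matches running the whole sequence at once. A self-contained sketch (a fresh model with hypothetical dimensions, not the mdl above):

```python
import numpy as np
import tensorflow as tf

batch_size, input_dim, nodes, num_timesteps = 10, 3, 10, 2

#  variable-length time axis so the same model accepts 2 or 4 timesteps
inp = tf.keras.Input(shape=(None, input_dim), batch_size=batch_size)
lstm = tf.keras.layers.RNN(
    tf.keras.layers.LSTMCell(nodes),
    return_state=True,
    return_sequences=True,
    stateful=True,
)
lstm_out, h_state, c_state = lstm(inp)
mdl = tf.keras.Model(inputs=inp, outputs=[h_state, c_state, lstm_out])

long_x = np.random.rand(batch_size, 2 * num_timesteps, input_dim).astype(np.float32)
first, second = long_x[:, :num_timesteps], long_x[:, num_timesteps:]

zeros = [np.zeros((batch_size, nodes)), np.zeros((batch_size, nodes))]

#  whole sequence in one shot, starting from a zero state
lstm.reset_states(states=zeros)
_, _, full_out = mdl(long_x)

#  same sequence in two chunks - the stateful layer carries (h, c) across calls
lstm.reset_states(states=zeros)
_, _, _ = mdl(first)
_, _, chunk_out = mdl(second)

assert np.allclose(full_out[:, num_timesteps:], chunk_out, atol=1e-5)
```

This is the pattern truncated backpropagation through time relies on: reset to zeros at a sequence boundary, then let the state flow between chunks.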
Using a non-stateful LSTM
One major downside of using a stateful LSTM is that you are forced to use the same batch size for forward and backward passes. I wanted the ability to pass a single sample through the LSTM as well as being able to train in batches.
This method actually overrides one of the functions used internally in Tensorflow (tf.keras.layers.RNN().get_initial_state). I felt a bit dirty doing this, but whenever I tried to pass the states through in the call I got a TypeError: call() got an unexpected keyword argument 'states'.
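As an aside - in later, stable releases of TF 2, passing initial_state when calling the layer does appear to work if the states are fed in as extra Input tensors. A sketch of that route (hypothetical dimensions, not the override approach used below):

```python
import numpy as np
import tensorflow as tf

num_timesteps, input_dim, nodes = 2, 3, 10

seq_in = tf.keras.Input(shape=(num_timesteps, input_dim))
h_in = tf.keras.Input(shape=(nodes,))
c_in = tf.keras.Input(shape=(nodes,))

lstm = tf.keras.layers.RNN(
    tf.keras.layers.LSTMCell(nodes),
    return_state=True,
    return_sequences=True,
)
#  initial_state is wired in at call time, as the docs describe
lstm_out, h_out, c_out = lstm(seq_in, initial_state=[h_in, c_in])
mdl = tf.keras.Model([seq_in, h_in, c_in], [lstm_out, h_out, c_out])

x = np.random.rand(4, num_timesteps, input_dim).astype(np.float32)
zeros = np.zeros((4, nodes), dtype=np.float32)
ones = np.ones((4, nodes), dtype=np.float32)

out_zero, _, _ = mdl([x, zeros, zeros])
out_ones, _, _ = mdl([x, ones, ones])
assert not np.allclose(out_zero, out_ones)
```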
import numpy as np
import tensorflow as tf

np.random.seed(42)
tf.random.set_seed(42)


class Model:
    def __init__(self):
        #  no fixed batch size this time - we want to feed batches of any size
        input_layer = tf.keras.Input(shape=(num_timesteps, input_dim))
        cell = tf.keras.layers.LSTMCell(
            nodes,
            kernel_initializer='glorot_uniform',
            recurrent_initializer='glorot_uniform',
            bias_initializer='zeros',
        )
        self.lstm = tf.keras.layers.RNN(
            cell,
            return_state=True,
            return_sequences=True,
            stateful=False,
        )
        lstm_out, hidden_state, cell_state = self.lstm(input_layer)
        output = tf.keras.layers.Dense(output_dim)(lstm_out)
        self.net = tf.keras.Model(
            inputs=input_layer,
            outputs=[output, hidden_state, cell_state]
        )

    def get_zero_initial_state(self, inputs):
        #  zeros that match the batch size of whatever was passed in
        batch = tf.shape(inputs)[0]
        return [tf.zeros((batch, nodes)), tf.zeros((batch, nodes))]

    def get_initial_state(self, inputs):
        return self.initial_state

    def __call__(self, inputs, states=None):
        if states is None:
            self.lstm.get_initial_state = self.get_zero_initial_state
        else:
            self.initial_state = states
            self.lstm.get_initial_state = self.get_initial_state
        return self.net(inputs)
So does this work? Let’s generate another batch, this time a single sample:
mdl = Model()
x = np.random.rand(1, num_timesteps, input_dim).astype(np.float32)
out, hidden_state, cell_state = mdl(x)
print(np.mean(out))
0.00057914766
Unlike with the stateful LSTM, if we try this again we get the same result:
out, hidden_state, cell_state = mdl(x)
np.mean(out)
0.00057914766
And most importantly, we gain the ability to control the initial state for the sequence:
out, hidden_state, cell_state = mdl(x, states=[tf.ones((1, nodes)), tf.ones((1, nodes))] )
np.mean(out)
0.25189233
Thanks for reading!