# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple LSTM layer with benchmarks.
This sets up a simple LSTM (Long Short Term Memory) layer, unrolled to a fixed
length sequence. The only deviation from standard LSTM cells is that
activations are clipped, inspired by the GNMT machine translation model.
The GNMT paper has more details: https://arxiv.org/abs/1609.08144
"""

from six.moves import range

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_v1


def Clip(x):
"""Clips x to the range [-1., 1.]."""
return math_ops.maximum(math_ops.minimum(x, 1.), -1.)
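
# For example, Clip maps the values [-2.0, 0.3, 1.5] element-wise to
# [-1.0, 0.3, 1.0].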


def LSTMCellWeightsShape(num_inputs, num_nodes):
"""Returns the shape of the weights for a single LSTM cell."""
# Dimension 0 accounts for combining x with the previous m state.
# Dimension 1 accounts for the in value and the (in, forget, out) gates.
return [num_inputs + num_nodes, 4 * num_nodes]
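
# A quick sanity-check sketch of the shape arithmetic: with num_inputs=16 and
# num_nodes=32, each concatenated [x, m] row has 16 + 32 = 48 entries, and the
# output stacks the four blocks (in value plus in/forget/out gates) for
# 4 * 32 = 128 columns.
assert LSTMCellWeightsShape(16, 32) == [48, 128]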


def LSTMCell(weights, m_prev, c_prev, x, pad):
  """Unrolls a single LSTM cell with clipped activations forward by one step.

  Args:
    weights: Weight matrix with shape LSTMCellWeightsShape.
    m_prev: Previous m states with shape [batch_size, num_nodes].
    c_prev: Previous c states with shape [batch_size, num_nodes].
    x: Input with shape [batch_size, num_inputs].
    pad: Padding with shape [batch_size, 1]. Each padding value is either
      0 or 1, where 1 indicates padding; i.e. the input is shorter than the
      sequence length, and the (m, c) states should simply be passed through
      from the previous states.

  Returns:
    The next (m, c) states, each with shape [batch_size, num_nodes].
  """
# Apply weights to the input and previous hidden state.
# The matmul here is the "big" operation.
xm = array_ops.concat([x, m_prev], 1)
xmw = math_ops.matmul(xm, weights)
# Element-wise ops for the standard LSTM cell, with clipped activations.
# XLA can fuse these operations into a single loop.
in_value, in_gate, forget_gate, out_gate = array_ops.split(
value=xmw, num_or_size_splits=4, axis=1)
in_value = math_ops.tanh(in_value)
in_gate = math_ops.sigmoid(in_gate)
forget_gate = math_ops.sigmoid(forget_gate)
out_gate = math_ops.sigmoid(out_gate)
c_next = Clip(Clip(forget_gate * c_prev) + Clip(in_gate * in_value))
m_next = Clip(out_gate * c_next)
# Account for padding.
c_next = c_prev * pad + c_next * (1.0 - pad)
m_next = m_prev * pad + m_next * (1.0 - pad)
return m_next, c_next
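
# A minimal usage sketch of LSTMCell; the helper name _LSTMCellStepSketch is
# illustrative only and not part of the benchmark graph.
def _LSTMCellStepSketch():
  """Advances zero-initialized (m, c) states by one clipped-LSTM step."""
  batch_size, num_inputs, num_nodes = 4, 8, 16
  # Random weights are fine for checking shapes; training would use a
  # variable, as RandomVar does below.
  weights = random_ops.random_uniform(
      LSTMCellWeightsShape(num_inputs, num_nodes))
  m_prev = array_ops.zeros([batch_size, num_nodes])
  c_prev = array_ops.zeros([batch_size, num_nodes])
  x = random_ops.random_uniform([batch_size, num_inputs])
  pad = array_ops.zeros([batch_size, 1])  # No example in the batch is padding.
  # Both returned states have shape [batch_size, num_nodes].
  return LSTMCell(weights, m_prev, c_prev, x, pad)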


def LSTMLayer(cell_name, weights, m, c, x_seq, pad_seq):
  """Unrolls a layer of LSTM cells forward by the sequence length.

  The sequence length is determined by the length of x_seq and pad_seq, which
  must be the same.

  Args:
    cell_name: Base name of each cell.
    weights: Weight matrix with shape LSTMCellWeightsShape.
    m: Initial m states with shape [batch_size, num_nodes].
    c: Initial c states with shape [batch_size, num_nodes].
    x_seq: List of inputs, each with shape [batch_size, num_inputs].
      The length of the list is the sequence length.
    pad_seq: List of paddings, each with shape [batch_size, 1].
      The length of the list is the sequence length.
      Each padding value is either 0 or 1, where 1 indicates padding;
      i.e. the input is shorter than the sequence length.

  Returns:
    List of per-sequence-step outputs, each with shape [batch_size, num_nodes].

  Raises:
    ValueError: If len(x_seq) != len(pad_seq).
  """
if len(x_seq) != len(pad_seq):
raise ValueError('length of x_seq(%d) != pad_seq(%d)' %
(len(x_seq), len(pad_seq)))
out_seq = []
for seq in range(len(x_seq)):
with ops.name_scope('%s_%d' % (cell_name, seq)):
m, c = LSTMCell(weights, m, c, x_seq[seq], pad_seq[seq])
out_seq.append(array_ops.identity(m, name='out'))
return out_seq
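
# A minimal usage sketch of LSTMLayer; _LSTMLayerUnrollSketch is illustrative
# only and not used by the benchmark. The second step is fully padded, so its
# output passes the first step's m state straight through.
def _LSTMLayerUnrollSketch():
  batch_size, num_inputs, num_nodes = 2, 3, 5
  weights = random_ops.random_uniform(
      LSTMCellWeightsShape(num_inputs, num_nodes))
  m = array_ops.zeros([batch_size, num_nodes])
  c = array_ops.zeros([batch_size, num_nodes])
  x_seq = [random_ops.random_uniform([batch_size, num_inputs])
           for _ in range(2)]
  pad_seq = [array_ops.zeros([batch_size, 1]),  # Step 0: real input.
             array_ops.ones([batch_size, 1])]   # Step 1: padding.
  return LSTMLayer('lstm_sketch', weights, m, c, x_seq, pad_seq)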


def RandomVar(shape, name=None):
"""Returns a variable of the given shape initialized to random values."""
return variable_v1.VariableV1(
random_ops.random_uniform(shape), dtype=dtypes.float32, name=name)


def RandomInputs(batch_size, seq_length, num_inputs):
"""Returns randomly initialized (x_seq, pad_seq) sequences."""
x_seq = []
pad_seq = []
with ops.name_scope('inputs'):
for seq in range(seq_length):
x_seq.append(RandomVar([batch_size, num_inputs], name='x_seq_%d' % seq))
# Real padding values are always a sequence of 0 followed by a
# sequence of 1, but random values are fine for benchmarking.
pad_seq.append(RandomVar([batch_size, 1], name='pad_seq_%d' % seq))
return x_seq, pad_seq


def BuildLSTMLayer(batch_size, seq_length, num_inputs, num_nodes):
  """Builds a single LSTM layer with random weights and inputs.

  Args:
    batch_size: Inputs are fed in batches of this size.
    seq_length: The sequence length to unroll the LSTM layer.
    num_inputs: Dimension of inputs that are fed into each LSTM cell.
    num_nodes: The number of nodes in each LSTM cell.

  Returns:
    (out_seq, weights) pair. The out_seq is a list of per-sequence-step
    outputs, each with shape [batch_size, num_nodes]. The weights are a list
    of weight variables that may be trained.
  """
weights = RandomVar(
LSTMCellWeightsShape(num_inputs, num_nodes), name='weights')
m = array_ops.zeros([batch_size, num_nodes], name='init_m')
c = array_ops.zeros([batch_size, num_nodes], name='init_c')
x_seq, pad_seq = RandomInputs(batch_size, seq_length, num_inputs)
out_seq = LSTMLayer('lstm', weights, m, c, x_seq, pad_seq)
return out_seq, [weights]
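

if __name__ == '__main__':
  # End-to-end sketch: build the layer in a TF1-style graph and evaluate the
  # final step's output under a Session; the result should have shape (4, 32),
  # i.e. [batch_size, num_nodes].
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops import variables

  with ops.Graph().as_default():
    out_seq, _ = BuildLSTMLayer(
        batch_size=4, seq_length=8, num_inputs=16, num_nodes=32)
    with session_lib.Session() as sess:
      sess.run(variables.global_variables_initializer())
      print(sess.run(out_seq[-1]).shape)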