Creates an RNN specified by RNNCell cell and loop function loop_fn.
tf.compat.v1.nn.raw_rnn(
    cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None
)
This function is a more primitive version of
dynamic_rnn that provides
more direct access to the inputs each iteration. It also provides more
control over when to start and finish reading the sequence, and
what to emit for the output.
For example, it can be used to implement the dynamic decoder of a seq2seq model.
Instead of working with
Tensor objects, most operations work with
TensorArray objects directly.
The operation of
raw_rnn, in pseudo-code, is basically the following:
    time = tf.constant(0, dtype=tf.int32)
    (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
        time=time, cell_output=None, cell_state=None, loop_state=None)
    emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
    state = initial_state
    while not all(finished):
        (output, cell_state) = cell(next_input, state)
        (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
            time=time + 1, cell_output=output, cell_state=cell_state,
            loop_state=loop_state)
        # Emit zeros and copy forward state for minibatch entries that are finished.
        state = tf.where(finished, state, next_state)
        emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
        emit_ta = emit_ta.write(time, emit)
        # If any new minibatch entries are marked as finished, mark these.
        finished = tf.logical_or(finished, next_finished)
        time += 1
    return (emit_ta, state, loop_state)
with the additional properties that output and state may be (possibly nested)
tuples, as determined by cell.output_size and cell.state_size, and
as a result the final state and
emit_ta may themselves be tuples.
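The masking behaviour in the pseudo-code (zeros emitted and state copied forward for finished entries) can be traced with a small plain-Python emulation. This is only a sketch, not TensorFlow's implementation: toy_cell, emulate_raw_rnn, and the sample inputs are hypothetical names introduced for illustration, each batch entry is a single float, and Python lists stand in for tensors and the TensorArray.

```python
def toy_cell(x, state):
    """Hypothetical stand-in for an RNNCell: accumulates its inputs."""
    new_state = state + x
    return new_state, new_state  # (output, new_state)


def emulate_raw_rnn(cell, loop_fn):
    """Plain-Python rendering of the pseudo-code, one float per batch entry."""
    time = 0
    finished, next_input, state, _, loop_state = loop_fn(time, None, None, None)
    emit_ta = []  # a list stands in for the TensorArray
    while not all(finished):
        stepped = [cell(x, s) for x, s in zip(next_input, state)]
        output = [o for o, _ in stepped]
        cell_state = [s for _, s in stepped]
        next_finished, next_input, next_state, emit, loop_state = loop_fn(
            time + 1, output, cell_state, loop_state)
        # Finished entries keep their old state and emit zero.
        state = [s if f else ns for f, s, ns in zip(finished, state, next_state)]
        emit = [0.0 if f else e for f, e in zip(finished, emit)]
        emit_ta.append(emit)
        # Entries newly marked as finished stay finished.
        finished = [f or nf for f, nf in zip(finished, next_finished)]
        time += 1
    return emit_ta, state, loop_state


# Hypothetical data: max_time = 3, batch_size = 2, indexed [time][batch].
inputs = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
sequence_length = [3, 1]


def loop_fn(time, cell_output, cell_state, loop_state):
    next_state = [0.0, 0.0] if cell_output is None else cell_state
    elements_finished = [time >= n for n in sequence_length]
    next_input = [0.0, 0.0] if all(elements_finished) else inputs[time]
    return elements_finished, next_input, next_state, cell_output, None


emit_ta, final_state, _ = emulate_raw_rnn(toy_cell, loop_fn)
print(emit_ta)      # [[1.0, 10.0], [3.0, 0.0], [6.0, 0.0]]
print(final_state)  # [6.0, 10.0]
```

The second batch entry has length 1, so after its first step it emits zeros and its state stays at 10.0 while the first entry keeps accumulating.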
A simple implementation of
dynamic_rnn via raw_rnn looks like this:
    import tensorflow as tf

    # max_time, batch_size, input_depth and num_units are assumed to be
    # Python integers defined elsewhere.
    inputs = tf.compat.v1.placeholder(
        shape=(max_time, batch_size, input_depth), dtype=tf.float32)
    sequence_length = tf.compat.v1.placeholder(
        shape=(batch_size,), dtype=tf.int32)
    # Split the time-major inputs into one TensorArray element per time step.
    inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
    inputs_ta = inputs_ta.unstack(inputs)

    cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)

    def loop_fn(time, cell_output, cell_state, loop_state):
        emit_output = cell_output  # == None for time == 0
        if cell_output is None:  # time == 0
            next_cell_state = cell.zero_state(batch_size, tf.float32)
        else:
            next_cell_state = cell_state
        elements_finished = (time >= sequence_length)
        finished = tf.reduce_all(elements_finished)
        # Once every sequence is finished there is nothing left to read,
        # so feed zeros instead of reading past the end of inputs_ta.
        next_input = tf.cond(
            finished,
            lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
            lambda: inputs_ta.read(time))
        next_loop_state = None
        return (elements_finished, next_input, next_cell_state,
                emit_output, next_loop_state)

    outputs_ta, final_state, _ = tf.compat.v1.nn.raw_rnn(cell, loop_fn)
    outputs = outputs_ta.stack()
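For the seq2seq decoder use case mentioned earlier, the main difference is that loop_fn takes next_input from the previous cell output instead of reading a pre-built input array. The following plain-Python sketch illustrates that feedback pattern only. It does not use TensorFlow, and EOS, START_TOKENS, toy_cell, and run are hypothetical names chosen for the illustration, with integer token ids standing in for tensors.

```python
EOS = 9                # hypothetical end-of-sequence token id
START_TOKENS = [7, 8]  # hypothetical first input for each batch entry


def toy_cell(token, state):
    """Hypothetical stand-in for an RNNCell: predicts token + 1, capped at EOS."""
    return min(token + 1, EOS), state + 1  # (output, new_state); state counts steps


def loop_fn(time, cell_output, cell_state, loop_state):
    if cell_output is None:  # time == 0: supply start tokens and initial state
        n = len(START_TOKENS)
        return [False] * n, START_TOKENS, [0] * n, None, None
    finished = [tok == EOS for tok in cell_output]
    # Decoder-style feedback: the previous output becomes the next input.
    return finished, cell_output, cell_state, cell_output, None


def run(cell, loop_fn):
    """Plain-Python rendering of the raw_rnn pseudo-code (not TensorFlow code)."""
    finished, next_input, state, _, loop_state = loop_fn(0, None, None, None)
    time, emitted = 0, []
    while not all(finished):
        stepped = [cell(x, s) for x, s in zip(next_input, state)]
        output = [o for o, _ in stepped]
        cell_state = [s for _, s in stepped]
        next_finished, next_input, next_state, emit, loop_state = loop_fn(
            time + 1, output, cell_state, loop_state)
        # Finished entries keep their old state and emit zero.
        state = [s if f else ns for f, s, ns in zip(finished, state, next_state)]
        emitted.append([0 if f else e for f, e in zip(finished, emit)])
        finished = [f or nf for f, nf in zip(finished, next_finished)]
        time += 1
    return emitted, state


decoded, steps_taken = run(toy_cell, loop_fn)
print(decoded)      # [[8, 9], [9, 0]]
print(steps_taken)  # [2, 1]
```

The second entry produces EOS on its first step and is masked afterwards, while the first entry runs one step longer. A real decoder would normally also mark entries as finished once time reaches a maximum length, so that the loop terminates even if the end token is never produced.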
Args
|cell|An instance of RNNCell.|
|loop_fn|A callable that takes inputs (time, cell_output, cell_state, loop_state) and returns the tuple (finished, next_input, next_cell_state, emit_output, next_loop_state).|
|parallel_iterations|(Default: 32). The number of iterations to run in parallel. Operations that have no temporal dependency and can be run in parallel will be. This parameter trades off time for space: values >> 1 use more memory but take less time, while smaller values use less memory but take longer.|
|swap_memory|Transparently swap the tensors produced in forward inference but needed for backprop from GPU to CPU. This allows training RNNs that would typically not fit on a single GPU, with minimal (or no) performance penalty.|
|scope|VariableScope for the created subgraph; defaults to "rnn".|
Returns
|A tuple (emit_ta, final_state, final_loop_state).|