
GRUBlockCell

public final class GRUBlockCell

Computes the GRU cell forward propagation for 1 time step.

Args
  x: Input to the GRU cell.
  h_prev: State input from the previous GRU cell.
  w_ru: Weight matrix for the reset and update gate.
  w_c: Weight matrix for the cell connection gate.
  b_ru: Bias vector for the reset and update gate.
  b_c: Bias vector for the cell connection gate.

Returns
  r: Output of the reset gate.
  u: Output of the update gate.
  c: Output of the cell connection gate.
  h: Current state of the GRU cell.

Note on notation of the variables:

Concatenation of a and b is represented by a_b (for example, x_h_prev).
The element-wise product of a and b is represented by ab (for example, h_prevr).
The element-wise product operator is represented by \circ.
Matrix multiplication is represented by *.

Biases are initialized with:
  `b_ru` - constant_initializer(1.0)
  `b_c` - constant_initializer(0.0)
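
As a rough illustration of the arguments and the bias initialization described above, the plain-Java sketch below allocates one set of cell parameters. The class `GruParamsSketch` and the sizes `inputSize` and `cellSize` are hypothetical names, and the array shapes are an assumption inferred from the equations on this page rather than something the op documents. The weights are left at zero only to keep the example short.

```java
import java.util.Arrays;

/** Hypothetical helper (not part of TensorFlow) that allocates GRU cell parameters. */
public class GruParamsSketch {
    public final double[][] wRu; // assumed shape: [inputSize + cellSize][2 * cellSize]
    public final double[][] wC;  // assumed shape: [inputSize + cellSize][cellSize]
    public final double[] bRu;   // assumed shape: [2 * cellSize], filled with 1.0
    public final double[] bC;    // assumed shape: [cellSize], filled with 0.0

    public GruParamsSketch(int inputSize, int cellSize) {
        // Weights stay at zero here; a real model would use a random initializer.
        wRu = new double[inputSize + cellSize][2 * cellSize];
        wC = new double[inputSize + cellSize][cellSize];

        bRu = new double[2 * cellSize];
        Arrays.fill(bRu, 1.0);     // mirrors constant_initializer(1.0)

        bC = new double[cellSize]; // Java zero-fills arrays, mirroring constant_initializer(0.0)
    }

    public static void main(String[] args) {
        GruParamsSketch p = new GruParamsSketch(3, 2);
        System.out.println("w_ru: " + p.wRu.length + " x " + p.wRu[0].length);
        System.out.println("w_c:  " + p.wC.length + " x " + p.wC[0].length);
        System.out.println("b_ru: " + Arrays.toString(p.bRu));
        System.out.println("b_c:  " + Arrays.toString(p.bC));
    }
}
```

The factor of two in `w_ru` and `b_ru` comes from the reset and update gates sharing one matrix multiplication, which is how `[r_bar u_bar]` is produced in the equations.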

This kernel op implements the following mathematical equations:

x_h_prev = [x, h_prev]
 
 [r_bar u_bar] = x_h_prev * w_ru + b_ru
 
 r = sigmoid(r_bar)
 u = sigmoid(u_bar)
 
 h_prevr = h_prev \circ r
 
 x_h_prevr = [x h_prevr]
 
 c_bar = x_h_prevr * w_c + b_c
 c = tanh(c_bar)
 
 h = (1-u) \circ c + u \circ h_prev
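
The equations above can be sketched in plain Java as follows. This is a minimal reference sketch, not the TensorFlow kernel: the class `GruCellSketch`, its `Result` holder, and the helpers `concat` and `affine` are hypothetical names introduced for illustration. It assumes row-major `[batch][features]` arrays and that the first half of `[r_bar u_bar]` holds `r_bar` and the second half holds `u_bar`, which is how the notation reads.

```java
/** Plain-Java sketch of the GRU block cell equations (hypothetical, not the TensorFlow kernel). */
public class GruCellSketch {

    /** Holds the four outputs r, u, c, h, each of shape [batch][cellSize]. */
    public static final class Result {
        public final double[][] r, u, c, h;
        Result(double[][] r, double[][] u, double[][] c, double[][] h) {
            this.r = r; this.u = u; this.c = c; this.h = h;
        }
    }

    // [a b]: concatenate two matrices along the feature (column) axis.
    static double[][] concat(double[][] a, double[][] b) {
        double[][] out = new double[a.length][];
        for (int i = 0; i < a.length; i++) {
            out[i] = new double[a[i].length + b[i].length];
            System.arraycopy(a[i], 0, out[i], 0, a[i].length);
            System.arraycopy(b[i], 0, out[i], a[i].length, b[i].length);
        }
        return out;
    }

    // a * w + bias, with the bias broadcast over the batch dimension.
    static double[][] affine(double[][] a, double[][] w, double[] bias) {
        double[][] out = new double[a.length][bias.length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < bias.length; j++) {
                double sum = bias[j];
                for (int t = 0; t < w.length; t++) {
                    sum += a[i][t] * w[t][j];
                }
                out[i][j] = sum;
            }
        }
        return out;
    }

    static double sigmoid(double v) {
        return 1.0 / (1.0 + Math.exp(-v));
    }

    public static Result forward(double[][] x, double[][] hPrev, double[][] wRu,
                                 double[][] wC, double[] bRu, double[] bC) {
        int batch = x.length;
        int cell = hPrev[0].length;

        // [r_bar u_bar] = x_h_prev * w_ru + b_ru
        double[][] ruBar = affine(concat(x, hPrev), wRu, bRu);

        double[][] r = new double[batch][cell];
        double[][] u = new double[batch][cell];
        double[][] hPrevR = new double[batch][cell];
        for (int i = 0; i < batch; i++) {
            for (int j = 0; j < cell; j++) {
                r[i][j] = sigmoid(ruBar[i][j]);        // r = sigmoid(r_bar)
                u[i][j] = sigmoid(ruBar[i][cell + j]); // u = sigmoid(u_bar)
                hPrevR[i][j] = hPrev[i][j] * r[i][j];  // h_prevr = h_prev \circ r
            }
        }

        // c_bar = x_h_prevr * w_c + b_c
        double[][] cBar = affine(concat(x, hPrevR), wC, bC);

        double[][] c = new double[batch][cell];
        double[][] h = new double[batch][cell];
        for (int i = 0; i < batch; i++) {
            for (int j = 0; j < cell; j++) {
                c[i][j] = Math.tanh(cBar[i][j]);                            // c = tanh(c_bar)
                h[i][j] = (1 - u[i][j]) * c[i][j] + u[i][j] * hPrev[i][j];  // h = (1-u) \circ c + u \circ h_prev
            }
        }
        return new Result(r, u, c, h);
    }

    public static void main(String[] args) {
        double[][] x = {{1.0, 2.0}};
        double[][] hPrev = {{0.5}};
        // Zero weights, b_ru = 1.0 and b_c = 0.0, matching the initializers noted on this page.
        Result out = forward(x, hPrev, new double[3][2], new double[3][1],
                             new double[] {1.0, 1.0}, new double[] {0.0});
        System.out.println("r=" + out.r[0][0] + " u=" + out.u[0][0]
                + " c=" + out.c[0][0] + " h=" + out.h[0][0]);
    }
}
```

With zero weights, both gates reduce to `sigmoid(1.0)` and `c` is `tanh(0.0) = 0`, so `h` is simply `u \circ h_prev`. That makes the example easy to check by hand.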
 

Public Methods

Output<T>  c()
static <T extends Number> GRUBlockCell<T>  create(Scope scope, Operand<T> x, Operand<T> hPrev, Operand<T> wRu, Operand<T> wC, Operand<T> bRu, Operand<T> bC)
    Factory method to create a class wrapping a new GRUBlockCell operation.
Output<T>  h()
Output<T>  r()
Output<T>  u()


Public Methods

public Output<T> c ()

public static <T extends Number> GRUBlockCell<T> create (Scope scope, Operand<T> x, Operand<T> hPrev, Operand<T> wRu, Operand<T> wC, Operand<T> bRu, Operand<T> bC)

Factory method to create a class wrapping a new GRUBlockCell operation.

Parameters
  • scope: current scope
Returns
  • a new instance of GRUBlockCell

public Output<T> h ()

public Output<T> r ()

public Output<T> u ()