Base neural network module class.

A module is a named container for tf.Variables, other tf.Modules, and functions that apply to user input. For example, a dense layer in a neural network might be implemented as a tf.Module:

import tensorflow as tf

class Dense(tf.Module):
  # A fully connected layer: y = relu(x @ w + b).
  def __init__(self, input_dim, output_size, name=None):
    super(Dense, self).__init__(name=name)
    # Variables assigned to attributes are tracked by tf.Module automatically.
    self.w = tf.Variable(
      tf.random.normal([input_dim, output_size]), name='w')
    self.b = tf.Variable(tf.zeros([output_size]), name='b')
  def __call__(self, x):
    y = tf.matmul(x, self.w) + self.b
    return tf.nn.relu(y)

You can use the Dense layer as you would expect:

d = Dense(input_dim=3, output_size=2)
d(tf.ones([1, 3]))
<tf.Tensor: shape=(1, 2), dtype=float32, numpy=...>

By subclassing tf.Module instead of object, any tf.Variable or tf.Module instances assigned to object properties can be collected using the variables, trainable_variables or submodules property:

d.variables
    (<tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=...>,
    <tf.Variable 'w:0' shape=(3, 2) dtype=float32, numpy=...>)
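The same collection works through nested modules. Below is a minimal sketch, assuming the `Dense` class above is repeated so the block runs on its own; `MLP`, the layer names and the sizes are hypothetical choices for illustration. Modules held in a Python list are still found, and a variable created with `trainable=False` shows up in `variables` but not in `trainable_variables`.

```python
import tensorflow as tf


class Dense(tf.Module):
  # Same dense layer as above: y = relu(x @ w + b).
  def __init__(self, input_dim, output_size, name=None):
    super().__init__(name=name)
    self.w = tf.Variable(
        tf.random.normal([input_dim, output_size]), name='w')
    self.b = tf.Variable(tf.zeros([output_size]), name='b')

  def __call__(self, x):
    return tf.nn.relu(tf.matmul(x, self.w) + self.b)


class MLP(tf.Module):
  # Hypothetical two-layer model; layers stored in a list are still tracked.
  def __init__(self, name=None):
    super().__init__(name=name)
    self.layers = [Dense(3, 4, name='hidden'), Dense(4, 2, name='out')]
    # Non-trainable variable: listed in `variables` only.
    self.step = tf.Variable(0, trainable=False, name='step')

  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x


mlp = MLP(name='mlp')
y = mlp(tf.ones([1, 3]))

print(len(mlp.submodules))           # the two Dense layers
print(len(mlp.variables))            # w and b of each layer, plus step
print(len(mlp.trainable_variables))  # step is excluded
```

Because the properties are computed by walking the object's attributes, no manual registration of the inner layers is needed.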

Subclasses of tf.Module can also take advantage of the _flatten method, which can be used to implement tracking of any other type.
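A minimal sketch of tracking a custom type with `_flatten`, assuming a hypothetical `Hyperparameter` class and hypothetical `Block` and `Model` modules. `_flatten` is a private method, so its keyword arguments (here `predicate` and `with_path`) are worth checking against the installed TensorFlow version. It walks the module's attributes, descending into nested modules, and yields each leaf for which `predicate` returns true; with `with_path=True` each leaf comes paired with the tuple of attribute names leading to it.

```python
import tensorflow as tf


class Hyperparameter:
  # Hypothetical plain Python type that tf.Module does not collect by default.
  def __init__(self, value):
    self.value = value


class Block(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)
    self.dropout = Hyperparameter(0.1)


class Model(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)
    self.learning_rate = Hyperparameter(1e-3)
    self.block = Block(name='block')

  @property
  def hyperparameters(self):
    # Collect Hyperparameter leaves from this module and its submodules,
    # keyed by the attribute path that leads to each one.
    return dict(self._flatten(
        predicate=lambda v: isinstance(v, Hyperparameter), with_path=True))


model = Model(name='model')
for path, hp in model.hyperparameters.items():
  print(path, hp.value)
```

This mirrors how the built-in variables and submodules properties are computed, only with a different predicate.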

All tf.Module classes have an associated