

Use this method to track nested keras models in a shim-decorated method.

This method can be used inside the methods of a tf.keras.Layer that are decorated by the track_tf1_style_variables shim, to additionally track inner Keras Model objects created within the same method. The inner model's variables and losses are then accessible through the outer layer's variables and losses attributes.

This enables tracking of inner keras models using TF2 behaviors, with minimal changes to existing TF1-style code.


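```python
import tensorflow as tf


class NestedLayer(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  def build_model(self):
    # Zero-argument factory that builds the inner Keras model.
    inp = tf.keras.Input(shape=(5, 5))
    dense_layer = tf.keras.layers.Dense(
        10, name="dense", kernel_regularizer="l2")
    model = tf.keras.Model(inputs=inp, outputs=dense_layer(inp))
    return model

  # The shim decorator is required for get_or_create_layer to track the model.
  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    # Creates the inner model on first use and reuses it afterwards; its
    # variables and regularization losses are tracked by this layer.
    model = tf.compat.v1.keras.utils.get_or_create_layer(
        "dense_model", self.build_model)
    return model(inputs)
```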
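The tracking described above can be pictured with a plain-Python toy that does not use TensorFlow. It is a simplified sketch, not how the shim is implemented: ToyModel and ToyNestedLayer are hypothetical stand-ins, and get_or_create_layer is written as a method here for simplicity, whereas the real function is a module-level call that relies on the decorator. The toy keeps nested models in a per-instance dictionary keyed by name, runs the zero-argument factory only when that name is not yet tracked, and reports the nested models' variables and losses through its own attributes.

```python
# Plain-Python toy, not TensorFlow code. ToyModel and ToyNestedLayer are
# hypothetical stand-ins for an inner Keras model and a shim-decorated layer.


class ToyModel:
    """Stands in for an inner Keras model that owns variables and losses."""

    def __init__(self, name):
        self.variables = [name + "/kernel", name + "/bias"]
        self.losses = [0.5]  # placeholder for a kernel regularization penalty


class ToyNestedLayer:
    """Stands in for an outer layer whose call method is shim-decorated."""

    def __init__(self):
        self._nested = {}  # name -> nested model, filled on first use
        self.build_count = 0  # how many times the factory actually ran

    def get_or_create_layer(self, name, create_layer_method):
        # Run the zero-argument factory only if nothing is tracked under
        # this name yet; later calls return the stored model.
        if name not in self._nested:
            self._nested[name] = create_layer_method()
        return self._nested[name]

    def build_model(self):
        self.build_count += 1
        return ToyModel("dense")

    def call(self):
        return self.get_or_create_layer("dense_model", self.build_model)

    @property
    def variables(self):
        # The outer layer exposes every nested model's variables as its own.
        return [v for m in self._nested.values() for v in m.variables]

    @property
    def losses(self):
        return [loss for m in self._nested.values() for loss in m.losses]


layer = ToyNestedLayer()
first = layer.call()
second = layer.call()
print(first is second, layer.build_count)  # True 1
print(layer.variables)  # ['dense/kernel', 'dense/bias']
print(layer.losses)  # [0.5]
```

Keying the store by name is what lets repeated calls to the same method reuse one inner model instead of building a new one each time.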

The inner model's creation should be confined to its own zero-argument function, which is passed into this method. In TF1, this method immediately creates and returns the desired model, without any tracking.
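The zero-argument requirement and the TF1 behavior can also be sketched in plain Python. A builder that needs parameters can be adapted into a zero-argument callable with functools.partial (a lambda works as well). The names toy_get_or_create_layer and build_model are hypothetical; passing store=None models the TF1 path, where the factory runs on every call and nothing is recorded, while a dictionary store models TF2-style tracking by name.

```python
import functools

# Plain-Python toy, not TensorFlow code. `store` plays the role of the
# shim's per-layer tracker; None models the untracked TF1 path.


def toy_get_or_create_layer(store, name, create_layer_method):
    if store is None:
        # TF1-style path: build immediately and return; nothing is recorded.
        return create_layer_method()
    if name not in store:
        store[name] = create_layer_method()
    return store[name]


def build_model(units, input_shape):
    """A builder that needs arguments, so it cannot be passed directly."""
    return {"units": units, "input_shape": input_shape}


# functools.partial binds the arguments and yields a zero-argument callable.
make_model = functools.partial(build_model, units=10, input_shape=(5, 5))

tracked = {}
m1 = toy_get_or_create_layer(tracked, "dense_model", make_model)
m2 = toy_get_or_create_layer(tracked, "dense_model", make_model)
print(m1 is m2, list(tracked))  # True ['dense_model']

u1 = toy_get_or_create_layer(None, "dense_model", make_model)
u2 = toy_get_or_create_layer(None, "dense_model", make_model)
print(u1 is u2)  # False
```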

Args:
  name: A name to give the nested layer to track.
  create_layer_method: A callable that takes no arguments and returns the nested layer.

Returns:
  The created layer.