tf.contrib.checkpoint.List

An append-only sequence type which is trackable.

Maintains checkpoint dependencies on its contents (which must also be trackable), and forwards any Layer metadata such as updates and losses.
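The two duties named above (recording a dependency on every item, and forwarding Layer metadata) can be illustrated with a minimal pure-Python sketch. Trackable, FakeLayer, TrackedList and checkpoint_dependencies are hypothetical names invented for this illustration, not TensorFlow's classes. Naming dependencies by list position is an assumption of the sketch.

```python
# Hypothetical stand-ins for illustration; these are NOT TensorFlow classes.
class Trackable:
    """Marker base class for objects that can take part in a checkpoint."""


class FakeLayer(Trackable):
    """Hypothetical layer that only carries losses/updates metadata."""

    def __init__(self, name, losses=(), updates=()):
        self.name = name
        self.losses = list(losses)
        self.updates = list(updates)


class TrackedList(Trackable):
    """Append-only container that records a dependency on each item."""

    def __init__(self, values=()):
        self._storage = []
        for value in values:
            self.append(value)

    def append(self, value):
        # Contents must themselves be trackable, as the docs require.
        if not isinstance(value, Trackable):
            raise ValueError(f"Non-trackable value: {value!r}")
        self._storage.append(value)

    def checkpoint_dependencies(self):
        # Assumption: each dependency is named by its position ("0", "1", ...).
        return {str(i): value for i, value in enumerate(self._storage)}

    @property
    def losses(self):
        # Forward metadata by concatenating it from every contained item.
        return [loss for item in self._storage for loss in getattr(item, "losses", [])]

    @property
    def updates(self):
        return [upd for item in self._storage for upd in getattr(item, "updates", [])]


tracked = TrackedList([FakeLayer("a", losses=[0.5])])
tracked.append(FakeLayer("b", losses=[0.25], updates=["moving_mean"]))
print(sorted(tracked.checkpoint_dependencies()))  # ['0', '1']
print(tracked.losses)   # [0.5, 0.25]
print(tracked.updates)  # ['moving_mean']
try:
    tracked.append(3.0)  # a plain float is not Trackable
except ValueError as err:
    print(err)  # Non-trackable value: 3.0
```

Because the metadata properties are recomputed on every access, items appended later are picked up automatically.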

Note that List is purely a container. It lets a tf.keras.Model or other trackable object know about its contents, but does not call any Layer instances which are added to it. To indicate a sequence of Layer instances which should be called sequentially, use tf.keras.Sequential.
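The container-versus-caller distinction can be made concrete with a small pure-Python sketch. CallCounter and SequentialSketch are hypothetical stand-ins, not TensorFlow classes, and a plain Python list plays the role of the container here. Holding layers in a container triggers no calls. A Sequential-style wrapper is what chains them.

```python
class CallCounter:
    """Hypothetical layer stand-in that scales its input and counts calls."""

    def __init__(self, scale):
        self.scale = scale
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        return x * self.scale


class SequentialSketch:
    """Hypothetical analogue of tf.keras.Sequential: calling it runs each layer in order."""

    def __init__(self, layers):
        self.layers = list(layers)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


# Storing layers in a pure container invokes none of them.
container = [CallCounter(2), CallCounter(3)]
print([layer.calls for layer in container])  # [0, 0]

# The Sequential-style wrapper is what actually chains the calls: 1.0 * 2 * 3.
model = SequentialSketch(container)
print(model(1.0))                            # 6.0
print([layer.calls for layer in container])  # [1, 1]
```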

Example usage:

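import tensorflow as tf  # TensorFlow 1.x only: tf.contrib was removed in TF 2.0
from tensorflow.keras import layers

class HasList(tf.keras.Model):

  def __init__(self):
    super(HasList, self).__init__()
    # Wrapping the layers in List makes their weights part of this Model
    # and records a checkpoint dependency on each one.
    self.layer_list = tf.contrib.checkpoint.List([layers.Dense(3)])
    self.layer_list.append(layers.Dense(4))

  def call(self, x):
    aggregation = 0.
    # List only stores the layers; calling them in order happens here.
    for layer in self.layer_list:
      x = layer(x)
      aggregation += tf.reduce_sum(x)
    return aggregation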

This kind of wrapping is necessary because Trackable objects do not (yet) deeply inspect regular Python data structures. For example, assigning a regular list (self.layer_list = [layers.Dense(3)]) creates no checkpoint dependency and does not add the Layer instance's weights to the parent Model.
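The shallow inspection described above can be modeled in a few lines of pure Python. Trackable, Weight, TrackedList and Parent are hypothetical classes for this sketch, not TensorFlow's. The parent records a dependency only when the attribute value is itself trackable, so it never looks inside a plain list.

```python
class Trackable:
    """Hypothetical marker base class for checkpointable objects."""


class Weight(Trackable):
    def __init__(self, name):
        self.name = name


class TrackedList(Trackable, list):
    """Hypothetical wrapper: a list that is itself trackable."""


class Parent(Trackable):
    """Records a dependency only when an attribute value is Trackable."""

    def __init__(self):
        object.__setattr__(self, "_deps", {})

    def __setattr__(self, name, value):
        # Shallow inspection: the attribute value itself is checked, but the
        # elements of a plain Python list are never looked at.
        if isinstance(value, Trackable):
            self._deps[name] = value
        object.__setattr__(self, name, value)


p = Parent()
p.plain = [Weight("w0")]                 # plain list: no dependency recorded
p.wrapped = TrackedList([Weight("w1")])  # wrapped list: dependency recorded
print(sorted(p._deps))  # ['wrapped']
```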

Attributes

layers

losses: Aggregate losses from any Layer instances.

non_trainable_variables

non_trainable_weights

trainable

trainable_variables

trainable_weights

updates: Aggregate updates from any Layer instances.

variables

weights
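The weight-related attributes above are commonly partitioned by each variable's own flag and by the container's trainable flag. The following is a minimal sketch of that partitioning. It assumes, as in Keras layers, that freezing the container moves every weight into the non-trainable group and that the *_variables names alias the *_weights names. Variable, WeightsSketch and names are hypothetical.

```python
class Variable:
    """Hypothetical variable stand-in with a per-variable trainable flag."""

    def __init__(self, name, trainable=True):
        self.name = name
        self.trainable = trainable


class WeightsSketch:
    """Hypothetical container that gathers weights from its sub-layers."""

    def __init__(self, sublayer_weights):
        self._sublayer_weights = [list(ws) for ws in sublayer_weights]
        self.trainable = True

    @property
    def weights(self):
        return [w for ws in self._sublayer_weights for w in ws]

    @property
    def trainable_weights(self):
        # A frozen container reports nothing as trainable.
        if not self.trainable:
            return []
        return [w for w in self.weights if w.trainable]

    @property
    def non_trainable_weights(self):
        if not self.trainable:
            return self.weights
        return [w for w in self.weights if not w.trainable]

    # Assumed aliases, mirroring Keras naming: *_variables == *_weights.
    variables = weights
    trainable_variables = trainable_weights
    non_trainable_variables = non_trainable_weights


def names(variables):
    return [v.name for v in variables]


dense = [Variable("kernel"), Variable("bias")]
norm = [Variable("gamma"), Variable("moving_mean", trainable=False)]
container = WeightsSketch([dense, norm])
print(names(container.trainable_weights))      # ['kernel', 'bias', 'gamma']
print(names(container.non_trainable_weights))  # ['moving_mean']
container.trainable = False  # freeze everything held by the container
print(names(container.trainable_variables))      # []
print(names(container.non_trainable_variables))  # ['kernel', 'bias', 'gamma', 'moving_mean']
```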

Methods

append

Add a new trackable value.

copy

count

S.count(value) -> integer -- return number of occurrences of value

extend

Add a sequence of trackable values.
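Append-only here means the sequence can grow through append and extend, but it offers no item assignment or deletion. A minimal sketch, assuming a design built on collections.abc.Sequence (which defines no __setitem__). AppendOnlyList is a hypothetical name, and a real trackable list would also record a checkpoint dependency inside append.

```python
from collections.abc import Sequence


class AppendOnlyList(Sequence):
    """Hypothetical append-only sequence; not TensorFlow's implementation."""

    def __init__(self, values=()):
        self._storage = []
        self.extend(values)

    def append(self, value):
        # A trackable list would also record a checkpoint dependency here.
        self._storage.append(value)

    def extend(self, values):
        # Extending is just appending each value in order.
        for value in values:
            self.append(value)

    def __getitem__(self, index):
        return self._storage[index]

    def __len__(self):
        return len(self._storage)


items = AppendOnlyList(["a"])
items.extend(["b", "c"])
print(list(items))  # ['a', 'b', 'c']
try:
    items[0] = "z"  # Sequence defines no __setitem__, so this fails
except TypeError as err:
    print("rejected:", err)
```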

index

S.index(value, [start, [stop]]) -> integer -- return first index of value. Raises ValueError if the value is not present.

(This note is inherited from collections.abc.Sequence: supporting the start and stop arguments is optional for Sequence implementations, but recommended.)
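The count and index docstrings above are the generic ones from Python's collections.abc.Sequence. Any class that implements only __getitem__ and __len__ receives count, index, __contains__, __iter__ and __reversed__ as mixin methods. The sketch below demonstrates this with a hypothetical MinimalSequence. That List obtains these methods the same way is an assumption based on the matching docstrings.

```python
from collections.abc import Sequence


class MinimalSequence(Sequence):
    """Defines only the two abstract methods that Sequence requires."""

    def __init__(self, values):
        self._storage = list(values)

    def __getitem__(self, index):
        return self._storage[index]

    def __len__(self):
        return len(self._storage)


seq = MinimalSequence(["x", "y", "x", "z"])
print(seq.count("x"))     # 2
print(seq.index("x", 1))  # 2 (search starts at position 1)
print("z" in seq)         # True (mixin __contains__)
try:
    seq.index("missing")
except ValueError:
    print("ValueError")   # raised when the value is not present
```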

__add__

__contains__

__eq__

Return self==value.

__getitem__

__iter__

__len__

__mul__

__radd__

__rmul__

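The operator methods listed above (__add__, __radd__, __mul__, __rmul__, __eq__) let a wrapped list interoperate with plain Python lists. This page does not say whether TensorFlow's operators return a List or a plain list. The hypothetical OperatorList below returns a new wrapper so the result stays wrapped, which is a design choice of the sketch only.

```python
class OperatorList:
    """Hypothetical wrapper showing list-like operators; not TensorFlow's code."""

    def __init__(self, values=()):
        self._storage = list(values)

    @staticmethod
    def _unwrap(other):
        # Accept either another wrapper or a plain Python sequence.
        return other._storage if isinstance(other, OperatorList) else list(other)

    def __add__(self, other):
        return OperatorList(self._storage + self._unwrap(other))

    def __radd__(self, other):
        # Called for `plain_list + wrapper`; keep the left operand first.
        return OperatorList(self._unwrap(other) + self._storage)

    def __mul__(self, n):
        return OperatorList(self._storage * n)

    __rmul__ = __mul__  # `n * wrapper` repeats just like `wrapper * n`

    def __eq__(self, other):
        # Compare contents with another wrapper or with a plain list.
        if isinstance(other, OperatorList):
            return self._storage == other._storage
        if isinstance(other, list):
            return self._storage == other
        return NotImplemented

    def __repr__(self):
        return f"OperatorList({self._storage!r})"


a = OperatorList([1, 2])
print(a + [3])                # OperatorList([1, 2, 3])
print([0] + a)                # OperatorList([0, 1, 2])
print(2 * a == [1, 2, 1, 2])  # True
```

Python tries the right operand's __radd__ when the built-in list cannot add the wrapper, which is why `[0] + a` works without changing list itself.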