An append-only sequence type which is trackable.
`tf.contrib.checkpoint.List(*args, **kwargs)`
Maintains checkpoint dependencies on its contents (which must also be trackable), and forwards any `Layer` metadata such as updates and losses.
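The behavior described above can be sketched in plain Python. This is a minimal illustration, not TensorFlow's implementation: `FakeLayer` and `AppendOnlyTrackableList` are hypothetical names. The container is a read-only `collections.abc.Sequence` whose only mutator is `append`, and it gathers `losses` and `updates` from whatever it holds.

```python
from collections.abc import Sequence


class FakeLayer:
    """Hypothetical stand-in for a Layer exposing losses and updates."""

    def __init__(self, losses=(), updates=()):
        self.losses = list(losses)
        self.updates = list(updates)


class AppendOnlyTrackableList(Sequence):
    """Hypothetical read-only Sequence whose only mutator is append()."""

    def __init__(self, values=()):
        self._storage = []
        for value in values:
            self.append(value)

    def append(self, value):
        self._storage.append(value)

    # Sequence derives __iter__, __contains__, index and count from these
    # two methods, but defines no __setitem__ or __delitem__.
    def __getitem__(self, index):
        return self._storage[index]

    def __len__(self):
        return len(self._storage)

    def _gather(self, attribute):
        # Concatenate the named attribute from every element that has it.
        gathered = []
        for value in self._storage:
            gathered += getattr(value, attribute, [])
        return gathered

    @property
    def losses(self):
        return self._gather("losses")

    @property
    def updates(self):
        return self._gather("updates")


container = AppendOnlyTrackableList([FakeLayer(losses=["l2_a"],
                                               updates=["ema_a"])])
container.append(FakeLayer(losses=["l2_b"]))
print(container.losses)   # ['l2_a', 'l2_b']
print(container.updates)  # ['ema_a']
```

Building on `Sequence` gives indexing, iteration, `in`, `index`, and `count` for free while item assignment stays unsupported, which is one simple way to get "append-only".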
Note that `List` is purely a container. It lets a `tf.keras.Model` or other trackable object know about its contents, but does not call any `Layer` instances that are added to it. To indicate a sequence of `Layer` instances that should be called in order, use `tf.keras.Sequential`.
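The container-versus-callable distinction can be shown without TensorFlow. In this sketch, `CountingLayer` is a hypothetical layer stand-in and `run_sequentially` is a hypothetical helper playing the role `tf.keras.Sequential` plays for real layers. Storing layers in a container never calls them; only an explicit loop does.

```python
class CountingLayer:
    """Hypothetical layer stand-in that counts how often it is called."""

    def __init__(self, scale):
        self.scale = scale
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        return x * self.scale


def run_sequentially(layers, x):
    """Hypothetical helper: feed each layer's output into the next layer."""
    for layer in layers:
        x = layer(x)
    return x


first, second = CountingLayer(2), CountingLayer(10)
container = [first]       # storing a layer does not call it
container.append(second)
print(first.calls, second.calls)   # 0 0

output = run_sequentially(container, 3)
print(output)                      # 60, i.e. 3 * 2 * 10
print(first.calls, second.calls)   # 1 1
```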
Example usage:
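```python
# tf.contrib (and therefore tf.contrib.checkpoint) exists only in TensorFlow 1.x.
import tensorflow as tf
from tensorflow.keras import layers

class HasList(tf.keras.Model):
    def __init__(self):
        super(HasList, self).__init__()
        # Wrapping the layers in a List makes the Model track them.
        self.layer_list = tf.contrib.checkpoint.List([layers.Dense(3)])
        self.layer_list.append(layers.Dense(4))

    def call(self, x):
        aggregation = 0.
        # List does not call its contents, so apply each layer explicitly.
        for layer in self.layer_list:
            x = layer(x)
            aggregation += tf.reduce_sum(x)
        return aggregation
```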
This kind of wrapping is necessary because `Trackable` objects do not (yet) deeply inspect regular Python data structures. For example, assigning a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a checkpoint dependency and does not add the `Layer` instance's weights to its parent `Model`.
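The limitation can be modeled with a toy dependency walker. Every name here (`Trackable`, `Layer`, `TrackedList`, `Model`, `collect_paths`) is hypothetical and unrelated to TensorFlow's classes. The walker discovers dependencies only from attributes whose values are themselves trackable, so a plain `list` hides its contents while a trackable wrapper exposes them. The slash-separated paths are an assumption of the sketch.

```python
class Trackable:
    """Hypothetical base class (not TensorFlow's) for checkpointable objects."""

    def children(self):
        # Direct dependencies: attributes whose values are Trackable.
        return {name: value for name, value in vars(self).items()
                if isinstance(value, Trackable)}


class Layer(Trackable):
    def __init__(self, units):
        self.units = units


class TrackedList(Trackable):
    """Hypothetical wrapper that exposes its elements as dependencies."""

    def __init__(self, values):
        self._values = list(values)

    def children(self):
        return {str(i): value for i, value in enumerate(self._values)}


class Model(Trackable):
    pass


def collect_paths(obj, prefix=""):
    """Walk the dependency graph and return the path of every dependency."""
    paths = []
    for name, child in obj.children().items():
        path = prefix + name
        paths.append(path)
        paths.extend(collect_paths(child, path + "/"))
    return paths


plain = Model()
plain.layer_list = [Layer(3)]                # a plain list is not inspected
wrapped = Model()
wrapped.layer_list = TrackedList([Layer(3)])

print(collect_paths(plain))    # []
print(collect_paths(wrapped))  # ['layer_list', 'layer_list/0']
```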
Attribute | Description
---|---
`layers` |
`losses` | Aggregate losses from any `Layer` instances.
`non_trainable_variables` |
`non_trainable_weights` |
`trainable` |
`trainable_variables` |
`trainable_weights` |
`updates` | Aggregate updates from any `Layer` instances.
`variables` |
`weights` |
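The weight-related attributes could plausibly be gathered from a container's contents as in this hypothetical sketch. It assumes, following the usual Keras convention, that setting `trainable = False` on the container reports every contained weight as non-trainable. `FakeLayer` and `WeightAggregatingList` are illustrative names, strings stand in for real variables, and the `*_variables` attributes are omitted.

```python
class FakeLayer:
    """Hypothetical layer stand-in with pre-split weight lists."""

    def __init__(self, trainable_weights, non_trainable_weights):
        self.trainable_weights = list(trainable_weights)
        self.non_trainable_weights = list(non_trainable_weights)

    @property
    def weights(self):
        return self.trainable_weights + self.non_trainable_weights


class WeightAggregatingList:
    """Hypothetical container that reports its contents' weights."""

    def __init__(self, layers):
        self._layers = list(layers)
        self.trainable = True

    @property
    def trainable_weights(self):
        if not self.trainable:
            return []  # assumed: a frozen container exposes nothing as trainable
        weights = []
        for layer in self._layers:
            weights += layer.trainable_weights
        return weights

    @property
    def non_trainable_weights(self):
        weights = []
        for layer in self._layers:
            # Assumed: when frozen, every weight of every element is non-trainable.
            weights += (layer.non_trainable_weights if self.trainable
                        else layer.weights)
        return weights

    @property
    def weights(self):
        return self.trainable_weights + self.non_trainable_weights


container = WeightAggregatingList([
    FakeLayer(["kernel_a", "bias_a"], ["moving_mean_a"]),
    FakeLayer(["kernel_b"], []),
])
print(container.trainable_weights)      # ['kernel_a', 'bias_a', 'kernel_b']
print(container.non_trainable_weights)  # ['moving_mean_a']
container.trainable = False
print(container.trainable_weights)      # []
print(container.non_trainable_weights)
# ['kernel_a', 'bias_a', 'moving_mean_a', 'kernel_b']
```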
Methods
`append`
`append(value)`
Add a new trackable value.
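Because contents must be trackable, `append` presumably validates its argument. This sketch assumes, hypothetically, that the check is an `isinstance` test against a marker base class and that a failure raises `ValueError`. The exact exception TensorFlow raises is not documented here, and `Trackable`, `Layer`, and `StrictList` are illustrative names.

```python
class Trackable:
    """Hypothetical marker base class for objects that can be checkpointed."""


class Layer(Trackable):
    pass


class StrictList:
    """Hypothetical append-only list that only accepts Trackable values."""

    def __init__(self):
        self._storage = []

    def append(self, value):
        # Refuse values that could not become checkpoint dependencies.
        if not isinstance(value, Trackable):
            raise ValueError(
                "Only trackable objects may be stored; got %r" % (value,))
        self._storage.append(value)

    def __len__(self):
        return len(self._storage)


items = StrictList()
items.append(Layer())
try:
    items.append(42)
except ValueError as err:
    print("rejected:", err)
print(len(items))  # 1
```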
`copy`
`copy()`
`count`
`count(value)`
Return the number of occurrences of `value` as an integer.
`extend`
`extend(values)`
Add a sequence of trackable values.
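Adding a sequence of values can be expressed as repeated `append`. The sketch below does so and also records each element as a dependency named after its position. Both the `IndexNamedList` class and this position-based naming scheme are assumptions for illustration, not a description of TensorFlow's checkpoint keys. Strings stand in for trackable layers.

```python
class IndexNamedList:
    """Hypothetical container that names each dependency by its position."""

    def __init__(self):
        self._storage = []
        self.dependencies = {}  # dependency name -> value

    def append(self, value):
        name = str(len(self._storage))  # "0", "1", ... (assumed scheme)
        self.dependencies[name] = value
        self._storage.append(value)

    def extend(self, values):
        # Adding a sequence is just appending each element in order.
        for value in values:
            self.append(value)


tracked = IndexNamedList()
tracked.append("dense_3")
tracked.extend(["dense_4", "dense_5"])
print(tracked.dependencies)
# {'0': 'dense_3', '1': 'dense_4', '2': 'dense_5'}
```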
`index`
`index(value, start=0, stop=None)`
Return the first index of `value`, searching only positions from `start` up to (but not including) `stop` when they are given. Raises `ValueError` if the value is not present.
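The `count` and `index` docstrings come from Python's `collections.abc.Sequence`. For a dependency-free illustration, a plain Python list stands in for a `List` here, on the assumption that both follow the same `Sequence` semantics, including the optional `start` and `stop` bounds on `index`.

```python
values = ["a", "b", "a", "c", "a"]

print(values.count("a"))        # 3
print(values.index("a"))        # 0: first occurrence
print(values.index("a", 1))     # 2: search starts at position 1
print(values.index("a", 3, 5))  # 4: search limited to positions 3 and 4

try:
    values.index("a", 3, 4)     # only "c" lies in this range
except ValueError:
    print("not found")
```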
`__add__`
`__add__(other)`
`__contains__`
`__contains__(value)`
`__eq__`
`__eq__(other)`
Return `self == other`.
`__getitem__`
`__getitem__(key)`
`__iter__`
`__iter__()`
`__len__`
`__len__()`
`__mul__`
`__mul__(n)`
`__radd__`
`__radd__(other)`
`__rmul__`
`__rmul__(n)`