Ta strona została przetłumaczona przez Cloud Translation API.
Switch to English

Tworzenie nowych warstw & amp; Modele poprzez podklasy

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHub Pobierz notatnik

Ustawiać

 import tensorflow as tf
from tensorflow import keras
 

Klasa Layer : połączenie stanu (wagi) i niektórych obliczeń

Jedną z głównych abstrakcji w Keras jest klasa Layer . Warstwa hermetyzuje zarówno stan („wagi” warstwy), jak i transformację danych wejściowych do wyjściowych („wywołanie”, przejście warstwy do przodu).

Oto gęsto połączona warstwa. Ma stan: zmienne w i b .

 class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(
            initial_value=w_init(shape=(input_dim, units), dtype="float32"),
            trainable=True,
        )
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(
            initial_value=b_init(shape=(units,), dtype="float32"), trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

 

Możesz użyć warstwy, wywołując ją na niektórych wejściach tensora, podobnie jak funkcja Pythona.

 x = tf.ones((2, 2))
linear_layer = Linear(4, 2)
y = linear_layer(x)
print(y)
 
tf.Tensor(
[[-0.00892124  0.03003723  0.01141541 -0.13389507]
 [-0.00892124  0.03003723  0.01141541 -0.13389507]], shape=(2, 4), dtype=float32)

Zwróć uwagę, że wagi w i b są automatycznie śledzone przez warstwę po ustawieniu ich jako atrybutów warstwy:

 assert linear_layer.weights == [linear_layer.w, linear_layer.b]
 

Zauważ, że masz również dostęp do szybszego skrótu do dodawania wagi do warstwy: metoda add_weight() :

 class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        self.w = self.add_weight(
            shape=(input_dim, units), initializer="random_normal", trainable=True
        )
        self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b


x = tf.ones((2, 2))
linear_layer = Linear(4, 2)
y = linear_layer(x)
print(y)
 
tf.Tensor(
[[-0.01266684  0.01941528 -0.09573359  0.03471692]
 [-0.01266684  0.01941528 -0.09573359  0.03471692]], shape=(2, 4), dtype=float32)

Warstwy mogą mieć ciężary, których nie można trenować

Oprócz ciężarków, których można trenować, do warstwy można również dodać ciężarki, których nie można trenować. Takie ciężary nie powinny być brane pod uwagę podczas wstecznej propagacji, kiedy trenujesz warstwę.

Oto jak dodać i użyć ciężaru, którego nie można trenować:

 class ComputeSum(keras.layers.Layer):
    def __init__(self, input_dim):
        super(ComputeSum, self).__init__()
        self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), trainable=False)

    def call(self, inputs):
        self.total.assign_add(tf.reduce_sum(inputs, axis=0))
        return self.total


x = tf.ones((2, 2))
my_sum = ComputeSum(2)
y = my_sum(x)
print(y.numpy())
y = my_sum(x)
print(y.numpy())
 
[2. 2.]
[4. 4.]

Jest częścią layer.weights , ale zostaje sklasyfikowana jako layer.weights , której nie można layer.weights :

 print("weights:", len(my_sum.weights))
print("non-trainable weights:", len(my_sum.non_trainable_weights))

# It's not included in the trainable weights:
print("trainable_weights:", my_sum.trainable_weights)
 
weights: 1
non-trainable weights: 1
trainable_weights: []

Najlepsza praktyka: odraczanie tworzenia masy do czasu poznania kształtu nakładów

Nasza warstwa Linear powyżej przyjęła argument input_dim który został użyty do obliczenia kształtu wag w i b w __init__() :

 class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        self.w = self.add_weight(
            shape=(input_dim, units), initializer="random_normal", trainable=True
        )
        self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

 

W wielu przypadkach możesz nie znać z góry rozmiaru danych wejściowych i chciałbyś leniwie tworzyć wagi, gdy ta wartość stanie się znana, jakiś czas po utworzeniu wystąpienia warstwy.

W API Keras zalecamy tworzenie wag warstw w metodzie build(self, inputs_shape) Twojej warstwy. Lubię to:

 class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

 

__call__() Twojej warstwy uruchomi się automatycznie przy jej pierwszym wywołaniu. Masz teraz warstwę, która jest leniwa i dzięki temu łatwiejsza w użyciu:

 # At instantiation, we don't know on what inputs this is going to get called
linear_layer = Linear(32)

# The layer's weights are created dynamically the first time the layer is called
y = linear_layer(x)
 

Warstwy są rekurencyjne

Jeśli przypiszesz wystąpienie warstwy jako atrybut innej warstwy, warstwa zewnętrzna zacznie śledzić wagi warstwy wewnętrznej.

Zalecamy tworzenie takich podwarstw w __init__() (ponieważ podwarstwy zazwyczaj będą miały metodę budowania, zostaną zbudowane, gdy warstwa zewnętrzna zostanie zbudowana).

 # Let's assume we are reusing the Linear class
# with a `build` method that we defined above.


class MLPBlock(keras.layers.Layer):
    def __init__(self):
        super(MLPBlock, self).__init__()
        self.linear_1 = Linear(32)
        self.linear_2 = Linear(32)
        self.linear_3 = Linear(1)

    def call(self, inputs):
        x = self.linear_1(inputs)
        x = tf.nn.relu(x)
        x = self.linear_2(x)
        x = tf.nn.relu(x)
        return self.linear_3(x)


mlp = MLPBlock()
y = mlp(tf.ones(shape=(3, 64)))  # The first call to the `mlp` will create the weights
print("weights:", len(mlp.weights))
print("trainable weights:", len(mlp.trainable_weights))
 
weights: 6
trainable weights: 6

Metoda add_loss()

Pisząc metodę call() warstwy, możesz utworzyć tensory strat, których będziesz chciał użyć później, podczas pisania pętli treningowej. Można to zrobić, wywołując self.add_loss(value) :

 # A layer that creates an activity regularization loss
class ActivityRegularizationLayer(keras.layers.Layer):
    def __init__(self, rate=1e-2):
        super(ActivityRegularizationLayer, self).__init__()
        self.rate = rate

    def call(self, inputs):
        self.add_loss(self.rate * tf.reduce_sum(inputs))
        return inputs

 

Straty te (w tym te utworzone przez dowolną warstwę wewnętrzną) można odzyskać za pomocą funkcji layer.losses . Ta właściwość jest resetowana na początku każdego __call__() do warstwy najwyższego poziomu, tak że layer.losses zawsze zawiera wartości strat utworzone podczas ostatniego przejścia w przód.

 class OuterLayer(keras.layers.Layer):
    def __init__(self):
        super(OuterLayer, self).__init__()
        self.activity_reg = ActivityRegularizationLayer(1e-2)

    def call(self, inputs):
        return self.activity_reg(inputs)


layer = OuterLayer()
assert len(layer.losses) == 0  # No losses yet since the layer has never been called

_ = layer(tf.zeros(1, 1))
assert len(layer.losses) == 1  # We created one loss value

# `layer.losses` gets reset at the start of each __call__
_ = layer(tf.zeros(1, 1))
assert len(layer.losses) == 1  # This is the loss created during the call above
 

Ponadto właściwość loss zawiera również straty regularyzacyjne utworzone dla wag dowolnej warstwy wewnętrznej:

 class OuterLayerWithKernelRegularizer(keras.layers.Layer):
    def __init__(self):
        super(OuterLayerWithKernelRegularizer, self).__init__()
        self.dense = keras.layers.Dense(
            32, kernel_regularizer=tf.keras.regularizers.l2(1e-3)
        )

    def call(self, inputs):
        return self.dense(inputs)


layer = OuterLayerWithKernelRegularizer()
_ = layer(tf.zeros((1, 1)))

# This is `1e-3 * sum(layer.dense.kernel ** 2)`,
# created by the `kernel_regularizer` above.
print(layer.losses)
 
[<tf.Tensor: shape=(), dtype=float32, numpy=0.0019264814>]

Te straty należy wziąć pod uwagę podczas pisania pętli treningowych, takich jak:

 # Instantiate an optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Iterate over the batches of a dataset.
for x_batch_train, y_batch_train in train_dataset:
  with tf.GradientTape() as tape:
    logits = layer(x_batch_train)  # Logits for this minibatch
    # Loss value for this minibatch
    loss_value = loss_fn(y_batch_train, logits)
    # Add extra losses created during this forward pass:
    loss_value += sum(model.losses)

  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
 

Aby uzyskać szczegółowy przewodnik dotyczący pisania pętli treningowych, zobacz przewodnik dotyczący pisania pętli szkoleniowej od podstaw .

Straty te działają również płynnie z fit() (są automatycznie sumowane i dodawane do głównej straty, jeśli występuje):

 import numpy as np

inputs = keras.Input(shape=(3,))
outputs = ActivityRegularizationLayer()(inputs)
model = keras.Model(inputs, outputs)

# If there is a loss passed in `compile`, thee regularization
# losses get added to it
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# It's also possible not to pass any loss in `compile`,
# since the model already has a loss to minimize, via the `add_loss`
# call during the forward pass!
model.compile(optimizer="adam")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
 
1/1 [==============================] - 0s 1ms/step - loss: 0.2169
1/1 [==============================] - 0s 875us/step - loss: 0.0396

<tensorflow.python.keras.callbacks.History at 0x7fa7639fd6a0>

Metoda add_metric()

Podobnie jak w przypadku add_loss() , warstwy mają również metodę add_metric() do śledzenia średniej ruchomej ilości podczas uczenia.

Rozważmy następującą warstwę: warstwę „logistycznego punktu końcowego”. Przyjmuje jako dane wejściowe prognozy i cele, oblicza stratę, którą śledzi za pomocą add_loss() , i oblicza skalar dokładności, który śledzi za pomocą add_metric() .

 class LogisticEndpoint(keras.layers.Layer):
    def __init__(self, name=None):
        super(LogisticEndpoint, self).__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy_fn = keras.metrics.BinaryAccuracy()

    def call(self, targets, logits, sample_weights=None):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        loss = self.loss_fn(targets, logits, sample_weights)
        self.add_loss(loss)

        # Log accuracy as a metric and add it
        # to the layer using `self.add_metric()`.
        acc = self.accuracy_fn(targets, logits, sample_weights)
        self.add_metric(acc, name="accuracy")

        # Return the inference-time prediction tensor (for `.predict()`).
        return tf.nn.softmax(logits)

 

Metryki śledzone w ten sposób są dostępne poprzez layer.metrics :

 layer = LogisticEndpoint()

targets = tf.ones((2, 2))
logits = tf.ones((2, 2))
y = layer(targets, logits)

print("layer.metrics:", layer.metrics)
print("current accuracy value:", float(layer.metrics[0].result()))
 
layer.metrics: [<tensorflow.python.keras.metrics.BinaryAccuracy object at 0x7fa7f03601d0>]
current accuracy value: 1.0

Podobnie jak w przypadku add_loss() , te metryki są śledzone przez fit() :

 inputs = keras.Input(shape=(3,), name="inputs")
targets = keras.Input(shape=(10,), name="targets")
logits = keras.layers.Dense(10)(inputs)
predictions = LogisticEndpoint(name="predictions")(logits, targets)

model = keras.Model(inputs=[inputs, targets], outputs=predictions)
model.compile(optimizer="adam")

data = {
    "inputs": np.random.random((3, 3)),
    "targets": np.random.random((3, 10)),
}
model.fit(data)
 
1/1 [==============================] - 0s 1ms/step - loss: 0.9958 - binary_accuracy: 0.0000e+00

<tensorflow.python.keras.callbacks.History at 0x7fa7639ffeb8>

Opcjonalnie możesz włączyć serializację na swoich warstwach

Jeśli chcesz, aby warstwy niestandardowe były możliwe do serializacji w ramach modelu funkcjonalnego , możesz opcjonalnie zaimplementować get_config() :

 class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        return {"units": self.units}


# Now you can recreate the layer from its config:
layer = Linear(64)
config = layer.get_config()
print(config)
new_layer = Linear.from_config(config)
 
{'units': 64}

Zwróć uwagę, że __init__() klasy bazowej Layer pobiera pewne argumenty słów kluczowych, w szczególności name i dtype . Dobrą praktyką jest przekazanie tych argumentów do klasy nadrzędnej w __init__() i uwzględnienie ich w konfiguracji warstwy:

 class Linear(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super(Linear, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        config = super(Linear, self).get_config()
        config.update({"units": self.units})
        return config


layer = Linear(64)
config = layer.get_config()
print(config)
new_layer = Linear.from_config(config)
 
{'name': 'linear_8', 'trainable': True, 'dtype': 'float32', 'units': 64}

Jeśli potrzebujesz większej elastyczności podczas deserializacji warstwy z jej konfiguracji, możesz również zastąpić metodę klasy from_config() . To jest podstawowa implementacja from_config() :

 def from_config(cls, config):
  return cls(**config)
 

Aby dowiedzieć się więcej o serializacji i zapisywaniu, zobacz pełny przewodnik dotyczący zapisywania i serializacji modeli .

Uprzywilejowany argument training w metodzie call()

Niektóre warstwy, w szczególności warstwa BatchNormalization i warstwa Dropout , zachowują się inaczej podczas uczenia i wnioskowania. W przypadku takich warstw standardową praktyką jest ujawnianie argumentu training (boolowskiego) w metodzie call() .

Ujawniając ten argument w call() , umożliwiasz wbudowanym pętlom uczenia i oceny (np. fit() ) poprawne używanie warstwy podczas uczenia i wnioskowania.

 class CustomDropout(keras.layers.Layer):
    def __init__(self, rate, **kwargs):
        super(CustomDropout, self).__init__(**kwargs)
        self.rate = rate

    def call(self, inputs, training=None):
        if training:
            return tf.nn.dropout(inputs, rate=self.rate)
        return inputs

 

Argument mask uprzywilejowanej w metodzie call()

Innym uprzywilejowanym argumentem obsługiwanym przez call() jest argument mask .

Znajdziesz go we wszystkich warstwach Keras RNN. Maska jest logicznym tensorem (jedna wartość logiczna na krok czasu na wejściu) używanym do pomijania pewnych wejściowych kroków czasowych podczas przetwarzania danych z serii czasu.

Keras automatycznie przekaże prawidłowy argument mask do __call__() dla warstw, które ją obsługują, gdy maska ​​jest generowana przez poprzednią warstwę. Warstwy generujące maskę to warstwa Embedding skonfigurowana z wartością mask_zero=True i warstwa Masking .

Aby dowiedzieć się więcej na temat maskowania i pisania warstw z włączoną funkcją maskowania, zapoznaj się z przewodnikiem „Zrozumienie wypełniania i maskowania” .

Klasa Model

Ogólnie rzecz biorąc, będziesz używać klasy Layer do definiowania wewnętrznych bloków obliczeniowych, a klasy Model do definiowania modelu zewnętrznego - obiektu, który będziesz szkolić.

Na przykład w modelu ResNet50 miałbyś kilka bloków ResNet podklasowych Layer i jeden Model obejmujący całą sieć ResNet50.

Klasa Model ma ten sam interfejs API co Layer , z następującymi różnicami:

  • Udostępnia wbudowane pętle treningowe, oceny i przewidywania ( model.fit() , model.evaluate() , model.predict() ).
  • model.layers listę swoich wewnętrznych warstw za pośrednictwem właściwości model.layers .
  • Udostępnia interfejsy API zapisywania i serializacji ( save() , save_weights() ...)

W rzeczywistości klasa Layer odpowiada temu, co w literaturze nazywamy „warstwą” (jak „warstwa splotu” lub „warstwa rekurencyjna”) lub „blokiem” (jak „blok ResNet” lub „blok początkowy”) ).

Tymczasem klasa Model odpowiada temu, co w literaturze określa się mianem „modelu” (jak w „modelu głębokiego uczenia się”) lub jako „sieci” (jak w „głębokiej sieci neuronowej”).

Więc jeśli zastanawiasz się, „czy powinienem używać klasy Layer czy Model ?”, Zadaj sobie pytanie: czy będę musiał wywoływać fit() ? Czy będę musiał wywołać save() ? Jeśli tak, przejdź do Model . Jeśli nie (albo dlatego, że twoja klasa jest tylko blokiem w większym systemie, albo dlatego, że sam piszesz szkolenie i zapisujesz kod), użyj Layer .

Na przykład moglibyśmy wziąć powyższy przykład mini-resnet i użyć go do zbudowania Model , który moglibyśmy wytrenować za pomocą fit() i który moglibyśmy zapisać za pomocą save_weights() :

 class ResNet(tf.keras.Model):

    def __init__(self):
        super(ResNet, self).__init__()
        self.block_1 = ResNetBlock()
        self.block_2 = ResNetBlock()
        self.global_pool = layers.GlobalAveragePooling2D()
        self.classifier = Dense(num_classes)

    def call(self, inputs):
        x = self.block_1(inputs)
        x = self.block_2(x)
        x = self.global_pool(x)
        return self.classifier(x)


resnet = ResNet()
dataset = ...
resnet.fit(dataset, epochs=10)
resnet.save(filepath)
 

Wszystko razem: kompleksowy przykład

Oto, czego się do tej pory nauczyłeś:

  • Layer hermetyzuje stan (utworzony w __init__() lub build() ) i niektóre obliczenia (zdefiniowane w call() ).
  • Warstwy można zagnieżdżać rekurencyjnie, aby tworzyć nowe, większe bloki obliczeniowe.
  • Warstwy mogą tworzyć i śledzić straty (zwykle straty regularyzacyjne), a także metryki za pomocą add_loss() i add_metric()
  • Zewnętrzny pojemnik, rzecz, którą chcesz trenować, to Model . Model jest podobny do Layer , ale z dodatkowymi narzędziami do szkolenia i serializacji.

Połączmy wszystkie te rzeczy w całościowy przykład: zamierzamy zaimplementować Variational AutoEncoder (VAE). Będziemy trenować na cyfrach MNIST.

Nasz VAE będzie podklasą Model , zbudowaną jako zagnieżdżona kompozycja warstw, która jest podklasą Layer . Będzie charakteryzował się utratą regularyzacji (dywergencja KL).

 from tensorflow.keras import layers


class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


class Encoder(layers.Layer):
    """Maps MNIST digits to a triplet (z_mean, z_log_var, z)."""

    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z


class Decoder(layers.Layer):
    """Converts z, the encoded digit vector, back into a readable digit."""

    def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        x = self.dense_proj(inputs)
        return self.dense_output(x)


class VariationalAutoEncoder(keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training."""

    def __init__(
        self,
        original_dim,
        intermediate_dim=64,
        latent_dim=32,
        name="autoencoder",
        **kwargs
    ):
        super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Add KL divergence regularization loss.
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        return reconstructed

 

Napiszmy prostą pętlę szkoleniową na MNIST:

 original_dim = 784
vae = VariationalAutoEncoder(original_dim, 64, 32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
mse_loss_fn = tf.keras.losses.MeanSquaredError()

loss_metric = tf.keras.metrics.Mean()

(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255

train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

epochs = 2

# Iterate over epochs.
for epoch in range(epochs):
    print("Start of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, x_batch_train in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            reconstructed = vae(x_batch_train)
            # Compute reconstruction loss
            loss = mse_loss_fn(x_batch_train, reconstructed)
            loss += sum(vae.losses)  # Add KLD regularization loss

        grads = tape.gradient(loss, vae.trainable_weights)
        optimizer.apply_gradients(zip(grads, vae.trainable_weights))

        loss_metric(loss)

        if step % 100 == 0:
            print("step %d: mean loss = %.4f" % (step, loss_metric.result()))
 
Start of epoch 0
step 0: mean loss = 0.3052
step 100: mean loss = 0.1252
step 200: mean loss = 0.0990
step 300: mean loss = 0.0890
step 400: mean loss = 0.0841
step 500: mean loss = 0.0808
step 600: mean loss = 0.0787
step 700: mean loss = 0.0771
step 800: mean loss = 0.0759
step 900: mean loss = 0.0749
Start of epoch 1
step 0: mean loss = 0.0746
step 100: mean loss = 0.0740
step 200: mean loss = 0.0735
step 300: mean loss = 0.0730
step 400: mean loss = 0.0727
step 500: mean loss = 0.0723
step 600: mean loss = 0.0720
step 700: mean loss = 0.0717
step 800: mean loss = 0.0714
step 900: mean loss = 0.0712

Zwróć uwagę, że ponieważ VAE jest Model podklasy, ma wbudowane pętle szkoleniowe. Więc możesz to również wyszkolić w ten sposób:

 vae = VariationalAutoEncoder(784, 64, 32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
vae.fit(x_train, x_train, epochs=2, batch_size=64)
 
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0749
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0676

<tensorflow.python.keras.callbacks.History at 0x7fa7583261d0>

Poza programowaniem obiektowym: funkcjonalny interfejs API

Czy ten przykład był dla Ciebie zbyt intensywny? Możesz również tworzyć modele przy użyciu funkcjonalnego interfejsu API . Co ważne, wybór takiego czy innego stylu nie przeszkadza Ci w wykorzystaniu komponentów napisanych w innym stylu: zawsze możesz mieszać i dopasowywać.

Na przykład poniższy przykład funkcjonalnego interfejsu API ponownie wykorzystuje tę samą warstwę Sampling , którą zdefiniowaliśmy w powyższym przykładzie:

 original_dim = 784
intermediate_dim = 64
latent_dim = 32

# Define encoder model.
original_inputs = tf.keras.Input(shape=(original_dim,), name="encoder_input")
x = layers.Dense(intermediate_dim, activation="relu")(original_inputs)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()((z_mean, z_log_var))
encoder = tf.keras.Model(inputs=original_inputs, outputs=z, name="encoder")

# Define decoder model.
latent_inputs = tf.keras.Input(shape=(latent_dim,), name="z_sampling")
x = layers.Dense(intermediate_dim, activation="relu")(latent_inputs)
outputs = layers.Dense(original_dim, activation="sigmoid")(x)
decoder = tf.keras.Model(inputs=latent_inputs, outputs=outputs, name="decoder")

# Define VAE model.
outputs = decoder(z)
vae = tf.keras.Model(inputs=original_inputs, outputs=outputs, name="vae")

# Add KL divergence regularization loss.
kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
vae.add_loss(kl_loss)

# Train.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
vae.fit(x_train, x_train, epochs=3, batch_size=64)
 
Epoch 1/3
938/938 [==============================] - 2s 2ms/step - loss: 0.0751
Epoch 2/3
938/938 [==============================] - 2s 2ms/step - loss: 0.0676
Epoch 3/3
938/938 [==============================] - 2s 2ms/step - loss: 0.0676

<tensorflow.python.keras.callbacks.History at 0x7fa7580f1668>

Aby uzyskać więcej informacji, przeczytaj przewodnik po funkcjonalnym interfejsie API .