Tworzenie nowych warstw i modeli za pomocą podklas

Zadbaj o dobrą organizację dzięki kolekcji Zapisuj i kategoryzuj treści zgodnie ze swoimi preferencjami.

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHub Pobierz notatnik

Ustawiać

import tensorflow as tf
from tensorflow import keras

Layer klasa: kombinacja państwowych (odważników) i niektóre obliczenia

Jednym z głównych abstrakcji w Keras jest Layer klasy. Warstwa zawiera zarówno stan ("wagi") warstwy, jak i transformację z danych wejściowych na dane wyjściowe ("wywołanie", przekazanie warstwy do przodu).

Oto gęsto połączona warstwa. Posiada Stan: zmienne w i b .

class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(
            initial_value=w_init(shape=(input_dim, units), dtype="float32"),
            trainable=True,
        )
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(
            initial_value=b_init(shape=(units,), dtype="float32"), trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

Użyjesz warstwy, wywołując ją na niektórych wejściach tensorowych, podobnie jak funkcja Pythona.

x = tf.ones((2, 2))
linear_layer = Linear(4, 2)
y = linear_layer(x)
print(y)
tf.Tensor(
[[ 0.00962844 -0.01307489 -0.1452128   0.0538918 ]
 [ 0.00962844 -0.01307489 -0.1452128   0.0538918 ]], shape=(2, 4), dtype=float32)

Należy zauważyć, że obciążniki w i b są automatycznie śledzone przez warstwę na ustawionym jako atrybuty warstwowych:

assert linear_layer.weights == [linear_layer.w, linear_layer.b]

Zauważ też mają dostęp do szybszego skrótu do dodawania wagę do warstwy: na add_weight() metoda:

class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        self.w = self.add_weight(
            shape=(input_dim, units), initializer="random_normal", trainable=True
        )
        self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b


x = tf.ones((2, 2))
linear_layer = Linear(4, 2)
y = linear_layer(x)
print(y)
tf.Tensor(
[[ 0.05790994  0.060931   -0.0402256  -0.09450993]
 [ 0.05790994  0.060931   -0.0402256  -0.09450993]], shape=(2, 4), dtype=float32)

Warstwy mogą mieć ciężary, których nie można trenować

Oprócz ciężarów, które można trenować, możesz dodać do warstwy również ciężary, których nie można trenować. Takich ciężarów nie należy brać pod uwagę podczas propagacji wstecznej, gdy trenujesz warstwę.

Oto jak dodać i używać wagi, której nie można trenować:

class ComputeSum(keras.layers.Layer):
    def __init__(self, input_dim):
        super(ComputeSum, self).__init__()
        self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), trainable=False)

    def call(self, inputs):
        self.total.assign_add(tf.reduce_sum(inputs, axis=0))
        return self.total


x = tf.ones((2, 2))
my_sum = ComputeSum(2)
y = my_sum(x)
print(y.numpy())
y = my_sum(x)
print(y.numpy())
[2. 2.]
[4. 4.]

To część layer.weights , ale robi się sklasyfikować jako nie nadającego się do szkolenia wagi:

print("weights:", len(my_sum.weights))
print("non-trainable weights:", len(my_sum.non_trainable_weights))

# It's not included in the trainable weights:
print("trainable_weights:", my_sum.trainable_weights)
weights: 1
non-trainable weights: 1
trainable_weights: []

Najlepsza praktyka: odroczenie tworzenia wagi do czasu poznania kształtu danych wejściowych

Nasze Linear powyżej warstwy brał input_dim argumentu, który został użyty do obliczenia kształtu wag w i b w __init__() :

class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        self.w = self.add_weight(
            shape=(input_dim, units), initializer="random_normal", trainable=True
        )
        self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

W wielu przypadkach możesz nie znać z góry rozmiaru danych wejściowych i chciałbyś leniwie tworzyć wagi, gdy ta wartość stanie się znana, jakiś czas po utworzeniu instancji warstwy.

W API Keras, zalecamy utworzenie wagi warstwy w build(self, inputs_shape) metody swojej warstwie. Lubię to:

class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

__call__() metoda swojej warstwie automatycznie uruchomić zbudować pierwszy raz to jest tzw. Masz teraz warstwę, która jest leniwa, a przez to łatwiejsza w użyciu:

# At instantiation, we don't know on what inputs this is going to get called
linear_layer = Linear(32)

# The layer's weights are created dynamically the first time the layer is called
y = linear_layer(x)

Implementacja build() oddzielnie, jak przedstawiono powyżej, dobrze oddziela tworzenie masy tylko raz z użyciem wagi w każdym połączeniu. Jednak w przypadku niektórych zaawansowanych warstw niestandardowych rozdzielenie tworzenia stanu i obliczeń może stać się niepraktyczne. Realizatorzy warstwowe mogą odroczyć tworzenia masy do pierwszej __call__() , ale trzeba uważać, że późniejsze rozmowy używać tych samych ciężarów. Ponadto, ponieważ __call__() może być wykonany po raz pierwszy wewnątrz tf.function , tworzenie dowolnej zmiennej, która ma miejsce w __call__() powinny być zapakowane w tf.init_scope .

Warstwy można komponować rekurencyjnie

Jeśli przypiszesz instancję warstwy jako atrybut innej warstwy, warstwa zewnętrzna zacznie śledzić wagi utworzone przez warstwę wewnętrzną.

Zalecamy utworzenie takich podwarstwy w __init__() metoda i pozostawić go do pierwszego __call__() do spustu budują swoje ciężary.

class MLPBlock(keras.layers.Layer):
    def __init__(self):
        super(MLPBlock, self).__init__()
        self.linear_1 = Linear(32)
        self.linear_2 = Linear(32)
        self.linear_3 = Linear(1)

    def call(self, inputs):
        x = self.linear_1(inputs)
        x = tf.nn.relu(x)
        x = self.linear_2(x)
        x = tf.nn.relu(x)
        return self.linear_3(x)


mlp = MLPBlock()
y = mlp(tf.ones(shape=(3, 64)))  # The first call to the `mlp` will create the weights
print("weights:", len(mlp.weights))
print("trainable weights:", len(mlp.trainable_weights))
weights: 6
trainable weights: 6

add_loss() Sposób

Pisząc call() metodę warstwą, można utworzyć tensory strat, które chcesz użyć później, pisząc swoją pętlę treningową. To jest wykonalne poprzez wywołanie self.add_loss(value) :

# A layer that creates an activity regularization loss
class ActivityRegularizationLayer(keras.layers.Layer):
    def __init__(self, rate=1e-2):
        super(ActivityRegularizationLayer, self).__init__()
        self.rate = rate

    def call(self, inputs):
        self.add_loss(self.rate * tf.reduce_sum(inputs))
        return inputs

Te straty (w tym te, które powstały przez każdą warstwę wewnętrzną) mogą być pobrane przez layer.losses . Ta właściwość jest wyzerowany na początku każdego __call__() do warstwy najwyższym poziomie, dzięki czemu layer.losses zawsze zawiera wartości straty powstałe podczas ostatniej przełęczy przodu.

class OuterLayer(keras.layers.Layer):
    def __init__(self):
        super(OuterLayer, self).__init__()
        self.activity_reg = ActivityRegularizationLayer(1e-2)

    def call(self, inputs):
        return self.activity_reg(inputs)


layer = OuterLayer()
assert len(layer.losses) == 0  # No losses yet since the layer has never been called

_ = layer(tf.zeros(1, 1))
assert len(layer.losses) == 1  # We created one loss value

# `layer.losses` gets reset at the start of each __call__
_ = layer(tf.zeros(1, 1))
assert len(layer.losses) == 1  # This is the loss created during the call above

Ponadto loss właściwość zawiera również straty utworzone dla legalizacji wag dowolnej warstwy wewnętrznej:

class OuterLayerWithKernelRegularizer(keras.layers.Layer):
    def __init__(self):
        super(OuterLayerWithKernelRegularizer, self).__init__()
        self.dense = keras.layers.Dense(
            32, kernel_regularizer=tf.keras.regularizers.l2(1e-3)
        )

    def call(self, inputs):
        return self.dense(inputs)


layer = OuterLayerWithKernelRegularizer()
_ = layer(tf.zeros((1, 1)))

# This is `1e-3 * sum(layer.dense.kernel ** 2)`,
# created by the `kernel_regularizer` above.
print(layer.losses)
[<tf.Tensor: shape=(), dtype=float32, numpy=0.0024520475>]

Te straty mają być brane pod uwagę podczas pisania pętli treningowych, takich jak:

# Instantiate an optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Iterate over the batches of a dataset.
for x_batch_train, y_batch_train in train_dataset:
  with tf.GradientTape() as tape:
    logits = layer(x_batch_train)  # Logits for this minibatch
    # Loss value for this minibatch
    loss_value = loss_fn(y_batch_train, logits)
    # Add extra losses created during this forward pass:
    loss_value += sum(model.losses)

  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))

Dla szczegółowego przewodnika na temat pisania pętle treningowe, zobacz podręcznik do pisania pętlę szkolenia od podstaw .

Straty te również działają bezproblemowo z fit() (oni trafiają automatycznie sumowane i dodawane do głównej strat, jeśli występują):

import numpy as np

inputs = keras.Input(shape=(3,))
outputs = ActivityRegularizationLayer()(inputs)
model = keras.Model(inputs, outputs)

# If there is a loss passed in `compile`, the regularization
# losses get added to it
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# It's also possible not to pass any loss in `compile`,
# since the model already has a loss to minimize, via the `add_loss`
# call during the forward pass!
model.compile(optimizer="adam")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
1/1 [==============================] - 0s 209ms/step - loss: 0.1948
1/1 [==============================] - 0s 49ms/step - loss: 0.0298
<keras.callbacks.History at 0x7fce9052d290>

add_metric() Sposób

Podobnie jak w add_loss() , warstwy mają również add_metric() metodę śledzenia średnia ruchoma ilości podczas treningu.

Rozważ następującą warstwę: warstwa „logistycznego punktu końcowego”. Zajmuje jako wejścia przewidywania i cele, to oblicza straty który śledzi poprzez add_loss() , i oblicza się skalarne dokładności, który śledzi poprzez add_metric() .

class LogisticEndpoint(keras.layers.Layer):
    def __init__(self, name=None):
        super(LogisticEndpoint, self).__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy_fn = keras.metrics.BinaryAccuracy()

    def call(self, targets, logits, sample_weights=None):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        loss = self.loss_fn(targets, logits, sample_weights)
        self.add_loss(loss)

        # Log accuracy as a metric and add it
        # to the layer using `self.add_metric()`.
        acc = self.accuracy_fn(targets, logits, sample_weights)
        self.add_metric(acc, name="accuracy")

        # Return the inference-time prediction tensor (for `.predict()`).
        return tf.nn.softmax(logits)

Metryki śledzone w ten sposób są dostępne za pośrednictwem layer.metrics :

layer = LogisticEndpoint()

targets = tf.ones((2, 2))
logits = tf.ones((2, 2))
y = layer(targets, logits)

print("layer.metrics:", layer.metrics)
print("current accuracy value:", float(layer.metrics[0].result()))
layer.metrics: [<keras.metrics.BinaryAccuracy object at 0x7fce90578490>]
current accuracy value: 1.0

Podobnie jak dla add_loss() , dane te są śledzone przez fit() :

inputs = keras.Input(shape=(3,), name="inputs")
targets = keras.Input(shape=(10,), name="targets")
logits = keras.layers.Dense(10)(inputs)
predictions = LogisticEndpoint(name="predictions")(logits, targets)

model = keras.Model(inputs=[inputs, targets], outputs=predictions)
model.compile(optimizer="adam")

data = {
    "inputs": np.random.random((3, 3)),
    "targets": np.random.random((3, 10)),
}
model.fit(data)
1/1 [==============================] - 0s 274ms/step - loss: 0.9291 - binary_accuracy: 0.0000e+00
<keras.callbacks.History at 0x7fce90448c50>

Opcjonalnie możesz włączyć serializację na swoich warstwach

Jeśli potrzebujesz niestandardowych warstw być serializable jako część modelu funkcjonalnego , opcjonalnie można wdrożyć get_config() metodę:

class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        return {"units": self.units}


# Now you can recreate the layer from its config:
layer = Linear(64)
config = layer.get_config()
print(config)
new_layer = Linear.from_config(config)
{'units': 64}

Należy zauważyć, że __init__() Sposób baza Layer klasy trwa kilka argumentów słów kluczowych, w szczególności name i dtype . Jest to dobra praktyka, aby przekazać te argumenty do klasy dominującej w __init__() i włączenie ich w config warstwy:

class Linear(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super(Linear, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        config = super(Linear, self).get_config()
        config.update({"units": self.units})
        return config


layer = Linear(64)
config = layer.get_config()
print(config)
new_layer = Linear.from_config(config)
{'name': 'linear_8', 'trainable': True, 'dtype': 'float32', 'units': 64}

Jeśli potrzebujesz większej elastyczności przy deserializacji warstwę od jego konfiguracji, można również zastąpić from_config() metody klasy. To jest podstawa realizacja from_config() :

def from_config(cls, config):
  return cls(**config)

Aby dowiedzieć się więcej na temat serializacji i zapisywania zobaczyć kompletny przewodnik do oszczędzania i szeregowania modeli .

Uprzywilejowany training argumentem w call() metody

Kilka warstw, w szczególności BatchNormalization warstwa i Dropout warstwy mają różne zachowania podczas treningu i wnioskowania. Do takich warstw, to standardową praktyką w celu odsłonięcia training (typ logiczny) argument w call() sposobu.

Przez wystawienie tego argumentu w call() , włączyć wbudowaną pętle szkoleniowych i ewaluacji (np fit() ), aby prawidłowo zastosować warstwę w szkoleniu i wnioskowania.

class CustomDropout(keras.layers.Layer):
    def __init__(self, rate, **kwargs):
        super(CustomDropout, self).__init__(**kwargs)
        self.rate = rate

    def call(self, inputs, training=None):
        if training:
            return tf.nn.dropout(inputs, rate=self.rate)
        return inputs

Uprzywilejowane mask argumentem w call() sposobem

Drugim argumentem uprzywilejowany obsługiwane przez call() jest mask argumentem.

Znajdziesz go we wszystkich warstwach Keras RNN. Maska to tensor logiczny (jedna wartość logiczna na krok czasowy w danych wejściowych) używany do pomijania pewnych wejściowych przedziałów czasowych podczas przetwarzania danych szeregów czasowych.

Keras automatycznie przechodzą właściwą mask argumentu __call__() do warstw, które obsługują, gdy maska jest generowane przez uprzednie warstwy. Maska wytwarzania warstwy są Embedding warstwy skonfigurowany mask_zero=True i Masking warstwy.

Aby dowiedzieć się więcej na temat maskowania i jak napisać maskowanie z obsługą warstw, proszę zapoznać się z instrukcji „zrozumienie wyściółkę i maskującą” .

Model klasy

Ogólnie rzecz biorąc, można użyć Layer klasy zdefiniowanie wewnętrznych bloków obliczeniowych i użyje Model klasy do określenia modelu zewnętrzną - obiekt będzie pociąg.

Na przykład, w modelu ResNet50, to masz kilka bloków ResNet instacji Layer i pojedynczy Model obejmujący całą sieć ResNet50.

Model klasa ma to samo API jako Layer , z następującymi różnicami:

  • Naraża wbudowanej pętli szkolenia, oceny i prognozowania ( model.fit() , model.evaluate() , model.predict() ).
  • To naraża listę swoich wewnętrznych warstw, poprzez model.layers nieruchomości.
  • To naraża oszczędności i serializacji API ( save() , save_weights() ...)

Skutecznie, że Layer klasy odpowiada temu, co nazywamy w literaturze jako „warstwa” (jak w „warstwie splotu” lub „nawracającego warstwy”) lub jako „bloku” (jak w „ResNet bloku” lub „bloku Incepcja” ).

Tymczasem Model klasy odpowiada temu, co jest określane w literaturze jako „model” (jak w „głębokim modelu uczenia się”) lub jako „sieć” (jak w „głębokiej sieci neuronowej”).

Więc jeśli zastanawiasz się, „powinno się używać Layer klasy lub Model klasy?”, Należy zadać sobie pytanie: czy muszę zadzwonić fit() na nim? Będę musiał zadzwonić save() na nim? Jeśli tak, należy przejść z Model . Jeśli nie (albo dlatego, że klasa jest po prostu blok w większym systemie, albo dlatego, że pisanie kodu i szkolenia samodzielnie zapisu), należy użyć Layer .

Na przykład, możemy wziąć nasz przykład mini-resnet powyżej, oraz wykorzystanie go do budowy Model , że mogliśmy trenować z fit() , i że możemy uratować z save_weights() :

class ResNet(tf.keras.Model):

    def __init__(self, num_classes=1000):
        super(ResNet, self).__init__()
        self.block_1 = ResNetBlock()
        self.block_2 = ResNetBlock()
        self.global_pool = layers.GlobalAveragePooling2D()
        self.classifier = Dense(num_classes)

    def call(self, inputs):
        x = self.block_1(inputs)
        x = self.block_2(x)
        x = self.global_pool(x)
        return self.classifier(x)


resnet = ResNet()
dataset = ...
resnet.fit(dataset, epochs=10)
resnet.save(filepath)

Podsumowując: kompletny przykład

Oto, czego się do tej pory nauczyłeś:

  • Layer hermetyzacji stan (utworzony w __init__() lub build() ), a niektóre obliczenia (określonej w call() ).
  • Warstwy mogą być rekursywnie zagnieżdżane, aby tworzyć nowe, większe bloki obliczeniowe.
  • Warstwy mogą tworzyć i straty track (zwykle straty regularyzacji), jak również dane, poprzez add_loss() i add_metric()
  • Pojemnik zewnętrzny, rzecz chcesz trenować, to Model . Model jest jak Layer , ale z dodatkiem narzędzi szkoleniowych i serializacji.

Połączmy wszystkie te rzeczy razem w kompletny przykład: zamierzamy zaimplementować Variational AutoEncoder (VAE). Nauczymy go cyframi MNIST.

Nasz VAE będzie podklasą Model , zbudowany jako zagnieżdżonych kompozycji warstw tej podklasy Layer . Będzie się on charakteryzował stratą regularyzacyjną (dywergencja KL).

from tensorflow.keras import layers


class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


class Encoder(layers.Layer):
    """Maps MNIST digits to a triplet (z_mean, z_log_var, z)."""

    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z


class Decoder(layers.Layer):
    """Converts z, the encoded digit vector, back into a readable digit."""

    def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        x = self.dense_proj(inputs)
        return self.dense_output(x)


class VariationalAutoEncoder(keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training."""

    def __init__(
        self,
        original_dim,
        intermediate_dim=64,
        latent_dim=32,
        name="autoencoder",
        **kwargs
    ):
        super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Add KL divergence regularization loss.
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        return reconstructed

Napiszmy prostą pętlę treningową na MNIST:

original_dim = 784
vae = VariationalAutoEncoder(original_dim, 64, 32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
mse_loss_fn = tf.keras.losses.MeanSquaredError()

loss_metric = tf.keras.metrics.Mean()

(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255

train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

epochs = 2

# Iterate over epochs.
for epoch in range(epochs):
    print("Start of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, x_batch_train in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            reconstructed = vae(x_batch_train)
            # Compute reconstruction loss
            loss = mse_loss_fn(x_batch_train, reconstructed)
            loss += sum(vae.losses)  # Add KLD regularization loss

        grads = tape.gradient(loss, vae.trainable_weights)
        optimizer.apply_gradients(zip(grads, vae.trainable_weights))

        loss_metric(loss)

        if step % 100 == 0:
            print("step %d: mean loss = %.4f" % (step, loss_metric.result()))
Start of epoch 0
step 0: mean loss = 0.3184
step 100: mean loss = 0.1252
step 200: mean loss = 0.0989
step 300: mean loss = 0.0890
step 400: mean loss = 0.0841
step 500: mean loss = 0.0807
step 600: mean loss = 0.0787
step 700: mean loss = 0.0771
step 800: mean loss = 0.0759
step 900: mean loss = 0.0749
Start of epoch 1
step 0: mean loss = 0.0746
step 100: mean loss = 0.0740
step 200: mean loss = 0.0735
step 300: mean loss = 0.0730
step 400: mean loss = 0.0727
step 500: mean loss = 0.0723
step 600: mean loss = 0.0720
step 700: mean loss = 0.0717
step 800: mean loss = 0.0715
step 900: mean loss = 0.0712

Zauważ, że skoro VAE jest instacji Model , to funkcje wbudowane w pętle szkoleniowych. Więc mogłeś też wytrenować to w ten sposób:

vae = VariationalAutoEncoder(784, 64, 32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
vae.fit(x_train, x_train, epochs=2, batch_size=64)
Epoch 1/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0745
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0676
<keras.callbacks.History at 0x7fce90282750>

Poza programowaniem obiektowym: funkcjonalne API

Czy ten przykład nie był dla Ciebie zbyt zorientowany obiektowo? Można także zbudować model używając API funkcjonalna . Co ważne, wybór jednego lub drugiego stylu nie przeszkadza w wykorzystywaniu komponentów napisanych w innym stylu: zawsze możesz mieszać i dopasowywać.

Na przykład funkcjonalne przykład API poniżej ponownie wykorzystuje to samo Sampling warstwa to zdefiniowano w powyższym przykładzie:

original_dim = 784
intermediate_dim = 64
latent_dim = 32

# Define encoder model.
original_inputs = tf.keras.Input(shape=(original_dim,), name="encoder_input")
x = layers.Dense(intermediate_dim, activation="relu")(original_inputs)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()((z_mean, z_log_var))
encoder = tf.keras.Model(inputs=original_inputs, outputs=z, name="encoder")

# Define decoder model.
latent_inputs = tf.keras.Input(shape=(latent_dim,), name="z_sampling")
x = layers.Dense(intermediate_dim, activation="relu")(latent_inputs)
outputs = layers.Dense(original_dim, activation="sigmoid")(x)
decoder = tf.keras.Model(inputs=latent_inputs, outputs=outputs, name="decoder")

# Define VAE model.
outputs = decoder(z)
vae = tf.keras.Model(inputs=original_inputs, outputs=outputs, name="vae")

# Add KL divergence regularization loss.
kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
vae.add_loss(kl_loss)

# Train.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
vae.fit(x_train, x_train, epochs=3, batch_size=64)
Epoch 1/3
938/938 [==============================] - 3s 3ms/step - loss: 0.0748
Epoch 2/3
938/938 [==============================] - 3s 3ms/step - loss: 0.0676
Epoch 3/3
938/938 [==============================] - 3s 3ms/step - loss: 0.0676
<keras.callbacks.History at 0x7fce90233cd0>

Aby uzyskać więcej informacji, należy zapoznać się z Functional instrukcji API .