This guide gives you the basics to get started with Keras. It's a 10-minute read.
Import tf.keras
tf.keras is TensorFlow's implementation of the Keras API specification. This is a high-level API to build and train models that includes first-class support for TensorFlow-specific functionality, such as eager execution, tf.data pipelines, and Estimators. tf.keras makes TensorFlow easier to use without sacrificing flexibility and performance.
To get started, import tf.keras as part of your TensorFlow program setup:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
tf.keras can run any Keras-compatible code, but keep in mind:
- The tf.keras version in the latest TensorFlow release might not be the same as the latest keras version from PyPI. Check tf.keras.__version__.
- When saving a model's weights, tf.keras defaults to the checkpoint format. Pass save_format='h5' to use HDF5 (or pass a filename that ends in .h5).
Build a simple model
Sequential model
In Keras, you assemble layers to build models. A model is (usually) a graph of layers. The most common type of model is a stack of layers: the tf.keras.Sequential model.
To build a simple, fully-connected network (i.e. multi-layer perceptron):
from tensorflow.keras import layers
model = tf.keras.Sequential()
# Adds a densely-connected layer with 64 units to the model:
model.add(layers.Dense(64, activation='relu'))
# Add another:
model.add(layers.Dense(64, activation='relu'))
# Add a softmax layer with 10 output units:
model.add(layers.Dense(10, activation='softmax'))
You can find a complete, short example of how to use Sequential models here.
To learn about building more advanced models than Sequential models, see:
- Guide to the Keras Functional API
- Guide to writing layers and models from scratch with subclassing
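As a quick sanity check on the stack above, the following sketch rebuilds the same three layers and compares the parameter count reported by Keras with a hand-computed one. The 32-feature input is an assumption, borrowed from the training examples in this guide; each Dense layer holds inputs × units kernel weights plus one bias per unit.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Assumption: 32 input features, matching the training examples in this guide.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(32,)),
    layers.Dense(64, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax'),
])

# Dense parameters = inputs * units (kernel) + units (bias).
expected_params = (32 * 64 + 64) + (64 * 64 + 64) + (64 * 10 + 10)
total_params = model.count_params()

print(total_params, expected_params)
```

Declaring the input shape up front lets Keras create the weights immediately, which is why count_params can be called before any data has passed through the model.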
Configure the layers
There are many tf.keras.layers available. Most of them share some common constructor arguments:
- activation: Set the activation function for the layer. This parameter is specified by the name of a built-in function or as a callable object. By default, no activation is applied.
- kernel_initializer and bias_initializer: The initialization schemes that create the layer's weights (kernel and bias). This parameter is a name or a callable object. For Dense layers, the kernel defaults to the "Glorot uniform" initializer and the bias defaults to zeros.
- kernel_regularizer and bias_regularizer: The regularization schemes that apply to the layer's weights (kernel and bias), such as L1 or L2 regularization. By default, no regularization is applied.
The following instantiates tf.keras.layers.Dense layers using constructor arguments:
# Create a sigmoid layer:
layers.Dense(64, activation='sigmoid')
# Or:
layers.Dense(64, activation=tf.keras.activations.sigmoid)
# A linear layer with L1 regularization of factor 0.01 applied to the kernel matrix:
layers.Dense(64, kernel_regularizer=tf.keras.regularizers.l1(0.01))
# A linear layer with L2 regularization of factor 0.01 applied to the bias vector:
layers.Dense(64, bias_regularizer=tf.keras.regularizers.l2(0.01))
# A linear layer with a kernel initialized to a random orthogonal matrix:
layers.Dense(64, kernel_initializer='orthogonal')
# A linear layer with a bias vector initialized to 2.0s:
layers.Dense(64, bias_initializer=tf.keras.initializers.Constant(2.0))
<tensorflow.python.keras.layers.core.Dense at 0x7fe74f13eac8>
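To see that these constructor arguments take effect, the sketch below builds one small layer and inspects its weights. The layer size (4 units) and the 3-feature input are arbitrary choices for illustration, not values from this guide.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# A linear layer whose bias starts at 2.0 and whose kernel carries an L1 penalty.
layer = layers.Dense(
    4,
    bias_initializer=tf.keras.initializers.Constant(2.0),
    kernel_regularizer=tf.keras.regularizers.l1(0.01),
)

# Calling the layer on an all-zero batch creates its weights. With no
# activation, the output for zero input equals the bias.
outputs = layer(tf.zeros((1, 3)))
bias_values = layer.bias.numpy()
kernel_values = layer.kernel.numpy()

# The regularization penalty is collected in layer.losses.
l1_penalty = float(layer.losses[0])
expected_penalty = 0.01 * np.abs(kernel_values).sum()

print(bias_values)
print(l1_penalty, expected_penalty)
```

During training, Keras adds the penalties found in the losses of each layer to the loss passed to compile.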
Train and evaluate
Set up training
After the model is constructed, configure its learning process by calling the compile method:
model = tf.keras.Sequential([
    # Adds a densely-connected layer with 64 units to the model:
    layers.Dense(64, activation='relu', input_shape=(32,)),
    # Add another:
    layers.Dense(64, activation='relu'),
    # Add a softmax layer with 10 output units:
    layers.Dense(10, activation='softmax')])

model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
tf.keras.Model.compile takes three important arguments:
- optimizer: This object specifies the training procedure. Pass it optimizer instances from the tf.keras.optimizers module, such as tf.keras.optimizers.Adam or tf.keras.optimizers.SGD. If you just want to use the default parameters, you can also specify optimizers via strings, such as 'adam' or 'sgd'.
- loss: The function to minimize during optimization. Common choices include mean square error (mse), categorical_crossentropy, and binary_crossentropy. Loss functions are specified by name or by passing a callable object from the tf.keras.losses module.
- metrics: Used to monitor training. These are string names or callables from the tf.keras.metrics module.
Additionally, to make sure the model trains and evaluates eagerly, pass run_eagerly=True as a parameter to compile.
The following shows a few examples of configuring a model for training:
# Configure a model for mean-squared error regression.
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss='mse',       # mean squared error
              metrics=['mae'])  # mean absolute error

# Configure a model for categorical classification.
model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=[tf.keras.metrics.CategoricalAccuracy()])
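The string shortcuts and the run_eagerly flag described above can be combined in one call. The following is a minimal sketch for a hypothetical binary classifier; the model, its layer sizes, and the random data are illustrative assumptions, not part of the guide's running example.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical binary classifier: 32 input features, one sigmoid output.
binary_model = tf.keras.Sequential([
    tf.keras.Input(shape=(32,)),
    layers.Dense(16, activation='relu'),
    layers.Dense(1, activation='sigmoid'),
])

# String names select the optimizer and loss with their default parameters.
# run_eagerly=True asks Keras to run training and evaluation eagerly.
binary_model.compile(optimizer='sgd',
                     loss='binary_crossentropy',
                     metrics=['accuracy'],
                     run_eagerly=True)

# Random inputs and 0/1 labels, for illustration only.
x = np.random.random((64, 32)).astype('float32')
y = np.random.randint(0, 2, size=(64, 1)).astype('float32')

loss_value, accuracy_value = binary_model.evaluate(x, y, batch_size=32, verbose=0)
print(loss_value, accuracy_value)
```

Eager execution is usually slower than the default, so it tends to be most useful while debugging.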
Train from NumPy data
For small datasets, use in-memory NumPy arrays to train and evaluate a model. The model is "fit" to the training data using the fit method:
import numpy as np
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))
model.fit(data, labels, epochs=10, batch_size=32)
Train on 1000 samples
Epoch 1/10
1000/1000 [==============================] - 1s 941us/sample - loss: 292.6379 - categorical_accuracy: 0.1150
Epoch 2/10
1000/1000 [==============================] - 0s 79us/sample - loss: 1195.7818 - categorical_accuracy: 0.0940
Epoch 3/10
1000/1000 [==============================] - 0s 79us/sample - loss: 2382.5031 - categorical_accuracy: 0.0850
Epoch 4/10
1000/1000 [==============================] - 0s 77us/sample - loss: 4133.6880 - categorical_accuracy: 0.0900
Epoch 5/10
1000/1000 [==============================] - 0s 78us/sample - loss: 6172.5911 - categorical_accuracy: 0.1010
Epoch 6/10
1000/1000 [==============================] - 0s 77us/sample - loss: 8432.6698 - categorical_accuracy: 0.1120
Epoch 7/10
1000/1000 [==============================] - 0s 79us/sample - loss: 11067.2603 - categorical_accuracy: 0.0920
Epoch 8/10
1000/1000 [==============================] - 0s 80us/sample - loss: 14060.8606 - categorical_accuracy: 0.0980
Epoch 9/10
1000/1000 [==============================] - 0s 76us/sample - loss: 17402.4838 - categorical_accuracy: 0.0930
Epoch 10/10
1000/1000 [==============================] - 0s 77us/sample - loss: 21134.0576 - categorical_accuracy: 0.0910
<tensorflow.python.keras.callbacks.History at 0x7fe74c0f1240>
tf.keras.Model.fit takes three important arguments:
- epochs: Training is structured into epochs. An epoch is one iteration over the entire input data (this is done in smaller batches).
- batch_size: When passed NumPy data, the model slices the data into smaller batches and iterates over these batches during training. This integer specifies the size of each batch. Be aware that the last batch may be smaller if the total number of samples is not divisible by the batch size.
- validation_data: When prototyping a model, you want to easily monitor its performance on some validation data. Passing this argument (a tuple of inputs and labels) allows the model to display the loss and metrics in inference mode for the passed data, at the end of each epoch.
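The note about the last batch can be checked with plain arithmetic. Using the sizes from the example above (1000 samples, batch_size=32), this short sketch works out how many batches one epoch contains and how large the final one is.

```python
import math

num_samples = 1000
batch_size = 32

# Number of batches (steps) in one epoch, counting a final partial batch.
steps_per_epoch = math.ceil(num_samples / batch_size)

# Size of the last batch: the remainder, or a full batch if it divides evenly.
last_batch_size = num_samples % batch_size or batch_size

print(steps_per_epoch, last_batch_size)  # 32 8
```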
Here's an example using validation_data:
import numpy as np
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))
val_data = np.random.random((100, 32))
val_labels = np.random.random((100, 10))
model.fit(data, labels, epochs=10, batch_size=32,
          validation_data=(val_data, val_labels))
Train on 1000 samples, validate on 100 samples
Epoch 1/10
1000/1000 [==============================] - 0s 178us/sample - loss: 24941.5426 - categorical_accuracy: 0.1110 - val_loss: 22170.3123 - val_categorical_accuracy: 0.0800
Epoch 2/10
1000/1000 [==============================] - 0s 95us/sample - loss: 27946.4825 - categorical_accuracy: 0.0990 - val_loss: 22531.8664 - val_categorical_accuracy: 0.1000
Epoch 3/10
1000/1000 [==============================] - 0s 95us/sample - loss: 33313.0864 - categorical_accuracy: 0.0960 - val_loss: 40268.1456 - val_categorical_accuracy: 0.1300
Epoch 4/10
1000/1000 [==============================] - 0s 93us/sample - loss: 39161.3360 - categorical_accuracy: 0.1050 - val_loss: 33861.4163 - val_categorical_accuracy: 0.1000
Epoch 5/10
1000/1000 [==============================] - 0s 100us/sample - loss: 44375.6921 - categorical_accuracy: 0.1040 - val_loss: 46568.1248 - val_categorical_accuracy: 0.0700
Epoch 6/10
1000/1000 [==============================] - 0s 95us/sample - loss: 48923.2586 - categorical_accuracy: 0.0970 - val_loss: 81242.8403 - val_categorical_accuracy: 0.0700
Epoch 7/10
1000/1000 [==============================] - 0s 94us/sample - loss: 58926.1624 - categorical_accuracy: 0.0910 - val_loss: 67530.7906 - val_categorical_accuracy: 0.0900
Epoch 8/10
1000/1000 [==============================] - 0s 92us/sample - loss: 61365.5712 - categorical_accuracy: 0.1020 - val_loss: 73100.3025 - val_categorical_accuracy: 0.1400
Epoch 9/10
1000/1000 [==============================] - 0s 94us/sample - loss: 71333.1506 - categorical_accuracy: 0.0980 - val_loss: 71719.7863 - val_categorical_accuracy: 0.1400
Epoch 10/10
1000/1000 [==============================] - 0s 94us/sample - loss: 74634.9986 - categorical_accuracy: 0.1100 - val_loss: 67717.8081 - val_categorical_accuracy: 0.1100
<tensorflow.python.keras.callbacks.History at 0x7fe74fad4ba8>
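The History object that fit returns also keeps these per-epoch values for later inspection. The sketch below is a smaller variation on the example above, not a reproduction of it: it assumes a two-layer model, fewer samples and epochs, and one-hot labels instead of the random floats used above, so that each label row is a valid class distribution.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

model = tf.keras.Sequential([
    tf.keras.Input(shape=(32,)),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax'),
])
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Assumption: random features with random one-hot labels, for illustration only.
rng = np.random.default_rng(0)
data = rng.random((200, 32)).astype('float32')
labels = np.eye(10, dtype='float32')[rng.integers(0, 10, size=200)]
val_data = rng.random((50, 32)).astype('float32')
val_labels = np.eye(10, dtype='float32')[rng.integers(0, 10, size=50)]

history = model.fit(data, labels, epochs=3, batch_size=32,
                    validation_data=(val_data, val_labels), verbose=0)

# history.history maps each metric name to a list with one value per epoch.
for name, values in sorted(history.history.items()):
    print(name, [round(float(v), 4) for v in values])
```

Entries prefixed with val_ come from the data passed as validation_data.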
Train from tf.data datasets
Use the Datasets API to scale to large datasets or multi-device training. Pass a tf.data.Dataset instance to the fit method:
# Instantiates a toy dataset instance:
dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32)
model.fit(dataset, epochs=10)
Epoch 1/10
32/32 [==============================] - 0s 4ms/step - loss: 81467.1696 - categorical_accuracy: 0.1040
Epoch 2/10
32/32 [==============================] - 0s 3ms/step - loss: 95568.6485 - categorical_accuracy: 0.0970
Epoch 3/10
32/32 [==============================] - 0s 3ms/step - loss: 97389.2383 - categorical_accuracy: 0.1060
Epoch 4/10
32/32 [==============================] - 0s 3ms/step - loss: 104265.7252 - categorical_accuracy: 0.0950
Epoch 5/10
32/32 [==============================] - 0s 3ms/step - loss: 116323.9572 - categorical_accuracy: 0.1010
Epoch 6/10
32/32 [==============================] - 0s 3ms/step - loss: 124712.1657 - categorical_accuracy: 0.1090
Epoch 7/10
32/32 [==============================] - 0s 3ms/step - loss: 132602.0694 - categorical_accuracy: 0.0930
Epoch 8/10
32/32 [==============================] - 0s 3ms/step - loss: 144176.1806 - categorical_accuracy: 0.0930
Epoch 9/10
32/32 [==============================] - 0s 3ms/step - loss: 151150.5936 - categorical_accuracy: 0.1160
Epoch 10/10
32/32 [==============================] - 0s 3ms/step - loss: 163552.8846 - categorical_accuracy: 0.1040
<tensorflow.python.keras.callbacks.History at 0x7fe72c563940>
Since the Dataset yields batches of data, this snippet does not require a batch_size.
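To make "yields batches" concrete, the following sketch iterates over a Dataset built the same way as above and records the size of each element it produces. Only the variable batch_sizes is introduced here for illustration.

```python
import numpy as np
import tensorflow as tf

data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

dataset = tf.data.Dataset.from_tensor_slices((data, labels)).batch(32)

# Collect the leading (batch) dimension of every element the Dataset yields.
batch_sizes = [int(features.shape[0]) for features, _ in dataset]

print(len(batch_sizes))   # number of steps in one epoch
print(batch_sizes[0])     # size of a full batch
print(batch_sizes[-1])    # size of the final, partial batch
```

The number of elements matches the 32/32 step counter in the training log, and the final element holds the 8 samples left over after 31 full batches.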
Datasets can also be used for validation:
dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32)
val_dataset = tf.data.Dataset.from_tensor_slices((val_data, val_labels))
val_dataset = val_dataset.batch(32)
model.fit(dataset, epochs=10,
          validation_data=val_dataset)
Epoch 1/10
32/32 [==============================] - 0s 4ms/step - loss: 174823.2109 - categorical_accuracy: 0.1040 - val_loss: 0.0000e+00 - val_categorical_accuracy: 0.0000e+00
Epoch 2/10
32/32 [==============================] - 0s 3ms/step - loss: 182694.2235 - categorical_accuracy: 0.0920 - val_loss: 207405.1328 - val_categorical_accuracy: 0.1000
Epoch 3/10
32/32 [==============================] - 0s 3ms/step - loss: 198014.8274 - categorical_accuracy: 0.0910 - val_loss: 254707.6133 - val_categorical_accuracy: 0.0900
Epoch 4/10
32/32 [==============================] - 0s 3ms/step - loss: 210588.8274 - categorical_accuracy: 0.1020 - val_loss: 169632.5312 - val_categorical_accuracy: 0.1300
Epoch 5/10
32/32 [==============================] - 0s 3ms/step - loss: 219237.8966 - categorical_accuracy: 0.1020 - val_loss: 213207.9141 - val_categorical_accuracy: 0.1300
Epoch 6/10
32/32 [==============================] - 0s 3ms/step - loss: 234645.3691 - categorical_accuracy: 0.1040 - val_loss: 273672.8672 - val_categorical_accuracy: 0.0900
Epoch 7/10
32/32 [==============================] - 0s 3ms/step - loss: 246899.7756 - categorical_accuracy: 0.1050 - val_loss: 215929.5312 - val_categorical_accuracy: 0.1000
Epoch 8/10
32/32 [==============================] - 0s 3ms/step - loss: 257410.2769 - categorical_accuracy: 0.1090 - val_loss: 303968.7422 - val_categorical_accuracy: 0.1300
Epoch 9/10
32/32 [==============================] - 0s 3ms/step - loss: 273098.0448 - categorical_accuracy: 0.0910 - val_loss: 279622.2812 - val_categorical_accuracy: 0.0900
Epoch 10/10
32/32 [==============================] - 0s 3ms/step - loss: 291685.4157 - categorical_accuracy: 0.0890 - val_loss: 277359.3203 - val_categorical_accuracy: 0.1000
<tensorflow.python.keras.callbacks.History at 0x7fe72c27f0f0>
Evaluate and predict
The tf.keras.Model.evaluate and tf.keras.Model.predict methods can use NumPy data and a tf.data.Dataset.
Here's how to evaluate the inference-mode loss and metrics for the data provided:
# With Numpy arrays
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))
model.evaluate(data, labels, batch_size=32)
# With a Dataset
dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32)
model.evaluate(dataset)
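The text above names predict, but the example covers only evaluate. As a hedged sketch of the other half, the code below runs predict on NumPy data and on a Dataset. The small untrained model is a stand-in assumed for illustration, so the probabilities themselves are meaningless; only their shape and structure are of interest.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Stand-in for the model built earlier in this guide (untrained here).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(32,)),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax'),
])

data = np.random.random((1000, 32)).astype('float32')

# From NumPy data: one row of class probabilities per input sample.
predictions = model.predict(data, batch_size=32, verbose=0)

# From a Dataset that yields batches of features only.
dataset = tf.data.Dataset.from_tensor_slices(data).batch(32)
dataset_predictions = model.predict(dataset, verbose=0)

print(predictions.shape, dataset_predictions.shape)
```

Because the last layer is a softmax, each row of the result sums to 1. As with fit, the Dataset variant needs no batch_size, since the Dataset is already batched.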
================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================
================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================
================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================
================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================
================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================
========] - 0s 53us/sample - loss: 249148.8575 - categorical_accuracy: 0.0960
32/32 [==============================] - 0s 2ms/step - loss: 262983.5693 - categorical_accuracy: 0.0960
[262983.5693359375, 0.096]
And here's how to predict the output of the last layer in inference for the data provided, as a NumPy array:
result = model.predict(data, batch_size=32)
print(result.shape)
(1000, 10)
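Each row of the array returned by `predict` holds one score per output unit. As a minimal sketch, assuming a softmax output like the one used in this guide, the scores can be reduced to a predicted class index per sample with NumPy. The `probs` array below is a made-up stand-in for the result of `model.predict`, not real model output:

```python
import numpy as np

# Hypothetical stand-in for `model.predict(...)`: 3 samples, 4 classes.
probs = np.array([[0.10, 0.70, 0.10, 0.10],
                  [0.60, 0.20, 0.10, 0.10],
                  [0.05, 0.05, 0.10, 0.80]])

# Index of the highest score in each row is the predicted class.
predicted_classes = np.argmax(probs, axis=1)

# The highest score itself can serve as a rough confidence value.
confidence = probs.max(axis=1)

print(predicted_classes)
print(confidence)
```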
For a complete guide on training and evaluation, including how to write custom training loops from scratch, see the guide to training and evaluation.
Build complex models
The Functional API
The tf.keras.Sequential
model is a simple stack of layers that cannot
represent arbitrary models. Use the
Keras functional API
to build complex model topologies such as:
- Multi-input models,
- Multi-output models,
- Models with shared layers (the same layer called several times),
- Models with non-sequential data flows (e.g. residual connections).
Building a model with the functional API works like this:
- A layer instance is callable and returns a tensor.
- Input tensors and output tensors are used to define a tf.keras.Model instance.
- This model is trained just like the Sequential model.
The following example uses the functional API to build a simple, fully-connected network:
inputs = tf.keras.Input(shape=(32,)) # Returns an input placeholder
# A layer instance is callable on a tensor, and returns a tensor.
x = layers.Dense(64, activation='relu')(inputs)
x = layers.Dense(64, activation='relu')(x)
predictions = layers.Dense(10, activation='softmax')(x)
Instantiate the model given inputs and outputs.
model = tf.keras.Model(inputs=inputs, outputs=predictions)
# The compile step specifies the training configuration.
model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# Trains for 5 epochs
model.fit(data, labels, batch_size=32, epochs=5)
Train on 1000 samples
Epoch 1/5
1000/1000 [==============================] - 0s 472us/sample - loss: 14.3942 - accuracy: 0.0980
Epoch 2/5
1000/1000 [==============================] - 0s 83us/sample - loss: 29.4704 - accuracy: 0.1010
Epoch 3/5
1000/1000 [==============================] - 0s 80us/sample - loss: 54.4626 - accuracy: 0.1120
Epoch 4/5
1000/1000 [==============================] - 0s 78us/sample - loss: 87.0312 - accuracy: 0.1020
Epoch 5/5
1000/1000 [==============================] - 0s 80us/sample - loss: 125.3873 - accuracy: 0.1160
<tensorflow.python.keras.callbacks.History at 0x7fe72c01f128>
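The fully-connected example above could also be written with Sequential. As a sketch of one of the non-sequential data flows listed earlier, the following builds a small residual connection with the functional API. The layer sizes and the name `residual_model` are illustrative assumptions, not part of the original guide:

```python
import tensorflow as tf
from tensorflow.keras import layers

inputs = tf.keras.Input(shape=(32,))
x = layers.Dense(64, activation='relu')(inputs)

# Residual block: the block's input is added back to its output,
# so the data flow is no longer a plain stack of layers.
block = layers.Dense(64, activation='relu')(x)
block = layers.Dense(64)(block)
x = layers.add([x, block])
x = layers.Activation('relu')(x)

outputs = layers.Dense(10, activation='softmax')(x)
residual_model = tf.keras.Model(inputs=inputs, outputs=outputs)

print(residual_model.output_shape)
```

Because both branches of the addition must have the same shape, the second Dense layer in the block keeps 64 units.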
Model subclassing
Build a fully-customizable model by subclassing tf.keras.Model
and defining
your own forward pass. Create layers in the __init__
method and set them as
attributes of the class instance. Define the forward pass in the call
method.
Model subclassing is particularly useful when eager execution is enabled, because it allows the forward pass to be written imperatively.
The following example shows a subclassed tf.keras.Model
using a custom forward
pass that does not have to be run imperatively:
class MyModel(tf.keras.Model):

    def __init__(self, num_classes=10):
        super(MyModel, self).__init__(name='my_model')
        self.num_classes = num_classes
        # Define your layers here.
        self.dense_1 = layers.Dense(32, activation='relu')
        self.dense_2 = layers.Dense(num_classes, activation='sigmoid')

    def call(self, inputs):
        # Define your forward pass here,
        # using layers you previously defined (in `__init__`).
        x = self.dense_1(inputs)
        return self.dense_2(x)
Instantiate the new model class:
model = MyModel(num_classes=10)
# The compile step specifies the training configuration.
model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# Trains for 5 epochs.
model.fit(data, labels, batch_size=32, epochs=5)
Train on 1000 samples
Epoch 1/5
1000/1000 [==============================] - 1s 572us/sample - loss: 11.4854 - accuracy: 0.0910
Epoch 2/5
1000/1000 [==============================] - 0s 78us/sample - loss: 11.4548 - accuracy: 0.0910
Epoch 3/5
1000/1000 [==============================] - 0s 83us/sample - loss: 11.4474 - accuracy: 0.1060
Epoch 4/5
1000/1000 [==============================] - 0s 81us/sample - loss: 11.4434 - accuracy: 0.1070
Epoch 5/5
1000/1000 [==============================] - 0s 80us/sample - loss: 11.4400 - accuracy: 0.1070
<tensorflow.python.keras.callbacks.History at 0x7fe6d05a1d30>
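To illustrate what writing the forward pass imperatively can look like, here is a minimal sketch of a subclassed model whose `call` method uses an ordinary Python loop. The class name `StackedDense` and its `depth` and `units` arguments are hypothetical, and the snippet assumes eager execution (the TensorFlow 2.x default):

```python
import tensorflow as tf
from tensorflow.keras import layers


class StackedDense(tf.keras.Model):
    """Hypothetical model whose depth is chosen at construction time."""

    def __init__(self, depth=3, units=16, num_classes=10):
        super(StackedDense, self).__init__()
        # Layers stored in a list attribute are tracked by the model.
        self.hidden = [layers.Dense(units, activation='relu')
                       for _ in range(depth)]
        self.classifier = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs):
        x = inputs
        for layer in self.hidden:  # plain Python control flow
            x = layer(x)
        return self.classifier(x)


stacked_model = StackedDense(depth=2)
# Calling the model on a tensor runs the forward pass immediately.
outputs = stacked_model(tf.zeros((4, 32)))
print(outputs.shape)
```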
Custom layers
Create a custom layer by subclassing tf.keras.layers.Layer
and implementing
the following methods:
- __init__: Optionally define sublayers to be used by this layer.
- build: Create the weights of the layer. Add weights with the add_weight method.
- call: Define the forward pass.
- Optionally, a layer can be serialized by implementing the get_config method and the from_config class method.
Here's an example of a custom layer that implements a matmul
of an input with
a kernel matrix:
class MyLayer(layers.Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)

    def get_config(self):
        base_config = super(MyLayer, self).get_config()
        base_config['output_dim'] = self.output_dim
        return base_config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
Create a model using your custom layer:
model = tf.keras.Sequential([
    MyLayer(10),
    layers.Activation('softmax')])
# The compile step specifies the training configuration
model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# Trains for 5 epochs.
model.fit(data, labels, batch_size=32, epochs=5)
Train on 1000 samples
Epoch 1/5
1000/1000 [==============================] - 0s 289us/sample - loss: 11.4484 - accuracy: 0.1000
Epoch 2/5
1000/1000 [==============================] - 0s 62us/sample - loss: 11.4486 - accuracy: 0.1030
Epoch 3/5
1000/1000 [==============================] - 0s 65us/sample - loss: 11.4487 - accuracy: 0.1010
Epoch 4/5
1000/1000 [==============================] - 0s 64us/sample - loss: 11.4484 - accuracy: 0.1030
Epoch 5/5
1000/1000 [==============================] - 0s 64us/sample - loss: 11.4483 - accuracy: 0.1000
<tensorflow.python.keras.callbacks.History at 0x7fe6c0714160>
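The serialization methods are easiest to understand as a round trip. The sketch below defines a layer equivalent to the one above under the hypothetical name `MatmulLayer`, extracts its configuration, and rebuilds a fresh layer from it. It relies on the inherited `from_config`, which by default passes the config back to the constructor. Only the configuration is carried over, so the rebuilt layer gets newly initialized weights:

```python
import tensorflow as tf
from tensorflow.keras import layers


class MatmulLayer(layers.Layer):
    """Hypothetical layer: multiplies its input by a trainable kernel."""

    def __init__(self, output_dim, **kwargs):
        super(MatmulLayer, self).__init__(**kwargs)
        self.output_dim = output_dim

    def build(self, input_shape):
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[-1], self.output_dim),
                                      initializer='random_uniform',
                                      trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)

    def get_config(self):
        # Extend the base config with this layer's constructor argument.
        config = super(MatmulLayer, self).get_config()
        config['output_dim'] = self.output_dim
        return config


layer = MatmulLayer(10, name='projection')
config = layer.get_config()

# Rebuild an equivalent (but freshly initialized) layer from the config.
clone = MatmulLayer.from_config(config)
print(clone.name, clone.output_dim)
print(clone(tf.ones((2, 32))).shape)
```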
Learn more about creating new layers and models from scratch with subclassing in the Guide to writing layers and models from scratch.
Callbacks
A callback is an object passed to a model to customize and extend its behavior
during training. You can write your own custom callback, or use the built-in
tf.keras.callbacks that include:
- tf.keras.callbacks.ModelCheckpoint: Save checkpoints of your model at regular intervals.
- tf.keras.callbacks.LearningRateScheduler: Dynamically change the learning rate.
- tf.keras.callbacks.EarlyStopping: Interrupt training when validation performance has stopped improving.
- tf.keras.callbacks.TensorBoard: Monitor the model's behavior using TensorBoard.
To use a tf.keras.callbacks.Callback, pass it to the model's fit method:
callbacks = [
    # Interrupt training if `val_loss` stops improving for over 2 epochs
    tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
    # Write TensorBoard logs to `./logs` directory
    tf.keras.callbacks.TensorBoard(log_dir='./logs')
]
model.fit(data, labels, batch_size=32, epochs=5, callbacks=callbacks,
          validation_data=(val_data, val_labels))
Train on 1000 samples, validate on 100 samples
Epoch 1/5
1000/1000 [==============================] - 0s 157us/sample - loss: 11.4479 - accuracy: 0.1000 - val_loss: 11.7400 - val_accuracy: 0.1400
Epoch 2/5
1000/1000 [==============================] - 0s 82us/sample - loss: 11.4479 - accuracy: 0.1010 - val_loss: 11.7400 - val_accuracy: 0.1400
Epoch 3/5
1000/1000 [==============================] - 0s 79us/sample - loss: 11.4478 - accuracy: 0.1040 - val_loss: 11.7404 - val_accuracy: 0.1400
Epoch 4/5
1000/1000 [==============================] - 0s 80us/sample - loss: 11.4478 - accuracy: 0.1030 - val_loss: 11.7407 - val_accuracy: 0.1400
<tensorflow.python.keras.callbacks.History at 0x7fe6c02d5240>
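The run above stopped before the fifth epoch because the validation loss stopped improving. The idea behind the `patience` argument can be sketched in plain Python. This is a simplified illustration, not the library's implementation: the function name `epoch_of_early_stop` and the loss values are made up, and options such as `min_delta` are ignored:

```python
def epoch_of_early_stop(val_losses, patience=2):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float('inf')
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical validation losses, one per epoch.
history = [0.90, 0.80, 0.82, 0.81, 0.79]

stop_with_patience_2 = epoch_of_early_stop(history, patience=2)
stop_with_patience_3 = epoch_of_early_stop(history, patience=3)
print(stop_with_patience_2, stop_with_patience_3)
```

With a patience of 2, the two non-improving epochs after epoch 2 end the run at epoch 4. With a patience of 3, the improvement at epoch 5 arrives in time and training runs to completion.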
Save and restore
Save just the weights values
Save and load the weights of a model using tf.keras.Model.save_weights and tf.keras.Model.load_weights:
model = tf.keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(32,)),
    layers.Dense(10, activation='softmax')])
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# Save weights to a TensorFlow Checkpoint file
model.save_weights('./weights/my_model')
# Restore the model's state,
# this requires a model with the same architecture.
model.load_weights('./weights/my_model')
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fe76d32d128>
By default, this saves the model's weights in the TensorFlow checkpoint file format. Weights can also be saved to the Keras HDF5 format (the default for the multi-backend implementation of Keras):
# Save weights to an HDF5 file
model.save_weights('my_model.h5', save_format='h5')
# Restore the model's state
model.load_weights('my_model.h5')
Save just the model configuration
A model's configuration can be saved—this serializes the model architecture without any weights. A saved configuration can recreate and initialize the same model, even without the code that defined the original model. Keras supports JSON and YAML serialization formats:
# Serialize a model to JSON format
json_string = model.to_json()
json_string
'{"class_name": "Sequential", "config": {"name": "sequential_3", "layers": [{"class_name": "Dense", "config": {"name": "dense_17", "trainable": true, "batch_input_shape": [null, 32], "dtype": "float32", "units": 64, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dense", "config": {"name": "dense_18", "trainable": true, "dtype": "float32", "units": 10, "activation": "softmax", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "keras_version": "2.2.4-tf", "backend": "tensorflow"}'
import json
import pprint
pprint.pprint(json.loads(json_string))
{'backend': 'tensorflow', 'class_name': 'Sequential', 'config': {'layers': [{'class_name': 'Dense', 'config': {'activation': 'relu', 'activity_regularizer': None, 'batch_input_shape': [None, 32], 'bias_constraint': None, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'bias_regularizer': None, 'dtype': 'float32', 'kernel_constraint': None, 'kernel_initializer': {'class_name': 'GlorotUniform', 'config': {'seed': None}}, 'kernel_regularizer': None, 'name': 'dense_17', 'trainable': True, 'units': 64, 'use_bias': True}}, {'class_name': 'Dense', 'config': {'activation': 'softmax', 'activity_regularizer': None, 'bias_constraint': None, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'bias_regularizer': None, 'dtype': 'float32', 'kernel_constraint': None, 'kernel_initializer': {'class_name': 'GlorotUniform', 'config': {'seed': None}}, 'kernel_regularizer': None, 'name': 'dense_18', 'trainable': True, 'units': 10, 'use_bias': True}}], 'name': 'sequential_3'}, 'keras_version': '2.2.4-tf'}
Recreate the model (newly initialized) from the JSON:
fresh_model = tf.keras.models.model_from_json(json_string)
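Because the serialized configuration is ordinary JSON, it can be inspected with the standard library alone, without TensorFlow. The sketch below uses a trimmed, hand-written string in the same shape as the output above (most keys omitted for brevity) and lists each layer's class, size, and activation:

```python
import json

# Trimmed, hypothetical config in the same shape as the JSON shown above.
trimmed_json = '''
{"class_name": "Sequential",
 "config": {"name": "sequential_3",
            "layers": [
              {"class_name": "Dense",
               "config": {"name": "dense_17", "units": 64, "activation": "relu"}},
              {"class_name": "Dense",
               "config": {"name": "dense_18", "units": 10, "activation": "softmax"}}]}}
'''

architecture = json.loads(trimmed_json)

# Collect (class, units, activation) for every layer in the stack.
layer_summary = [
    (layer["class_name"], layer["config"]["units"], layer["config"]["activation"])
    for layer in architecture["config"]["layers"]
]

for entry in layer_summary:
    print(entry)
```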
Serializing a model to YAML format requires that you install pyyaml
before you import TensorFlow:
yaml_string = model.to_yaml()
print(yaml_string)
backend: tensorflow
class_name: Sequential
config:
  layers:
  - class_name: Dense
    config:
      activation: relu
      activity_regularizer: null
      batch_input_shape: !!python/tuple [null, 32]
      bias_constraint: null
      bias_initializer:
        class_name: Zeros
        config: {}
      bias_regularizer: null
      dtype: float32
      kernel_constraint: null
      kernel_initializer:
        class_name: GlorotUniform
        config: {seed: null}
      kernel_regularizer: null
      name: dense_17
      trainable: true
      units: 64
      use_bias: true
  - class_name: Dense
    config:
      activation: softmax
      activity_regularizer: null
      bias_constraint: null
      bias_initializer:
        class_name: Zeros
        config: {}
      bias_regularizer: null
      dtype: float32
      kernel_constraint: null
      kernel_initializer:
        class_name: GlorotUniform
        config: {seed: null}
      kernel_regularizer: null
      name: dense_18
      trainable: true
      units: 10
      use_bias: true
  name: sequential_3
keras_version: 2.2.4-tf
Recreate the model from the YAML:
fresh_model = tf.keras.models.model_from_yaml(yaml_string)
Save the entire model in one file
The entire model can be saved to a file that contains the weight values, the model's configuration, and even the optimizer's configuration. This allows you to checkpoint a model and resume training later—from the exact same state—without access to the original code.
# Create a simple model
model = tf.keras.Sequential([
    layers.Dense(10, activation='softmax', input_shape=(32,)),
    layers.Dense(10, activation='softmax')
])
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(data, labels, batch_size=32, epochs=5)
# Save entire model to an HDF5 file
model.save('my_model.h5')
# Recreate the exact same model, including weights and optimizer.
model = tf.keras.models.load_model('my_model.h5')
Train on 1000 samples
Epoch 1/5
1000/1000 [==============================] - 0s 404us/sample - loss: 11.4749 - accuracy: 0.1130
Epoch 2/5
1000/1000 [==============================] - 0s 73us/sample - loss: 11.4722 - accuracy: 0.1040
Epoch 3/5
1000/1000 [==============================] - 0s 75us/sample - loss: 11.4959 - accuracy: 0.1060
Epoch 4/5
1000/1000 [==============================] - 0s 73us/sample - loss: 11.5455 - accuracy: 0.0990
Epoch 5/5
1000/1000 [==============================] - 0s 74us/sample - loss: 11.5848 - accuracy: 0.0970
Learn more about saving and serialization for Keras models in the guide to save and serialize models.
Eager execution
Eager execution is an imperative programming
environment that evaluates operations immediately. This is not required for
Keras, but is supported by tf.keras
and useful for inspecting your program and
debugging.
All of the tf.keras
model-building APIs are compatible with eager execution.
And while the Sequential
and functional APIs can be used, eager execution
especially benefits model subclassing and building custom layers—the APIs
that require you to write the forward pass as code (instead of the APIs that
create models by assembling existing layers).
See the eager execution guide for
examples of using Keras models with custom training loops and tf.GradientTape.
You can also find a complete, short example here.
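As a minimal sketch of a custom training loop, assuming eager execution (the TensorFlow 2.x default), the following fits a one-unit Keras model to toy data with tf.GradientTape. The data, the learning rate, and the helper name `train_step` are illustrative assumptions rather than part of the original guide:

```python
import numpy as np
import tensorflow as tf

# Toy regression data: the target is simply twice the input.
x = np.linspace(-1.0, 1.0, 64, dtype=np.float32).reshape(-1, 1)
y = 2.0 * x

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.MeanSquaredError()


def train_step(inputs, targets):
    # Record the forward pass so gradients can be computed afterwards.
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(targets, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return float(loss)


losses = [train_step(x, y) for _ in range(50)]
print(losses[0], losses[-1])
```

Each call to `train_step` runs immediately and returns a plain Python float, which makes it easy to print or inspect intermediate values while debugging.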
Distribution
Multiple GPUs
tf.keras models can run on multiple GPUs using tf.distribute.Strategy. This API provides distributed
training on multiple GPUs with almost no changes to existing code.
Currently, tf.distribute.MirroredStrategy is the only supported
distribution strategy. MirroredStrategy does in-graph replication with
synchronous training using all-reduce on a single machine. To use a
tf.distribute.Strategy, nest the optimizer instantiation and the model
construction and compilation inside the Strategy's .scope(), then
train the model.
The following example distributes a tf.keras.Model
across multiple GPUs on a
single machine.
First, define a model inside the distributed strategy scope:
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential()
    model.add(layers.Dense(16, activation='relu', input_shape=(10,)))
    model.add(layers.Dense(1, activation='sigmoid'))
    optimizer = tf.keras.optimizers.SGD(0.2)
    model.compile(loss='binary_crossentropy', optimizer=optimizer)
model.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_21 (Dense)             (None, 16)                176
_________________________________________________________________
dense_22 (Dense)             (None, 1)                 17
=================================================================
Total params: 193
Trainable params: 193
Non-trainable params: 0
_________________________________________________________________
Next, train the model on data as usual:
x = np.random.random((1024, 10))
y = np.random.randint(2, size=(1024, 1))
x = tf.cast(x, tf.float32)
dataset = tf.data.Dataset.from_tensor_slices((x, y))
dataset = dataset.shuffle(buffer_size=1024).batch(32)
model.fit(dataset, epochs=1)
32/32 [==============================] - 2s 75ms/step - loss: 0.7055
<tensorflow.python.keras.callbacks.History at 0x7fe6607a3898>
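The synchronous all-reduce idea behind MirroredStrategy can be pictured without any GPUs. The following plain-Python sketch is a conceptual illustration only, not how TensorFlow implements it: each simulated replica computes a gradient on its own equal-sized shard of the batch, and the averaged result matches the gradient computed on the full batch. The scalar model, the data, and the function name `mse_gradient` are assumptions made for the example:

```python
def mse_gradient(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


# Toy batch for a scalar linear model y = w * x.
xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [3.0 * x for x in xs]
w = 1.0

# Split the batch into equal shards, one per simulated replica.
num_replicas = 2
shard = len(xs) // num_replicas
replica_grads = [
    mse_gradient(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
    for i in range(num_replicas)
]

# "All-reduce": combine the per-replica gradients so every replica
# applies the same update and the model copies stay in sync.
all_reduced = sum(replica_grads) / num_replicas
full_batch = mse_gradient(w, xs, ys)

print(all_reduced, full_batch)
```

Because every replica applies the same reduced gradient, the mirrored copies of the weights remain identical after each step.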
For more information, see the full guide on Distributed Training in TensorFlow.