
tf.keras.utils.plot_model

Converts a Keras model to dot format and saves it to a file.
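The sketch below makes "dot format" concrete. It shows the kind of Graphviz DOT text such a conversion produces for a purely sequential stack of layers.

Two assumptions apply. First, `layers_to_dot` is a hypothetical helper, not a Keras function. Second, the text that `plot_model` actually emits carries more detail, such as shapes, dtypes, and nested clusters. The `rankdir` graph attribute written here is the same Graphviz setting that the `rankdir` argument below controls.

```python
# Hypothetical helper, NOT part of Keras: it only illustrates the kind of
# DOT text a model-to-graph conversion yields for a linear chain of layers.
def layers_to_dot(layer_names, rankdir="TB"):
    """Return a DOT digraph string with one node per layer, chained in order."""
    lines = ["digraph model {", f"  rankdir={rankdir};"]
    # One record-shaped node per layer, identified by its position.
    for i, name in enumerate(layer_names):
        lines.append(f'  n{i} [label="{name}", shape=record];')
    # Connect each layer to the next one to form a chain.
    for i in range(len(layer_names) - 1):
        lines.append(f"  n{i} -> n{i + 1};")
    lines.append("}")
    return "\n".join(lines)


print(layers_to_dot(["input", "embedding", "lstm", "dense", "output"], rankdir="LR"))
```

Graphviz (for example `dot -Tpng`) can render a string like this into an image. Saving that rendered image is the role `to_file` plays in `plot_model`.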

Example:

import tensorflow as tf

# plot_model requires the pydot package and the Graphviz binaries to be installed.
inputs = tf.keras.Input(shape=(100,), dtype='int32', name='input')
x = tf.keras.layers.Embedding(
    output_dim=512, input_dim=10000, input_length=100)(inputs)
x = tf.keras.layers.LSTM(32)(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
model = tf.keras.Model(inputs=[inputs], outputs=[output])
dot_img_file = '/tmp/model_1.png'
tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

Args:
model: A Keras model instance.
to_file: File name of the plot image.
show_shapes: Whether to display shape information.
show_dtype: Whether to display layer dtypes.
show_layer_names: Whether to display layer names.
rankdir: The rankdir argument passed to PyDot, a string specifying the format of the plot: 'TB' creates a vertical plot and 'LR' creates a horizontal plot.
expand_nested: Whether to expand nested models into clusters.
dpi: Dots per inch.
layer_range: A list of two strings, the names of the starting and ending layers (both inclusive), that selects the range of layers to plot. Regex patterns are also accepted instead of exact names. In that case, the start is the first layer whose name matches layer_range[0] and the end is the last layer whose name matches layer_range[1]. Defaults to None, which plots all layers of the model. The range must be chosen so that the resulting subgraph is complete.
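The layer_range matching rule described above can be sketched in plain Python.

This is not the Keras implementation. `resolve_layer_range` is a hypothetical helper, and the sketch rests on two further assumptions. First, each pattern is tested with `re.match`, which anchors it at the start of a layer name. Second, the sample names are the defaults the example model above would likely receive in a fresh session.

```python
import re


def resolve_layer_range(layer_names, layer_range=None):
    """Return inclusive (start, end) indices into layer_names.

    Hypothetical helper, NOT the Keras implementation. Assumes each pattern
    is tested with re.match (anchored at the start of the layer name).
    """
    if layer_range is None:
        # Default: every layer is included.
        return 0, len(layer_names) - 1
    if len(layer_range) != 2 or not all(isinstance(p, str) for p in layer_range):
        raise ValueError("layer_range must be a list of two strings")
    start_pattern, end_pattern = layer_range
    starts = [i for i, n in enumerate(layer_names) if re.match(start_pattern, n)]
    ends = [i for i, n in enumerate(layer_names) if re.match(end_pattern, n)]
    if not starts or not ends:
        raise ValueError("no layer name matches one of the patterns")
    # Start at the FIRST match of layer_range[0]; end at the LAST match of layer_range[1].
    start, end = starts[0], ends[-1]
    if start > end:
        raise ValueError("empty range: the start layer comes after the end layer")
    return start, end


names = ["input", "embedding", "lstm", "dense", "dense_1", "dense_2", "output"]
print(resolve_layer_range(names, ["emb", "dense"]))  # (1, 5)
```

Here "dense" matches three layers, and the last of them (`dense_2`, index 5) becomes the end of the range. This sketch raises an error when the start match falls after the end match; the library may handle that case differently.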

Returns:
A Jupyter notebook Image object if Jupyter is installed. This enables inline display of the model plots in notebooks.