# 函数式 API

## 设置

``````import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
``````

## 简介

Keras 函数式 API 是一种比 `tf.keras.Sequential` API 更加灵活的模型创建方式。函数式 API 可以处理具有非线性拓扑的模型、具有共享层的模型，以及具有多个输入或输出的模型。

``````
(input: 784-dimensional vectors)
       ↧
[Dense (64 units, relu activation)]
       ↧
[Dense (64 units, relu activation)]
       ↧
[Dense (10 units, softmax activation)]
       ↧
(output: logits of a probability distribution over 10 classes)
``````

``````inputs = keras.Input(shape=(784,))  # symbolic input: 784-element vectors, batch size left unspecified
``````

``````# Just for demonstration purposes.
# Symbolic input of shape (32, 32, 3) per sample, e.g. a 32x32 image with 3 channels.
img_inputs = keras.Input(shape=(32, 32, 3))
``````

``````inputs.shape  # static shape of the symbolic input; the leading None is the batch dimension
``````
```TensorShape([None, 784])
```

dtype 如下：

``````inputs.dtype  # element type of the symbolic input
``````
```tf.float32
```

``````dense = layers.Dense(64, activation="relu")  # a 64-unit fully connected layer with ReLU activation
# Calling the layer on `inputs` adds a node to the graph and returns its output tensor.
x = dense(inputs)
``````

“层调用”操作就像从“输入”向您创建的该层绘制一个箭头。您将输入“传递”到 `dense` 层，然后得到 `x`。

``````x = layers.Dense(64, activation="relu")(x)  # second 64-unit hidden layer
# Final 10-unit layer with no activation, so it emits raw logits.
outputs = layers.Dense(10)(x)
``````

``````model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_model")  # build the model from the graph's input and output tensors
``````

``````model.summary()  # print a per-layer table of output shapes and parameter counts
``````
```Model: "mnist_model"
_________________________________________________________________
Layer (type)                Output Shape              Param #
=================================================================
input_1 (InputLayer)        [(None, 784)]             0

dense (Dense)               (None, 64)                50240

dense_1 (Dense)             (None, 64)                4160

dense_2 (Dense)             (None, 10)                650

=================================================================
Total params: 55,050
Trainable params: 55,050
Non-trainable params: 0
_________________________________________________________________
```

``````keras.utils.plot_model(model, "my_first_model.png")  # render the graph to an image file (requires pydot and graphviz)
``````
```You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.
```

``````keras.utils.plot_model(model, "my_first_model_with_shape_info.png", show_shapes=True)  # same plot, with shape information shown for each layer
``````
```You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.
```

“层计算图”是深度学习模型的直观心理图像，而函数式 API 是创建密切反映此图像的模型的方法。

## 训练、评估和推断

``````(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

x_train = x_train.reshape(60000, 784).astype("float32") / 255
x_test = x_test.reshape(10000, 784).astype("float32") / 255

model.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.RMSprop(),
metrics=["accuracy"],
)

history = model.fit(x_train, y_train, batch_size=64, epochs=2, validation_split=0.2)

test_scores = model.evaluate(x_test, y_test, verbose=2)
print("Test loss:", test_scores[0])
print("Test accuracy:", test_scores[1])
``````
```Epoch 1/2
750/750 [==============================] - 3s 3ms/step - loss: 0.3513 - accuracy: 0.9004 - val_loss: 0.1908 - val_accuracy: 0.9461
Epoch 2/2
750/750 [==============================] - 2s 2ms/step - loss: 0.1573 - accuracy: 0.9538 - val_loss: 0.1337 - val_accuracy: 0.9623
313/313 - 0s - loss: 0.1257 - accuracy: 0.9621 - 432ms/epoch - 1ms/step
Test loss: 0.125714972615242
Test accuracy: 0.9621000289916992
```

## 保存和序列化

• 模型架构
• 模型权重值（在训练过程中得知）
• 模型训练配置（如果有的话，如传递给 `compile`）
• 优化器及其状态（如果有的话，用来从上次中断的地方重新开始训练）
``````model.save("path_to_my_model")
del model
# Recreate the exact same model purely from the file:
``````
```INFO:tensorflow:Assets written to: path_to_my_model/assets
```

## 使用相同的层计算图定义多个模型

``````encoder_input = keras.Input(shape=(28, 28, 1), name="img")
# Encoder: convolution and pooling layers, ending in a global max pool
# that yields one 16-element vector per sample.
x = layers.Conv2D(16, 3, activation="relu")(encoder_input)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.MaxPooling2D(3)(x)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.Conv2D(16, 3, activation="relu")(x)
encoder_output = layers.GlobalMaxPooling2D()(x)

# First model: stops at the encoded vector.
encoder = keras.Model(encoder_input, encoder_output, name="encoder")
encoder.summary()

# Decoder: continues the SAME graph from `encoder_output`, reshaping the
# 16-element vector to (4, 4, 1) and expanding it back to (28, 28, 1).
x = layers.Reshape((4, 4, 1))(encoder_output)
x = layers.Conv2DTranspose(16, 3, activation="relu")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu")(x)
x = layers.UpSampling2D(3)(x)
x = layers.Conv2DTranspose(16, 3, activation="relu")(x)
decoder_output = layers.Conv2DTranspose(1, 3, activation="relu")(x)

# Second model: reuses `encoder_input` and therefore every encoder layer above.
autoencoder = keras.Model(encoder_input, decoder_output, name="autoencoder")
autoencoder.summary()
``````
```Model: "encoder"
_________________________________________________________________
Layer (type)                Output Shape              Param #
=================================================================
img (InputLayer)            [(None, 28, 28, 1)]       0

conv2d (Conv2D)             (None, 26, 26, 16)        160

conv2d_1 (Conv2D)           (None, 24, 24, 32)        4640

max_pooling2d (MaxPooling2D  (None, 8, 8, 32)         0
)

conv2d_2 (Conv2D)           (None, 6, 6, 32)          9248

conv2d_3 (Conv2D)           (None, 4, 4, 16)          4624

global_max_pooling2d (Globa  (None, 16)               0
lMaxPooling2D)

=================================================================
Total params: 18,672
Trainable params: 18,672
Non-trainable params: 0
_________________________________________________________________
Model: "autoencoder"
_________________________________________________________________
Layer (type)                Output Shape              Param #
=================================================================
img (InputLayer)            [(None, 28, 28, 1)]       0

conv2d (Conv2D)             (None, 26, 26, 16)        160

conv2d_1 (Conv2D)           (None, 24, 24, 32)        4640

max_pooling2d (MaxPooling2D  (None, 8, 8, 32)         0
)

conv2d_2 (Conv2D)           (None, 6, 6, 32)          9248

conv2d_3 (Conv2D)           (None, 4, 4, 16)          4624

global_max_pooling2d (Globa  (None, 16)               0
lMaxPooling2D)

reshape (Reshape)           (None, 4, 4, 1)           0

conv2d_transpose (Conv2DTra  (None, 6, 6, 16)         160
nspose)

conv2d_transpose_1 (Conv2DT  (None, 8, 8, 32)         4640
ranspose)

up_sampling2d (UpSampling2D  (None, 24, 24, 32)       0
)

conv2d_transpose_2 (Conv2DT  (None, 26, 26, 16)       4624
ranspose)

conv2d_transpose_3 (Conv2DT  (None, 28, 28, 1)        145
ranspose)

=================================================================
Total params: 28,241
Trainable params: 28,241
Non-trainable params: 0
_________________________________________________________________
```

`Conv2D` 层的反面是 `Conv2DTranspose` 层，`MaxPooling2D` 层的反面是 `UpSampling2D` 层。

## 所有模型均可像层一样调用

``````encoder_input = keras.Input(shape=(28, 28, 1), name="original_img")
# Stand-alone encoder: maps a (28, 28, 1) input to a 16-element vector.
x = layers.Conv2D(16, 3, activation="relu")(encoder_input)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.MaxPooling2D(3)(x)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.Conv2D(16, 3, activation="relu")(x)
encoder_output = layers.GlobalMaxPooling2D()(x)

encoder = keras.Model(encoder_input, encoder_output, name="encoder")
encoder.summary()

# Stand-alone decoder with its own input: maps a 16-element vector back to (28, 28, 1).
decoder_input = keras.Input(shape=(16,), name="encoded_img")
x = layers.Reshape((4, 4, 1))(decoder_input)
x = layers.Conv2DTranspose(16, 3, activation="relu")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu")(x)
x = layers.UpSampling2D(3)(x)
x = layers.Conv2DTranspose(16, 3, activation="relu")(x)
decoder_output = layers.Conv2DTranspose(1, 3, activation="relu")(x)

decoder = keras.Model(decoder_input, decoder_output, name="decoder")
decoder.summary()

# Chain the two models by calling them like layers on a fresh input.
autoencoder_input = keras.Input(shape=(28, 28, 1), name="img")
encoded_img = encoder(autoencoder_input)
decoded_img = decoder(encoded_img)
autoencoder = keras.Model(autoencoder_input, decoded_img, name="autoencoder")
autoencoder.summary()
``````
```Model: "encoder"
_________________________________________________________________
Layer (type)                Output Shape              Param #
=================================================================
original_img (InputLayer)   [(None, 28, 28, 1)]       0

conv2d_4 (Conv2D)           (None, 26, 26, 16)        160

conv2d_5 (Conv2D)           (None, 24, 24, 32)        4640

max_pooling2d_1 (MaxPooling  (None, 8, 8, 32)         0
2D)

conv2d_6 (Conv2D)           (None, 6, 6, 32)          9248

conv2d_7 (Conv2D)           (None, 4, 4, 16)          4624

global_max_pooling2d_1 (Glo  (None, 16)               0
balMaxPooling2D)

=================================================================
Total params: 18,672
Trainable params: 18,672
Non-trainable params: 0
_________________________________________________________________
Model: "decoder"
_________________________________________________________________
Layer (type)                Output Shape              Param #
=================================================================
encoded_img (InputLayer)    [(None, 16)]              0

reshape_1 (Reshape)         (None, 4, 4, 1)           0

conv2d_transpose_4 (Conv2DT  (None, 6, 6, 16)         160
ranspose)

conv2d_transpose_5 (Conv2DT  (None, 8, 8, 32)         4640
ranspose)

up_sampling2d_1 (UpSampling  (None, 24, 24, 32)       0
2D)

conv2d_transpose_6 (Conv2DT  (None, 26, 26, 16)       4624
ranspose)

conv2d_transpose_7 (Conv2DT  (None, 28, 28, 1)        145
ranspose)

=================================================================
Total params: 9,569
Trainable params: 9,569
Non-trainable params: 0
_________________________________________________________________
Model: "autoencoder"
_________________________________________________________________
Layer (type)                Output Shape              Param #
=================================================================
img (InputLayer)            [(None, 28, 28, 1)]       0

encoder (Functional)        (None, 16)                18672

decoder (Functional)        (None, 28, 28, 1)         9569

=================================================================
Total params: 28,241
Trainable params: 28,241
Non-trainable params: 0
_________________________________________________________________
```

``````def get_model():
inputs = keras.Input(shape=(128,))
outputs = layers.Dense(1)(inputs)
return keras.Model(inputs, outputs)

model1 = get_model()
model2 = get_model()
model3 = get_model()

inputs = keras.Input(shape=(128,))
y1 = model1(inputs)
y2 = model2(inputs)
y3 = model3(inputs)
outputs = layers.average([y1, y2, y3])
ensemble_model = keras.Model(inputs=inputs, outputs=outputs)
``````

## 处理复杂的计算图拓扑

### 具有多个输入和输出的模型

• 工单标题（文本输入），
• 工单的文本正文（文本输入），以及
• 用户添加的任何标签（分类输入）

• 介于 0 和 1 之间的优先级分数（标量 Sigmoid 输出），以及
• 应该处理工单的部门（部门范围内的 Softmax 输出）。

``````num_tags = 12  # Number of unique issue tags
num_words = 10000  # Size of vocabulary obtained when preprocessing text data
num_departments = 4  # Number of departments for predictions

title_input = keras.Input(
shape=(None,), name="title"
)  # Variable-length sequence of ints
body_input = keras.Input(shape=(None,), name="body")  # Variable-length sequence of ints
tags_input = keras.Input(
shape=(num_tags,), name="tags"
)  # Binary vectors of size `num_tags`

# Embed each word in the title into a 64-dimensional vector
title_features = layers.Embedding(num_words, 64)(title_input)
# Embed each word in the text into a 64-dimensional vector
body_features = layers.Embedding(num_words, 64)(body_input)

# Reduce sequence of embedded words in the title into a single 128-dimensional vector
title_features = layers.LSTM(128)(title_features)
# Reduce sequence of embedded words in the body into a single 32-dimensional vector
body_features = layers.LSTM(32)(body_features)

# Merge all available features into a single large vector via concatenation
x = layers.concatenate([title_features, body_features, tags_input])

# Stick a logistic regression for priority prediction on top of the features
priority_pred = layers.Dense(1, name="priority")(x)
# Stick a department classifier on top of the features
department_pred = layers.Dense(num_departments, name="department")(x)

# Instantiate an end-to-end model predicting both priority and department
model = keras.Model(
inputs=[title_input, body_input, tags_input],
outputs=[priority_pred, department_pred],
)
``````

``````keras.utils.plot_model(model, "multi_input_and_output_model.png", show_shapes=True)  # plot the multi-input/multi-output graph with shape information
``````
```You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.
```

``````model.compile(
optimizer=keras.optimizers.RMSprop(1e-3),
loss=[
keras.losses.BinaryCrossentropy(from_logits=True),
keras.losses.CategoricalCrossentropy(from_logits=True),
],
loss_weights=[1.0, 0.2],
)
``````

``````model.compile(
optimizer=keras.optimizers.RMSprop(1e-3),
loss={
"priority": keras.losses.BinaryCrossentropy(from_logits=True),
"department": keras.losses.CategoricalCrossentropy(from_logits=True),
},
loss_weights=[1.0, 0.2],
)
``````

``````# Dummy input data
title_data = np.random.randint(num_words, size=(1280, 10))
body_data = np.random.randint(num_words, size=(1280, 100))
tags_data = np.random.randint(2, size=(1280, num_tags)).astype("float32")

# Dummy target data
priority_targets = np.random.random(size=(1280, 1))
dept_targets = np.random.randint(2, size=(1280, num_departments))

model.fit(
{"title": title_data, "body": body_data, "tags": tags_data},
{"priority": priority_targets, "department": dept_targets},
epochs=2,
batch_size=32,
)
``````
```Epoch 1/2
40/40 [==============================] - 4s 8ms/step - loss: 1.3104 - priority_loss: 0.7062 - department_loss: 3.0209
Epoch 2/2
40/40 [==============================] - 0s 7ms/step - loss: 1.3112 - priority_loss: 0.6982 - department_loss: 3.0651
<keras.callbacks.History at 0x7fdc09bf2d00>
```

### 小 ResNet 模型

``````inputs = keras.Input(shape=(32, 32, 3), name="img")
x = layers.Conv2D(32, 3, activation="relu")(inputs)
x = layers.Conv2D(64, 3, activation="relu")(x)
block_1_output = layers.MaxPooling2D(3)(x)

x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_1_output)
x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)

x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_2_output)
x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)

x = layers.Conv2D(64, 3, activation="relu")(block_3_output)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(256, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(10)(x)

model = keras.Model(inputs, outputs, name="toy_resnet")
model.summary()
``````
```Model: "toy_resnet"
__________________________________________________________________________________________________
Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
img (InputLayer)               [(None, 32, 32, 3)]  0           []

conv2d_8 (Conv2D)              (None, 30, 30, 32)   896         ['img[0][0]']

conv2d_9 (Conv2D)              (None, 28, 28, 64)   18496       ['conv2d_8[0][0]']

max_pooling2d_2 (MaxPooling2D)  (None, 9, 9, 64)    0           ['conv2d_9[0][0]']

conv2d_10 (Conv2D)             (None, 9, 9, 64)     36928       ['max_pooling2d_2[0][0]']

conv2d_11 (Conv2D)             (None, 9, 9, 64)     36928       ['conv2d_10[0][0]']

add (Add)                      (None, 9, 9, 64)     0           ['conv2d_11[0][0]',
'max_pooling2d_2[0][0]']

conv2d_12 (Conv2D)             (None, 9, 9, 64)     36928       ['add[0][0]']

conv2d_13 (Conv2D)             (None, 9, 9, 64)     36928       ['conv2d_12[0][0]']

add_1 (Add)                    (None, 9, 9, 64)     0           ['conv2d_13[0][0]',
'add[0][0]']

conv2d_14 (Conv2D)             (None, 7, 7, 64)     36928       ['add_1[0][0]']

global_average_pooling2d (Glob  (None, 64)          0           ['conv2d_14[0][0]']
alAveragePooling2D)

dense_6 (Dense)                (None, 256)          16640       ['global_average_pooling2d[0][0]'
]

dropout (Dropout)              (None, 256)          0           ['dense_6[0][0]']

dense_7 (Dense)                (None, 10)           2570        ['dropout[0][0]']

==================================================================================================
Total params: 223,242
Trainable params: 223,242
Non-trainable params: 0
__________________________________________________________________________________________________
```

``````keras.utils.plot_model(model, "mini_resnet.png", show_shapes=True)  # plot the ResNet graph with shape information
``````
```You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.
```

``````(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

model.compile(
optimizer=keras.optimizers.RMSprop(1e-3),
loss=keras.losses.CategoricalCrossentropy(from_logits=True),
metrics=["acc"],
)
# We restrict the data to the first 1000 samples so as to limit execution time
# on Colab. Try to train on the entire dataset until convergence!
model.fit(x_train[:1000], y_train[:1000], batch_size=64, epochs=1, validation_split=0.2)
``````
```Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 2s 0us/step
13/13 [==============================] - 2s 28ms/step - loss: 2.3238 - acc: 0.0862 - val_loss: 2.3014 - val_acc: 0.1100
<keras.callbacks.History at 0x7fdc08d11040>
```

## 共享层

``````# Embedding for 1000 unique words mapped to 128-dimensional vectors
# A single layer instance: every call below uses this one object.
shared_embedding = layers.Embedding(1000, 128)

# Variable-length sequence of integers
text_input_a = keras.Input(shape=(None,), dtype="int32")

# Variable-length sequence of integers
text_input_b = keras.Input(shape=(None,), dtype="int32")

# Reuse the same layer to encode both inputs
encoded_input_a = shared_embedding(text_input_a)
encoded_input_b = shared_embedding(text_input_b)
``````

## 提取和重用层计算图中的节点

``````vgg19 = tf.keras.applications.VGG19()  # instantiate VGG19 with default arguments (weights are downloaded on first use)
``````
```Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels.h5
574710816/574710816 [==============================] - 3s 0us/step
```

``````features_list = [layer.output for layer in vgg19.layers]  # output tensor of every layer in the VGG19 graph
``````

``````feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list)  # one model exposing all of `features_list` as outputs

# Random array of shape (1, 224, 224, 3) standing in for one input sample.
img = np.random.random((1, 224, 224, 3)).astype("float32")
# Calling the model yields one value per entry of `features_list`.
extracted_features = feat_extraction_model(img)
``````

## 使用自定义层扩展 API

`tf.keras` 包含了各种内置层，例如：

• 卷积层：`Conv1D`、`Conv2D`、`Conv3D`、`Conv2DTranspose`
• 池化层：`MaxPooling1D`、`MaxPooling2D`、`MaxPooling3D`、`AveragePooling1D`
• RNN 层：`GRU`、`LSTM`、`ConvLSTM2D`
• `BatchNormalization`、`Dropout`、`Embedding` 等

• `call` 方法，用于指定由层完成的计算。
• `build` 方法，用于创建层的权重（这只是一种样式约定，因为您也可以在 `__init__` 中创建权重）。

``````class CustomDense(layers.Layer):
def __init__(self, units=32):
super(CustomDense, self).__init__()
self.units = units

def build(self, input_shape):
shape=(input_shape[-1], self.units),
initializer="random_normal",
trainable=True,
)
shape=(self.units,), initializer="random_normal", trainable=True
)

def call(self, inputs):
return tf.matmul(inputs, self.w) + self.b

inputs = keras.Input((4,))
outputs = CustomDense(10)(inputs)

model = keras.Model(inputs, outputs)
``````

``````class CustomDense(layers.Layer):
def __init__(self, units=32):
super(CustomDense, self).__init__()
self.units = units

def build(self, input_shape):
shape=(input_shape[-1], self.units),
initializer="random_normal",
trainable=True,
)
shape=(self.units,), initializer="random_normal", trainable=True
)

def call(self, inputs):
return tf.matmul(inputs, self.w) + self.b

def get_config(self):
return {"units": self.units}

inputs = keras.Input((4,))
outputs = CustomDense(10)(inputs)

model = keras.Model(inputs, outputs)
config = model.get_config()

new_model = keras.Model.from_config(config, custom_objects={"CustomDense": CustomDense})
``````

``````def from_config(cls, config):   return cls(**config)
``````

## 何时使用函数式 API

### 函数式 API 的优势：

#### 更加简洁

``````inputs = keras.Input(shape=(32,)) x = layers.Dense(64, activation='relu')(inputs) outputs = layers.Dense(10)(x) mlp = keras.Model(inputs, outputs)
``````

``````class MLP(keras.Model):    def __init__(self, **kwargs):     super(MLP, self).__init__(**kwargs)     self.dense_1 = layers.Dense(64, activation='relu')     self.dense_2 = layers.Dense(10)    def call(self, inputs):     x = self.dense_1(inputs)     return self.dense_2(x)  # Instantiate the model. mlp = MLP() # Necessary to create the model's state. # The model doesn't have a state until it's called at least once. _ = mlp(tf.zeros((1, 32)))
``````

#### 函数式模型可绘制且可检查

``````features_list = [layer.output for layer in vgg19.layers] feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list)
``````

## 混搭 API 样式

``````units = 32
timesteps = 10
input_dim = 5

# Define a Functional model
inputs = keras.Input((None, units))
x = layers.GlobalAveragePooling1D()(inputs)
outputs = layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

class CustomRNN(layers.Layer):
def __init__(self):
super(CustomRNN, self).__init__()
self.units = units
self.projection_1 = layers.Dense(units=units, activation="tanh")
self.projection_2 = layers.Dense(units=units, activation="tanh")
# Our previously-defined Functional model
self.classifier = model

def call(self, inputs):
outputs = []
state = tf.zeros(shape=(inputs.shape[0], self.units))
for t in range(inputs.shape[1]):
x = inputs[:, t, :]
h = self.projection_1(x)
y = h + self.projection_2(state)
state = y
outputs.append(y)
features = tf.stack(outputs, axis=1)
print(features.shape)
return self.classifier(features)

rnn_model = CustomRNN()
_ = rnn_model(tf.zeros((1, timesteps, input_dim)))
``````
```(1, 10, 32)
```

• `call(self, inputs, **kwargs)` - 其中 `inputs` 是张量或张量的嵌套结构（例如张量列表），`**kwargs` 是非张量参数（非输入）。
• `call(self, inputs, training=None, **kwargs)` - 其中 `training` 是指示该层应在训练模式还是推断模式下运行的布尔值。
• `call(self, inputs, mask=None, **kwargs)` - 其中 `mask` 是一个布尔掩码张量（对 RNN 等十分有用）。
• `call(self, inputs, training=None, mask=None, **kwargs)` - 当然，您可以同时具有掩码和训练特有的行为。

``````units = 32
timesteps = 10
input_dim = 5
batch_size = 16

class CustomRNN(layers.Layer):
def __init__(self):
super(CustomRNN, self).__init__()
self.units = units
self.projection_1 = layers.Dense(units=units, activation="tanh")
self.projection_2 = layers.Dense(units=units, activation="tanh")
self.classifier = layers.Dense(1)

def call(self, inputs):
outputs = []
state = tf.zeros(shape=(inputs.shape[0], self.units))
for t in range(inputs.shape[1]):
x = inputs[:, t, :]
h = self.projection_1(x)
y = h + self.projection_2(state)
state = y
outputs.append(y)
features = tf.stack(outputs, axis=1)
return self.classifier(features)

# Note that you specify a static batch size for the inputs with the `batch_shape`
# arg, because the inner computation of `CustomRNN` requires a static batch size
# (when you create the `state` zeros tensor).
inputs = keras.Input(batch_shape=(batch_size, timesteps, input_dim))
x = layers.Conv1D(32, 3)(inputs)
outputs = CustomRNN()(x)

model = keras.Model(inputs, outputs)

rnn_model = CustomRNN()
_ = rnn_model(tf.zeros((1, 10, 5)))
``````
