{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "headers"
},
"source": [
"Project: /overview/_project.yaml\n",
"Book: /overview/_book.yaml\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d201e826ab29"
},
"source": [
"# Writing your own callbacks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d75eb2e25f36"
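},
"source": [
"## Introduction\n",
"\n",
"A callback is a powerful tool for customizing the behavior of a Keras model during training, evaluation, or inference. Examples include `tf.keras.callbacks.TensorBoard`, which lets you visualize training progress and results in TensorBoard, and `tf.keras.callbacks.ModelCheckpoint`, which periodically saves your model during training.\n",
"\n",
"In this guide, you will learn what a Keras callback is, what it can do, and how to build your own. We start with a few demos of simple callback applications."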
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b3600ee25c8e"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2021-02-12T21:32:37.920169Z",
"iopub.status.busy": "2021-02-12T21:32:37.919573Z",
"iopub.status.idle": "2021-02-12T21:32:43.586422Z",
"shell.execute_reply": "2021-02-12T21:32:43.585852Z"
},
"id": "4dadb6688663"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "42676f705fc8"
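},
"source": [
"## Keras callbacks overview\n",
"\n",
"All callbacks subclass the `keras.callbacks.Callback` class and override a set of methods that are called at various stages of training, testing, and predicting. Callbacks are useful for getting a view of the model's internal state and statistics during training.\n",
"\n",
"You can pass a list of callbacks (as the keyword argument `callbacks`) to the following model methods:\n",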
"\n",
"- `keras.Model.fit()`\n",
"- `keras.Model.evaluate()`\n",
"- `keras.Model.predict()`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "46945bdf5056"
},
"source": [
"## An overview of callback methods\n",
"\n",
"### Global methods\n",
"\n",
"#### `on_(train|test|predict)_begin(self, logs=None)`\n",
"\n",
"Called at the beginning of `fit`/`evaluate`/`predict`.\n",
"\n",
"#### `on_(train|test|predict)_end(self, logs=None)`\n",
"\n",
"Called at the end of `fit`/`evaluate`/`predict`.\n",
"\n",
"### Batch-level methods for training/testing/predicting\n",
"\n",
"#### `on_(train|test|predict)_batch_begin(self, batch, logs=None)`\n",
"\n",
"Called right before a batch is processed during training/testing/predicting.\n",
"\n",
"#### `on_(train|test|predict)_batch_end(self, batch, logs=None)`\n",
"\n",
"Called at the end of training/testing/predicting a batch. Within this method, `logs` is a dict containing the metrics results (during `predict`, it holds the batch outputs under the key `outputs` instead).\n",
"\n",
"### Epoch-level methods (training only)\n",
"\n",
"#### `on_epoch_begin(self, epoch, logs=None)`\n",
"\n",
"Called at the beginning of an epoch during training.\n",
"\n",
"#### `on_epoch_end(self, epoch, logs=None)`\n",
"\n",
"Called at the end of an epoch during training."
]
},
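{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the pairing of the epoch-level hooks concrete, here is a minimal sketch of a callback that times each epoch. `EpochTimer` and its `epoch_times` attribute are hypothetical names, not part of Keras. During `fit()`, Keras calls these hooks itself; below they are invoked by hand only to show the order in which they run, so no model is needed.\n",
"\n",
"```python\n",
"import time\n",
"\n",
"from tensorflow import keras\n",
"\n",
"\n",
"class EpochTimer(keras.callbacks.Callback):\n",
"    \"\"\"Hypothetical example callback: records the wall-clock duration of each epoch.\"\"\"\n",
"\n",
"    def on_train_begin(self, logs=None):\n",
"        # Reset the measurements at the start of every fit() call.\n",
"        self.epoch_times = []\n",
"\n",
"    def on_epoch_begin(self, epoch, logs=None):\n",
"        self._start = time.perf_counter()\n",
"\n",
"    def on_epoch_end(self, epoch, logs=None):\n",
"        self.epoch_times.append(time.perf_counter() - self._start)\n",
"\n",
"\n",
"# Mimic the order in which fit() would call the hooks for two epochs.\n",
"timer = EpochTimer()\n",
"timer.on_train_begin()\n",
"for epoch in range(2):\n",
"    timer.on_epoch_begin(epoch)\n",
"    time.sleep(0.01)  # stand-in for the work done during one epoch\n",
"    timer.on_epoch_end(epoch)\n",
"print(timer.epoch_times)\n",
"```\n",
"\n",
"In a real run you would pass `callbacks=[EpochTimer()]` to `model.fit()` and read `epoch_times` from the callback instance afterwards."
]
},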
{
"cell_type": "markdown",
"metadata": {
"id": "82f2370418a1"
},
"source": [
"## A basic example\n",
"\n",
"Let's take a look at a concrete example. To get started, define a simple Sequential Keras model (TensorFlow and Keras were already imported in the setup cell):"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2021-02-12T21:32:43.592409Z",
"iopub.status.busy": "2021-02-12T21:32:43.591752Z",
"iopub.status.idle": "2021-02-12T21:32:43.593078Z",
"shell.execute_reply": "2021-02-12T21:32:43.593525Z"
},
"id": "7350ea602e50"
},
"outputs": [],
"source": [
"# Define the Keras model to add callbacks to\n",
"def get_model():\n",
"    model = keras.Sequential()\n",
"    model.add(keras.layers.Dense(1, input_dim=784))\n",
"    model.compile(\n",
"        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),\n",
"        loss=\"mean_squared_error\",\n",
"        metrics=[\"mean_absolute_error\"],\n",
"    )\n",
"    return model\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "044db5f2dc6f"
},
"source": [
"Next, load the MNIST training and test data from the Keras datasets API:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2021-02-12T21:32:43.598628Z",
"iopub.status.busy": "2021-02-12T21:32:43.598029Z",
"iopub.status.idle": "2021-02-12T21:32:43.944306Z",
"shell.execute_reply": "2021-02-12T21:32:43.944704Z"
},
"id": "f8826736a184"
},
"outputs": [],
"source": [
"# Load example MNIST data and pre-process it\n",
"(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
"x_train = x_train.reshape(-1, 784).astype(\"float32\") / 255.0\n",
"x_test = x_test.reshape(-1, 784).astype(\"float32\") / 255.0\n",
"\n",
"# Limit the data to 1000 samples\n",
"x_train = x_train[:1000]\n",
"y_train = y_train[:1000]\n",
"x_test = x_test[:1000]\n",
"y_test = y_test[:1000]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b9acd50b2215"
},
"source": [
"Now, define a simple custom callback that logs:\n",
"\n",
"- When `fit`/`evaluate`/`predict` starts and ends\n",
"- When each epoch starts and ends\n",
"- When each training batch starts and ends\n",
"- When each evaluation (test) batch starts and ends\n",
"- When each inference (prediction) batch starts and ends"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2021-02-12T21:32:43.956463Z",
"iopub.status.busy": "2021-02-12T21:32:43.955774Z",
"iopub.status.idle": "2021-02-12T21:32:43.957338Z",
"shell.execute_reply": "2021-02-12T21:32:43.957685Z"
},
"id": "cc9888d28e79"
},
"outputs": [],
"source": [
"class CustomCallback(keras.callbacks.Callback):\n",
"    def on_train_begin(self, logs=None):\n",
"        keys = list(logs.keys())\n",
"        print(\"Starting training; got log keys: {}\".format(keys))\n",
"\n",
"    def on_train_end(self, logs=None):\n",
"        keys = list(logs.keys())\n",
"        print(\"Stop training; got log keys: {}\".format(keys))\n",
"\n",
"    def on_epoch_begin(self, epoch, logs=None):\n",
"        keys = list(logs.keys())\n",
"        print(\"Start epoch {} of training; got log keys: {}\".format(epoch, keys))\n",
"\n",
"    def on_epoch_end(self, epoch, logs=None):\n",
"        keys = list(logs.keys())\n",
"        print(\"End epoch {} of training; got log keys: {}\".format(epoch, keys))\n",
"\n",
"    def on_test_begin(self, logs=None):\n",
"        keys = list(logs.keys())\n",
"        print(\"Start testing; got log keys: {}\".format(keys))\n",
"\n",
"    def on_test_end(self, logs=None):\n",
"        keys = list(logs.keys())\n",
"        print(\"Stop testing; got log keys: {}\".format(keys))\n",
"\n",
"    def on_predict_begin(self, logs=None):\n",
"        keys = list(logs.keys())\n",
"        print(\"Start predicting; got log keys: {}\".format(keys))\n",
"\n",
"    def on_predict_end(self, logs=None):\n",
"        keys = list(logs.keys())\n",
"        print(\"Stop predicting; got log keys: {}\".format(keys))\n",
"\n",
"    def on_train_batch_begin(self, batch, logs=None):\n",
"        keys = list(logs.keys())\n",
"        print(\"...Training: start of batch {}; got log keys: {}\".format(batch, keys))\n",
"\n",
"    def on_train_batch_end(self, batch, logs=None):\n",
"        keys = list(logs.keys())\n",
"        print(\"...Training: end of batch {}; got log keys: {}\".format(batch, keys))\n",
"\n",
"    def on_test_batch_begin(self, batch, logs=None):\n",
"        keys = list(logs.keys())\n",
"        print(\"...Evaluating: start of batch {}; got log keys: {}\".format(batch, keys))\n",
"\n",
"    def on_test_batch_end(self, batch, logs=None):\n",
"        keys = list(logs.keys())\n",
"        print(\"...Evaluating: end of batch {}; got log keys: {}\".format(batch, keys))\n",
"\n",
"    def on_predict_batch_begin(self, batch, logs=None):\n",
"        keys = list(logs.keys())\n",
"        print(\"...Predicting: start of batch {}; got log keys: {}\".format(batch, keys))\n",
"\n",
"    def on_predict_batch_end(self, batch, logs=None):\n",
"        keys = list(logs.keys())\n",
"        print(\"...Predicting: end of batch {}; got log keys: {}\".format(batch, keys))\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8184bd3a76c2"
},
"source": [
"Let's try it out:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2021-02-12T21:32:43.962751Z",
"iopub.status.busy": "2021-02-12T21:32:43.962169Z",
"iopub.status.idle": "2021-02-12T21:32:46.779838Z",
"shell.execute_reply": "2021-02-12T21:32:46.779366Z"
},
"id": "75f7aa1edac6"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting training; got log keys: []\n",
"Start epoch 0 of training; got log keys: []\n",
"...Training: start of batch 0; got log keys: []\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"...Training: end of batch 0; got log keys: ['loss', 'mean_absolute_error']\n",
"...Training: start of batch 1; got log keys: []\n",
"...Training: end of batch 1; got log keys: ['loss', 'mean_absolute_error']\n",
"...Training: start of batch 2; got log keys: []\n",
"...Training: end of batch 2; got log keys: ['loss', 'mean_absolute_error']\n",
"...Training: start of batch 3; got log keys: []\n",
"...Training: end of batch 3; got log keys: ['loss', 'mean_absolute_error']\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start testing; got log keys: []\n",
"...Evaluating: start of batch 0; got log keys: []\n",
"...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error']\n",
"...Evaluating: start of batch 1; got log keys: []\n",
"...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error']\n",
"...Evaluating: start of batch 2; got log keys: []\n",
"...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error']\n",
"...Evaluating: start of batch 3; got log keys: []\n",
"...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error']\n",
"Stop testing; got log keys: ['loss', 'mean_absolute_error']\n",
"End epoch 0 of training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']\n",
"Stop training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']\n",
"Start testing; got log keys: []\n",
"...Evaluating: start of batch 0; got log keys: []\n",
"...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error']\n",
"...Evaluating: start of batch 1; got log keys: []\n",
"...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error']\n",
"...Evaluating: start of batch 2; got log keys: []\n",
"...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error']\n",
"...Evaluating: start of batch 3; got log keys: []\n",
"...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error']\n",
"...Evaluating: start of batch 4; got log keys: []\n",
"...Evaluating: end of batch 4; got log keys: ['loss', 'mean_absolute_error']\n",
"...Evaluating: start of batch 5; got log keys: []\n",
"...Evaluating: end of batch 5; got log keys: ['loss', 'mean_absolute_error']\n",
"...Evaluating: start of batch 6; got log keys: []\n",
"...Evaluating: end of batch 6; got log keys: ['loss', 'mean_absolute_error']\n",
"...Evaluating: start of batch 7; got log keys: []\n",
"...Evaluating: end of batch 7; got log keys: ['loss', 'mean_absolute_error']\n",
"Stop testing; got log keys: ['loss', 'mean_absolute_error']\n",
"Start predicting; got log keys: []\n",
"...Predicting: start of batch 0; got log keys: []\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"...Predicting: end of batch 0; got log keys: ['outputs']\n",
"...Predicting: start of batch 1; got log keys: []\n",
"...Predicting: end of batch 1; got log keys: ['outputs']\n",
"...Predicting: start of batch 2; got log keys: []\n",
"...Predicting: end of batch 2; got log keys: ['outputs']\n",
"...Predicting: start of batch 3; got log keys: []\n",
"...Predicting: end of batch 3; got log keys: ['outputs']\n",
"...Predicting: start of batch 4; got log keys: []\n",
"...Predicting: end of batch 4; got log keys: ['outputs']\n",
"...Predicting: start of batch 5; got log keys: []\n",
"...Predicting: end of batch 5; got log keys: ['outputs']\n",
"...Predicting: start of batch 6; got log keys: []\n",
"...Predicting: end of batch 6; got log keys: ['outputs']\n",
"...Predicting: start of batch 7; got log keys: []\n",
"...Predicting: end of batch 7; got log keys: ['outputs']\n",
"Stop predicting; got log keys: []\n"
]
}
],
"source": [
"model = get_model()\n",
"model.fit(\n",
" x_train,\n",
" y_train,\n",
" batch_size=128,\n",
" epochs=1,\n",
" verbose=0,\n",
" validation_split=0.5,\n",
" callbacks=[CustomCallback()],\n",
")\n",
"\n",
"res = model.evaluate(\n",
" x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]\n",
")\n",
"\n",
"res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "02113b8677eb"
},
"source": [
"### Using the `logs` dictionary\n",
"\n",
"The `logs` dictionary contains the loss value and all the metrics at the end of a batch or epoch. The example below includes the loss and the mean absolute error."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2021-02-12T21:32:46.788758Z",
"iopub.status.busy": "2021-02-12T21:32:46.788110Z",
"iopub.status.idle": "2021-02-12T21:32:47.221177Z",
"shell.execute_reply": "2021-02-12T21:32:47.220518Z"
},
"id": "2c5e71af7abe"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"For batch 0, loss is 31.57.\n",
"For batch 1, loss is 452.68.\n",
"For batch 2, loss is 311.45.\n",
"For batch 3, loss is 236.22.\n",
"For batch 4, loss is 190.69.\n",
"For batch 5, loss is 159.97.\n",
"For batch 6, loss is 138.05.\n",
"For batch 7, loss is 124.27.\n",
"The average loss for epoch 0 is 124.27 and mean absolute error is 6.12.\n",
"For batch 0, loss is 4.65.\n",
"For batch 1, loss is 4.74.\n",
"For batch 2, loss is 5.01.\n",
"For batch 3, loss is 4.82.\n",
"For batch 4, loss is 4.79.\n",
"For batch 5, loss is 4.70.\n",
"For batch 6, loss is 4.66.\n",
"For batch 7, loss is 4.56.\n",
"The average loss for epoch 1 is 4.56 and mean absolute error is 1.72.\n",
"For batch 0, loss is 4.67.\n",
"For batch 1, loss is 4.24.\n",
"For batch 2, loss is 4.29.\n",
"For batch 3, loss is 4.22.\n",
"For batch 4, loss is 4.36.\n",
"For batch 5, loss is 4.35.\n",
"For batch 6, loss is 4.30.\n",
"For batch 7, loss is 4.25.\n"
]
}
],
"source": [
"class LossAndErrorPrintingCallback(keras.callbacks.Callback):\n",
"    def on_train_batch_end(self, batch, logs=None):\n",
"        print(\"For batch {}, loss is {:7.2f}.\".format(batch, logs[\"loss\"]))\n",
"\n",
"    def on_test_batch_end(self, batch, logs=None):\n",
"        print(\"For batch {}, loss is {:7.2f}.\".format(batch, logs[\"loss\"]))\n",
"\n",
"    def on_epoch_end(self, epoch, logs=None):\n",
"        print(\n",
"            \"The average loss for epoch {} is {:7.2f} \"\n",
"            \"and mean absolute error is {:7.2f}.\".format(\n",
"                epoch, logs[\"loss\"], logs[\"mean_absolute_error\"]\n",
"            )\n",
"        )\n",
"\n",
"\n",
"model = get_model()\n",
"model.fit(\n",
"    x_train,\n",
"    y_train,\n",
"    batch_size=128,\n",
"    epochs=2,\n",
"    verbose=0,\n",
"    callbacks=[LossAndErrorPrintingCallback()],\n",
")\n",
"\n",
"res = model.evaluate(\n",
"    x_test,\n",
"    y_test,\n",
"    batch_size=128,\n",
"    verbose=0,\n",
"    callbacks=[LossAndErrorPrintingCallback()],\n",
")"
]
},
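{
"cell_type": "markdown",
"metadata": {},
"source": [
"Judging from the output above, the `loss` value passed to `on_train_batch_end` appears to be the running average since the start of the epoch rather than the loss of that single batch: batch 7 of epoch 0 prints 124.27, the same value later reported as the average for the whole epoch. If you want to keep every batch-level value, a minimal sketch is a callback that appends it to a list. `BatchLossHistory` is a hypothetical name, and the hook is called by hand with values copied from the output above so that the sketch runs without training.\n",
"\n",
"```python\n",
"from tensorflow import keras\n",
"\n",
"\n",
"class BatchLossHistory(keras.callbacks.Callback):\n",
"    \"\"\"Hypothetical example callback: stores the loss reported after each training batch.\"\"\"\n",
"\n",
"    def on_train_begin(self, logs=None):\n",
"        self.batch_losses = []\n",
"\n",
"    def on_train_batch_end(self, batch, logs=None):\n",
"        # logs defaults to None, so guard before reading from it.\n",
"        logs = logs or {}\n",
"        if \"loss\" in logs:\n",
"            self.batch_losses.append(float(logs[\"loss\"]))\n",
"\n",
"\n",
"# During fit(), Keras supplies the logs; here they are passed by hand.\n",
"history = BatchLossHistory()\n",
"history.on_train_begin()\n",
"history.on_train_batch_end(0, logs={\"loss\": 31.57})\n",
"history.on_train_batch_end(1, logs={\"loss\": 452.68})\n",
"history.on_train_batch_end(2, logs=None)  # ignored: no loss available\n",
"print(history.batch_losses)  # [31.57, 452.68]\n",
"```"
]
},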
{
"cell_type": "markdown",
"metadata": {
"id": "742d62e5394a"
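},
"source": [
"## Using the `self.model` attribute\n",
"\n",
"In addition to receiving log information when one of its methods is called, a callback has access to the model associated with the current round of training/evaluation/inference through `self.model`.\n",
"\n",
"Here are a few of the things you can do with `self.model` in a callback:\n",
"\n",
"- Set `self.model.stop_training = True` to interrupt training.\n",
"- Mutate hyperparameters of the optimizer (available as `self.model.optimizer`), such as `self.model.optimizer.learning_rate`.\n",
"- Save the model at periodic intervals.\n",
"- Record the output of `model.predict()` on a few test samples at the end of each epoch, to use as a sanity check during training.\n",
"- Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.\n",
"- And more.\n",
"\n",
"Let's see this in action in a couple of examples."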
]
},
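{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of the sanity-check idea from the list above, the callback below calls `self.model.predict()` on a few fixed samples at the end of every epoch. `PredictionLogger` is a hypothetical name, and the random data only stands in for a real dataset so that the cell is self-contained.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from tensorflow import keras\n",
"\n",
"\n",
"class PredictionLogger(keras.callbacks.Callback):\n",
"    \"\"\"Hypothetical example callback: predicts on fixed samples after each epoch.\"\"\"\n",
"\n",
"    def __init__(self, samples):\n",
"        super().__init__()\n",
"        self.samples = samples\n",
"        self.predictions = []  # one array of shape (num_samples, 1) per epoch\n",
"\n",
"    def on_epoch_end(self, epoch, logs=None):\n",
"        # self.model is the model that fit() is currently training.\n",
"        preds = self.model.predict(self.samples, verbose=0)\n",
"        self.predictions.append(preds)\n",
"        print(\"Epoch {}: predictions {}\".format(epoch, preds.ravel()))\n",
"\n",
"\n",
"# Random stand-in data with the same input size as the MNIST example.\n",
"x = np.random.rand(64, 784).astype(\"float32\")\n",
"y = np.random.rand(64).astype(\"float32\")\n",
"\n",
"model = keras.Sequential([keras.Input(shape=(784,)), keras.layers.Dense(1)])\n",
"model.compile(optimizer=\"rmsprop\", loss=\"mean_squared_error\")\n",
"\n",
"logger = PredictionLogger(x[:3])\n",
"model.fit(x, y, batch_size=32, epochs=2, verbose=0, callbacks=[logger])\n",
"```\n",
"\n",
"Comparing `logger.predictions` across epochs gives a quick signal of whether the outputs change at all, without waiting for training to finish."
]
},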
{
"cell_type": "markdown",
"metadata": {
"id": "7eb29d3ed752"
},
"source": [
"## Examples of Keras callback applications"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2d1d29d99fa5"
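},
"source": [
"### Early stopping at minimum loss\n",
"\n",
"This first example creates a `Callback` that stops training once the loss reaches its minimum, by setting the attribute `self.model.stop_training` (boolean). Optionally, the argument `patience` specifies how many epochs to wait before actually stopping once a local minimum has been reached. When training stops, the callback also restores the weights from the best epoch.\n",
"\n",
"`tf.keras.callbacks.EarlyStopping` provides a more complete and general implementation."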
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2021-02-12T21:32:47.232608Z",
"iopub.status.busy": "2021-02-12T21:32:47.229781Z",
"iopub.status.idle": "2021-02-12T21:32:47.579577Z",
"shell.execute_reply": "2021-02-12T21:32:47.580005Z"
},
"id": "5d2acd79cecd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"For batch 0, loss is 31.99.\n",
"For batch 1, loss is 456.49.\n",
"For batch 2, loss is 316.17.\n",
"For batch 3, loss is 240.72.\n",
"For batch 4, loss is 194.07.\n",
"The average loss for epoch 0 is 194.07 and mean absolute error is 8.71.\n",
"For batch 0, loss is 5.99.\n",
"For batch 1, loss is 5.77.\n",
"For batch 2, loss is 5.26.\n",
"For batch 3, loss is 5.28.\n",
"For batch 4, loss is 5.48.\n",
"The average loss for epoch 1 is 5.48 and mean absolute error is 1.90.\n",
"For batch 0, loss is 6.50.\n",
"For batch 1, loss is 4.84.\n",
"For batch 2, loss is 4.86.\n",
"For batch 3, loss is 4.95.\n",
"For batch 4, loss is 4.73.\n",
"The average loss for epoch 2 is 4.73 and mean absolute error is 1.76.\n",
"For batch 0, loss is 4.44.\n",
"For batch 1, loss is 5.31.\n",
"For batch 2, loss is 5.56.\n",
"For batch 3, loss is 5.97.\n",
"For batch 4, loss is 6.98.\n",
"The average loss for epoch 3 is 6.98 and mean absolute error is 2.10.\n",
"Restoring model weights from the end of the best epoch.\n",
"Epoch 00004: early stopping\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"\n",
"\n",
"class EarlyStoppingAtMinLoss(keras.callbacks.Callback):\n",
"    \"\"\"Stop training when the loss is at its min, i.e. the loss stops decreasing.\n",
"\n",
"    Arguments:\n",
"        patience: Number of epochs to wait after min has been hit. After this\n",
"            many epochs without improvement, training stops.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, patience=0):\n",
"        super(EarlyStoppingAtMinLoss, self).__init__()\n",
"        self.patience = patience\n",
"        # best_weights to store the weights at which the minimum loss occurs.\n",
"        self.best_weights = None\n",
"\n",
"    def on_train_begin(self, logs=None):\n",
"        # Number of epochs waited since the loss last improved.\n",
"        self.wait = 0\n",
"        # The epoch the training stops at.\n",
"        self.stopped_epoch = 0\n",
"        # Initialize the best loss as infinity.\n",
"        self.best = np.inf\n",
"\n",
"    def on_epoch_end(self, epoch, logs=None):\n",
"        current = logs.get(\"loss\")\n",
"        if np.less(current, self.best):\n",
"            self.best = current\n",
"            self.wait = 0\n",
"            # Record the best weights if the current result is better (less).\n",
"            self.best_weights = self.model.get_weights()\n",
"        else:\n",
"            self.wait += 1\n",
"            if self.wait >= self.patience:\n",
"                self.stopped_epoch = epoch\n",
"                self.model.stop_training = True\n",
"                print(\"Restoring model weights from the end of the best epoch.\")\n",
"                self.model.set_weights(self.best_weights)\n",
"\n",
"    def on_train_end(self, logs=None):\n",
"        if self.stopped_epoch > 0:\n",
"            print(\"Epoch %05d: early stopping\" % (self.stopped_epoch + 1))\n",
"\n",
"\n",
"model = get_model()\n",
"model.fit(\n",
"    x_train,\n",
"    y_train,\n",
"    batch_size=64,\n",
"    steps_per_epoch=5,\n",
"    epochs=30,\n",
"    verbose=0,\n",
"    callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],\n",
")"
]
},
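{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how `patience` interacts with the `self.wait >= self.patience` check, the bookkeeping from `on_epoch_end` can be replayed on a fixed list of epoch losses without a model. `simulate_early_stopping` is a hypothetical helper written for illustration; the losses are the epoch averages printed by the run above.\n",
"\n",
"```python\n",
"def simulate_early_stopping(epoch_losses, patience=0):\n",
"    \"\"\"Replay EarlyStoppingAtMinLoss's logic on precomputed epoch losses.\n",
"\n",
"    Returns (stopped_epoch, best_epoch); stopped_epoch is None when\n",
"    training would run through every epoch.\n",
"    \"\"\"\n",
"    best = float(\"inf\")\n",
"    best_epoch = None\n",
"    wait = 0\n",
"    for epoch, loss in enumerate(epoch_losses):\n",
"        if loss < best:\n",
"            best, best_epoch, wait = loss, epoch, 0\n",
"        else:\n",
"            wait += 1\n",
"            if wait >= patience:\n",
"                return epoch, best_epoch\n",
"    return None, best_epoch\n",
"\n",
"\n",
"# Epoch-end losses printed by the early-stopping run above.\n",
"losses = [194.07, 5.48, 4.73, 6.98]\n",
"print(simulate_early_stopping(losses, patience=0))  # (3, 2)\n",
"print(simulate_early_stopping(losses, patience=1))  # (3, 2)\n",
"print(simulate_early_stopping(losses, patience=2))  # (None, 2)\n",
"```\n",
"\n",
"With `patience=0`, training stops at epoch index 3 (printed as `Epoch 00004`) and the weights from epoch index 2 are restored, matching the run above. Because the check is `>=`, `patience=0` and `patience=1` behave identically: training stops after `max(patience, 1)` consecutive epochs without improvement."
]
},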
{
"cell_type": "markdown",
"metadata": {
"id": "939ecfbe0383"
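},
"source": [
"### Learning rate scheduling\n",
"\n",
"This example shows how a custom callback can be used to dynamically change the learning rate of the optimizer during the course of training.\n",
"\n",
"See `keras.callbacks.LearningRateScheduler` for a more general implementation."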
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2021-02-12T21:32:47.590429Z",
"iopub.status.busy": "2021-02-12T21:32:47.589849Z",
"iopub.status.idle": "2021-02-12T21:32:48.057347Z",
"shell.execute_reply": "2021-02-12T21:32:48.056664Z"
},
"id": "71c752b248c0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 00000: Learning rate is 0.1000.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"For batch 0, loss is 33.29.\n",
"For batch 1, loss is 387.77.\n",
"For batch 2, loss is 268.60.\n",
"For batch 3, loss is 205.39.\n",
"For batch 4, loss is 166.73.\n",
"The average loss for epoch 0 is 166.73 and mean absolute error is 8.26.\n",
"\n",
"Epoch 00001: Learning rate is 0.1000.\n",
"For batch 0, loss is 6.32.\n",
"For batch 1, loss is 6.86.\n",
"For batch 2, loss is 6.77.\n",
"For batch 3, loss is 6.16.\n",
"For batch 4, loss is 6.01.\n",
"The average loss for epoch 1 is 6.01 and mean absolute error is 2.01.\n",
"\n",
"Epoch 00002: Learning rate is 0.1000.\n",
"For batch 0, loss is 6.02.\n",
"For batch 1, loss is 6.14.\n",
"For batch 2, loss is 6.53.\n",
"For batch 3, loss is 6.60.\n",
"For batch 4, loss is 7.06.\n",
"The average loss for epoch 2 is 7.06 and mean absolute error is 2.24.\n",
"\n",
"Epoch 00003: Learning rate is 0.0500.\n",
"For batch 0, loss is 22.22.\n",
"For batch 1, loss is 12.58.\n",
"For batch 2, loss is 9.46.\n",
"For batch 3, loss is 8.06.\n",
"For batch 4, loss is 7.24.\n",
"The average loss for epoch 3 is 7.24 and mean absolute error is 2.05.\n",
"\n",
"Epoch 00004: Learning rate is 0.0500.\n",
"For batch 0, loss is 4.29.\n",
"For batch 1, loss is 3.73.\n",
"For batch 2, loss is 4.17.\n",
"For batch 3, loss is 4.00.\n",
"For batch 4, loss is 4.07.\n",
"The average loss for epoch 4 is 4.07 and mean absolute error is 1.55.\n",
"\n",
"Epoch 00005: Learning rate is 0.0500.\n",
"For batch 0, loss is 4.19.\n",
"For batch 1, loss is 4.38.\n",
"For batch 2, loss is 4.37.\n",
"For batch 3, loss is 4.78.\n",
"For batch 4, loss is 6.04.\n",
"The average loss for epoch 5 is 6.04 and mean absolute error is 1.97.\n",
"\n",
"Epoch 00006: Learning rate is 0.0100.\n",
"For batch 0, loss is 18.12.\n",
"For batch 1, loss is 13.69.\n",
"For batch 2, loss is 10.70.\n",
"For batch 3, loss is 8.68.\n",
"For batch 4, loss is 7.65.\n",
"The average loss for epoch 6 is 7.65 and mean absolute error is 2.21.\n",
"\n",
"Epoch 00007: Learning rate is 0.0100.\n",
"For batch 0, loss is 4.44.\n",
"For batch 1, loss is 3.88.\n",
"For batch 2, loss is 3.77.\n",
"For batch 3, loss is 3.73.\n",
"For batch 4, loss is 3.73.\n",
"The average loss for epoch 7 is 3.73 and mean absolute error is 1.50.\n",
"\n",
"Epoch 00008: Learning rate is 0.0100.\n",
"For batch 0, loss is 3.97.\n",
"For batch 1, loss is 4.00.\n",
"For batch 2, loss is 3.94.\n",
"For batch 3, loss is 3.86.\n",
"For batch 4, loss is 3.75.\n",
"The average loss for epoch 8 is 3.75 and mean absolute error is 1.49.\n",
"\n",
"Epoch 00009: Learning rate is 0.0050.\n",
"For batch 0, loss is 2.06.\n",
"For batch 1, loss is 2.53.\n",
"For batch 2, loss is 2.68.\n",
"For batch 3, loss is 2.75.\n",
"For batch 4, loss is 3.13.\n",
"The average loss for epoch 9 is 3.13 and mean absolute error is 1.39.\n",
"\n",
"Epoch 00010: Learning rate is 0.0050.\n",
"For batch 0, loss is 4.04.\n",
"For batch 1, loss is 3.53.\n",
"For batch 2, loss is 3.26.\n",
"For batch 3, loss is 2.98.\n",
"For batch 4, loss is 2.97.\n",
"The average loss for epoch 10 is 2.97 and mean absolute error is 1.35.\n",
"\n",
"Epoch 00011: Learning rate is 0.0050.\n",
"For batch 0, loss is 4.37.\n",
"For batch 1, loss is 3.51.\n",
"For batch 2, loss is 3.39.\n",
"For batch 3, loss is 3.55.\n",
"For batch 4, loss is 3.48.\n",
"The average loss for epoch 11 is 3.48 and mean absolute error is 1.46.\n",
"\n",
"Epoch 00012: Learning rate is 0.0010.\n",
"For batch 0, loss is 3.20.\n",
"For batch 1, loss is 3.12.\n",
"For batch 2, loss is 3.21.\n",
"For batch 3, loss is 3.13.\n",
"For batch 4, loss is 3.41.\n",
"The average loss for epoch 12 is 3.41 and mean absolute error is 1.45.\n",
"\n",
"Epoch 00013: Learning rate is 0.0010.\n",
"For batch 0, loss is 2.97.\n",
"For batch 1, loss is 3.40.\n",
"For batch 2, loss is 3.36.\n",
"For batch 3, loss is 3.19.\n",
"For batch 4, loss is 3.33.\n",
"The average loss for epoch 13 is 3.33 and mean absolute error is 1.41.\n",
"\n",
"Epoch 00014: Learning rate is 0.0010.\n",
"For batch 0, loss is 3.69.\n",
"For batch 1, loss is 2.83.\n",
"For batch 2, loss is 3.09.\n",
"For batch 3, loss is 3.07.\n",
"For batch 4, loss is 3.07.\n",
"The average loss for epoch 14 is 3.07 and mean absolute error is 1.37.\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class CustomLearningRateScheduler(keras.callbacks.Callback):\n",
"    \"\"\"Learning rate scheduler which sets the learning rate according to schedule.\n",
"\n",
"    Arguments:\n",
"        schedule: a function that takes an epoch index\n",
"            (integer, indexed from 0) and current learning rate\n",
"            as inputs and returns a new learning rate as output (float).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, schedule):\n",
"        super(CustomLearningRateScheduler, self).__init__()\n",
"        self.schedule = schedule\n",
"\n",
"    def on_epoch_begin(self, epoch, logs=None):\n",
"        if not hasattr(self.model.optimizer, \"learning_rate\"):\n",
"            raise ValueError('Optimizer must have a \"learning_rate\" attribute.')\n",
"        # Get the current learning rate from model's optimizer.\n",
"        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))\n",
"        # Call schedule function to get the scheduled learning rate.\n",
"        scheduled_lr = self.schedule(epoch, lr)\n",
"        # Set the value back to the optimizer before this epoch starts\n",
"        tf.keras.backend.set_value(self.model.optimizer.learning_rate, scheduled_lr)\n",
"        print(\"\\nEpoch %05d: Learning rate is %6.4f.\" % (epoch, scheduled_lr))\n",
"\n",
"\n",
"LR_SCHEDULE = [\n",
"    # (epoch to start, learning rate) tuples\n",
"    (3, 0.05),\n",
"    (6, 0.01),\n",
"    (9, 0.005),\n",
"    (12, 0.001),\n",
"]\n",
"\n",
"\n",
"def lr_schedule(epoch, lr):\n",
"    \"\"\"Helper function to retrieve the scheduled learning rate based on epoch.\"\"\"\n",
"    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:\n",
"        return lr\n",
"    for i in range(len(LR_SCHEDULE)):\n",
"        if epoch == LR_SCHEDULE[i][0]:\n",
"            return LR_SCHEDULE[i][1]\n",
"    return lr\n",
"\n",
"\n",
"model = get_model()\n",
"model.fit(\n",
"    x_train,\n",
"    y_train,\n",
"    batch_size=64,\n",
"    steps_per_epoch=5,\n",
"    epochs=15,\n",
"    verbose=0,\n",
"    callbacks=[\n",
"        LossAndErrorPrintingCallback(),\n",
"        CustomLearningRateScheduler(lr_schedule),\n",
"    ],\n",
")"
]
},
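{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `lr_schedule` helper above returns a new value only on the exact boundary epochs and otherwise relies on the optimizer still holding the rate set earlier. A stateless alternative, sketched below with the same `LR_SCHEDULE` values, looks up the rate for any epoch with Python's `bisect` module, so it would also return the intended rate if training were resumed from a later epoch (for example with `initial_epoch`). `piecewise_lr` is a hypothetical name.\n",
"\n",
"```python\n",
"import bisect\n",
"\n",
"# (epoch to start, learning rate) pairs, same values as LR_SCHEDULE above.\n",
"LR_SCHEDULE = [(3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001)]\n",
"START_EPOCHS = [start for start, _ in LR_SCHEDULE]\n",
"\n",
"\n",
"def piecewise_lr(epoch, lr):\n",
"    \"\"\"Return the scheduled learning rate for `epoch`.\n",
"\n",
"    Epochs before the first boundary keep the current rate `lr`.\n",
"    \"\"\"\n",
"    # Index of the last boundary that is <= epoch, or -1 if there is none.\n",
"    i = bisect.bisect_right(START_EPOCHS, epoch) - 1\n",
"    return lr if i < 0 else LR_SCHEDULE[i][1]\n",
"\n",
"\n",
"print([piecewise_lr(epoch, 0.1) for epoch in range(15)])\n",
"```\n",
"\n",
"For epochs 0 to 14 this yields the same sequence of rates printed by the run above (0.1, then 0.05, 0.01, 0.005, and 0.001, each for three epochs). Because it keeps the `(epoch, lr)` signature, `piecewise_lr` could be passed to `CustomLearningRateScheduler` or to the built-in `keras.callbacks.LearningRateScheduler`."
]
},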
{
"cell_type": "markdown",
"metadata": {
"id": "c9be225b57f1"
},
"source": [
"### Built-in Keras callbacks\n",
"\n",
"Be sure to check out the existing Keras callbacks by reading the [API docs](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/). Applications include logging to CSV, saving the model, visualizing metrics in TensorBoard, and a lot more!"
]
}
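,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of combining several built-in callbacks in one `fit()` call, the cell below uses `keras.callbacks.CSVLogger`, `keras.callbacks.ModelCheckpoint`, and `keras.callbacks.EarlyStopping`; `keras.callbacks.TensorBoard(log_dir=...)` can be added to the same list. The random data, the file names, and the temporary directory are assumptions made only so that the example is self-contained.\n",
"\n",
"```python\n",
"import os\n",
"import tempfile\n",
"\n",
"import numpy as np\n",
"from tensorflow import keras\n",
"\n",
"# Random stand-in data with the same input size as the MNIST example.\n",
"x = np.random.rand(256, 784).astype(\"float32\")\n",
"y = np.random.rand(256).astype(\"float32\")\n",
"\n",
"model = keras.Sequential([keras.Input(shape=(784,)), keras.layers.Dense(1)])\n",
"model.compile(optimizer=\"rmsprop\", loss=\"mean_squared_error\")\n",
"\n",
"out_dir = tempfile.mkdtemp()\n",
"csv_path = os.path.join(out_dir, \"training_log.csv\")\n",
"ckpt_path = os.path.join(out_dir, \"best.weights.h5\")\n",
"\n",
"callbacks = [\n",
"    # Write one row of epoch-level logs per epoch to a CSV file.\n",
"    keras.callbacks.CSVLogger(csv_path),\n",
"    # Save the weights whenever val_loss improves on the best value so far.\n",
"    keras.callbacks.ModelCheckpoint(\n",
"        ckpt_path, monitor=\"val_loss\", save_best_only=True, save_weights_only=True\n",
"    ),\n",
"    # Stop after 2 epochs without val_loss improvement; on stopping, restore the best weights.\n",
"    keras.callbacks.EarlyStopping(\n",
"        monitor=\"val_loss\", patience=2, restore_best_weights=True\n",
"    ),\n",
"]\n",
"\n",
"history = model.fit(\n",
"    x,\n",
"    y,\n",
"    validation_split=0.25,\n",
"    batch_size=32,\n",
"    epochs=5,\n",
"    verbose=0,\n",
"    callbacks=callbacks,\n",
")\n",
"```"
]
}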
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "custom_callback.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 0
}