![]() |
![]() |
![]() |
![]() |
このチュートリアルでは、文字ベースの RNN を使ってテキストを生成する方法を示します。ここでは、Andrej Karpathy の The Unreasonable Effectiveness of Recurrent Neural Networks からのシェイクスピア作品のデータセットを使います。このデータからの文字列("Shakespear")を入力にして、文字列中の次の文字("e")を予測するモデルを訓練します。このモデルを繰り返し呼び出すことで、より長い文字列を生成することができます。
このチュートリアルには、tf.keras と eager execution を使ったコードが含まれています。下記は、このチュートリアルのモデルを 30 エポック訓練したものに対して、文字列 "Q" を初期値とした場合の出力例です。
QUEENE: I had thought thou hadst a Roman; for the oracle, Thus by All bids the man against the word, Which are so weak of care, by old care done; Your children were in your holy love, And the precipitation through the bleeding throne. BISHOP OF ELY: Marry, and will, my lord, to weep in such a one were prettiest; Yet now I was adopted heir Of the world's lamentable day, To watch the next way with his father with his face? ESCALUS: The cause why then we are all resolved more sons. VOLUMNIA: O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead, And love and pale as any will to that word. QUEEN ELIZABETH: But how long have I heard the soul for this world, And show his hands of life be proved to stand. PETRUCHIO: I say he look'd on, if I must be content To stay him from the fatal of our country's bliss. His lordship pluck'd from this sentence then for prey, And then let us twain, being the moon, were she such a case as fills m
いくつかは文法にあったものがある一方で、ほとんどは意味をなしていません。このモデルは、単語の意味を学習していませんが、次のことを考えてみてください。
このモデルは文字ベースです。訓練が始まった時に、モデルは英語の単語のスペルも知りませんし、単語がテキストの単位であることも知らないのです。
出力の構造は戯曲に似ています。だいたいのばあい、データセットとおなじ大文字で書かれた話し手の名前で始まっています。
以下に示すように、モデルはテキストの小さなバッチ(各100文字)で訓練されていますが、一貫した構造のより長いテキストのシーケンスを生成できます。
設定
TensorFlow 等のライブラリインポート
import tensorflow as tf
import numpy as np
import os
import time
2022-08-08 20:19:47.701165: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2022-08-08 20:19:48.474159: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory 2022-08-08 20:19:48.474417: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory 2022-08-08 20:19:48.474430: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
シェイクスピアデータセットのダウンロード
独自のデータで実行するためには下記の行を変更してください。
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt 1115394/1115394 [==============================] - 0s 0us/step
データの読み込み
まずはテキストをのぞいてみましょう。
# 読み込んだのち、Python 2 との互換性のためにデコード
text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
# テキストの長さは含まれる文字数
print ('Length of text: {} characters'.format(len(text)))
Length of text: 1115394 characters
# テキストの最初の 250文字を参照
print(text[:250])
First Citizen: Before we proceed any further, hear me speak. All: Speak, speak. First Citizen: You are all resolved rather to die than to famish? All: Resolved. resolved. First Citizen: First, you know Caius Marcius is chief enemy to the people.
# ファイル中のユニークな文字の数
vocab = sorted(set(text))
print ('{} unique characters'.format(len(vocab)))
65 unique characters
テキストの処理
テキストのベクトル化
訓練をする前に、文字列を数値表現に変換する必要があります。2つの参照テーブルを作成します。一つは文字を数字に変換するもの、もう一つは数字を文字に変換するものです。
# それぞれの文字からインデックスへの対応表を作成
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)
text_as_int = np.array([char2idx[c] for c in text])
これで、それぞれの文字を整数で表現できました。文字を、0 からlen(unique)
までのインデックスに変換していることに注意してください。
print('{')
for char,_ in zip(char2idx, range(20)):
print(' {:4s}: {:3d},'.format(repr(char), char2idx[char]))
print(' ...\n}')
{ '\n': 0, ' ' : 1, '!' : 2, '$' : 3, '&' : 4, "'" : 5, ',' : 6, '-' : 7, '.' : 8, '3' : 9, ':' : 10, ';' : 11, '?' : 12, 'A' : 13, 'B' : 14, 'C' : 15, 'D' : 16, 'E' : 17, 'F' : 18, 'G' : 19, ... }
# テキストの最初の 13 文字がどのように整数に変換されるかを見てみる
print ('{} ---- characters mapped to int ---- > {}'.format(repr(text[:13]), text_as_int[:13]))
'First Citizen' ---- characters mapped to int ---- > [18 47 56 57 58 1 15 47 58 47 64 43 52]
予測タスク
ある文字、あるいは文字列が与えられたとき、もっともありそうな次の文字はなにか?これが、モデルを訓練してやらせたいタスクです。モデルへの入力は文字列であり、モデルが出力、つまりそれぞれの時点での次の文字を予測をするようにモデルを訓練します。
RNN はすでに見た要素に基づく内部状態を保持しているため、この時点までに計算されたすべての文字を考えると、次の文字は何でしょうか?
訓練用サンプルとターゲットを作成
つぎに、テキストをサンプルシーケンスに分割します。それぞれの入力シーケンスは、元のテキストからの seq_length
個の文字を含みます。
入力シーケンスそれぞれに対して、対応するターゲットは同じ長さのテキストを含みますが、1文字ずつ右にシフトしたものです。
そのため、テキストを seq_length+1
のかたまりに分割します。たとえば、 seq_length
が 4 で、テキストが "Hello" だとします。入力シーケンスは "Hell" で、ターゲットシーケンスは "ello" となります。
これを行うために、最初に tf.data.Dataset.from_tensor_slices
関数を使ってテキストベクトルを文字インデックスの連続に変換します。
# ひとつの入力としたいシーケンスの文字数としての最大の長さ
seq_length = 100
examples_per_epoch = len(text)//(seq_length+1)
# 訓練用サンプルとターゲットを作る
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
for i in char_dataset.take(5):
print(idx2char[i.numpy()])
F i r s t
batch
メソッドを使うと、個々の文字を求める長さのシーケンスに簡単に変換できます。
sequences = char_dataset.batch(seq_length+1, drop_remainder=True)
for item in sequences.take(5):
print(repr(''.join(idx2char[item.numpy()])))
'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou ' 'are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k' "now Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us ki" "ll him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be d" 'one: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citi'
シーケンスそれぞれに対して、map
メソッドを使って各バッチに単純な関数を適用することで、複製とシフトを行い、入力テキストとターゲットテキストを生成します。
def split_input_target(chunk):
input_text = chunk[:-1]
target_text = chunk[1:]
return input_text, target_text
dataset = sequences.map(split_input_target)
最初のサンプルの入力とターゲットを出力します。
for input_example, target_example in dataset.take(1):
print ('Input data: ', repr(''.join(idx2char[input_example.numpy()])))
print ('Target data:', repr(''.join(idx2char[target_example.numpy()])))
Input data: 'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou' Target data: 'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
これらのベクトルのインデックスそれぞれが一つのタイムステップとして処理されます。タイムステップ 0 の入力として、モデルは "F" のインデックスを受け取り、次の文字として "i" のインデックスを予測しようとします。次のタイムステップでもおなじことをしますが、RNN
は現在の入力文字に加えて、過去のステップのコンテキストも考慮します。
for i, (input_idx, target_idx) in enumerate(zip(input_example[:5], target_example[:5])):
print("Step {:4d}".format(i))
print(" input: {} ({:s})".format(input_idx, repr(idx2char[input_idx])))
print(" expected output: {} ({:s})".format(target_idx, repr(idx2char[target_idx])))
Step 0 input: 18 ('F') expected output: 47 ('i') Step 1 input: 47 ('i') expected output: 56 ('r') Step 2 input: 56 ('r') expected output: 57 ('s') Step 3 input: 57 ('s') expected output: 58 ('t') Step 4 input: 58 ('t') expected output: 1 (' ')
訓練用バッチの作成
tf.data
を使ってテキストを分割し、扱いやすいシーケンスにします。しかし、このデータをモデルに供給する前に、データをシャッフルしてバッチにまとめる必要があります。
# バッチサイズ
BATCH_SIZE = 64
# データセットをシャッフルするためのバッファサイズ
# (TF data は可能性として無限長のシーケンスでも使えるように設計されています。
# このため、シーケンス全体をメモリ内でシャッフルしようとはしません。
# その代わりに、要素をシャッフルするためのバッファを保持しています)
BUFFER_SIZE = 10000
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
dataset
<BatchDataset element_spec=(TensorSpec(shape=(64, 100), dtype=tf.int64, name=None), TensorSpec(shape=(64, 100), dtype=tf.int64, name=None))>
モデルの構築
tf.keras.Sequential
を使ってモデルを定義します。この簡単な例では、モデルの定義に3つのレイヤーを使用しています。
tf.keras.layers.Embedding
: 入力レイヤー。それぞれの文字を表す数をembedding_dim
次元のベクトルに変換する、訓練可能な参照テーブル。tf.keras.layers.GRU
: サイズがunits=rnn_units
のRNNの一種(ここに LSTM レイヤーを使うこともできる)。tf.keras.layers.Dense
:vocab_size
の出力を持つ、出力レイヤー。
# 文字数で表されるボキャブラリーの長さ
vocab_size = len(vocab)
# 埋め込みベクトルの次元
embedding_dim = 256
# RNN ユニットの数
rnn_units = 1024
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
model = tf.keras.Sequential([
tf.keras.layers.Embedding(vocab_size, embedding_dim,
batch_input_shape=[batch_size, None]),
tf.keras.layers.GRU(rnn_units,
return_sequences=True,
stateful=True,
recurrent_initializer='glorot_uniform'),
tf.keras.layers.Dense(vocab_size)
])
return model
model = build_model(
vocab_size = len(vocab),
embedding_dim=embedding_dim,
rnn_units=rnn_units,
batch_size=BATCH_SIZE)
1文字ごとにモデルは埋め込みベクトルを検索し、その埋め込みベクトルを入力として GRU を 1 タイムステップ実行します。そして Dense レイヤーを適用して、次の文字の対数尤度を予測するロジットを生成します。
モデルを試す
期待通りに動作するかどうかを確認するためモデルを動かしてみましょう。
最初に、出力の shape を確認します。
for input_example_batch, target_example_batch in dataset.take(1):
example_batch_predictions = model(input_example_batch)
print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
(64, 100, 65) # (batch_size, sequence_length, vocab_size)
上記の例では、入力のシーケンスの長さは 100
ですが、モデルはどのような長さの入力でも実行できます。
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= embedding (Embedding) (64, None, 256) 16640 gru (GRU) (64, None, 1024) 3938304 dense (Dense) (64, None, 65) 66625 ================================================================= Total params: 4,021,569 Trainable params: 4,021,569 Non-trainable params: 0 _________________________________________________________________
モデルから実際の予測を得るには出力の分布からサンプリングを行う必要があります。この分布は、文字ボキャブラリー全体のロジットで定義されます。
バッチ中の最初のサンプルで試してみましょう。
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices,axis=-1).numpy()
これにより、タイムステップそれぞれにおいて、次の文字のインデックスの予測が得られます。
sampled_indices
array([32, 48, 26, 30, 37, 23, 55, 16, 39, 57, 61, 58, 56, 12, 7, 47, 12, 12, 39, 14, 33, 11, 28, 34, 26, 43, 16, 1, 24, 15, 8, 44, 3, 44, 23, 47, 64, 21, 62, 49, 10, 31, 51, 51, 48, 14, 57, 60, 15, 35, 38, 43, 21, 28, 21, 37, 37, 13, 19, 36, 34, 34, 28, 17, 46, 46, 27, 5, 54, 48, 50, 9, 16, 56, 36, 29, 46, 52, 5, 48, 8, 0, 23, 44, 52, 15, 38, 5, 13, 14, 30, 62, 13, 52, 34, 42, 5, 49, 57, 39])
これらをデコードすることで、この訓練前のモデルによる予測テキストをみることができます。
print("Input: \n", repr("".join(idx2char[input_example_batch[0]])))
print()
print("Next Char Predictions: \n", repr("".join(idx2char[sampled_indices ])))
Input: 'eak to-night\nFain would I dwell on form, fain, fain deny\nWhat I have spoke: but farewell compliment!' Next Char Predictions: "TjNRYKqDaswtr?-i??aBU;PVNeD LC.f$fKizIxk:SmmjBsvCWZeIPIYYAGXVVPEhhO'pjl3DrXQhn'j.\nKfnCZ'ABRxAnVd'ksa"
モデルの訓練
ここまでくれば問題は標準的な分類問題として扱うことができます。これまでの RNN の状態と、いまのタイムステップの入力が与えられ、次の文字のクラスを予測します。
オプティマイザと損失関数の付加
この場合、標準の tf.keras.losses.sparse_categorical_crossentropy
損失関数が使えます。予測の最後の次元に適用されるからです。
このモデルはロジットを返すので、from_logits
フラグをセットする必要があります。
def loss(labels, logits):
return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
example_batch_loss = loss(target_example_batch, example_batch_predictions)
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("scalar_loss: ", example_batch_loss.numpy().mean())
Prediction shape: (64, 100, 65) # (batch_size, sequence_length, vocab_size) scalar_loss: 4.173848
tf.keras.Model.compile
を使って、訓練手順を定義します。既定の引数を持った tf.keras.optimizers.Adam
と、先ほどの loss 関数を使用しましょう。
model.compile(optimizer='adam', loss=loss)
チェックポイントの構成
tf.keras.callbacks.ModelCheckpoint
を使って、訓練中にチェックポイントを保存するようにします。
# チェックポイントが保存されるディレクトリ
checkpoint_dir = './training_checkpoints'
# チェックポイントファイルの名称
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_prefix,
save_weights_only=True)
訓練の実行
訓練時間を適切に保つために、10エポックを使用してモデルを訓練します。Google Colab を使用する場合には、訓練を高速化するためにランタイムを GPU に設定します。
EPOCHS=10
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])
Epoch 1/10 172/172 [==============================] - 10s 40ms/step - loss: 2.7055 Epoch 2/10 172/172 [==============================] - 8s 40ms/step - loss: 1.9739 Epoch 3/10 172/172 [==============================] - 8s 40ms/step - loss: 1.7024 Epoch 4/10 172/172 [==============================] - 8s 40ms/step - loss: 1.5489 Epoch 5/10 172/172 [==============================] - 8s 40ms/step - loss: 1.4593 Epoch 6/10 172/172 [==============================] - 8s 40ms/step - loss: 1.3989 Epoch 7/10 172/172 [==============================] - 8s 40ms/step - loss: 1.3538 Epoch 8/10 172/172 [==============================] - 8s 40ms/step - loss: 1.3136 Epoch 9/10 172/172 [==============================] - 8s 40ms/step - loss: 1.2787 Epoch 10/10 172/172 [==============================] - 8s 40ms/step - loss: 1.2468
テキスト生成
最終チェックポイントの復元
予測ステップを単純にするため、バッチサイズ 1 を使用します。
RNN が状態をタイムステップからタイムステップへと渡す仕組みのため、モデルは一度構築されると固定されたバッチサイズしか受け付けられません。
モデルを異なる batch_size
で実行するためには、モデルを再構築し、チェックポイントから重みを復元する必要があります。
tf.train.latest_checkpoint(checkpoint_dir)
'./training_checkpoints/ckpt_10'
model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.build(tf.TensorShape([1, None]))
model.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= embedding_1 (Embedding) (1, None, 256) 16640 gru_1 (GRU) (1, None, 1024) 3938304 dense_1 (Dense) (1, None, 65) 66625 ================================================================= Total params: 4,021,569 Trainable params: 4,021,569 Non-trainable params: 0 _________________________________________________________________
予測ループ
下記のコードブロックでテキストを生成します。
最初に、開始文字列を選択し、RNN の状態を初期化して、生成する文字数を設定します。
開始文字列と RNN の状態を使って、次の文字の予測分布を得ます。
つぎに、カテゴリー分布を使用して、予測された文字のインデックスを計算します。この予測された文字をモデルの次の入力にします。
モデルによって返された RNN の状態はモデルにフィードバックされるため、1つの文字だけでなく、より多くのコンテキストを持つことになります。つぎの文字を予測した後、更新された RNN の状態が再びモデルにフィードバックされます。こうしてモデルは以前に予測した文字からさらにコンテキストを得ることで学習するのです。
生成されたテキストを見ると、モデルがどこを大文字にするかや、段落の区切り方、シェークスピアらしい書き言葉を真似ることを知っていることがわかります。しかし、訓練のエポック数が少ないので、まだ一貫した文章を生成するところまでは学習していません。
def generate_text(model, start_string):
# 評価ステップ(学習済みモデルを使ったテキスト生成)
# 生成する文字数
num_generate = 1000
# 開始文字列を数値に変換(ベクトル化)
input_eval = [char2idx[s] for s in start_string]
input_eval = tf.expand_dims(input_eval, 0)
# 結果を保存する空文字列
text_generated = []
# 低い temperature は、より予測しやすいテキストをもたらし
# 高い temperature は、より意外なテキストをもたらす
# 実験により最適な設定を見つけること
temperature = 1.0
# ここではバッチサイズ == 1
model.reset_states()
for i in range(num_generate):
predictions = model(input_eval)
# バッチの次元を削除
predictions = tf.squeeze(predictions, 0)
# カテゴリー分布をつかってモデルから返された文字を予測
predictions = predictions / temperature
predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()
# 過去の隠れ状態とともに予測された文字をモデルへのつぎの入力として渡す
input_eval = tf.expand_dims([predicted_id], 0)
text_generated.append(idx2char[predicted_id])
return (start_string + ''.join(text_generated))
print(generate_text(model, start_string=u"ROMEO: "))
ROMEO: I could Ay her aboard. RITESSTER: As Pompry! so play, you'll say she was there no strange hour, spend it that hope I; I would fain lad-leantier, ey my lubute as you speak, do not home. GONZALO: A very goodly name, good Gromise: but the sestight of light, And her me bastard Clifford consuld thy knave in thiefors great; And power to see thy brother grow'st, Taking a tall you gave him tod the farewar not. JULIET: O, too much minally pingrous vell! Call't add the capap it, my true lords Shall all the woe behold the office, a by breath ere now, my lord, To make risen husband; and gold tongues that which engues inform thy wrong. RICHMOND: Slat to me to stage we would fetch you such a wakes As far married heils, and with their head for a white wedility, and I are mindy head Be think it would were so weary in the crown, Were still the enemy. MISTRESS OVERDONE: Well, good God! Ty thing to good mourn, kneel'd welcome us! But that thy streel's my father's spare. When, my Lord of Barnar, by y
この結果を改善するもっとも簡単な方法は、もっと長く訓練することです(EPOCHS=30
を試してみましょう)。
また、異なる初期文字列を使ったり、モデルの精度を向上させるためにもうひとつ RNN レイヤーを加えたり、temperature パラメータを調整して、よりランダム性の強い、あるいは、弱い予測を試してみたりすることができます。
上級編: 訓練のカスタマイズ
上記の訓練手順は単純ですが、制御できるところがそれほどありません。
モデルを手動で実行する方法を見てきたので、訓練ループを展開し、自分で実装してみましょう。このことが、たとえばモデルのオープンループによる出力を安定化するための カリキュラム学習 を実装するための出発点になります。
勾配を追跡するために tf.GradientTape
を使用します。このアプローチについての詳細を学ぶには、 eager execution guide をお読みください。
この手順は下記のように動作します。
最初に、RNN の状態を初期化する。
tf.keras.Model.reset_states
メソッドを呼び出すことでこれを実行する。つぎに、(1バッチずつ)データセットを順番に処理し、それぞれのバッチに対する予測値を計算する。
tf.GradientTape
をオープンし、そのコンテキストで、予測値と損失を計算する。tf.GradientTape.grads
メソッドを使って、モデルの変数に対する損失の勾配を計算する。最後に、オプティマイザの
tf.train.Optimizer.apply_gradients
メソッドを使って、逆方向の処理を行う。
model = build_model(
vocab_size = len(vocab),
embedding_dim=embedding_dim,
rnn_units=rnn_units,
batch_size=BATCH_SIZE)
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function. WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).layer_with_weights-0.embeddings WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).layer_with_weights-2.kernel WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).layer_with_weights-2.bias WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).layer_with_weights-1.cell.kernel WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).layer_with_weights-1.cell.recurrent_kernel WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).layer_with_weights-1.cell.bias WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'v' for (root).layer_with_weights-0.embeddings WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'v' for (root).layer_with_weights-2.kernel WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'v' for (root).layer_with_weights-2.bias WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'v' for (root).layer_with_weights-1.cell.kernel WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'v' for (root).layer_with_weights-1.cell.recurrent_kernel WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'v' for (root).layer_with_weights-1.cell.bias
optimizer = tf.keras.optimizers.Adam()
@tf.function
def train_step(inp, target):
with tf.GradientTape() as tape:
predictions = model(inp)
loss = tf.reduce_mean(
tf.keras.losses.sparse_categorical_crossentropy(
target, predictions, from_logits=True))
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
return loss
# 訓練ステップ
EPOCHS = 10
for epoch in range(EPOCHS):
start = time.time()
# 各エポックの最初に、隠れ状態を初期化する
# 最初は隠れ状態は None
hidden = model.reset_states()
for (batch_n, (inp, target)) in enumerate(dataset):
loss = train_step(inp, target)
if batch_n % 100 == 0:
template = 'Epoch {} Batch {} Loss {}'
print(template.format(epoch+1, batch_n, loss))
# 5エポックごとにモデル(のチェックポイント)を保存する
if (epoch + 1) % 5 == 0:
model.save_weights(checkpoint_prefix.format(epoch=epoch))
print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))
print ('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
model.save_weights(checkpoint_prefix.format(epoch=epoch))
Epoch 1 Batch 0 Loss 4.1743550300598145 Epoch 1 Batch 100 Loss 2.333750009536743 Epoch 1 Loss 2.1486 Time taken for 1 epoch 8.718509435653687 sec Epoch 2 Batch 0 Loss 2.12961483001709 Epoch 2 Batch 100 Loss 1.9083470106124878 Epoch 2 Loss 1.8157 Time taken for 1 epoch 7.550929546356201 sec Epoch 3 Batch 0 Loss 1.793007254600525 Epoch 3 Batch 100 Loss 1.6301873922348022 Epoch 3 Loss 1.5751 Time taken for 1 epoch 7.586333751678467 sec Epoch 4 Batch 0 Loss 1.5677053928375244 Epoch 4 Batch 100 Loss 1.5695240497589111 Epoch 4 Loss 1.5237 Time taken for 1 epoch 7.602418422698975 sec Epoch 5 Batch 0 Loss 1.4697335958480835 Epoch 5 Batch 100 Loss 1.4353010654449463 Epoch 5 Loss 1.4499 Time taken for 1 epoch 7.613927602767944 sec Epoch 6 Batch 0 Loss 1.3613026142120361 Epoch 6 Batch 100 Loss 1.3927586078643799 Epoch 6 Loss 1.3805 Time taken for 1 epoch 7.568537473678589 sec Epoch 7 Batch 0 Loss 1.3312193155288696 Epoch 7 Batch 100 Loss 1.3174227476119995 Epoch 7 Loss 1.3735 Time taken for 1 epoch 7.580203056335449 sec Epoch 8 Batch 0 Loss 1.269272804260254 Epoch 8 Batch 100 Loss 1.2866979837417603 Epoch 8 Loss 1.3098 Time taken for 1 epoch 7.554703950881958 sec Epoch 9 Batch 0 Loss 1.2711602449417114 Epoch 9 Batch 100 Loss 1.3119456768035889 Epoch 9 Loss 1.3263 Time taken for 1 epoch 7.547409772872925 sec Epoch 10 Batch 0 Loss 1.205316424369812 Epoch 10 Batch 100 Loss 1.2689248323440552 Epoch 10 Loss 1.2364 Time taken for 1 epoch 7.631383180618286 sec