This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation. This is an advanced example that assumes some knowledge of sequence to sequence models.
After training the model in this notebook, you will be able to input a Spanish sentence, such as "¿todavia estan en casa?", and get back the English translation: "are you still at home?"
The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. It shows which parts of the input sentence have the model's attention while translating:
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from sklearn.model_selection import train_test_split
import unicodedata
import re
import numpy as np
import os
import io
import time
Download and prepare the dataset
We'll use a language dataset provided by http://www.manythings.org/anki/. This dataset contains language translation pairs in the format:
May I borrow this book? ¿Puedo tomar prestado este libro?
There are a variety of languages available, but we'll use the English-Spanish dataset. For convenience, we've hosted a copy of this dataset on Google Cloud, but you can also download your own copy. After downloading the dataset, here are the steps we'll take to prepare the data:
- Add a start and end token to each sentence.
- Clean the sentences by removing special characters.
- Create a word index and reverse word index (dictionaries mapping from word → id and id → word).
- Pad each sentence to a maximum length.
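Steps 3 and 4 can be illustrated without TensorFlow. The sketch below is a simplified, pure-Python stand-in for what tokenize() does later with Tokenizer(filters='') and pad_sequences(padding='post'). Index 0 is reserved for padding, more frequent words get smaller ids, and shorter sequences are zero-padded at the end. The helper names build_vocab and pad_post are hypothetical. The tie-breaking rule (equally frequent words keep first-seen order) is an assumption about how the Keras tokenizer orders ties.

```python
from collections import Counter


def build_vocab(sentences):
    """Map word -> id (1-based, most frequent first) and id -> word."""
    counts = Counter()
    for s in sentences:
        counts.update(s.split(' '))
    # sorted() is stable, so equally frequent words keep first-seen order
    ordered = sorted(counts, key=lambda w: counts[w], reverse=True)
    word_index = {w: i for i, w in enumerate(ordered, start=1)}  # 0 = padding
    index_word = {i: w for w, i in word_index.items()}
    return word_index, index_word


def pad_post(sequences, value=0):
    """Pad every sequence with `value` at the end up to the longest length."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [value] * (max_len - len(seq)) for seq in sequences]


sentences = ['<start> hola . <end>', '<start> hola , amigo . <end>']
word_index, index_word = build_vocab(sentences)
ids = [[word_index[w] for w in s.split(' ')] for s in sentences]
print(word_index)     # {'<start>': 1, 'hola': 2, '.': 3, '<end>': 4, ',': 5, 'amigo': 6}
print(pad_post(ids))  # [[1, 2, 3, 4, 0, 0], [1, 2, 5, 6, 3, 4]]
```

Because every sentence carries <start> and <end>, those tokens tend to receive the smallest ids. The index-to-word output later in this notebook shows the same pattern (<start> is 1, <end> is 2).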
# Download the file
path_to_zip = tf.keras.utils.get_file(
    'spa-eng.zip', origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',
    extract=True)

path_to_file = os.path.dirname(path_to_zip)+"/spa-eng/spa.txt"
Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip
2646016/2638744 [==============================] - 0s 0us/step
# Removes accents (Unicode combining marks) from a string
def unicode_to_ascii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')


def preprocess_sentence(w):
    w = unicode_to_ascii(w.lower().strip())

    # creating a space between a word and the punctuation following it
    # eg: "he is a boy." => "he is a boy ."
    # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation
    w = re.sub(r"([?.!,¿])", r" \1 ", w)
    w = re.sub(r'[" "]+', " ", w)

    # replacing everything with space except (a-z, A-Z, ".", "?", "!", ",", "¿")
    w = re.sub(r"[^a-zA-Z?.!,¿]+", " ", w)

    w = w.strip()

    # adding a start and an end token to the sentence
    # so that the model knows when to start and stop predicting.
    w = '<start> ' + w + ' <end>'
    return w
en_sentence = u"May I borrow this book?"
sp_sentence = u"¿Puedo tomar prestado este libro?"
print(preprocess_sentence(en_sentence))
print(preprocess_sentence(sp_sentence).encode('utf-8'))
<start> may i borrow this book ? <end>
b'<start> \xc2\xbf puedo tomar prestado este libro ? <end>'
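unicode_to_ascii works by NFD-decomposing each character into a base letter plus combining marks, then dropping the combining marks (Unicode category Mn). Despite its name, it does not force pure ASCII. A character such as ¿ has no decomposition and is not a combining mark, so it survives, which is why it appears in the output above. The standalone check below copies the function from this notebook. The sample strings are illustrative.

```python
import unicodedata


def unicode_to_ascii(s):
    # NFD splits 'í' into 'i' + U+0301 COMBINING ACUTE ACCENT (category Mn)
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')


print(unicode_to_ascii('¿Todavía están en casa?'))  # ¿Todavia estan en casa?
print([unicodedata.category(c) for c in unicodedata.normalize('NFD', 'í')])  # ['Ll', 'Mn']
```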
# 1. Remove the accents
# 2. Clean the sentences
# 3. Return word pairs in the format: [ENGLISH, SPANISH]
def create_dataset(path, num_examples):
    lines = io.open(path, encoding='UTF-8').read().strip().split('\n')

    word_pairs = [[preprocess_sentence(w) for w in l.split('\t')] for l in lines[:num_examples]]

    return zip(*word_pairs)
en, sp = create_dataset(path_to_file, None)
print(en[-1])
print(sp[-1])
<start> if you want to sound like a native speaker , you must be willing to practice saying the same sentence over and over in the same way that banjo players practice the same phrase over and over until they can play it correctly and at the desired tempo . <end>
<start> si quieres sonar como un hablante nativo , debes estar dispuesto a practicar diciendo la misma frase una y otra vez de la misma manera en que un musico de banjo practica el mismo fraseo una y otra vez hasta que lo puedan tocar correctamente y en el tiempo esperado . <end>
def tokenize(lang):
    lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(
        filters='')
    lang_tokenizer.fit_on_texts(lang)

    tensor = lang_tokenizer.texts_to_sequences(lang)

    tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,
                                                           padding='post')

    return tensor, lang_tokenizer


def load_dataset(path, num_examples=None):
    # creating cleaned input, output pairs
    targ_lang, inp_lang = create_dataset(path, num_examples)

    input_tensor, inp_lang_tokenizer = tokenize(inp_lang)
    target_tensor, targ_lang_tokenizer = tokenize(targ_lang)

    return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer
Limit the size of the dataset to experiment faster (optional)
Training on the complete dataset of more than 100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):
# Try experimenting with the size of that dataset
num_examples = 30000
input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path_to_file, num_examples)
# Calculate max_length of the target and input tensors
max_length_targ, max_length_inp = target_tensor.shape[1], input_tensor.shape[1]
# Creating training and validation sets using an 80-20 split
input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)
# Show length
print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))
24000 24000 6000 6000
def convert(lang, tensor):
    for t in tensor:
        if t != 0:
            print("%d ----> %s" % (t, lang.index_word[t]))
print ("Input Language; index to word mapping")
convert(inp_lang, input_tensor_train[0])
print ()
print ("Target Language; index to word mapping")
convert(targ_lang, target_tensor_train[0])
Input Language; index to word mapping
1 ----> <start>
6379 ----> dese
395 ----> vuelta
32 ----> ,
22 ----> por
50 ----> favor
3 ----> .
2 ----> <end>

Target Language; index to word mapping
1 ----> <start>
56 ----> please
205 ----> turn
197 ----> over
3 ----> .
2 ----> <end>
Create a tf.data dataset
BUFFER_SIZE = len(input_tensor_train)
BATCH_SIZE = 64
steps_per_epoch = len(input_tensor_train)//BATCH_SIZE
embedding_dim = 256
units = 1024
vocab_inp_size = len(inp_lang.word_index)+1
vocab_tar_size = len(targ_lang.word_index)+1
dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
example_input_batch, example_target_batch = next(iter(dataset))
example_input_batch.shape, example_target_batch.shape
(TensorShape([64, 16]), TensorShape([64, 11]))
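batch(..., drop_remainder=True) discards a final partial batch. As a result, every batch has exactly BATCH_SIZE examples, which train_step relies on when it builds dec_input from [...] * BATCH_SIZE. With 24,000 training pairs and BATCH_SIZE = 64, steps_per_epoch = 24000 // 64 = 375, and nothing is dropped (375 × 64 = 24,000).

The pure-Python sketch below mimics shuffle-then-batch on a toy list. shuffle_and_batch is a hypothetical helper. It uses a full in-memory shuffle, which matches tf.data's shuffle only when the buffer covers the whole dataset, as it does here (BUFFER_SIZE = len(input_tensor_train)).

```python
import random


def shuffle_and_batch(examples, batch_size, seed=None):
    """Shuffle all examples, then yield only full batches (like drop_remainder=True)."""
    items = list(examples)
    random.Random(seed).shuffle(items)
    num_full = len(items) // batch_size   # same formula as steps_per_epoch
    for b in range(num_full):
        yield items[b * batch_size:(b + 1) * batch_size]


batches = list(shuffle_and_batch(range(10), batch_size=3, seed=0))
print(len(batches))                             # 3 (one leftover example is dropped)
print(24000 // 64, 24000 - (24000 // 64) * 64)  # 375 0
```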
Write the encoder and decoder model
Implement an encoder-decoder model with attention, which you can read about in the TensorFlow Neural Machine Translation (seq2seq) tutorial. This example uses a more recent set of APIs. This notebook implements the attention equations from the seq2seq tutorial. The following diagram shows that the attention mechanism assigns each input word a weight, which the decoder then uses to predict the next word in the sentence. The picture and formulas below are an example of the attention mechanism from Luong's paper.
The input is put through an encoder model, which gives us the encoder output of shape (batch_size, max_length, hidden_size) and the encoder hidden state of shape (batch_size, hidden_size).
Here are the equations that are implemented:
This tutorial uses Bahdanau attention in the decoder. Let's decide on notation before writing down the simplified form:
- FC = Fully connected (dense) layer
- EO = Encoder output
- H = hidden state
- X = input to the decoder
And the pseudo-code:
- score = FC(tanh(FC(EO) + FC(H)))
- attention weights = softmax(score, axis = 1). Softmax is applied to the last axis by default, but here we want to apply it to axis 1 (the time axis), because the shape of score is (batch_size, max_length, 1). max_length is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.
- context vector = sum(attention weights * EO, axis = 1). The reason for choosing axis 1 is the same as above.
- embedding output = the input to the decoder, X, passed through an embedding layer.
- merged vector = concat(context vector, embedding output)
- This merged vector is then given to the GRU.
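The same pseudo-code can be written out as equations for a single example in the batch. The notation is as follows:

- T_x = max_length source positions, and EO_s is the encoder output at position s.
- W_1 and W_2 are units × hidden_size matrices, and V is a 1 × units row vector. They stand for the three Dense layers of BahdanauAttention, with their bias terms omitted for readability.
- [a ; b] denotes concatenation, in the order the Decoder code uses.

The softmax sum runs over source positions s, which is exactly why the code normalises along axis 1.

$$
\begin{aligned}
\mathrm{score}_s &= V \tanh\!\left(W_1 H + W_2\,\mathrm{EO}_s\right), \qquad s = 1, \dots, T_x \\
\alpha_s &= \frac{\exp(\mathrm{score}_s)}{\sum_{s'=1}^{T_x} \exp(\mathrm{score}_{s'})} \\
\text{context vector} &= \sum_{s=1}^{T_x} \alpha_s\,\mathrm{EO}_s \\
\text{merged vector} &= \left[\,\text{context vector}\,;\ \mathrm{Embedding}(X)\,\right]
\end{aligned}
$$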
The shapes of all the vectors at each step are specified in the comments in the code:
class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(self.enc_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')

    def call(self, x, hidden):
        x = self.embedding(x)
        output, state = self.gru(x, initial_state = hidden)
        return output, state

    def initialize_hidden_state(self):
        return tf.zeros((self.batch_sz, self.enc_units))
encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)
# sample input
sample_hidden = encoder.initialize_hidden_state()
sample_output, sample_hidden = encoder(example_input_batch, sample_hidden)
print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))
print ('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))
Encoder output shape: (batch size, sequence length, units) (64, 16, 1024)
Encoder Hidden state shape: (batch size, units) (64, 1024)
class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        # query hidden state shape == (batch_size, hidden size)
        # query_with_time_axis shape == (batch_size, 1, hidden size)
        # values shape == (batch_size, max_len, hidden size)
        # we are doing this to broadcast addition along the time axis to calculate the score
        query_with_time_axis = tf.expand_dims(query, 1)

        # score shape == (batch_size, max_length, 1)
        # the last axis is 1 because self.V is a Dense layer with a single unit
        # the shape of the tensor before applying self.V is (batch_size, max_length, units)
        score = self.V(tf.nn.tanh(
            self.W1(query_with_time_axis) + self.W2(values)))

        # attention_weights shape == (batch_size, max_length, 1)
        attention_weights = tf.nn.softmax(score, axis=1)

        # context_vector shape after sum == (batch_size, hidden_size)
        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights
attention_layer = BahdanauAttention(10)
attention_result, attention_weights = attention_layer(sample_hidden, sample_output)
print("Attention result shape: (batch size, units) {}".format(attention_result.shape))
print("Attention weights shape: (batch_size, sequence_length, 1) {}".format(attention_weights.shape))
Attention result shape: (batch size, units) (64, 1024)
Attention weights shape: (batch_size, sequence_length, 1) (64, 16, 1)
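To check the shape reasoning independently of TensorFlow, here is a NumPy sketch of the forward pass of BahdanauAttention. The weight matrices are randomly initialised stand-ins for the Dense layers: W1, W2 and V are not trained weights, biases are omitted, and bahdanau_attention is a hypothetical helper name. The sketch confirms two things:

- Normalising over axis 1 gives one weight per source position, and those weights sum to 1 for each example.
- The context vector has shape (batch_size, hidden_size).

```python
import numpy as np


def softmax(x, axis):
    # subtract the max for numerical stability
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def bahdanau_attention(query, values, W1, W2, V):
    """query: (batch, hidden); values: (batch, max_len, hidden)."""
    query_with_time_axis = query[:, None, :]                   # (batch, 1, hidden)
    # broadcasting adds the query term to every time step: (batch, max_len, units)
    hidden = np.tanh(query_with_time_axis @ W1 + values @ W2)
    score = hidden @ V                                          # (batch, max_len, 1)
    attention_weights = softmax(score, axis=1)                  # normalise over time
    context_vector = (attention_weights * values).sum(axis=1)  # (batch, hidden)
    return context_vector, attention_weights


rng = np.random.default_rng(0)
batch, max_len, hidden_size, units = 4, 7, 8, 10
query = rng.normal(size=(batch, hidden_size))
values = rng.normal(size=(batch, max_len, hidden_size))
W1 = rng.normal(size=(hidden_size, units))
W2 = rng.normal(size=(hidden_size, units))
V = rng.normal(size=(units, 1))

context, weights = bahdanau_attention(query, values, W1, W2, V)
print(context.shape, weights.shape)           # (4, 8) (4, 7, 1)
print(np.allclose(weights.sum(axis=1), 1.0))  # True
```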
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(self.dec_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')
        self.fc = tf.keras.layers.Dense(vocab_size)

        # used for attention
        self.attention = BahdanauAttention(self.dec_units)

    def call(self, x, hidden, enc_output):
        # enc_output shape == (batch_size, max_length, hidden_size)
        context_vector, attention_weights = self.attention(hidden, enc_output)

        # x shape after passing through embedding == (batch_size, 1, embedding_dim)
        x = self.embedding(x)

        # x shape after concatenation == (batch_size, 1, hidden_size + embedding_dim)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # passing the concatenated vector to the GRU
        # (note: `hidden` is only used as the attention query above; it is not
        # passed as initial_state, so the GRU starts from a zero state here)
        output, state = self.gru(x)

        # output shape == (batch_size * 1, hidden_size)
        output = tf.reshape(output, (-1, output.shape[2]))

        # output shape == (batch_size, vocab)
        x = self.fc(output)

        return x, state, attention_weights
decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)
sample_decoder_output, _, _ = decoder(tf.random.uniform((BATCH_SIZE, 1)),
                                      sample_hidden, sample_output)
print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))
Decoder output shape: (batch_size, vocab size) (64, 4935)
Define the optimizer and the loss function
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')


def loss_function(real, pred):
    # mask out padding positions (token id 0)
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask

    return tf.reduce_mean(loss_)
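The mask zeroes out the loss at padding positions (token id 0), but tf.reduce_mean still divides by the total number of positions, padding included. The NumPy sketch below reproduces that computation for one time step using made-up uniform logits. It also shows dividing by the number of real tokens instead. That second variant is a common alternative, not what this notebook does, and the helper names are hypothetical.

```python
import numpy as np


def sparse_ce_from_logits(labels, logits):
    """Per-example cross-entropy from logits, with no reduction."""
    logits = logits - logits.max(axis=-1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]


def masked_loss(real, pred, normalise_by_tokens=False):
    mask = (real != 0).astype(pred.dtype)    # 0 where the target is padding
    loss_ = sparse_ce_from_logits(real, pred) * mask
    if normalise_by_tokens:
        return loss_.sum() / mask.sum()      # average over real tokens only
    return loss_.mean()                      # same reduction as loss_function above


real = np.array([5, 0, 2, 0])                # two of the four targets are padding
pred = np.zeros((4, 6))                      # uniform logits over 6 classes
print(masked_loss(real, pred))               # ≈ 0.896, i.e. log(6) * 2 / 4
print(masked_loss(real, pred, True))         # ≈ 1.792, i.e. log(6)
```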
Checkpoints (Object-based saving)
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
Training
- Pass the input through the encoder, which returns the encoder output and the encoder hidden state.
- The encoder output, the encoder hidden state and the decoder input (which is the start token) are passed to the decoder.
- The decoder returns the predictions and the decoder hidden state.
- The decoder hidden state is then passed back into the model, and the predictions are used to calculate the loss.
- Use teacher forcing to decide the next input to the decoder.
- Teacher forcing is the technique where the target word is passed as the next input to the decoder.
- The final step is to calculate the gradients by backpropagation and apply them to the trainable variables with the optimizer.
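Teacher forcing (used in training) and feeding back the model's own prediction (used later in evaluate) differ only in where the next decoder input comes from. The toy sketch below makes that explicit. toy_decoder_step is a hypothetical stand-in for the decoder that looks up a fixed next token, and the target sentence is made up. It illustrates the control flow only, not the model.

```python
def toy_decoder_step(token):
    """Hypothetical stand-in for one decoder call: returns a predicted next token."""
    table = {'<start>': 'are', 'are': 'you', 'you': 'home', 'home': '<end>',
             'still': 'at', 'at': 'home'}
    return table.get(token, '<end>')


def decoder_inputs(target, teacher_forcing):
    """Return the token fed to the decoder at each of the len(target) - 1 steps."""
    inputs, dec_input = [], target[0]          # always start from <start>
    for t in range(1, len(target)):
        inputs.append(dec_input)
        prediction = toy_decoder_step(dec_input)
        # teacher forcing feeds the ground-truth token, not the prediction
        dec_input = target[t] if teacher_forcing else prediction
    return inputs


target = ['<start>', 'are', 'you', 'still', 'at', 'home', '<end>']
print(decoder_inputs(target, teacher_forcing=True))
# ['<start>', 'are', 'you', 'still', 'at', 'home']
print(decoder_inputs(target, teacher_forcing=False))
# ['<start>', 'are', 'you', 'home', '<end>', '<end>']
```

Without teacher forcing, one early mistake ('you' → 'home') changes every later input. With teacher forcing, each step is conditioned on the correct prefix, which typically makes training more stable.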
@tf.function
def train_step(inp, targ, enc_hidden):
    loss = 0

    with tf.GradientTape() as tape:
        enc_output, enc_hidden = encoder(inp, enc_hidden)

        dec_hidden = enc_hidden

        dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)

        # Teacher forcing - feeding the target as the next input
        for t in range(1, targ.shape[1]):
            # passing enc_output to the decoder
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)

            loss += loss_function(targ[:, t], predictions)

            # using teacher forcing
            dec_input = tf.expand_dims(targ[:, t], 1)

    batch_loss = (loss / int(targ.shape[1]))

    variables = encoder.trainable_variables + decoder.trainable_variables

    gradients = tape.gradient(loss, variables)

    optimizer.apply_gradients(zip(gradients, variables))

    return batch_loss
EPOCHS = 10

for epoch in range(EPOCHS):
    start = time.time()

    enc_hidden = encoder.initialize_hidden_state()
    total_loss = 0

    for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
        batch_loss = train_step(inp, targ, enc_hidden)
        total_loss += batch_loss

        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                         batch,
                                                         batch_loss.numpy()))
    # saving (checkpoint) the model every 2 epochs
    if (epoch + 1) % 2 == 0:
        checkpoint.save(file_prefix = checkpoint_prefix)

    print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                        total_loss / steps_per_epoch))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
Epoch 1 Batch 0 Loss 4.7113
Epoch 1 Batch 100 Loss 2.1051
Epoch 1 Batch 200 Loss 1.9095
Epoch 1 Batch 300 Loss 1.7646
Epoch 1 Loss 2.0334
Time taken for 1 epoch 26.513352870941162 sec

Epoch 2 Batch 0 Loss 1.4994
Epoch 2 Batch 100 Loss 1.4381
Epoch 2 Batch 200 Loss 1.3774
Epoch 2 Batch 300 Loss 1.1783
Epoch 2 Loss 1.3686
Time taken for 1 epoch 15.74858546257019 sec

Epoch 3 Batch 0 Loss 0.9827
Epoch 3 Batch 100 Loss 1.0305
Epoch 3 Batch 200 Loss 0.9073
Epoch 3 Batch 300 Loss 0.8466
Epoch 3 Loss 0.9339
Time taken for 1 epoch 15.360853910446167 sec

Epoch 4 Batch 0 Loss 0.5953
Epoch 4 Batch 100 Loss 0.6024
Epoch 4 Batch 200 Loss 0.6550
Epoch 4 Batch 300 Loss 0.6959
Epoch 4 Loss 0.6273
Time taken for 1 epoch 15.659878015518188 sec

Epoch 5 Batch 0 Loss 0.4362
Epoch 5 Batch 100 Loss 0.4403
Epoch 5 Batch 200 Loss 0.5202
Epoch 5 Batch 300 Loss 0.3749
Epoch 5 Loss 0.4293
Time taken for 1 epoch 15.344685077667236 sec

Epoch 6 Batch 0 Loss 0.3615
Epoch 6 Batch 100 Loss 0.2462
Epoch 6 Batch 200 Loss 0.2649
Epoch 6 Batch 300 Loss 0.3645
Epoch 6 Loss 0.2965
Time taken for 1 epoch 15.627461910247803 sec

Epoch 7 Batch 0 Loss 0.2720
Epoch 7 Batch 100 Loss 0.1868
Epoch 7 Batch 200 Loss 0.2354
Epoch 7 Batch 300 Loss 0.2372
Epoch 7 Loss 0.2145
Time taken for 1 epoch 15.387472867965698 sec

Epoch 8 Batch 0 Loss 0.1477
Epoch 8 Batch 100 Loss 0.1718
Epoch 8 Batch 200 Loss 0.1659
Epoch 8 Batch 300 Loss 0.1612
Epoch 8 Loss 0.1623
Time taken for 1 epoch 15.627415657043457 sec

Epoch 9 Batch 0 Loss 0.0871
Epoch 9 Batch 100 Loss 0.1062
Epoch 9 Batch 200 Loss 0.1450
Epoch 9 Batch 300 Loss 0.1639
Epoch 9 Loss 0.1268
Time taken for 1 epoch 15.357704162597656 sec

Epoch 10 Batch 0 Loss 0.0960
Epoch 10 Batch 100 Loss 0.0805
Epoch 10 Batch 200 Loss 0.1251
Epoch 10 Batch 300 Loss 0.1206
Epoch 10 Loss 0.1037
Time taken for 1 epoch 15.646350383758545 sec
Translate
- The evaluate function is similar to the training loop, except that we don't use teacher forcing here. The input to the decoder at each time step is its previous prediction, along with the hidden state and the encoder output.
- Stop predicting when the model predicts the end token.
- Store the attention weights for every time step.
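The loop in evaluate is greedy decoding. At each step it takes the arg-max token and appends it to the result (including <end>). It stops either when <end> is predicted or after max_length_targ steps. Below is a framework-free sketch of just that control flow. toy_step is a hypothetical stand-in for the decoder that returns one-hot logits, and the small integer ids stand in for targ_lang.word_index.

```python
def toy_step(token_id):
    """Hypothetical decoder stand-in: one-hot logits over a 5-token vocabulary."""
    # ids: 0 = padding, 1 = <start>, 2 = <end>, 3 and 4 = ordinary words
    next_id = {1: 3, 3: 4, 4: 2}.get(token_id, 2)
    logits = [0.0] * 5
    logits[next_id] = 1.0
    return logits


def greedy_decode(step_fn, start_id, end_id, max_len):
    """Feed back the arg-max id until end_id is produced or max_len steps have run."""
    result, dec_input = [], start_id
    for _ in range(max_len):
        logits = step_fn(dec_input)
        predicted_id = max(range(len(logits)), key=logits.__getitem__)  # arg-max
        result.append(predicted_id)      # evaluate() also keeps the <end> token
        if predicted_id == end_id:
            break
        dec_input = predicted_id         # no teacher forcing: reuse the prediction
    return result


print(greedy_decode(toy_step, start_id=1, end_id=2, max_len=10))  # [3, 4, 2]
print(greedy_decode(toy_step, start_id=1, end_id=2, max_len=2))   # [3, 4]
```

The second call shows the length cap. If <end> is never produced within max_len steps, the output is simply truncated, just as evaluate falls through to its final return.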
def evaluate(sentence):
    attention_plot = np.zeros((max_length_targ, max_length_inp))

    sentence = preprocess_sentence(sentence)

    # note: a word that never appeared in the training data raises a KeyError here
    inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]
    inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],
                                                           maxlen=max_length_inp,
                                                           padding='post')
    inputs = tf.convert_to_tensor(inputs)

    result = ''

    hidden = [tf.zeros((1, units))]
    enc_out, enc_hidden = encoder(inputs, hidden)

    dec_hidden = enc_hidden
    dec_input = tf.expand_dims([targ_lang.word_index['<start>']], 0)

    for t in range(max_length_targ):
        predictions, dec_hidden, attention_weights = decoder(dec_input,
                                                             dec_hidden,
                                                             enc_out)

        # storing the attention weights to plot later on
        attention_weights = tf.reshape(attention_weights, (-1, ))
        attention_plot[t] = attention_weights.numpy()

        predicted_id = tf.argmax(predictions[0]).numpy()

        result += targ_lang.index_word[predicted_id] + ' '

        if targ_lang.index_word[predicted_id] == '<end>':
            return result, sentence, attention_plot

        # the predicted ID is fed back into the model
        dec_input = tf.expand_dims([predicted_id], 0)

    return result, sentence, attention_plot
# function for plotting the attention weights
def plot_attention(attention, sentence, predicted_sentence):
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(1, 1, 1)
    ax.matshow(attention, cmap='viridis')

    fontdict = {'fontsize': 14}

    ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)
    ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)

    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()
def translate(sentence):
    result, sentence, attention_plot = evaluate(sentence)

    print('Input: %s' % (sentence))
    print('Predicted translation: {}'.format(result))

    attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]
    plot_attention(attention_plot, sentence.split(' '), result.split(' '))
Restore the latest checkpoint and test
# restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f3f4a04af60>
translate(u'hace mucho frio aqui.')
Input: <start> hace mucho frio aqui . <end>
Predicted translation: it s very cold here . <end>
translate(u'esta es mi vida.')
Input: <start> esta es mi vida . <end>
Predicted translation: this is my life . <end>
translate(u'¿todavia estan en casa?')
Input: <start> ¿ todavia estan en casa ? <end>
Predicted translation: are you still at home ? <end>
# another example
translate(u'trata de averiguarlo.')
Input: <start> trata de averiguarlo . <end>
Predicted translation: try to figure it out . <end>
Next steps
- Download a different dataset to experiment with translations, for example, English to German, or English to French.
- Experiment with training on a larger dataset, or using more epochs.