![]() | ![]() | ![]() | ![]() |
Este bloco de notas treina um modelo de sequência para sequência (seq2seq) para tradução de espanhol para inglês. Este é um exemplo avançado que pressupõe algum conhecimento de modelos de seqüência para seqüência.
Depois de treinar o modelo neste notebook, você poderá inserir uma frase em espanhol, como "¿todavia estan en casa?" , e retornar a tradução em inglês: "are you still at home?"
A qualidade da tradução é razoável para um exemplo de brinquedo, mas o gráfico de atenção gerado é talvez mais interessante. Isso mostra quais partes da frase de entrada chamam a atenção do modelo durante a tradução:
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from sklearn.model_selection import train_test_split
import unicodedata
import re
import numpy as np
import os
import io
import time
Baixe e prepare o conjunto de dados
Usaremos um conjunto de dados de idioma fornecido por http://www.manythings.org/anki/ Este conjunto de dados contém pares de tradução de idioma no formato:
May I borrow this book? ¿Puedo tomar prestado este libro?
Há uma variedade de idiomas disponíveis, mas usaremos o conjunto de dados inglês-espanhol. Para sua conveniência, hospedamos uma cópia deste conjunto de dados no Google Cloud, mas você também pode baixar sua própria cópia. Depois de fazer o download do conjunto de dados, aqui estão as etapas que seguiremos para preparar os dados:
- Adicione um token de início e fim para cada frase.
- Limpe as frases removendo caracteres especiais.
- Crie um índice de palavras e inverta o índice de palavras (mapeamento de dicionários de palavra → id e id → palavra).
- Preencha cada frase com um comprimento máximo.
# Download the file
path_to_zip = tf.keras.utils.get_file(
'spa-eng.zip', origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',
extract=True)
path_to_file = os.path.dirname(path_to_zip)+"/spa-eng/spa.txt"
Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip 2646016/2638744 [==============================] - 0s 0us/step
# Converts the unicode file to ascii
def unicode_to_ascii(s):
return ''.join(c for c in unicodedata.normalize('NFD', s)
if unicodedata.category(c) != 'Mn')
def preprocess_sentence(w):
w = unicode_to_ascii(w.lower().strip())
# creating a space between a word and the punctuation following it
# eg: "he is a boy." => "he is a boy ."
# Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation
w = re.sub(r"([?.!,¿])", r" \1 ", w)
w = re.sub(r'[" "]+', " ", w)
# replacing everything with space except (a-z, A-Z, ".", "?", "!", ",")
w = re.sub(r"[^a-zA-Z?.!,¿]+", " ", w)
w = w.strip()
# adding a start and an end token to the sentence
# so that the model know when to start and stop predicting.
w = '<start> ' + w + ' <end>'
return w
en_sentence = u"May I borrow this book?"
sp_sentence = u"¿Puedo tomar prestado este libro?"
print(preprocess_sentence(en_sentence))
print(preprocess_sentence(sp_sentence).encode('utf-8'))
<start> may i borrow this book ? <end> b'<start> \xc2\xbf puedo tomar prestado este libro ? <end>'
# 1. Remove the accents
# 2. Clean the sentences
# 3. Return word pairs in the format: [ENGLISH, SPANISH]
def create_dataset(path, num_examples):
lines = io.open(path, encoding='UTF-8').read().strip().split('\n')
word_pairs = [[preprocess_sentence(w) for w in l.split('\t')] for l in lines[:num_examples]]
return zip(*word_pairs)
en, sp = create_dataset(path_to_file, None)
print(en[-1])
print(sp[-1])
<start> if you want to sound like a native speaker , you must be willing to practice saying the same sentence over and over in the same way that banjo players practice the same phrase over and over until they can play it correctly and at the desired tempo . <end> <start> si quieres sonar como un hablante nativo , debes estar dispuesto a practicar diciendo la misma frase una y otra vez de la misma manera en que un musico de banjo practica el mismo fraseo una y otra vez hasta que lo puedan tocar correctamente y en el tiempo esperado . <end>
def tokenize(lang):
lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(
filters='')
lang_tokenizer.fit_on_texts(lang)
tensor = lang_tokenizer.texts_to_sequences(lang)
tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,
padding='post')
return tensor, lang_tokenizer
def load_dataset(path, num_examples=None):
# creating cleaned input, output pairs
targ_lang, inp_lang = create_dataset(path, num_examples)
input_tensor, inp_lang_tokenizer = tokenize(inp_lang)
target_tensor, targ_lang_tokenizer = tokenize(targ_lang)
return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer
Limite o tamanho do conjunto de dados para fazer experiências mais rapidamente (opcional)
O treinamento no conjunto de dados completo de> 100.000 sentenças levará muito tempo. Para treinar mais rápido, podemos limitar o tamanho do conjunto de dados a 30.000 frases (é claro, a qualidade da tradução diminui com menos dados):
# Try experimenting with the size of that dataset
num_examples = 30000
input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path_to_file, num_examples)
# Calculate max_length of the target tensors
max_length_targ, max_length_inp = target_tensor.shape[1], input_tensor.shape[1]
# Creating training and validation sets using an 80-20 split
input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)
# Show length
print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))
24000 24000 6000 6000
def convert(lang, tensor):
for t in tensor:
if t!=0:
print ("%d ----> %s" % (t, lang.index_word[t]))
print ("Input Language; index to word mapping")
convert(inp_lang, input_tensor_train[0])
print ()
print ("Target Language; index to word mapping")
convert(targ_lang, target_tensor_train[0])
Input Language; index to word mapping 1 ----> <start> 6379 ----> dese 395 ----> vuelta 32 ----> , 22 ----> por 50 ----> favor 3 ----> . 2 ----> <end> Target Language; index to word mapping 1 ----> <start> 56 ----> please 205 ----> turn 197 ----> over 3 ----> . 2 ----> <end>
Crie um conjunto de dados tf.data
BUFFER_SIZE = len(input_tensor_train)
BATCH_SIZE = 64
steps_per_epoch = len(input_tensor_train)//BATCH_SIZE
embedding_dim = 256
units = 1024
vocab_inp_size = len(inp_lang.word_index)+1
vocab_tar_size = len(targ_lang.word_index)+1
dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
example_input_batch, example_target_batch = next(iter(dataset))
example_input_batch.shape, example_target_batch.shape
(TensorShape([64, 16]), TensorShape([64, 11]))
Escreva o codificador e o modelo do decodificador
Implemente um modelo de codificador-decodificador com atenção, sobre o qual você pode ler no tutorial de Tradução automática neural do TensorFlow (seq2seq) . Este exemplo usa um conjunto mais recente de APIs. Este notebook implementa as equações de atenção do tutorial seq2seq. O diagrama a seguir mostra que a cada palavra de entrada é atribuído um peso pelo mecanismo de atenção que é então usado pelo decodificador para prever a próxima palavra na frase. A imagem e as fórmulas abaixo são um exemplo de mecanismo de atenção do artigo de Luong .
A entrada é passada por um modelo de codificador que nos dá a saída do codificador de forma (batch_size, max_length, hidden_size) e o estado oculto do codificador de forma (batch_size, hidden_size) .
Aqui estão as equações que são implementadas:
Este tutorial usa a atenção Bahdanau para o codificador. Vamos decidir sobre a notação antes de escrever a forma simplificada:
- FC = camada totalmente conectada (densa)
- EO = saída do codificador
- H = estado oculto
- X = entrada para o decodificador
E o pseudocódigo:
-
score = FC(tanh(FC(EO) + FC(H)))
-
attention weights = softmax(score, axis = 1)
. Softmax por padrão é aplicado no último eixo, mas aqui queremos aplicá-lo no primeiro eixo , uma vez que a forma da pontuação é (batch_size, max_length, hidden_size) .Max_length
é o comprimento de nossa entrada. Como estamos tentando atribuir um peso a cada entrada, o softmax deve ser aplicado a esse eixo. -
context vector = sum(attention weights * EO, axis = 1)
. Mesma razão acima para escolher o eixo como 1. -
embedding output
= a entrada para o decodificador X é passada por uma camada de incorporação. -
merged vector = concat(embedding output, context vector)
- Este vetor mesclado é então dado ao GRU
As formas de todos os vetores em cada etapa foram especificadas nos comentários do código:
class Encoder(tf.keras.Model):
def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
super(Encoder, self).__init__()
self.batch_sz = batch_sz
self.enc_units = enc_units
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
self.gru = tf.keras.layers.GRU(self.enc_units,
return_sequences=True,
return_state=True,
recurrent_initializer='glorot_uniform')
def call(self, x, hidden):
x = self.embedding(x)
output, state = self.gru(x, initial_state = hidden)
return output, state
def initialize_hidden_state(self):
return tf.zeros((self.batch_sz, self.enc_units))
encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)
# sample input
sample_hidden = encoder.initialize_hidden_state()
sample_output, sample_hidden = encoder(example_input_batch, sample_hidden)
print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))
print ('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))
Encoder output shape: (batch size, sequence length, units) (64, 16, 1024) Encoder Hidden state shape: (batch size, units) (64, 1024)
class BahdanauAttention(tf.keras.layers.Layer):
def __init__(self, units):
super(BahdanauAttention, self).__init__()
self.W1 = tf.keras.layers.Dense(units)
self.W2 = tf.keras.layers.Dense(units)
self.V = tf.keras.layers.Dense(1)
def call(self, query, values):
# query hidden state shape == (batch_size, hidden size)
# query_with_time_axis shape == (batch_size, 1, hidden size)
# values shape == (batch_size, max_len, hidden size)
# we are doing this to broadcast addition along the time axis to calculate the score
query_with_time_axis = tf.expand_dims(query, 1)
# score shape == (batch_size, max_length, 1)
# we get 1 at the last axis because we are applying score to self.V
# the shape of the tensor before applying self.V is (batch_size, max_length, units)
score = self.V(tf.nn.tanh(
self.W1(query_with_time_axis) + self.W2(values)))
# attention_weights shape == (batch_size, max_length, 1)
attention_weights = tf.nn.softmax(score, axis=1)
# context_vector shape after sum == (batch_size, hidden_size)
context_vector = attention_weights * values
context_vector = tf.reduce_sum(context_vector, axis=1)
return context_vector, attention_weights
attention_layer = BahdanauAttention(10)
attention_result, attention_weights = attention_layer(sample_hidden, sample_output)
print("Attention result shape: (batch size, units) {}".format(attention_result.shape))
print("Attention weights shape: (batch_size, sequence_length, 1) {}".format(attention_weights.shape))
Attention result shape: (batch size, units) (64, 1024) Attention weights shape: (batch_size, sequence_length, 1) (64, 16, 1)
class Decoder(tf.keras.Model):
def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
super(Decoder, self).__init__()
self.batch_sz = batch_sz
self.dec_units = dec_units
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
self.gru = tf.keras.layers.GRU(self.dec_units,
return_sequences=True,
return_state=True,
recurrent_initializer='glorot_uniform')
self.fc = tf.keras.layers.Dense(vocab_size)
# used for attention
self.attention = BahdanauAttention(self.dec_units)
def call(self, x, hidden, enc_output):
# enc_output shape == (batch_size, max_length, hidden_size)
context_vector, attention_weights = self.attention(hidden, enc_output)
# x shape after passing through embedding == (batch_size, 1, embedding_dim)
x = self.embedding(x)
# x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)
# passing the concatenated vector to the GRU
output, state = self.gru(x)
# output shape == (batch_size * 1, hidden_size)
output = tf.reshape(output, (-1, output.shape[2]))
# output shape == (batch_size, vocab)
x = self.fc(output)
return x, state, attention_weights
decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)
sample_decoder_output, _, _ = decoder(tf.random.uniform((BATCH_SIZE, 1)),
sample_hidden, sample_output)
print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))
Decoder output shape: (batch_size, vocab size) (64, 4935)
Defina o otimizador e a função de perda
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction='none')
def loss_function(real, pred):
mask = tf.math.logical_not(tf.math.equal(real, 0))
loss_ = loss_object(real, pred)
mask = tf.cast(mask, dtype=loss_.dtype)
loss_ *= mask
return tf.reduce_mean(loss_)
Pontos de verificação (salvamento baseado em objeto)
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
encoder=encoder,
decoder=decoder)
Treinamento
- Passe a entrada pelo codificador que retorna a saída do codificador e o estado oculto do codificador .
- A saída do codificador, o estado oculto do codificador e a entrada do decodificador (que é o token de início ) são passados para o decodificador.
- O decodificador retorna as previsões e o estado oculto do decodificador .
- O estado oculto do decodificador é então passado de volta para o modelo e as previsões são usadas para calcular a perda.
- Use a força do professor para decidir a próxima entrada para o decodificador.
- Forçar professor é a técnica em que a palavra alvo é passada como a próxima entrada para o decodificador.
- A etapa final é calcular os gradientes e aplicá-los ao otimizador e retropropagar.
@tf.function
def train_step(inp, targ, enc_hidden):
loss = 0
with tf.GradientTape() as tape:
enc_output, enc_hidden = encoder(inp, enc_hidden)
dec_hidden = enc_hidden
dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)
# Teacher forcing - feeding the target as the next input
for t in range(1, targ.shape[1]):
# passing enc_output to the decoder
predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
loss += loss_function(targ[:, t], predictions)
# using teacher forcing
dec_input = tf.expand_dims(targ[:, t], 1)
batch_loss = (loss / int(targ.shape[1]))
variables = encoder.trainable_variables + decoder.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return batch_loss
EPOCHS = 10
for epoch in range(EPOCHS):
start = time.time()
enc_hidden = encoder.initialize_hidden_state()
total_loss = 0
for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
batch_loss = train_step(inp, targ, enc_hidden)
total_loss += batch_loss
if batch % 100 == 0:
print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
batch,
batch_loss.numpy()))
# saving (checkpoint) the model every 2 epochs
if (epoch + 1) % 2 == 0:
checkpoint.save(file_prefix = checkpoint_prefix)
print('Epoch {} Loss {:.4f}'.format(epoch + 1,
total_loss / steps_per_epoch))
print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
Epoch 1 Batch 0 Loss 4.7113 Epoch 1 Batch 100 Loss 2.1051 Epoch 1 Batch 200 Loss 1.9095 Epoch 1 Batch 300 Loss 1.7646 Epoch 1 Loss 2.0334 Time taken for 1 epoch 26.513352870941162 sec Epoch 2 Batch 0 Loss 1.4994 Epoch 2 Batch 100 Loss 1.4381 Epoch 2 Batch 200 Loss 1.3774 Epoch 2 Batch 300 Loss 1.1783 Epoch 2 Loss 1.3686 Time taken for 1 epoch 15.74858546257019 sec Epoch 3 Batch 0 Loss 0.9827 Epoch 3 Batch 100 Loss 1.0305 Epoch 3 Batch 200 Loss 0.9073 Epoch 3 Batch 300 Loss 0.8466 Epoch 3 Loss 0.9339 Time taken for 1 epoch 15.360853910446167 sec Epoch 4 Batch 0 Loss 0.5953 Epoch 4 Batch 100 Loss 0.6024 Epoch 4 Batch 200 Loss 0.6550 Epoch 4 Batch 300 Loss 0.6959 Epoch 4 Loss 0.6273 Time taken for 1 epoch 15.659878015518188 sec Epoch 5 Batch 0 Loss 0.4362 Epoch 5 Batch 100 Loss 0.4403 Epoch 5 Batch 200 Loss 0.5202 Epoch 5 Batch 300 Loss 0.3749 Epoch 5 Loss 0.4293 Time taken for 1 epoch 15.344685077667236 sec Epoch 6 Batch 0 Loss 0.3615 Epoch 6 Batch 100 Loss 0.2462 Epoch 6 Batch 200 Loss 0.2649 Epoch 6 Batch 300 Loss 0.3645 Epoch 6 Loss 0.2965 Time taken for 1 epoch 15.627461910247803 sec Epoch 7 Batch 0 Loss 0.2720 Epoch 7 Batch 100 Loss 0.1868 Epoch 7 Batch 200 Loss 0.2354 Epoch 7 Batch 300 Loss 0.2372 Epoch 7 Loss 0.2145 Time taken for 1 epoch 15.387472867965698 sec Epoch 8 Batch 0 Loss 0.1477 Epoch 8 Batch 100 Loss 0.1718 Epoch 8 Batch 200 Loss 0.1659 Epoch 8 Batch 300 Loss 0.1612 Epoch 8 Loss 0.1623 Time taken for 1 epoch 15.627415657043457 sec Epoch 9 Batch 0 Loss 0.0871 Epoch 9 Batch 100 Loss 0.1062 Epoch 9 Batch 200 Loss 0.1450 Epoch 9 Batch 300 Loss 0.1639 Epoch 9 Loss 0.1268 Time taken for 1 epoch 15.357704162597656 sec Epoch 10 Batch 0 Loss 0.0960 Epoch 10 Batch 100 Loss 0.0805 Epoch 10 Batch 200 Loss 0.1251 Epoch 10 Batch 300 Loss 0.1206 Epoch 10 Loss 0.1037 Time taken for 1 epoch 15.646350383758545 sec
Traduzir
- A função de avaliação é semelhante ao loop de treinamento, exceto que não usamos o forçamento do professor aqui. A entrada para o decodificador em cada etapa de tempo são suas previsões anteriores junto com o estado oculto e a saída do codificador.
- Pare de prever quando o modelo prevê o token final .
- E armazene os pesos de atenção para cada passo de tempo .
def evaluate(sentence):
attention_plot = np.zeros((max_length_targ, max_length_inp))
sentence = preprocess_sentence(sentence)
inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]
inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],
maxlen=max_length_inp,
padding='post')
inputs = tf.convert_to_tensor(inputs)
result = ''
hidden = [tf.zeros((1, units))]
enc_out, enc_hidden = encoder(inputs, hidden)
dec_hidden = enc_hidden
dec_input = tf.expand_dims([targ_lang.word_index['<start>']], 0)
for t in range(max_length_targ):
predictions, dec_hidden, attention_weights = decoder(dec_input,
dec_hidden,
enc_out)
# storing the attention weights to plot later on
attention_weights = tf.reshape(attention_weights, (-1, ))
attention_plot[t] = attention_weights.numpy()
predicted_id = tf.argmax(predictions[0]).numpy()
result += targ_lang.index_word[predicted_id] + ' '
if targ_lang.index_word[predicted_id] == '<end>':
return result, sentence, attention_plot
# the predicted ID is fed back into the model
dec_input = tf.expand_dims([predicted_id], 0)
return result, sentence, attention_plot
# function for plotting the attention weights
def plot_attention(attention, sentence, predicted_sentence):
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)
ax.matshow(attention, cmap='viridis')
fontdict = {'fontsize': 14}
ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)
ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
plt.show()
def translate(sentence):
result, sentence, attention_plot = evaluate(sentence)
print('Input: %s' % (sentence))
print('Predicted translation: {}'.format(result))
attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]
plot_attention(attention_plot, sentence.split(' '), result.split(' '))
Restaure o último ponto de verificação e teste
# restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f3f4a04af60>
translate(u'hace mucho frio aqui.')
Input: <start> hace mucho frio aqui . <end> Predicted translation: it s very cold here . <end>
translate(u'esta es mi vida.')
Input: <start> esta es mi vida . <end> Predicted translation: this is my life . <end>
translate(u'¿todavia estan en casa?')
Input: <start> ¿ todavia estan en casa ? <end> Predicted translation: are you still at home ? <end>
# wrong translation
translate(u'trata de averiguarlo.')
Input: <start> trata de averiguarlo . <end> Predicted translation: try to figure it out . <end>
Próximos passos
- Baixe um conjunto de dados diferente para experimentar traduções, por exemplo, inglês para alemão ou inglês para francês.
- Experimente treinar em um conjunto de dados maior ou usar mais épocas