Overview
This notebook gives a brief introduction to the sequence-to-sequence (seq2seq) model architecture. It broadly covers four essential topics for neural machine translation:
- Data cleaning
- Data preparation
- Neural translation model with attention
- Final translation with tf.addons.seq2seq.BasicDecoder and tf.addons.seq2seq.BeamSearchDecoder
The basic idea behind such models, though, is just the encoder-decoder architecture. These networks are commonly used for a variety of tasks such as text summarization, machine translation, and image captioning. This tutorial provides a hands-on understanding of the concepts, explains the technical jargon where necessary, and focuses on neural machine translation (NMT), the very first testbed for seq2seq models.
Setup
pip install tensorflow-addons==0.11.2
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from sklearn.model_selection import train_test_split
import unicodedata
import re
import numpy as np
import os
import io
import time
Data cleaning and data preparation
We will use a language dataset provided by http://www.manythings.org/anki/. This dataset contains language translation pairs in the following format:
May I borrow this book? ¿Puedo tomar prestado este libro?
A variety of languages are available, but we will use the English-Spanish dataset. After downloading the dataset, these are the steps we will take to prepare the data:
- Add a start and end token to each sentence.
- Clean the sentences by removing special characters.
- Create a vocabulary with a word index (mapping from word → id) and a reverse word index (mapping from id → word).
- Pad each sentence to a maximum length. (Why? We need to fix the maximum length of the inputs to the recurrent encoder.) A short sketch of steps 3 and 4 follows this list.
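The snippet below is a minimal, illustrative sketch of steps 3 and 4 on two toy sentences, using the same Keras utilities the NMTDataset class further down relies on; the toy_corpus, toy_tokenizer, toy_ids and toy_padded names are made up for this example only.
import tensorflow as tf
# Two toy sentences that have already been through steps 1 and 2.
toy_corpus = ['<start> may i borrow this book ? <end>',
              '<start> this is my life . <end>']
# Step 3: build the word index (word -> id); the reverse mapping is toy_tokenizer.index_word.
toy_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='', oov_token='<OOV>')
toy_tokenizer.fit_on_texts(toy_corpus)
toy_ids = toy_tokenizer.texts_to_sequences(toy_corpus)
# Step 4: pad every sentence (post-padding with id 0) to the length of the longest one.
toy_padded = tf.keras.preprocessing.sequence.pad_sequences(toy_ids, padding='post')
print(toy_tokenizer.word_index)
print(toy_padded)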
def download_nmt():
path_to_zip = tf.keras.utils.get_file(
'spa-eng.zip', origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',
extract=True)
path_to_file = os.path.dirname(path_to_zip)+"/spa-eng/spa.txt"
return path_to_file
We define an NMTDataset class with the functions needed to carry out steps 1 through 4. Its call() method returns:
- train_dataset and val_dataset: tf.data.Dataset objects
- inp_lang_tokenizer and targ_lang_tokenizer: tf.keras.preprocessing.text.Tokenizer objects
class NMTDataset:
def __init__(self, problem_type='en-spa'):
self.problem_type = 'en-spa'
self.inp_lang_tokenizer = None
self.targ_lang_tokenizer = None
def unicode_to_ascii(self, s):
return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')
## Step 1 and Step 2
def preprocess_sentence(self, w):
w = self.unicode_to_ascii(w.lower().strip())
# creating a space between a word and the punctuation following it
# eg: "he is a boy." => "he is a boy ."
# Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation
w = re.sub(r"([?.!,¿])", r" \1 ", w)
w = re.sub(r'[" "]+', " ", w)
# replacing everything with space except (a-z, A-Z, ".", "?", "!", ",")
w = re.sub(r"[^a-zA-Z?.!,¿]+", " ", w)
w = w.strip()
# adding a start and an end token to the sentence
# so that the model knows when to start and stop predicting.
w = '<start> ' + w + ' <end>'
return w
def create_dataset(self, path, num_examples):
# path : path to spa-eng.txt file
# num_examples : Limit the total number of training example for faster training (set num_examples = len(lines) to use full data)
lines = io.open(path, encoding='UTF-8').read().strip().split('\n')
word_pairs = [[self.preprocess_sentence(w) for w in l.split('\t')] for l in lines[:num_examples]]
return zip(*word_pairs)
# Step 3 and Step 4
def tokenize(self, lang):
# lang = list of sentences in a language
# print(len(lang), "example sentence: {}".format(lang[0]))
lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='', oov_token='<OOV>')
lang_tokenizer.fit_on_texts(lang)
## tf.keras.preprocessing.text.Tokenizer.texts_to_sequences converts a sentence (w1, w2, w3, ......, wn)
## to a list of corresponding integer word ids (id_w1, id_w2, id_w3, ...., id_wn)
tensor = lang_tokenizer.texts_to_sequences(lang)
## tf.keras.preprocessing.sequence.pad_sequences takes argument a list of integer id sequences
## and pads the sequences to match the longest sequences in the given input
tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor, padding='post')
return tensor, lang_tokenizer
def load_dataset(self, path, num_examples=None):
# creating cleaned input, output pairs
targ_lang, inp_lang = self.create_dataset(path, num_examples)
input_tensor, inp_lang_tokenizer = self.tokenize(inp_lang)
target_tensor, targ_lang_tokenizer = self.tokenize(targ_lang)
return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer
def call(self, num_examples, BUFFER_SIZE, BATCH_SIZE):
file_path = download_nmt()
input_tensor, target_tensor, self.inp_lang_tokenizer, self.targ_lang_tokenizer = self.load_dataset(file_path, num_examples)
input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)
train_dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train))
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
val_dataset = tf.data.Dataset.from_tensor_slices((input_tensor_val, target_tensor_val))
val_dataset = val_dataset.batch(BATCH_SIZE, drop_remainder=True)
return train_dataset, val_dataset, self.inp_lang_tokenizer, self.targ_lang_tokenizer
BUFFER_SIZE = 32000
BATCH_SIZE = 64
# Let's limit the #training examples for faster training
num_examples = 30000
dataset_creator = NMTDataset('en-spa')
train_dataset, val_dataset, inp_lang, targ_lang = dataset_creator.call(num_examples, BUFFER_SIZE, BATCH_SIZE)
example_input_batch, example_target_batch = next(iter(train_dataset))
example_input_batch.shape, example_target_batch.shape
(TensorShape([64, 16]), TensorShape([64, 11]))
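As a quick sanity check (purely illustrative, not required by the pipeline), you can map one padded input row back to words to see the <start>/<end> tokens and confirm how the padding behaves:
# Decode the first input row back to words; padded positions (id 0) are skipped by the tokenizer.
print(inp_lang.sequences_to_texts(example_input_batch[:1].numpy()))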
Some important parameters
vocab_inp_size = len(inp_lang.word_index)+1
vocab_tar_size = len(targ_lang.word_index)+1
max_length_input = example_input_batch.shape[1]
max_length_output = example_target_batch.shape[1]
embedding_dim = 256
units = 1024
steps_per_epoch = num_examples//BATCH_SIZE
print("max_length_spanish, max_length_english, vocab_size_spanish, vocab_size_english")
max_length_input, max_length_output, vocab_inp_size, vocab_tar_size
max_length_spanish, max_length_english, vocab_size_spanish, vocab_size_english (16, 11, 9415, 4936)
Neural translation model with attention
class Encoder(tf.keras.Model):
def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
super(Encoder, self).__init__()
self.batch_sz = batch_sz
self.enc_units = enc_units
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
##________ LSTM layer in Encoder ------- ##
self.lstm_layer = tf.keras.layers.LSTM(self.enc_units,
return_sequences=True,
return_state=True,
recurrent_initializer='glorot_uniform')
def call(self, x, hidden):
x = self.embedding(x)
output, h, c = self.lstm_layer(x, initial_state = hidden)
return output, h, c
def initialize_hidden_state(self):
return [tf.zeros((self.batch_sz, self.enc_units)), tf.zeros((self.batch_sz, self.enc_units))]
## Test Encoder Stack
encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)
# sample input
sample_hidden = encoder.initialize_hidden_state()
sample_output, sample_h, sample_c = encoder(example_input_batch, sample_hidden)
print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))
print ('Encoder h vector shape: (batch size, units) {}'.format(sample_h.shape))
print ('Encoder c vector shape: (batch size, units) {}'.format(sample_c.shape))
Encoder output shape: (batch size, sequence length, units) (64, 16, 1024) Encoder h vector shape: (batch size, units) (64, 1024) Encoder c vector shape: (batch size, units) (64, 1024)
class Decoder(tf.keras.Model):
def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz, attention_type='luong'):
super(Decoder, self).__init__()
self.batch_sz = batch_sz
self.dec_units = dec_units
self.attention_type = attention_type
# Embedding Layer
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
#Final Dense layer on which softmax will be applied
self.fc = tf.keras.layers.Dense(vocab_size)
# Define the fundamental cell for decoder recurrent structure
self.decoder_rnn_cell = tf.keras.layers.LSTMCell(self.dec_units)
# Sampler
self.sampler = tfa.seq2seq.sampler.TrainingSampler()
# Create attention mechanism with memory = None
self.attention_mechanism = self.build_attention_mechanism(self.dec_units,
None, self.batch_sz*[max_length_input], self.attention_type)
# Wrap attention mechanism with the fundamental rnn cell of decoder
self.rnn_cell = self.build_rnn_cell(batch_sz)
# Define the decoder with respect to fundamental rnn cell
self.decoder = tfa.seq2seq.BasicDecoder(self.rnn_cell, sampler=self.sampler, output_layer=self.fc)
def build_rnn_cell(self, batch_sz):
rnn_cell = tfa.seq2seq.AttentionWrapper(self.decoder_rnn_cell,
self.attention_mechanism, attention_layer_size=self.dec_units)
return rnn_cell
def build_attention_mechanism(self, dec_units, memory, memory_sequence_length, attention_type='luong'):
# ------------- #
# typ: Which sort of attention (Bahdanau, Luong)
# dec_units: final dimension of attention outputs
# memory: encoder hidden states of shape (batch_size, max_length_input, enc_units)
# memory_sequence_length: 1d array of shape (batch_size) with every element set to max_length_input (for masking purpose)
if(attention_type=='bahdanau'):
return tfa.seq2seq.BahdanauAttention(units=dec_units, memory=memory, memory_sequence_length=memory_sequence_length)
else:
return tfa.seq2seq.LuongAttention(units=dec_units, memory=memory, memory_sequence_length=memory_sequence_length)
def build_initial_state(self, batch_sz, encoder_state, Dtype):
decoder_initial_state = self.rnn_cell.get_initial_state(batch_size=batch_sz, dtype=Dtype)
decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
return decoder_initial_state
def call(self, inputs, initial_state):
x = self.embedding(inputs)
outputs, _, _ = self.decoder(x, initial_state=initial_state, sequence_length=self.batch_sz*[max_length_output-1])
return outputs
# Test decoder stack
decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE, 'luong')
sample_x = tf.random.uniform((BATCH_SIZE, max_length_output), maxval=vocab_tar_size, dtype=tf.int32)  # random token ids as a dummy decoder input
decoder.attention_mechanism.setup_memory(sample_output)
initial_state = decoder.build_initial_state(BATCH_SIZE, [sample_h, sample_c], tf.float32)
sample_decoder_outputs = decoder(sample_x, initial_state)
print("Decoder Outputs Shape: ", sample_decoder_outputs.rnn_output.shape)
Decoder Outputs Shape: (64, 10, 4936)
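To make the wiring above a bit more concrete, the short (illustrative) check below inspects the AttentionWrapperState produced by build_initial_state: its cell_state slot was overwritten with the encoder's final [h, c], while its attention slot holds the attention vector that the wrapper feeds back at the next step.
# Illustrative inspection of the initial decoder state created above.
print(type(initial_state).__name__)        # AttentionWrapperState
print(initial_state.cell_state[0].shape)   # encoder h used as the cell state: (64, 1024)
print(initial_state.attention.shape)       # attention vector, size = attention_layer_size: (64, 1024)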
Define the optimizer and the loss function
optimizer = tf.keras.optimizers.Adam()
def loss_function(real, pred):
# real shape = (BATCH_SIZE, max_length_output)
# pred shape = (BATCH_SIZE, max_length_output, tar_vocab_size )
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
loss = cross_entropy(y_true=real, y_pred=pred)
mask = tf.logical_not(tf.math.equal(real,0)) #output 0 for y=0 else output 1
mask = tf.cast(mask, dtype=loss.dtype)
loss = mask* loss
loss = tf.reduce_mean(loss)
return loss
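As a quick, illustrative check of the masking (the toy_real and toy_logits names are made up for this example): positions where the target is the padding id 0 are zeroed out before the mean is taken, so padded time steps add no error signal.
# Illustrative check: one target row with two real tokens followed by two pad tokens.
toy_real = tf.constant([[5, 2, 0, 0]])
toy_logits = tf.random.uniform((1, 4, vocab_tar_size))   # fake, untrained logits
print(loss_function(toy_real, toy_logits))               # scalar loss; only the first two steps contribute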
Checkpoints (object-based saving)
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
encoder=encoder,
decoder=decoder)
One train_step operation
@tf.function
def train_step(inp, targ, enc_hidden):
loss = 0
with tf.GradientTape() as tape:
enc_output, enc_h, enc_c = encoder(inp, enc_hidden)
dec_input = targ[ : , :-1 ] # Ignore <end> token
real = targ[ : , 1: ] # ignore <start> token
# Set the AttentionMechanism object with encoder_outputs
decoder.attention_mechanism.setup_memory(enc_output)
# Create AttentionWrapperState as initial_state for decoder
decoder_initial_state = decoder.build_initial_state(BATCH_SIZE, [enc_h, enc_c], tf.float32)
pred = decoder(dec_input, decoder_initial_state)
logits = pred.rnn_output
loss = loss_function(real, logits)
variables = encoder.trainable_variables + decoder.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss
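Optionally (this smoke test is not part of the original notebook and performs one real optimizer update), you can run train_step on a single batch to make sure all the shapes line up before committing to the full training loop:
# Optional smoke test: one real training step on one batch (this updates the weights once).
smoke_inp, smoke_targ = next(iter(train_dataset))
print(train_step(smoke_inp, smoke_targ, encoder.initialize_hidden_state()).numpy())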
Train the model
EPOCHS = 10
for epoch in range(EPOCHS):
start = time.time()
enc_hidden = encoder.initialize_hidden_state()
total_loss = 0
# print(enc_hidden[0].shape, enc_hidden[1].shape)
for (batch, (inp, targ)) in enumerate(train_dataset.take(steps_per_epoch)):
batch_loss = train_step(inp, targ, enc_hidden)
total_loss += batch_loss
if batch % 100 == 0:
print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
batch,
batch_loss.numpy()))
# saving (checkpoint) the model every 2 epochs
if (epoch + 1) % 2 == 0:
checkpoint.save(file_prefix = checkpoint_prefix)
print('Epoch {} Loss {:.4f}'.format(epoch + 1,
total_loss / steps_per_epoch))
print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
Epoch 1 Batch 0 Loss 5.1692 Epoch 1 Batch 100 Loss 2.2288 Epoch 1 Batch 200 Loss 1.9930 Epoch 1 Batch 300 Loss 1.7783 Epoch 1 Loss 1.6975 Time taken for 1 epoch 37.26002788543701 sec Epoch 2 Batch 0 Loss 1.6408 Epoch 2 Batch 100 Loss 1.5767 Epoch 2 Batch 200 Loss 1.4054 Epoch 2 Batch 300 Loss 1.3755 Epoch 2 Loss 1.1412 Time taken for 1 epoch 30.0094051361084 sec Epoch 3 Batch 0 Loss 1.0296 Epoch 3 Batch 100 Loss 1.0306 Epoch 3 Batch 200 Loss 1.0675 Epoch 3 Batch 300 Loss 0.9574 Epoch 3 Loss 0.8037 Time taken for 1 epoch 28.983767986297607 sec Epoch 4 Batch 0 Loss 0.5923 Epoch 4 Batch 100 Loss 0.7533 Epoch 4 Batch 200 Loss 0.7397 Epoch 4 Batch 300 Loss 0.6779 Epoch 4 Loss 0.5419 Time taken for 1 epoch 29.649972200393677 sec Epoch 5 Batch 0 Loss 0.4320 Epoch 5 Batch 100 Loss 0.4349 Epoch 5 Batch 200 Loss 0.4686 Epoch 5 Batch 300 Loss 0.4748 Epoch 5 Loss 0.3827 Time taken for 1 epoch 29.06334638595581 sec Epoch 6 Batch 0 Loss 0.3422 Epoch 6 Batch 100 Loss 0.3052 Epoch 6 Batch 200 Loss 0.3288 Epoch 6 Batch 300 Loss 0.3216 Epoch 6 Loss 0.2814 Time taken for 1 epoch 29.57170796394348 sec Epoch 7 Batch 0 Loss 0.2129 Epoch 7 Batch 100 Loss 0.2382 Epoch 7 Batch 200 Loss 0.2406 Epoch 7 Batch 300 Loss 0.2792 Epoch 7 Loss 0.2162 Time taken for 1 epoch 28.95500087738037 sec Epoch 8 Batch 0 Loss 0.2073 Epoch 8 Batch 100 Loss 0.2095 Epoch 8 Batch 200 Loss 0.1962 Epoch 8 Batch 300 Loss 0.1879 Epoch 8 Loss 0.1794 Time taken for 1 epoch 29.70877432823181 sec Epoch 9 Batch 0 Loss 0.1517 Epoch 9 Batch 100 Loss 0.2231 Epoch 9 Batch 200 Loss 0.2203 Epoch 9 Batch 300 Loss 0.2282 Epoch 9 Loss 0.1496 Time taken for 1 epoch 29.20821261405945 sec Epoch 10 Batch 0 Loss 0.1204 Epoch 10 Batch 100 Loss 0.1370 Epoch 10 Batch 200 Loss 0.1778 Epoch 10 Batch 300 Loss 0.2069 Epoch 10 Loss 0.1316 Time taken for 1 epoch 29.576894283294678 sec
Use tf-addons BasicDecoder for decoding
def evaluate_sentence(sentence):
sentence = dataset_creator.preprocess_sentence(sentence)
inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]
inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],
maxlen=max_length_input,
padding='post')
inputs = tf.convert_to_tensor(inputs)
inference_batch_size = inputs.shape[0]
result = ''
enc_start_state = [tf.zeros((inference_batch_size, units)), tf.zeros((inference_batch_size,units))]
enc_out, enc_h, enc_c = encoder(inputs, enc_start_state)
dec_h = enc_h
dec_c = enc_c
start_tokens = tf.fill([inference_batch_size], targ_lang.word_index['<start>'])
end_token = targ_lang.word_index['<end>']
greedy_sampler = tfa.seq2seq.GreedyEmbeddingSampler()
# Instantiate BasicDecoder object
decoder_instance = tfa.seq2seq.BasicDecoder(cell=decoder.rnn_cell, sampler=greedy_sampler, output_layer=decoder.fc)
# Setup Memory in decoder stack
decoder.attention_mechanism.setup_memory(enc_out)
# set decoder_initial_state
decoder_initial_state = decoder.build_initial_state(inference_batch_size, [enc_h, enc_c], tf.float32)
### Since the BasicDecoder wraps around the Decoder's rnn cell only, you have to ensure that the inputs
### to each decoding step are outputs of the embedding layer. tfa.seq2seq.GreedyEmbeddingSampler() takes care of this.
### You only need the embedding layer's weight matrix, available as decoder.embedding.variables[0], and pass it to BasicDecoder's call() function.
decoder_embedding_matrix = decoder.embedding.variables[0]
outputs, _, _ = decoder_instance(decoder_embedding_matrix, start_tokens = start_tokens, end_token= end_token, initial_state=decoder_initial_state)
return outputs.sample_id.numpy()
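Conceptually, what the GreedyEmbeddingSampler does at every step is take the argmax over the output logits and look up its embedding as the next decoder input. The sketch below only illustrates that idea (embedding_matrix, fake_logits, next_id and next_input are made-up names); it is not the sampler's actual implementation:
# Conceptual sketch of one greedy decoding step (illustration only).
embedding_matrix = decoder.embedding.variables[0]
fake_logits = tf.random.uniform((1, vocab_tar_size))            # stand-in for one step's logits
next_id = tf.argmax(fake_logits, axis=-1)                       # greedy choice of the next token id
next_input = tf.nn.embedding_lookup(embedding_matrix, next_id)  # embedded input for the next step
print(next_input.shape)                                         # (1, 256) == (batch, embedding_dim)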
def translate(sentence):
result = evaluate_sentence(sentence)
print(result)
result = targ_lang.sequences_to_texts(result)
print('Input: %s' % (sentence))
print('Predicted translation: {}'.format(result))
Restore the latest checkpoint and test
# restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f9499417390>
translate(u'hace mucho frio aqui.')
[[ 11 12 49 224 40 4 3]] Input: hace mucho frio aqui. Predicted translation: ['it s very pretty here . <end>']
translate(u'esta es mi vida.')
[[ 20 9 22 190 4 3]] Input: esta es mi vida. Predicted translation: ['this is my life . <end>']
translate(u'¿todavia estan en casa?')
[[25 7 90 8 3]] Input: ¿todavia estan en casa? Predicted translation: ['are you home ? <end>']
# wrong translation
translate(u'trata de averiguarlo.')
[[126 16 892 11 75 4 3]] Input: trata de averiguarlo. Predicted translation: ['try to figure it out . <end>']
Use tf-addons BeamSearchDecoder
def beam_evaluate_sentence(sentence, beam_width=3):
sentence = dataset_creator.preprocess_sentence(sentence)
inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]
inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],
maxlen=max_length_input,
padding='post')
inputs = tf.convert_to_tensor(inputs)
inference_batch_size = inputs.shape[0]
result = ''
enc_start_state = [tf.zeros((inference_batch_size, units)), tf.zeros((inference_batch_size,units))]
enc_out, enc_h, enc_c = encoder(inputs, enc_start_state)
dec_h = enc_h
dec_c = enc_c
start_tokens = tf.fill([inference_batch_size], targ_lang.word_index['<start>'])
end_token = targ_lang.word_index['<end>']
# From official documentation
# NOTE If you are using the BeamSearchDecoder with a cell wrapped in AttentionWrapper, then you must ensure that:
# The encoder output has been tiled to beam_width via tfa.seq2seq.tile_batch (NOT tf.tile).
# The batch_size argument passed to the get_initial_state method of this wrapper is equal to true_batch_size * beam_width.
# The initial state created with get_initial_state above contains a cell_state value containing properly tiled final state from the encoder.
enc_out = tfa.seq2seq.tile_batch(enc_out, multiplier=beam_width)
decoder.attention_mechanism.setup_memory(enc_out)
print("beam_with * [batch_size, max_length_input, rnn_units] : 3 * [1, 16, 1024]] :", enc_out.shape)
# set decoder_inital_state which is an AttentionWrapperState considering beam_width
hidden_state = tfa.seq2seq.tile_batch([enc_h, enc_c], multiplier=beam_width)
decoder_initial_state = decoder.rnn_cell.get_initial_state(batch_size=beam_width*inference_batch_size, dtype=tf.float32)
decoder_initial_state = decoder_initial_state.clone(cell_state=hidden_state)
# Instantiate BeamSearchDecoder
decoder_instance = tfa.seq2seq.BeamSearchDecoder(decoder.rnn_cell,beam_width=beam_width, output_layer=decoder.fc)
decoder_embedding_matrix = decoder.embedding.variables[0]
# The BeamSearchDecoder object's call() function takes care of everything.
outputs, final_state, sequence_lengths = decoder_instance(decoder_embedding_matrix, start_tokens=start_tokens, end_token=end_token, initial_state=decoder_initial_state)
# outputs is a tfa.seq2seq.FinalBeamSearchDecoderOutput object.
# The final beam predictions are stored in outputs.predicted_ids.
# outputs.beam_search_decoder_output is a tfa.seq2seq.BeamSearchDecoderOutput object which keeps track of beam_scores and parent_ids at each beam decoding step.
# final_state is a tfa.seq2seq.BeamSearchDecoderState object.
# sequence_lengths has shape (inference_batch_size, beam_width) and gives the length of each generated beam.
# outputs.predicted_ids.shape = (inference_batch_size, time_step_outputs, beam_width)
# outputs.beam_search_decoder_output.scores.shape = (inference_batch_size, time_step_outputs, beam_width)
# Convert the shape of outputs and beam_scores to (inference_batch_size, beam_width, time_step_outputs)
final_outputs = tf.transpose(outputs.predicted_ids, perm=(0,2,1))
beam_scores = tf.transpose(outputs.beam_search_decoder_output.scores, perm=(0,2,1))
return final_outputs.numpy(), beam_scores.numpy()
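The notes inside beam_evaluate_sentence stress using tfa.seq2seq.tile_batch rather than tf.tile. The tiny, illustrative comparison below shows why: tile_batch repeats each batch entry beam_width times in place, so all beams for one example stay contiguous, while tf.tile repeats the whole batch:
# Illustrative: tile_batch vs tf.tile on a toy batch of size 2 with beam_width 3.
toy_batch = tf.constant([[1, 2], [3, 4]])
print(tfa.seq2seq.tile_batch(toy_batch, multiplier=3).numpy())  # [[1 2] [1 2] [1 2] [3 4] [3 4] [3 4]]
print(tf.tile(toy_batch, [3, 1]).numpy())                       # [[1 2] [3 4] [1 2] [3 4] [1 2] [3 4]]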
def beam_translate(sentence):
result, beam_scores = beam_evaluate_sentence(sentence)
print(result.shape, beam_scores.shape)
for beam, score in zip(result, beam_scores):
print(beam.shape, score.shape)
output = targ_lang.sequences_to_texts(beam)
output = [a[:a.index('<end>')] for a in output]
beam_score = [a.sum() for a in score]
print('Input: %s' % (sentence))
for i in range(len(output)):
print('{} Predicted translation: {} {}'.format(i+1, output[i], beam_score[i]))
beam_translate(u'hace mucho frio aqui.')
beam_width * [batch_size, max_length_input, rnn_units] = 3 * [1, 16, 1024] : (3, 16, 1024) (1, 3, 7) (1, 3, 7) (3, 7) (3, 7) Input: hace mucho frio aqui. 1 Predicted translation: it s very pretty here . -4.117094039916992 2 Predicted translation: it s very cold here . -14.85302734375 3 Predicted translation: it s very pretty news . -25.59416389465332
beam_translate(u'¿todavia estan en casa?')
beam_width * [batch_size, max_length_input, rnn_units] = 3 * [1, 16, 1024] : (3, 16, 1024) (1, 3, 7) (1, 3, 7) (3, 7) (3, 7) Input: ¿todavia estan en casa? 1 Predicted translation: are you still home ? -4.036754131317139 2 Predicted translation: are you still at home ? -15.306867599487305 3 Predicted translation: are you still go home ? -20.533388137817383