
TensorFlow Addons Networks: Sequence-to-Sequence NMT with Attention Mechanism


Overview

This notebook gives a brief overview of the sequence-to-sequence model architecture. It broadly covers four key topics required for neural machine translation:

  • Data cleaning
  • Data preparation
  • Neural translation model with attention
  • Final translation

The basic idea behind such models, however, is simply the encoder-decoder architecture. These networks are commonly used for a variety of tasks such as text summarization, machine translation, and image captioning. This tutorial explains the technical terms as needed and builds a practical understanding of the concepts. We focus on the task of neural machine translation (NMT), which was the very first testbed for seq2seq models.

Setup

Additional resources:

These are the requirements you will need to install in order to run this notebook:

  1. German-English dataset

We need to download the dataset. You could use pre-trained embeddings to build this notebook, but here we train our own from scratch!

#download data
print("Downloading Dataset:")
!wget --quiet http://www.manythings.org/anki/deu-eng.zip
!unzip deu-eng.zip
import csv
import string
import re
from typing import List, Tuple
from pickle import dump
from unicodedata import normalize
import numpy as np
import itertools
from pickle import load
from tensorflow.keras.utils import to_categorical
from keras.utils.vis_utils import plot_model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Embedding
import random
import tensorflow as tf
from keras.models import load_model
from nltk.translate.bleu_score import corpus_bleu
from sklearn.model_selection import train_test_split
import tensorflow_addons as tfa
Using TensorFlow backend.

Data cleaning

Our dataset is a German-English translation dataset. It contains 152,820 pairs of English-to-German phrases, one pair per line, with a tab separating the two languages. Although the dataset is well organized, it needs cleaning before we can work with it; this removes unnecessary bumps we might otherwise hit during training. We also add a <start> token and an <end> token to each sentence so that the model knows when to start and stop predicting.

# Start of sentence
SOS = "<start>"
# End of sentence
EOS = "<end>"
# Relevant punctuation
PUNCTUATION = set("?,!.")


def load_dataset(filename: str) -> str:
    """
    load dataset into memory
    """
    with open(filename, mode="rt", encoding="utf-8") as fp:
        return fp.read()


def to_pairs(dataset: str, limit: int = None, shuffle=False) -> List[Tuple[str, str]]:
    """
    Split dataset into pairs of sentences, discards dataset line info.

    e.g.
    input -> 'Go.\tGeh.\tCC-BY 2.0 (France) Attribution: tatoeba.org
    #2877272 (CM) & #8597805 (Roujin)'
    output -> [('Go.', 'Geh.')]

    :param dataset: dataset containing examples of translations between
    two languages
    the examples are delimited by `\n` and the contents of the lines are
    delimited by `\t`
    :param limit: number that limits the dataset size (optional)
    :param shuffle: shuffle the lines before building pairs, default is False
    :return: list of pairs
    """
    assert isinstance(limit, (int, type(None))), TypeError(
        "the limit value must be an integer"
    )
    lines = dataset.strip().split("\n")
    # Randomly shuffle the dataset lines
    if shuffle is True:
        random.shuffle(lines)
    number_examples = limit or len(lines)  # if None get all
    pairs = []
    for line in lines[: abs(number_examples)]:
        # take only source and target
        src, trg, _ = line.split("\t")
        pairs.append((src, trg))

    # dataset size check
    assert len(pairs) == number_examples
    return pairs


def separe_punctuation(token: str) -> str:
    """
    Separate punctuation from the token if present
    """

    if not set(token).intersection(PUNCTUATION):
        return token
    for p in PUNCTUATION:
        token = f" {p} ".join(token.split(p))
    return " ".join(token.split())


def preprocess(sentence: str, add_start_end: bool=True) -> str:
    """

    - convert to lowercase
    - remove numbers
    - remove special characters
    - separate punctuation
    - add start-of-sentence <start> and end-of-sentence <end>

    :param add_start_end: add SOS (start-of-sentence) and EOS (end-of-sentence)
    """
    re_print = re.compile(f"[^{re.escape(string.printable)}]")
    # convert lowercase and normalizing unicode characters
    sentence = (
        normalize("NFD", sentence.lower()).encode("ascii", "ignore").decode("UTF-8")
    )
    cleaned_tokens = []
    # tokenize sentence on white space
    for token in sentence.split():
        # remove non-printable chars from each token
        token = re_print.sub("", token).strip()
        # ignore tokens with numbers
        if re.findall("[0-9]", token):
            continue
        # add space between words and punctuation eg: "ok?go!" => "ok ? go !"
        token = separe_punctuation(token)
        cleaned_tokens.append(token)

    # rebuild sentence with space between tokens
    sentence = " ".join(cleaned_tokens)

    # adding a start and an end token to the sentence
    if add_start_end is True:
        sentence = f"{SOS} {sentence} {EOS}"
    return sentence


def dataset_preprocess(dataset: List[Tuple[str, str]]) -> Tuple[List[str], List[str]]:
    """
    Return the processed dataset

    :param dataset: list of sentence pairs
    :return: two parallel lists of cleaned sentences, e.g.
    (['first source sentence', 'second', ...], ['first target sentence', 'second', ...])
    """
    source_cleaned = []
    target_cleaned = []
    for source, target in dataset:
        source_cleaned.append(preprocess(source))
        target_cleaned.append(preprocess(target))
    return source_cleaned, target_cleaned
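
To make the cleaning rules concrete, here is a quick, illustrative call to preprocess on a made-up sentence (this example is not part of the original pipeline):

# Illustrative only: punctuation is split off, the token containing a digit ("2")
# is dropped, and the <start>/<end> markers are added.
print(preprocess("Ok?go! I have 2 cats."))
# <start> ok ? go ! i have cats . <end>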

Create the dataset

  • Limit the number of examples
  • Load the dataset into pairs [('Be nice.', 'Seien Sie nett!'), ('Beat it.', 'Geh weg!'), ...]
  • Preprocess the dataset
NUM_EXAMPLES = 10000 # Limit dataset size

# load from .txt
filename = 'deu.txt' #change filename if necessary
dataset = load_dataset(filename)
# get pairs, limited to NUM_EXAMPLES
pairs = to_pairs(dataset, limit=NUM_EXAMPLES)
print(f"Dataset size: {len(pairs)}")
raw_data_en, raw_data_ge = dataset_preprocess(pairs)

# show last 5 pairs
for pair in zip(raw_data_en[-5:],raw_data_ge[-5:]):
    print(pair)
Dataset size: 10000
("<start> tom's hungover . <end>", '<start> tom ist verkatert . <end>')
("<start> tom's in there . <end>", '<start> tom ist da drinnen . <end>')
("<start> tom's innocent . <end>", '<start> tom ist unschuldig . <end>')
("<start> tom's laughing . <end>", '<start> tom lacht . <end>')
("<start> tom's not busy . <end>", '<start> tom ist nicht beschaftigt . <end>')

Tokenization

en_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
en_tokenizer.fit_on_texts(raw_data_en)

data_en = en_tokenizer.texts_to_sequences(raw_data_en)
data_en = tf.keras.preprocessing.sequence.pad_sequences(data_en,padding='post')

ge_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
ge_tokenizer.fit_on_texts(raw_data_ge)

data_ge = ge_tokenizer.texts_to_sequences(raw_data_ge)
data_ge = tf.keras.preprocessing.sequence.pad_sequences(data_ge,padding='post')
def max_len(tensor):
    return max(len(t) for t in tensor)
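
As a quick sanity check (illustrative, not required by the pipeline), you can inspect how one cleaned sentence is mapped to integer ids and padded; the exact ids depend on the fitted vocabulary:

# Illustrative check: how a cleaned sentence becomes a padded id sequence.
sample = raw_data_en[0]
print(sample)                                         # cleaned text with <start>/<end>
print(en_tokenizer.texts_to_sequences([sample])[0])   # integer ids
print(data_en[0])                                     # same ids, zero-padded at the end
print(max_len(data_en), max_len(data_ge))             # max source / target lengths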

Model parameters

X_train,  X_test, Y_train, Y_test = train_test_split(data_en,data_ge,test_size=0.2)
BATCH_SIZE = 64
BUFFER_SIZE = len(X_train)
steps_per_epoch = BUFFER_SIZE//BATCH_SIZE
embedding_dims = 256
rnn_units = 1024
dense_units = 1024
Dtype = tf.float32   #used to initialize DecoderCell Zero state

Dataset preparation

Tx = max_len(data_en)
Ty = max_len(data_ge)  

input_vocab_size = len(en_tokenizer.word_index)+1  
output_vocab_size = len(ge_tokenizer.word_index)+ 1
dataset = tf.data.Dataset.from_tensor_slices((X_train, Y_train)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
example_X, example_Y = next(iter(dataset))
print(example_X.shape) 
print(example_Y.shape) 
(64, 9)
(64, 13)

Defining the NMT model

#ENCODER
class EncoderNetwork(tf.keras.Model):
    def __init__(self,input_vocab_size,embedding_dims, rnn_units ):
        super().__init__()
        self.encoder_embedding = tf.keras.layers.Embedding(input_dim=input_vocab_size,
                                                           output_dim=embedding_dims)
        self.encoder_rnnlayer = tf.keras.layers.LSTM(rnn_units,return_sequences=True, 
                                                     return_state=True )
    
#DECODER
class DecoderNetwork(tf.keras.Model):
    def __init__(self,output_vocab_size, embedding_dims, rnn_units):
        super().__init__()
        self.decoder_embedding = tf.keras.layers.Embedding(input_dim=output_vocab_size,
                                                           output_dim=embedding_dims) 
        self.dense_layer = tf.keras.layers.Dense(output_vocab_size)
        self.decoder_rnncell = tf.keras.layers.LSTMCell(rnn_units)
        # Sampler
        self.sampler = tfa.seq2seq.sampler.TrainingSampler()
        # Create attention mechanism with memory = None
        self.attention_mechanism = self.build_attention_mechanism(dense_units,None,BATCH_SIZE*[Tx])
        self.rnn_cell =  self.build_rnn_cell(BATCH_SIZE)
        self.decoder = tfa.seq2seq.BasicDecoder(self.rnn_cell, sampler= self.sampler,
                                                output_layer=self.dense_layer)

    def build_attention_mechanism(self, units,memory, memory_sequence_length):
        return tfa.seq2seq.LuongAttention(units, memory = memory, 
                                          memory_sequence_length=memory_sequence_length)
        #return tfa.seq2seq.BahdanauAttention(units, memory = memory, memory_sequence_length=memory_sequence_length)

    # wrap the decoder RNN cell with attention
    def build_rnn_cell(self, batch_size ):
        rnn_cell = tfa.seq2seq.AttentionWrapper(self.decoder_rnncell, self.attention_mechanism,
                                                attention_layer_size=dense_units)
        return rnn_cell
    
    def build_decoder_initial_state(self, batch_size, encoder_state,Dtype):
        decoder_initial_state = self.rnn_cell.get_initial_state(batch_size = batch_size, 
                                                                dtype = Dtype)
        decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state) 
        return decoder_initial_state

encoderNetwork = EncoderNetwork(input_vocab_size,embedding_dims, rnn_units)
decoderNetwork = DecoderNetwork(output_vocab_size,embedding_dims, rnn_units)
optimizer = tf.keras.optimizers.Adam()
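
Before wiring up training, an optional shape check (illustrative; it reuses the example_X batch drawn earlier) shows what the encoder produces: per-timestep outputs that serve as the attention memory, and the final hidden/cell states that initialize the decoder.

# Optional shape check (illustrative): run one batch through the encoder.
sample_emb = encoderNetwork.encoder_embedding(example_X)
sample_a, sample_h, sample_c = encoderNetwork.encoder_rnnlayer(
    sample_emb,
    initial_state=[tf.zeros((BATCH_SIZE, rnn_units)), tf.zeros((BATCH_SIZE, rnn_units))])
print(sample_a.shape)  # (BATCH_SIZE, Tx, rnn_units) -> attention memory
print(sample_h.shape)  # (BATCH_SIZE, rnn_units)     -> decoder initial hidden state
print(sample_c.shape)  # (BATCH_SIZE, rnn_units)     -> decoder initial cell state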

Initialize training functions

def loss_function(y_pred, y):
   
    #shape of y [batch_size, ty]
    #shape of y_pred [batch_size, Ty, output_vocab_size] 
    sparsecategoricalcrossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                                                  reduction='none')
    loss = sparsecategoricalcrossentropy(y_true=y, y_pred=y_pred)
    mask = tf.logical_not(tf.math.equal(y,0))   #output 0 for y=0 else output 1
    mask = tf.cast(mask, dtype=loss.dtype)
    loss = mask* loss
    loss = tf.reduce_mean(loss)
    return loss

def train_step(input_batch, output_batch,encoder_initial_cell_state):
    #initialize loss = 0
    loss = 0
    with tf.GradientTape() as tape:
        encoder_emb_inp = encoderNetwork.encoder_embedding(input_batch)
        a, a_tx, c_tx = encoderNetwork.encoder_rnnlayer(encoder_emb_inp, 
                                                        initial_state =encoder_initial_cell_state)

        #[last step activations,last memory_state] of encoder passed as input to decoder Network
        
         
        # Prepare correct Decoder input & output sequence data
        decoder_input = output_batch[:,:-1] # ignore <end>
        #compare logits with timestepped +1 version of decoder_input
        decoder_output = output_batch[:,1:] #ignore <start>


        # Decoder Embeddings
        decoder_emb_inp = decoderNetwork.decoder_embedding(decoder_input)

        #Setting up decoder memory from encoder output and Zero State for AttentionWrapperState
        decoderNetwork.attention_mechanism.setup_memory(a)
        decoder_initial_state = decoderNetwork.build_decoder_initial_state(BATCH_SIZE,
                                                                           encoder_state=[a_tx, c_tx],
                                                                           Dtype=tf.float32)
        
        #BasicDecoderOutput        
        outputs, _, _ = decoderNetwork.decoder(decoder_emb_inp,initial_state=decoder_initial_state,
                                               sequence_length=BATCH_SIZE*[Ty-1])

        logits = outputs.rnn_output
        #Calculate loss

        loss = loss_function(logits, decoder_output)

    #Returns the list of all layer variables / weights.
    variables = encoderNetwork.trainable_variables + decoderNetwork.trainable_variables  
    # differentiate loss wrt variables
    gradients = tape.gradient(loss, variables)

    #grads_and_vars – List of(gradient, variable) pairs.
    grads_and_vars = zip(gradients,variables)
    optimizer.apply_gradients(grads_and_vars)
    return loss
#RNN LSTM hidden and memory state initializer
def initialize_initial_state():
    return [tf.zeros((BATCH_SIZE, rnn_units)), tf.zeros((BATCH_SIZE, rnn_units))]
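
The padding mask inside loss_function zeroes out the contribution of padded target positions (id 0) before the mean is taken. A tiny illustrative check with made-up values (not part of training) confirms it returns a scalar:

# Illustrative only: padded positions (target id 0) contribute nothing to the loss sum.
y_true = tf.constant([[5, 7, 2, 0, 0]])                   # one padded target row
y_pred = tf.random.uniform((1, 5, output_vocab_size))     # fake logits
print(loss_function(y_pred, y_true))                      # scalar masked loss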

Training

epochs = 15
for i in range(1, epochs+1):

    encoder_initial_cell_state = initialize_initial_state()
    total_loss = 0.0

    for ( batch , (input_batch, output_batch)) in enumerate(dataset.take(steps_per_epoch)):
        batch_loss = train_step(input_batch, output_batch, encoder_initial_cell_state)
        total_loss += batch_loss
        if (batch+1)%5 == 0:
            print("total loss: {} epoch {} batch {} ".format(batch_loss.numpy(), i, batch+1))
total loss: 4.11566686630249 epoch 1 batch 5 
total loss: 2.993711233139038 epoch 1 batch 10 
total loss: 2.456167459487915 epoch 1 batch 15 
total loss: 2.1430583000183105 epoch 1 batch 20 
total loss: 2.202500820159912 epoch 1 batch 25 
total loss: 2.0447075366973877 epoch 1 batch 30 
total loss: 1.943502426147461 epoch 1 batch 35 
total loss: 1.8647733926773071 epoch 1 batch 40 
total loss: 1.887935757637024 epoch 1 batch 45 
total loss: 2.0422816276550293 epoch 1 batch 50 
total loss: 1.7727909088134766 epoch 1 batch 55 
total loss: 1.64265775680542 epoch 1 batch 60 
total loss: 1.708620548248291 epoch 1 batch 65 
total loss: 1.663000464439392 epoch 1 batch 70 
total loss: 1.733208179473877 epoch 1 batch 75 
total loss: 1.6179828643798828 epoch 1 batch 80 
total loss: 1.6496108770370483 epoch 1 batch 85 
total loss: 1.7499841451644897 epoch 1 batch 90 
total loss: 1.6253910064697266 epoch 1 batch 95 
total loss: 1.6513166427612305 epoch 1 batch 100 
total loss: 1.6768431663513184 epoch 1 batch 105 
total loss: 1.5870885848999023 epoch 1 batch 110 
total loss: 1.5872650146484375 epoch 1 batch 115 
total loss: 1.6291579008102417 epoch 1 batch 120 
total loss: 1.5899494886398315 epoch 1 batch 125 
total loss: 1.4799423217773438 epoch 2 batch 5 
total loss: 1.5262809991836548 epoch 2 batch 10 
total loss: 1.5344295501708984 epoch 2 batch 15 
total loss: 1.4018179178237915 epoch 2 batch 20 
total loss: 1.2517988681793213 epoch 2 batch 25 
total loss: 1.3529373407363892 epoch 2 batch 30 
total loss: 1.3586145639419556 epoch 2 batch 35 
total loss: 1.445724606513977 epoch 2 batch 40 
total loss: 1.3398317098617554 epoch 2 batch 45 
total loss: 1.3680673837661743 epoch 2 batch 50 
total loss: 1.2311376333236694 epoch 2 batch 55 
total loss: 1.4052680730819702 epoch 2 batch 60 
total loss: 1.333901286125183 epoch 2 batch 65 
total loss: 1.3520702123641968 epoch 2 batch 70 
total loss: 1.3728466033935547 epoch 2 batch 75 
total loss: 1.2714239358901978 epoch 2 batch 80 
total loss: 1.2564586400985718 epoch 2 batch 85 
total loss: 1.3137339353561401 epoch 2 batch 90 
total loss: 1.2975138425827026 epoch 2 batch 95 
total loss: 1.387586236000061 epoch 2 batch 100 
total loss: 1.3724294900894165 epoch 2 batch 105 
total loss: 1.2119255065917969 epoch 2 batch 110 
total loss: 1.3122206926345825 epoch 2 batch 115 
total loss: 1.2198586463928223 epoch 2 batch 120 
total loss: 1.3899301290512085 epoch 2 batch 125 
total loss: 1.165706753730774 epoch 3 batch 5 
total loss: 1.1252425909042358 epoch 3 batch 10 
total loss: 1.238714575767517 epoch 3 batch 15 
total loss: 1.1738801002502441 epoch 3 batch 20 
total loss: 1.0969573259353638 epoch 3 batch 25 
total loss: 1.333953857421875 epoch 3 batch 30 
total loss: 1.150842547416687 epoch 3 batch 35 
total loss: 1.188685417175293 epoch 3 batch 40 
total loss: 1.1465986967086792 epoch 3 batch 45 
total loss: 1.192483901977539 epoch 3 batch 50 
total loss: 1.0624631643295288 epoch 3 batch 55 
total loss: 1.0510236024856567 epoch 3 batch 60 
total loss: 1.2240933179855347 epoch 3 batch 65 
total loss: 1.1826471090316772 epoch 3 batch 70 
total loss: 1.1488090753555298 epoch 3 batch 75 
total loss: 1.0666199922561646 epoch 3 batch 80 
total loss: 1.099870204925537 epoch 3 batch 85 
total loss: 1.2090061902999878 epoch 3 batch 90 
total loss: 1.0520261526107788 epoch 3 batch 95 
total loss: 1.1468778848648071 epoch 3 batch 100 
total loss: 1.1265020370483398 epoch 3 batch 105 
total loss: 1.1606225967407227 epoch 3 batch 110 
total loss: 1.0110392570495605 epoch 3 batch 115 
total loss: 1.0859214067459106 epoch 3 batch 120 
total loss: 1.0578597784042358 epoch 3 batch 125 
total loss: 0.9532763957977295 epoch 4 batch 5 
total loss: 0.9986910223960876 epoch 4 batch 10 
total loss: 0.956193208694458 epoch 4 batch 15 
total loss: 0.9690749049186707 epoch 4 batch 20 
total loss: 1.101245403289795 epoch 4 batch 25 
total loss: 0.9993131160736084 epoch 4 batch 30 
total loss: 0.9002986550331116 epoch 4 batch 35 
total loss: 0.9263710975646973 epoch 4 batch 40 
total loss: 0.9197303652763367 epoch 4 batch 45 
total loss: 0.928395688533783 epoch 4 batch 50 
total loss: 1.0114047527313232 epoch 4 batch 55 
total loss: 1.083633542060852 epoch 4 batch 60 
total loss: 0.9597204327583313 epoch 4 batch 65 
total loss: 0.948369562625885 epoch 4 batch 70 
total loss: 0.9748582243919373 epoch 4 batch 75 
total loss: 1.1318320035934448 epoch 4 batch 80 
total loss: 0.9337785243988037 epoch 4 batch 85 
total loss: 1.066165566444397 epoch 4 batch 90 
total loss: 0.896867573261261 epoch 4 batch 95 
total loss: 0.8608654141426086 epoch 4 batch 100 
total loss: 1.0241042375564575 epoch 4 batch 105 
total loss: 0.9655657410621643 epoch 4 batch 110 
total loss: 0.9644956588745117 epoch 4 batch 115 
total loss: 0.9764884114265442 epoch 4 batch 120 
total loss: 0.9435749650001526 epoch 4 batch 125 
total loss: 0.8453261852264404 epoch 5 batch 5 
total loss: 0.8299359679222107 epoch 5 batch 10 
total loss: 0.7513504028320312 epoch 5 batch 15 
total loss: 0.9070504307746887 epoch 5 batch 20 
total loss: 0.8284425139427185 epoch 5 batch 25 
total loss: 0.7402708530426025 epoch 5 batch 30 
total loss: 0.8092263340950012 epoch 5 batch 35 
total loss: 0.8729425072669983 epoch 5 batch 40 
total loss: 0.8656869530677795 epoch 5 batch 45 
total loss: 0.7958595752716064 epoch 5 batch 50 
total loss: 0.8578026294708252 epoch 5 batch 55 
total loss: 0.778644859790802 epoch 5 batch 60 
total loss: 0.7277960777282715 epoch 5 batch 65 
total loss: 0.8738289475440979 epoch 5 batch 70 
total loss: 0.6856063008308411 epoch 5 batch 75 
total loss: 0.8267806172370911 epoch 5 batch 80 
total loss: 0.946643054485321 epoch 5 batch 85 
total loss: 0.9214975237846375 epoch 5 batch 90 
total loss: 0.796623706817627 epoch 5 batch 95 
total loss: 0.8234322667121887 epoch 5 batch 100 
total loss: 0.8310582041740417 epoch 5 batch 105 
total loss: 0.7187271118164062 epoch 5 batch 110 
total loss: 0.881775438785553 epoch 5 batch 115 
total loss: 0.8496475219726562 epoch 5 batch 120 
total loss: 0.7749930024147034 epoch 5 batch 125 
total loss: 0.6057237386703491 epoch 6 batch 5 
total loss: 0.5688410401344299 epoch 6 batch 10 
total loss: 0.6365471482276917 epoch 6 batch 15 
total loss: 0.6626251935958862 epoch 6 batch 20 
total loss: 0.6636946797370911 epoch 6 batch 25 
total loss: 0.6313133835792542 epoch 6 batch 30 
total loss: 0.5917147397994995 epoch 6 batch 35 
total loss: 0.6965726017951965 epoch 6 batch 40 
total loss: 0.6281453371047974 epoch 6 batch 45 
total loss: 0.6475895047187805 epoch 6 batch 50 
total loss: 0.7765358090400696 epoch 6 batch 55 
total loss: 0.5973318219184875 epoch 6 batch 60 
total loss: 0.713416576385498 epoch 6 batch 65 
total loss: 0.7173630595207214 epoch 6 batch 70 
total loss: 0.7002382874488831 epoch 6 batch 75 
total loss: 0.6431768536567688 epoch 6 batch 80 
total loss: 0.6381948590278625 epoch 6 batch 85 
total loss: 0.7046375870704651 epoch 6 batch 90 
total loss: 0.6564927697181702 epoch 6 batch 95 
total loss: 0.7156146168708801 epoch 6 batch 100 
total loss: 0.7078973650932312 epoch 6 batch 105 
total loss: 0.6482166647911072 epoch 6 batch 110 
total loss: 0.5653694868087769 epoch 6 batch 115 
total loss: 0.768178403377533 epoch 6 batch 120 
total loss: 0.6993356347084045 epoch 6 batch 125 
total loss: 0.4355561435222626 epoch 7 batch 5 
total loss: 0.5312787294387817 epoch 7 batch 10 
total loss: 0.5179179906845093 epoch 7 batch 15 
total loss: 0.5177888870239258 epoch 7 batch 20 
total loss: 0.5274668335914612 epoch 7 batch 25 
total loss: 0.4485582113265991 epoch 7 batch 30 
total loss: 0.5205077528953552 epoch 7 batch 35 
total loss: 0.6028087735176086 epoch 7 batch 40 
total loss: 0.4433538615703583 epoch 7 batch 45 
total loss: 0.5281521677970886 epoch 7 batch 50 
total loss: 0.5123710036277771 epoch 7 batch 55 
total loss: 0.4892776906490326 epoch 7 batch 60 
total loss: 0.5777449011802673 epoch 7 batch 65 
total loss: 0.5938393473625183 epoch 7 batch 70 
total loss: 0.5447298884391785 epoch 7 batch 75 
total loss: 0.5399925112724304 epoch 7 batch 80 
total loss: 0.549943745136261 epoch 7 batch 85 
total loss: 0.5606051683425903 epoch 7 batch 90 
total loss: 0.6317020058631897 epoch 7 batch 95 
total loss: 0.5499157309532166 epoch 7 batch 100 
total loss: 0.5369137525558472 epoch 7 batch 105 
total loss: 0.6119964718818665 epoch 7 batch 110 
total loss: 0.6122032403945923 epoch 7 batch 115 
total loss: 0.6180634498596191 epoch 7 batch 120 
total loss: 0.5060015320777893 epoch 7 batch 125 
total loss: 0.4102749526500702 epoch 8 batch 5 
total loss: 0.4113573729991913 epoch 8 batch 10 
total loss: 0.34586894512176514 epoch 8 batch 15 
total loss: 0.4162067174911499 epoch 8 batch 20 
total loss: 0.4488414227962494 epoch 8 batch 25 
total loss: 0.47596967220306396 epoch 8 batch 30 
total loss: 0.43868470191955566 epoch 8 batch 35 
total loss: 0.4669533669948578 epoch 8 batch 40 
total loss: 0.4095423221588135 epoch 8 batch 45 
total loss: 0.4171658754348755 epoch 8 batch 50 
total loss: 0.41935643553733826 epoch 8 batch 55 
total loss: 0.42487478256225586 epoch 8 batch 60 
total loss: 0.5020427107810974 epoch 8 batch 65 
total loss: 0.46865570545196533 epoch 8 batch 70 
total loss: 0.48575273156166077 epoch 8 batch 75 
total loss: 0.402313232421875 epoch 8 batch 80 
total loss: 0.5250392556190491 epoch 8 batch 85 
total loss: 0.5152303576469421 epoch 8 batch 90 
total loss: 0.4697692394256592 epoch 8 batch 95 
total loss: 0.4108094274997711 epoch 8 batch 100 
total loss: 0.4215029776096344 epoch 8 batch 105 
total loss: 0.43752169609069824 epoch 8 batch 110 
total loss: 0.45470383763313293 epoch 8 batch 115 
total loss: 0.5394885540008545 epoch 8 batch 120 
total loss: 0.46421656012535095 epoch 8 batch 125 
total loss: 0.38278815150260925 epoch 9 batch 5 
total loss: 0.3325278162956238 epoch 9 batch 10 
total loss: 0.25561612844467163 epoch 9 batch 15 
total loss: 0.39196979999542236 epoch 9 batch 20 
total loss: 0.3144271671772003 epoch 9 batch 25 
total loss: 0.3374980390071869 epoch 9 batch 30 
total loss: 0.3220641613006592 epoch 9 batch 35 
total loss: 0.28498175740242004 epoch 9 batch 40 
total loss: 0.34717854857444763 epoch 9 batch 45 
total loss: 0.27360835671424866 epoch 9 batch 50 
total loss: 0.34681805968284607 epoch 9 batch 55 
total loss: 0.34650543332099915 epoch 9 batch 60 
total loss: 0.377156525850296 epoch 9 batch 65 
total loss: 0.3942091464996338 epoch 9 batch 70 
total loss: 0.40023937821388245 epoch 9 batch 75 
total loss: 0.3928321301937103 epoch 9 batch 80 
total loss: 0.3811839818954468 epoch 9 batch 85 
total loss: 0.3996661901473999 epoch 9 batch 90 
total loss: 0.4434957504272461 epoch 9 batch 95 
total loss: 0.36710819602012634 epoch 9 batch 100 
total loss: 0.4244243800640106 epoch 9 batch 105 
total loss: 0.39385613799095154 epoch 9 batch 110 
total loss: 0.40314915776252747 epoch 9 batch 115 
total loss: 0.38281798362731934 epoch 9 batch 120 
total loss: 0.34032365679740906 epoch 9 batch 125 
total loss: 0.25806427001953125 epoch 10 batch 5 
total loss: 0.2550051212310791 epoch 10 batch 10 
total loss: 0.23162484169006348 epoch 10 batch 15 
total loss: 0.26205796003341675 epoch 10 batch 20 
total loss: 0.2918882369995117 epoch 10 batch 25 
total loss: 0.28667184710502625 epoch 10 batch 30 
total loss: 0.30746373534202576 epoch 10 batch 35 
total loss: 0.24943065643310547 epoch 10 batch 40 
total loss: 0.24033032357692719 epoch 10 batch 45 
total loss: 0.29537105560302734 epoch 10 batch 50 
total loss: 0.3474333584308624 epoch 10 batch 55 
total loss: 0.31821370124816895 epoch 10 batch 60 
total loss: 0.35506772994995117 epoch 10 batch 65 
total loss: 0.40117380023002625 epoch 10 batch 70 
total loss: 0.2801777422428131 epoch 10 batch 75 
total loss: 0.26276424527168274 epoch 10 batch 80 
total loss: 0.33141613006591797 epoch 10 batch 85 
total loss: 0.2891913056373596 epoch 10 batch 90 
total loss: 0.34682735800743103 epoch 10 batch 95 
total loss: 0.39360567927360535 epoch 10 batch 100 
total loss: 0.40213945508003235 epoch 10 batch 105 
total loss: 0.2949744462966919 epoch 10 batch 110 
total loss: 0.27941974997520447 epoch 10 batch 115 
total loss: 0.28911301493644714 epoch 10 batch 120 
total loss: 0.3066214621067047 epoch 10 batch 125 
total loss: 0.24125127494335175 epoch 11 batch 5 
total loss: 0.20186877250671387 epoch 11 batch 10 
total loss: 0.2145632952451706 epoch 11 batch 15 
total loss: 0.23457588255405426 epoch 11 batch 20 
total loss: 0.2408994436264038 epoch 11 batch 25 
total loss: 0.1797456294298172 epoch 11 batch 30 
total loss: 0.19768892228603363 epoch 11 batch 35 
total loss: 0.21545785665512085 epoch 11 batch 40 
total loss: 0.23571373522281647 epoch 11 batch 45 
total loss: 0.25327250361442566 epoch 11 batch 50 
total loss: 0.2649385631084442 epoch 11 batch 55 
total loss: 0.30291682481765747 epoch 11 batch 60 
total loss: 0.2986145317554474 epoch 11 batch 65 
total loss: 0.20132605731487274 epoch 11 batch 70 
total loss: 0.24036192893981934 epoch 11 batch 75 
total loss: 0.29774945974349976 epoch 11 batch 80 
total loss: 0.24990446865558624 epoch 11 batch 85 
total loss: 0.27169445157051086 epoch 11 batch 90 
total loss: 0.2602415978908539 epoch 11 batch 95 
total loss: 0.26800140738487244 epoch 11 batch 100 
total loss: 0.27735427021980286 epoch 11 batch 105 
total loss: 0.26872459053993225 epoch 11 batch 110 
total loss: 0.27796170115470886 epoch 11 batch 115 
total loss: 0.3037005364894867 epoch 11 batch 120 
total loss: 0.3586186468601227 epoch 11 batch 125 
total loss: 0.1714320331811905 epoch 12 batch 5 
total loss: 0.17246638238430023 epoch 12 batch 10 
total loss: 0.2304478883743286 epoch 12 batch 15 
total loss: 0.20280466973781586 epoch 12 batch 20 
total loss: 0.1980326622724533 epoch 12 batch 25 
total loss: 0.24768061935901642 epoch 12 batch 30 
total loss: 0.17778398096561432 epoch 12 batch 35 
total loss: 0.2015562802553177 epoch 12 batch 40 
total loss: 0.1770702451467514 epoch 12 batch 45 
total loss: 0.2334766387939453 epoch 12 batch 50 
total loss: 0.20495925843715668 epoch 12 batch 55 
total loss: 0.21376878023147583 epoch 12 batch 60 
total loss: 0.24144266545772552 epoch 12 batch 65 
total loss: 0.2306946963071823 epoch 12 batch 70 
total loss: 0.23844711482524872 epoch 12 batch 75 
total loss: 0.24324734508991241 epoch 12 batch 80 
total loss: 0.1984959989786148 epoch 12 batch 85 
total loss: 0.2658829689025879 epoch 12 batch 90 
total loss: 0.24130244553089142 epoch 12 batch 95 
total loss: 0.23028753697872162 epoch 12 batch 100 
total loss: 0.27955183386802673 epoch 12 batch 105 
total loss: 0.269803524017334 epoch 12 batch 110 
total loss: 0.24687449634075165 epoch 12 batch 115 
total loss: 0.2637614905834198 epoch 12 batch 120 
total loss: 0.2655775249004364 epoch 12 batch 125 
total loss: 0.1553117036819458 epoch 13 batch 5 
total loss: 0.12917208671569824 epoch 13 batch 10 
total loss: 0.23377186059951782 epoch 13 batch 15 
total loss: 0.17143402993679047 epoch 13 batch 20 
total loss: 0.19789159297943115 epoch 13 batch 25 
total loss: 0.17325706779956818 epoch 13 batch 30 
total loss: 0.1461445689201355 epoch 13 batch 35 
total loss: 0.1638738512992859 epoch 13 batch 40 
total loss: 0.23124034702777863 epoch 13 batch 45 
total loss: 0.19878023862838745 epoch 13 batch 50 
total loss: 0.1812722235918045 epoch 13 batch 55 
total loss: 0.24695098400115967 epoch 13 batch 60 
total loss: 0.15736332535743713 epoch 13 batch 65 
total loss: 0.18134035170078278 epoch 13 batch 70 
total loss: 0.20316295325756073 epoch 13 batch 75 
total loss: 0.17294305562973022 epoch 13 batch 80 
total loss: 0.2048470824956894 epoch 13 batch 85 
total loss: 0.1972559690475464 epoch 13 batch 90 
total loss: 0.20555488765239716 epoch 13 batch 95 
total loss: 0.15902088582515717 epoch 13 batch 100 
total loss: 0.27476567029953003 epoch 13 batch 105 
total loss: 0.24714398384094238 epoch 13 batch 110 
total loss: 0.25630465149879456 epoch 13 batch 115 
total loss: 0.269127756357193 epoch 13 batch 120 
total loss: 0.23399965465068817 epoch 13 batch 125 
total loss: 0.14865447580814362 epoch 14 batch 5 
total loss: 0.16153651475906372 epoch 14 batch 10 
total loss: 0.17261719703674316 epoch 14 batch 15 
total loss: 0.22619158029556274 epoch 14 batch 20 
total loss: 0.13681507110595703 epoch 14 batch 25 
total loss: 0.16032403707504272 epoch 14 batch 30 
total loss: 0.14292384684085846 epoch 14 batch 35 
total loss: 0.13681446015834808 epoch 14 batch 40 
total loss: 0.18409228324890137 epoch 14 batch 45 
total loss: 0.1674126237630844 epoch 14 batch 50 
total loss: 0.14732179045677185 epoch 14 batch 55 
total loss: 0.13022463023662567 epoch 14 batch 60 
total loss: 0.18770740926265717 epoch 14 batch 65 
total loss: 0.16499507427215576 epoch 14 batch 70 
total loss: 0.13566173613071442 epoch 14 batch 75 
total loss: 0.15898260474205017 epoch 14 batch 80 
total loss: 0.16641056537628174 epoch 14 batch 85 
total loss: 0.1944132298231125 epoch 14 batch 90 
total loss: 0.2262207269668579 epoch 14 batch 95 
total loss: 0.20676560699939728 epoch 14 batch 100 
total loss: 0.2102840393781662 epoch 14 batch 105 
total loss: 0.19340692460536957 epoch 14 batch 110 
total loss: 0.187296524643898 epoch 14 batch 115 
total loss: 0.17335641384124756 epoch 14 batch 120 
total loss: 0.2099289447069168 epoch 14 batch 125 
total loss: 0.14340081810951233 epoch 15 batch 5 
total loss: 0.14579172432422638 epoch 15 batch 10 
total loss: 0.1293977051973343 epoch 15 batch 15 
total loss: 0.15074902772903442 epoch 15 batch 20 
total loss: 0.13329613208770752 epoch 15 batch 25 
total loss: 0.1491243988275528 epoch 15 batch 30 
total loss: 0.14245960116386414 epoch 15 batch 35 
total loss: 0.14042304456233978 epoch 15 batch 40 
total loss: 0.17087322473526 epoch 15 batch 45 
total loss: 0.18867500126361847 epoch 15 batch 50 
total loss: 0.17223608493804932 epoch 15 batch 55 
total loss: 0.16629959642887115 epoch 15 batch 60 
total loss: 0.15043802559375763 epoch 15 batch 65 
total loss: 0.16201333701610565 epoch 15 batch 70 
total loss: 0.1867101788520813 epoch 15 batch 75 
total loss: 0.17749939858913422 epoch 15 batch 80 
total loss: 0.18169927597045898 epoch 15 batch 85 
total loss: 0.18131349980831146 epoch 15 batch 90 
total loss: 0.18957491219043732 epoch 15 batch 95 
total loss: 0.15851835906505585 epoch 15 batch 100 
total loss: 0.15743960440158844 epoch 15 batch 105 
total loss: 0.22563040256500244 epoch 15 batch 110 
total loss: 0.17509043216705322 epoch 15 batch 115 
total loss: 0.16400296986103058 epoch 15 batch 120 
total loss: 0.20385797321796417 epoch 15 batch 125 

Evaluation

# In this section we evaluate the model on a raw English input and translate it to German.
# The whole sentence is passed through the encoder; a GreedyEmbeddingSampler then steps
# through the decoder, using the embedding matrix trained on the data to embed each token.
input_raw='how are you'

# We have a transcript file containing English-German pairs
# Preprocess X
input_raw = preprocess(input_raw, add_start_end=False)
input_lines = [f'{SOS} {input_raw}']
input_sequences = [[en_tokenizer.word_index[w] for w in line.split()] for line in input_lines]
input_sequences = tf.keras.preprocessing.sequence.pad_sequences(input_sequences,
                                                                maxlen=Tx, padding='post')
inp = tf.convert_to_tensor(input_sequences)
#print(inp.shape)
inference_batch_size = input_sequences.shape[0]
encoder_initial_cell_state = [tf.zeros((inference_batch_size, rnn_units)),
                              tf.zeros((inference_batch_size, rnn_units))]
encoder_emb_inp = encoderNetwork.encoder_embedding(inp)
a, a_tx, c_tx = encoderNetwork.encoder_rnnlayer(encoder_emb_inp,
                                                initial_state =encoder_initial_cell_state)
print('a_tx :', a_tx.shape)
print('c_tx :', c_tx.shape)

start_tokens = tf.fill([inference_batch_size],ge_tokenizer.word_index[SOS])

end_token = ge_tokenizer.word_index[EOS]

greedy_sampler = tfa.seq2seq.GreedyEmbeddingSampler()

decoder_input = tf.expand_dims([ge_tokenizer.word_index[SOS]]* inference_batch_size,1)
decoder_emb_inp = decoderNetwork.decoder_embedding(decoder_input)

decoder_instance = tfa.seq2seq.BasicDecoder(cell = decoderNetwork.rnn_cell, sampler = greedy_sampler,
                                            output_layer=decoderNetwork.dense_layer)
decoderNetwork.attention_mechanism.setup_memory(a)
#pass [ last step activations , encoder memory_state ] as input to decoder for LSTM
print(f"decoder_initial_state = [a_tx, c_tx] : {np.array([a_tx, c_tx]).shape}")
decoder_initial_state = decoderNetwork.build_decoder_initial_state(inference_batch_size,
                                                                   encoder_state=[a_tx, c_tx],
                                                                   Dtype=tf.float32)
print(f"""
Compared to simple encoder-decoder without attention, the decoder_initial_state
is an AttentionWrapperState object containing s_prev tensors and context and alignment vector

decoder initial state shape: {np.array(decoder_initial_state).shape}
decoder_initial_state tensor
{decoder_initial_state}
""")

# Since we do not know the target sequence lengths in advance, we use maximum_iterations to limit the translation lengths.
# One heuristic is to decode up to two times the source sentence lengths.
maximum_iterations = tf.round(tf.reduce_max(Tx) * 2)

#initialize inference decoder
decoder_embedding_matrix = decoderNetwork.decoder_embedding.variables[0] 
(first_finished, first_inputs,first_state) = decoder_instance.initialize(decoder_embedding_matrix,
                             start_tokens = start_tokens,
                             end_token=end_token,
                             initial_state = decoder_initial_state)
#print( first_finished.shape)
print(f"first_inputs returns the same decoder_input i.e. embedding of  {SOS} : {first_inputs.shape}")
print(f"start_index_emb_avg {tf.reduce_sum(tf.reduce_mean(first_inputs, axis=0))}") # mean along the batch

inputs = first_inputs
state = first_state  
predictions = np.empty((inference_batch_size,0), dtype = np.int32)                                                                             
for j in range(maximum_iterations):
    outputs, next_state, next_inputs, finished = decoder_instance.step(j,inputs,state)
    inputs = next_inputs
    state = next_state
    outputs = np.expand_dims(outputs.sample_id,axis = -1)
    predictions = np.append(predictions, outputs, axis = -1)
a_tx : (1, 1024)
c_tx : (1, 1024)
decoder_initial_state = [a_tx, c_tx] : (2, 1, 1024)

Compared to simple encoder-decoder without attention, the decoder_initial_state
is an AttentionWrapperState object containing s_prev tensors and context and alignment vector

decoder initial state shape: (6,)
decoder_initial_state tensor
AttentionWrapperState(cell_state=[<tf.Tensor: shape=(1, 1024), dtype=float32, numpy=
array([[ 0.0218722 , -0.00386145, -0.34212956, ..., -0.0818582 ,
         0.0042587 , -0.06107492]], dtype=float32)>, <tf.Tensor: shape=(1, 1024), dtype=float32, numpy=
array([[ 0.07267428, -0.01349923, -1.1421771 , ..., -0.27573448,
         0.01418022, -0.14704482]], dtype=float32)>], attention=<tf.Tensor: shape=(1, 1024), dtype=float32, numpy=array([[0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>, time=<tf.Tensor: shape=(), dtype=int32, numpy=0>, alignments=<tf.Tensor: shape=(1, 9), dtype=float32, numpy=array([[0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>, alignment_history=(), attention_state=<tf.Tensor: shape=(1, 9), dtype=float32, numpy=array([[0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>)

first_inputs returns the same decoder_input i.e. embedding of  <start> : (1, 256)
start_index_emb_avg -1.4379956722259521

Final translation

#prediction based on our sentence earlier
print("English Sentence:")
print(input_raw)
print("\nGerman Translation:")
for i in range(len(predictions)):
    line = predictions[i,:]
    seq = list(itertools.takewhile(lambda index: index != end_token, line))  # stop at the <end> token
    print(" ".join( [ge_tokenizer.index_word[w] for w in seq]))
English Sentence:
how are you

German Translation:
wie du bist !

Accuracy can be improved by implementing:

  • Beam search or lexicon search (a minimal beam-search sketch follows below)
  • A bidirectional encoder/decoder model
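
As a sketch of the first item, tfa.seq2seq also provides a BeamSearchDecoder that can reuse the trained decoder cell, dense layer, and embedding matrix. The snippet below is a minimal, illustrative sketch rather than part of the original notebook: the beam_width value is arbitrary, and it assumes the encoder outputs a, a_tx, c_tx and the start_tokens / end_token / decoder_embedding_matrix variables from the evaluation cell above are still in scope. The encoder memory and final states must be tiled across the beam dimension with tfa.seq2seq.tile_batch.

# Minimal beam-search sketch (illustrative, assumes the evaluation cell above was run).
beam_width = 3

# Tile the encoder memory and final states across the beam dimension.
decoderNetwork.attention_mechanism.setup_memory(
    tfa.seq2seq.tile_batch(a, multiplier=beam_width))
tiled_encoder_state = tfa.seq2seq.tile_batch([a_tx, c_tx], multiplier=beam_width)

beam_initial_state = decoderNetwork.rnn_cell.get_initial_state(
    batch_size=inference_batch_size * beam_width, dtype=Dtype)
beam_initial_state = beam_initial_state.clone(cell_state=tiled_encoder_state)

beam_decoder = tfa.seq2seq.BeamSearchDecoder(decoderNetwork.rnn_cell,
                                             beam_width=beam_width,
                                             output_layer=decoderNetwork.dense_layer)
# Calling the decoder runs dynamic decoding internally.
outputs, _, _ = beam_decoder(decoder_embedding_matrix,
                             start_tokens=start_tokens,
                             end_token=end_token,
                             initial_state=beam_initial_state)

# predicted_ids has shape (batch, time, beam_width); beam 0 is the highest-scoring one.
best_beam = outputs.predicted_ids[:, :, 0].numpy()
for row in best_beam:
    seq = list(itertools.takewhile(lambda index: index != end_token, row))
    print(" ".join(ge_tokenizer.index_word[w] for w in seq))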