This tutorial provides an example of how to use tf.data.TextLineDataset
to load examples from text files. TextLineDataset
is designed to create a dataset from a text file, in which each example is a line of text from the original file. This is potentially useful for any text data that is primarily line-based (for example, poetry or error logs).
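To see the basic idea, here is a minimal sketch (the file name sample.txt is a hypothetical placeholder, not one of the tutorial's files): TextLineDataset yields one string tensor per line of the file.
import tensorflow as tf
# Hypothetical file used only for illustration; each element is one line of text.
lines = tf.data.TextLineDataset('sample.txt')
for line in lines.take(3):
  print(line.numpy())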
In this tutorial, we'll use three different English translations of the same work, Homer's Iliad, and train a model to identify the translator given a single line of text.
Setup
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import tensorflow_datasets as tfds
import os
The texts of the three translations are by:
- William Cowper
- Edward, Earl of Derby
- Samuel Butler
The text files used in this tutorial have undergone some typical preprocessing tasks, mostly removing material such as the document header and footer, line numbers, and chapter titles. Download these lightly munged files locally.
DIRECTORY_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
FILE_NAMES = ['cowper.txt', 'derby.txt', 'butler.txt']
for name in FILE_NAMES:
  text_dir = tf.keras.utils.get_file(name, origin=DIRECTORY_URL+name)
parent_dir = os.path.dirname(text_dir)
parent_dir
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt
819200/815980 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt
811008/809730 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt
811008/807992 [==============================] - 0s 0us/step
'/home/kbuilder/.keras/datasets'
Load text into datasets
Iterate through the files, loading each one into its own dataset.
Each example needs to be individually labeled, so use tf.data.Dataset.map
to apply a labeler function to each one. This will iterate over every example in the dataset, returning (example, label) pairs.
def labeler(example, index):
  # Pair each line of text with its translator index as an int64 label.
  return example, tf.cast(index, tf.int64)

labeled_data_sets = []

for i, file_name in enumerate(FILE_NAMES):
  lines_dataset = tf.data.TextLineDataset(os.path.join(parent_dir, file_name))
  labeled_dataset = lines_dataset.map(lambda ex: labeler(ex, i))
  labeled_data_sets.append(labeled_dataset)
Combine these labeled datasets into a single dataset, and shuffle it. Shuffling with reshuffle_each_iteration=False fixes the shuffled order, so the train/test split created later with skip and take stays consistent from epoch to epoch.
BUFFER_SIZE = 50000
BATCH_SIZE = 64
TAKE_SIZE = 5000
all_labeled_data = labeled_data_sets[0]
for labeled_dataset in labeled_data_sets[1:]:
all_labeled_data = all_labeled_data.concatenate(labeled_dataset)
all_labeled_data = all_labeled_data.shuffle(
    BUFFER_SIZE, reshuffle_each_iteration=False)
You can use tf.data.Dataset.take
and print
to see what the (example, label)
pairs look like. The numpy
property shows each Tensor's value.
for ex in all_labeled_data.take(5):
  print(ex)
(<tf.Tensor: id=74, shape=(), dtype=string, numpy=b'I tell you plain. I will not yield my spouse.'>, <tf.Tensor: id=75, shape=(), dtype=int64, numpy=0>)
(<tf.Tensor: id=76, shape=(), dtype=string, numpy=b'of the sacred building; there, upon the knees of Minerva, let her lay'>, <tf.Tensor: id=77, shape=(), dtype=int64, numpy=2>)
(<tf.Tensor: id=78, shape=(), dtype=string, numpy=b'listen, but it would have been indeed better if I had done so. Now that'>, <tf.Tensor: id=79, shape=(), dtype=int64, numpy=2>)
(<tf.Tensor: id=80, shape=(), dtype=string, numpy=b'Shall rule all kingdoms bordering on his own.'>, <tf.Tensor: id=81, shape=(), dtype=int64, numpy=0>)
(<tf.Tensor: id=82, shape=(), dtype=string, numpy=b'they were shepherding, but he had taken a ransom for them; now,'>, <tf.Tensor: id=83, shape=(), dtype=int64, numpy=2>)
Encode text lines as numbers
Machine learning models work on numbers, not words, so the string values need to be converted into lists of numbers. To do that, map each unique word to a unique integer.
Build vocabulary
First, build a vocabulary by tokenizing the text into a collection of individual unique words. There are a few ways to do this in both TensorFlow and Python. For this tutorial:
- Iterate over each example's numpy value.
- Use tfds.features.text.Tokenizer to split it into tokens.
- Collect these tokens into a Python set, to remove duplicates.
- Get the size of the vocabulary for later use.
tokenizer = tfds.features.text.Tokenizer()
vocabulary_set = set()
for text_tensor, _ in all_labeled_data:
  some_tokens = tokenizer.tokenize(text_tensor.numpy())
  vocabulary_set.update(some_tokens)
vocab_size = len(vocabulary_set)
vocab_size
17178
Encode examples
Create an encoder by passing the vocabulary_set
to tfds.features.text.TokenTextEncoder
. The encoder's encode
method takes in a string of text and returns a list of integers.
encoder = tfds.features.text.TokenTextEncoder(vocabulary_set)
You can try this on a single line to see what the output looks like.
example_text = next(iter(all_labeled_data))[0].numpy()
print(example_text)
b'I tell you plain. I will not yield my spouse.'
encoded_example = encoder.encode(example_text)
print(encoded_example)
[9511, 14346, 16172, 12902, 9511, 4994, 5916, 722, 16304, 15699]
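As a quick sanity check (a small sketch, not part of the original tutorial), the encoder also provides a decode method that maps the integers back to tokens. Because the tokenizer drops punctuation, the round trip recovers the words but not the exact original line.
# Map the integer ids back to their tokens (punctuation is not recovered).
print(encoder.decode(encoded_example))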
Now run the encoder on the dataset by wrapping it in tf.py_function
and passing that to the dataset's map
method.
def encode(text_tensor, label):
  encoded_text = encoder.encode(text_tensor.numpy())
  return encoded_text, label

def encode_map_fn(text, label):
  # encode() needs eager tensors (it calls .numpy()), so wrap it in tf.py_function.
  return tf.py_function(encode, inp=[text, label], Tout=(tf.int64, tf.int64))
all_encoded_data = all_labeled_data.map(encode_map_fn)
Split the dataset into test and train batches
Use tf.data.Dataset.take
and tf.data.Dataset.skip
to create a small test dataset and a larger training set.
Before being passed into the model, the datasets need to be batched. Typically, the examples inside of a batch need to be the same size and shape. But the examples in these datasets are not all the same size: each line of text has a different number of words. So use tf.data.Dataset.padded_batch (instead of batch) to pad the examples to the same size.
train_data = all_encoded_data.skip(TAKE_SIZE).shuffle(BUFFER_SIZE)
train_data = train_data.padded_batch(BATCH_SIZE, padded_shapes=([-1],[]))
test_data = all_encoded_data.take(TAKE_SIZE)
test_data = test_data.padded_batch(BATCH_SIZE, padded_shapes=([-1],[]))
Now, test_data
and train_data
are not collections of (example, label) pairs, but collections of batches. Each batch is a pair of (many examples, many labels) represented as arrays.
To illustrate:
sample_text, sample_labels = next(iter(test_data))
sample_text[0], sample_labels[0]
(<tf.Tensor: id=99547, shape=(17,), dtype=int64, numpy= array([ 9511, 14346, 16172, 12902, 9511, 4994, 5916, 722, 16304, 15699, 0, 0, 0, 0, 0, 0, 0])>, <tf.Tensor: id=99551, shape=(), dtype=int64, numpy=0>)
Since we have introduced a new token encoding (the zero used for padding), the vocabulary size has increased by one.
vocab_size += 1
Build the model
model = tf.keras.Sequential()
The first layer converts integer representations to dense vector embeddings. See the word embeddings tutorial for more details.
model.add(tf.keras.layers.Embedding(vocab_size, 64))
The next layer is a Long Short-Term Memory layer, which lets the model understand words in their context with other words. A bidirectional wrapper on the LSTM helps it learn about each data point in relation to the data points that came before and after it.
model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)))
Finally, we'll have a series of one or more densely connected layers, with the last one being the output layer. The output layer produces a probability for each label; the one with the highest probability is the model's prediction of an example's label.
# One or more dense layers.
# Edit the list in the `for` line to experiment with layer sizes.
for units in [64, 64]:
  model.add(tf.keras.layers.Dense(units, activation='relu'))
# Output layer. The first argument is the number of labels.
model.add(tf.keras.layers.Dense(3, activation='softmax'))
Finally, compile the model. For a softmax categorization model, use sparse_categorical_crossentropy
as the loss function. You can try other optimizers, but adam
is very common.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
Train the model
This model running on this data produces decent results (about 83% accuracy on the validation data).
model.fit(train_data, epochs=3, validation_data=test_data)
Epoch 1/3
697/697 [==============================] - 33s 47ms/step - loss: 0.5192 - accuracy: 0.7501 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/3
697/697 [==============================] - 26s 37ms/step - loss: 0.3010 - accuracy: 0.8671 - val_loss: 0.3636 - val_accuracy: 0.8344
Epoch 3/3
697/697 [==============================] - 25s 37ms/step - loss: 0.2286 - accuracy: 0.8987 - val_loss: 0.3881 - val_accuracy: 0.8354
<tensorflow.python.keras.callbacks.History at 0x7ffb6c4b3cf8>
eval_loss, eval_acc = model.evaluate(test_data)
print('\nEval loss: {:.3f}, Eval accuracy: {:.3f}'.format(eval_loss, eval_acc))
79/79 [==============================] - 3s 41ms/step - loss: 0.3881 - accuracy: 0.8354
Eval loss: 0.388, Eval accuracy: 0.835
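As a final check (a minimal sketch, not part of the original tutorial), you can run the trained model on the single line encoded earlier and map the predicted label back to its source file. The labels were assigned in the order of FILE_NAMES, so the predicted index can be used to look up the file name.
# Add a batch dimension, since the model expects batches.
prediction = model.predict(tf.expand_dims(encoded_example, 0))

# The output is one probability per label; take the most likely one.
predicted_label = prediction[0].argmax()
print('Predicted label:', predicted_label)
print('Predicted source file:', FILE_NAMES[predicted_label])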