Google I / O возвращается 18-20 мая! Зарезервируйте место и составьте свое расписание Зарегистрируйтесь сейчас

Базовая классификация текста

Посмотреть на TensorFlow.org Запустить в Google Colab Посмотреть исходный код на GitHub Скачать блокнот

В этом руководстве демонстрируется классификация текста, начиная с простых текстовых файлов, хранящихся на диске. Вы научите бинарный классификатор выполнять анализ тональности набора данных IMDB. В конце записной книжки есть упражнение, которое вы можете попробовать, в котором вы научите мультиклассовый классификатор предсказывать тег для вопроса программирования в Stack Overflow.

import matplotlib.pyplot as plt
import os
import re
import shutil
import string
import tensorflow as tf

from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import preprocessing
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
print(tf.__version__)
2.4.1

Анализ настроений

Эта записная книжка обучает модель анализа настроений для классификации отзывов о фильмах как положительных или отрицательных на основе текста обзора. Это пример двоичной или двухклассовой классификации, важной и широко применяемой задачи машинного обучения.

Вы будете использовать набор данных Large Movie Review , содержащий текст 50 000 обзоров фильмов из базы данных Internet Movie . Они разделены на 25 000 обзоров для обучения и 25 000 обзоров для тестирования. Наборы для обучения и тестирования сбалансированы , то есть содержат одинаковое количество положительных и отрицательных отзывов.

Загрузите и изучите набор данных IMDB

Давайте загрузим и извлечем набор данных, а затем исследуем структуру каталогов.

url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"

dataset = tf.keras.utils.get_file("aclImdb_v1", url,
                                    untar=True, cache_dir='.',
                                    cache_subdir='')

dataset_dir = os.path.join(os.path.dirname(dataset), 'aclImdb')
Downloading data from https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
84131840/84125825 [==============================] - 6s 0us/step
os.listdir(dataset_dir)
['test', 'imdb.vocab', 'README', 'train', 'imdbEr.txt']
train_dir = os.path.join(dataset_dir, 'train')
os.listdir(train_dir)
['labeledBow.feat',
 'unsupBow.feat',
 'unsup',
 'urls_neg.txt',
 'neg',
 'urls_pos.txt',
 'urls_unsup.txt',
 'pos']

aclImdb/train/pos и aclImdb/train/neg содержат множество текстовых файлов, каждый из которых представляет собой отдельный обзор фильма. Давайте посмотрим на один из них.

sample_file = os.path.join(train_dir, 'pos/1181_9.txt')
with open(sample_file) as f:
  print(f.read())
Rachel Griffiths writes and directs this award winning short film. A heartwarming story about coping with grief and cherishing the memory of those we've loved and lost. Although, only 15 minutes long, Griffiths manages to capture so much emotion and truth onto film in the short space of time. Bud Tingwell gives a touching performance as Will, a widower struggling to cope with his wife's death. Will is confronted by the harsh reality of loneliness and helplessness as he proceeds to take care of Ruth's pet cow, Tulip. The film displays the grief and responsibility one feels for those they have loved and lost. Good cinematography, great direction, and superbly acted. It will bring tears to all those who have lost a loved one, and survived.

Загрузите набор данных

Затем вы загрузите данные с диска и подготовите их в формате, подходящем для обучения. Для этого вы будете использовать полезную утилиту text_dataset_from_directory , которая ожидает следующую структуру каталогов.

main_directory/
...class_a/
......a_text_1.txt
......a_text_2.txt
...class_b/
......b_text_1.txt
......b_text_2.txt

Чтобы подготовить набор данных для двоичной классификации, вам понадобятся две папки на диске, соответствующие class_a и class_b . Это будут положительные и отрицательные обзоры фильмов, которые можно найти в aclImdb/train/pos и aclImdb/train/neg . Поскольку набор данных IMDB содержит дополнительные папки, вы удалите их перед использованием этой утилиты.

remove_dir = os.path.join(train_dir, 'unsup')
shutil.rmtree(remove_dir)

Затем вы воспользуетесь утилитой text_dataset_from_directory для создания помеченногоtf.data.Dataset . tf.data - это мощный набор инструментов для работы с данными.

При проведении эксперимента с машинным обучением рекомендуется разделить набор данных на три части: обучение , проверка и тестирование .

Набор данных IMDB уже разделен на обучающий и тестовый, но в нем отсутствует набор для проверки. Давайте создадим набор проверки, используя разделение данных обучения 80:20, используя аргумент validation_split ниже.

batch_size = 32
seed = 42

raw_train_ds = tf.keras.preprocessing.text_dataset_from_directory(
    'aclImdb/train', 
    batch_size=batch_size, 
    validation_split=0.2, 
    subset='training', 
    seed=seed)
Found 25000 files belonging to 2 classes.
Using 20000 files for training.

Как вы можете видеть выше, в папке обучения 25 000 примеров, из которых вы будете использовать 80% (или 20 000) для обучения. Как вы вскоре увидите, вы можете обучить модель, передав набор данных непосредственно в model.fit . Если вы новичок в tf.data , вы также можете tf.data набор данных и распечатать несколько примеров, как tf.data ниже.

for text_batch, label_batch in raw_train_ds.take(1):
  for i in range(3):
    print("Review", text_batch.numpy()[i])
    print("Label", label_batch.numpy()[i])
Review b'"Pandemonium" is a horror movie spoof that comes off more stupid than funny. Believe me when I tell you, I love comedies. Especially comedy spoofs. "Airplane", "The Naked Gun" trilogy, "Blazing Saddles", "High Anxiety", and "Spaceballs" are some of my favorite comedies that spoof a particular genre. "Pandemonium" is not up there with those films. Most of the scenes in this movie had me sitting there in stunned silence because the movie wasn\'t all that funny. There are a few laughs in the film, but when you watch a comedy, you expect to laugh a lot more than a few times and that\'s all this film has going for it. Geez, "Scream" had more laughs than this film and that was more of a horror film. How bizarre is that?<br /><br />*1/2 (out of four)'
Label 0
Review b"David Mamet is a very interesting and a very un-equal director. His first movie 'House of Games' was the one I liked best, and it set a series of films with characters whose perspective of life changes as they get into complicated situations, and so does the perspective of the viewer.<br /><br />So is 'Homicide' which from the title tries to set the mind of the viewer to the usual crime drama. The principal characters are two cops, one Jewish and one Irish who deal with a racially charged area. The murder of an old Jewish shop owner who proves to be an ancient veteran of the Israeli Independence war triggers the Jewish identity in the mind and heart of the Jewish detective.<br /><br />This is were the flaws of the film are the more obvious. The process of awakening is theatrical and hard to believe, the group of Jewish militants is operatic, and the way the detective eventually walks to the final violent confrontation is pathetic. The end of the film itself is Mamet-like smart, but disappoints from a human emotional perspective.<br /><br />Joe Mantegna and William Macy give strong performances, but the flaws of the story are too evident to be easily compensated."
Label 0
Review b'Great documentary about the lives of NY firefighters during the worst terrorist attack of all time.. That reason alone is why this should be a must see collectors item.. What shocked me was not only the attacks, but the"High Fat Diet" and physical appearance of some of these firefighters. I think a lot of Doctors would agree with me that,in the physical shape they were in, some of these firefighters would NOT of made it to the 79th floor carrying over 60 lbs of gear. Having said that i now have a greater respect for firefighters and i realize becoming a firefighter is a life altering job. The French have a history of making great documentary\'s and that is what this is, a Great Documentary.....'
Label 1

Обратите внимание, что обзоры содержат необработанный текст (с пунктуацией и случайными HTML-тегами, например <br/> ). В следующем разделе вы покажете, как с ними справиться.

Ярлыки равны 0 или 1. Чтобы увидеть, какие из них соответствуют положительным и отрицательным обзорам фильмов, вы можете проверить свойство class_names в наборе данных.

print("Label 0 corresponds to", raw_train_ds.class_names[0])
print("Label 1 corresponds to", raw_train_ds.class_names[1])
Label 0 corresponds to neg
Label 1 corresponds to pos

Затем вы создадите набор данных для проверки и тестирования. Вы будете использовать оставшиеся 5000 отзывов из обучающего набора для проверки.

raw_val_ds = tf.keras.preprocessing.text_dataset_from_directory(
    'aclImdb/train', 
    batch_size=batch_size, 
    validation_split=0.2, 
    subset='validation', 
    seed=seed)
Found 25000 files belonging to 2 classes.
Using 5000 files for validation.
raw_test_ds = tf.keras.preprocessing.text_dataset_from_directory(
    'aclImdb/test', 
    batch_size=batch_size)
Found 25000 files belonging to 2 classes.

Подготовьте набор данных для обучения

Затем вы стандартизируете, токенизируете и векторизуете данные с помощью полезного слоя preprocessing.TextVectorization .

Стандартизация относится к предварительной обработке текста, обычно для удаления знаков препинания или элементов HTML для упрощения набора данных. Токенизация относится к разделению строк на токены (например, разделение предложения на отдельные слова путем разделения на пробелы). Векторизация относится к преобразованию токенов в числа, чтобы их можно было передать в нейронную сеть. Все эти задачи могут быть выполнены с помощью этого слоя.

Как вы видели выше, обзоры содержат различные HTML-теги, например <br /> . Эти теги не будут удалены стандартизатором по умолчанию в слое TextVectorization (который по умолчанию преобразует текст в нижний регистр и удаляет пунктуацию, но не удаляет HTML). Вы напишете специальную функцию стандартизации для удаления HTML.

def custom_standardization(input_data):
  lowercase = tf.strings.lower(input_data)
  stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')
  return tf.strings.regex_replace(stripped_html,
                                  '[%s]' % re.escape(string.punctuation),
                                  '')

Затем вы создадите слой TextVectorization . вы будете использовать этот слой для стандартизации, токенизации и векторизации наших данных. Вы устанавливаете output_mode в int чтобы создать уникальные целочисленные индексы для каждого токена.

Обратите внимание, что вы используете функцию разделения по умолчанию и пользовательскую функцию стандартизации, которую вы определили выше. Вы также определите некоторые константы для модели, например, явное максимальное значение sequence_length , которое заставит слой дополнять или усекать последовательности до точных значений sequence_length .

max_features = 10000
sequence_length = 250

vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=sequence_length)

Затем вы вызовете adapt чтобы подогнать состояние слоя предварительной обработки к набору данных. Это заставит модель построить индекс строк для целых чисел.

# Make a text-only dataset (without labels), then call adapt
train_text = raw_train_ds.map(lambda x, y: x)
vectorize_layer.adapt(train_text)

Давайте создадим функцию, чтобы увидеть результат использования этого слоя для предварительной обработки некоторых данных.

def vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return vectorize_layer(text), label
# retrieve a batch (of 32 reviews and labels) from the dataset
text_batch, label_batch = next(iter(raw_train_ds))
first_review, first_label = text_batch[0], label_batch[0]
print("Review", first_review)
print("Label", raw_train_ds.class_names[first_label])
print("Vectorized review", vectorize_text(first_review, first_label))
Review tf.Tensor(b'Silent Night, Deadly Night 5 is the very last of the series, and like part 4, it\'s unrelated to the first three except by title and the fact that it\'s a Christmas-themed horror flick.<br /><br />Except to the oblivious, there\'s some obvious things going on here...Mickey Rooney plays a toymaker named Joe Petto and his creepy son\'s name is Pino. Ring a bell, anyone? Now, a little boy named Derek heard a knock at the door one evening, and opened it to find a present on the doorstep for him. Even though it said "don\'t open till Christmas", he begins to open it anyway but is stopped by his dad, who scolds him and sends him to bed, and opens the gift himself. Inside is a little red ball that sprouts Santa arms and a head, and proceeds to kill dad. Oops, maybe he should have left well-enough alone. Of course Derek is then traumatized by the incident since he watched it from the stairs, but he doesn\'t grow up to be some killer Santa, he just stops talking.<br /><br />There\'s a mysterious stranger lurking around, who seems very interested in the toys that Joe Petto makes. We even see him buying a bunch when Derek\'s mom takes him to the store to find a gift for him to bring him out of his trauma. And what exactly is this guy doing? Well, we\'re not sure but he does seem to be taking these toys apart to see what makes them tick. He does keep his landlord from evicting him by promising him to pay him in cash the next day and presents him with a "Larry the Larvae" toy for his kid, but of course "Larry" is not a good toy and gets out of the box in the car and of course, well, things aren\'t pretty.<br /><br />Anyway, eventually what\'s going on with Joe Petto and Pino is of course revealed, and as with the old story, Pino is not a "real boy". Pino is probably even more agitated and naughty because he suffers from "Kenitalia" (a smooth plastic crotch) so that could account for his evil ways. And the identity of the lurking stranger is revealed too, and there\'s even kind of a happy ending of sorts. Whee.<br /><br />A step up from part 4, but not much of one. Again, Brian Yuzna is involved, and Screaming Mad George, so some decent special effects, but not enough to make this great. A few leftovers from part 4 are hanging around too, like Clint Howard and Neith Hunter, but that doesn\'t really make any difference. Anyway, I now have seeing the whole series out of my system. Now if I could get some of it out of my brain. 4 out of 5.', shape=(), dtype=string)
Label neg
Vectorized review (<tf.Tensor: shape=(1, 250), dtype=int64, numpy=
array([[1287,  313, 2380,  313,  661,    7,    2,   52,  229,    5,    2,
         200,    3,   38,  170,  669,   29, 5492,    6,    2,   83,  297,
         549,   32,  410,    3,    2,  186,   12,   29,    4,    1,  191,
         510,  549,    6,    2, 8229,  212,   46,  576,  175,  168,   20,
           1, 5361,  290,    4,    1,  761,  969,    1,    3,   24,  935,
        2271,  393,    7,    1, 1675,    4, 3747,  250,  148,    4,  112,
         436,  761, 3529,  548,    4, 3633,   31,    2, 1331,   28, 2096,
           3, 2912,    9,    6,  163,    4, 1006,   20,    2,    1,   15,
          85,   53,  147,    9,  292,   89,  959, 2314,  984,   27,  762,
           6,  959,    9,  564,   18,    7, 2140,   32,   24, 1254,   36,
           1,   85,    3, 3298,   85,    6, 1410,    3, 1936,    2, 3408,
         301,  965,    7,    4,  112,  740, 1977,   12,    1, 2014, 2772,
           3,    4,  428,    3, 5177,    6,  512, 1254,    1,  278,   27,
         139,   25,  308,    1,  579,    5,  259, 3529,    7,   92, 8981,
          32,    2, 3842,  230,   27,  289,    9,   35,    2, 5712,   18,
          27,  144, 2166,   56,    6,   26,   46,  466, 2014,   27,   40,
        2745,  657,  212,    4, 1376, 3002, 7080,  183,   36,  180,   52,
         920,    8,    2, 4028,   12,  969,    1,  158,   71,   53,   67,
          85, 2754,    4,  734,   51,    1, 1611,  294,   85,    6,    2,
        1164,    6,  163,    4, 3408,   15,   85,    6,  717,   85,   44,
           5,   24, 7158,    3,   48,  604,    7,   11,  225,  384,   73,
          65,   21,  242,   18,   27,  120,  295,    6,   26,  667,  129,
        4028,  948,    6,   67,   48,  158,   93,    1]])>, <tf.Tensor: shape=(), dtype=int32, numpy=0>)

Как вы можете видеть выше, каждый токен был заменен целым числом. Вы можете найти токен (строку), которому соответствует каждое целое число, вызвав .get_vocabulary() на уровне.

print("1287 ---> ",vectorize_layer.get_vocabulary()[1287])
print(" 313 ---> ",vectorize_layer.get_vocabulary()[313])
print('Vocabulary size: {}'.format(len(vectorize_layer.get_vocabulary())))
1287 --->  silent
 313 --->  night
Vocabulary size: 10000

Вы почти готовы обучать свою модель. В качестве последнего шага предварительной обработки вы примените слой TextVectorization, который вы создали ранее, к набору данных для обучения, проверки и тестирования.

train_ds = raw_train_ds.map(vectorize_text)
val_ds = raw_val_ds.map(vectorize_text)
test_ds = raw_test_ds.map(vectorize_text)

Настройте набор данных для повышения производительности

Это два важных метода, которые вы должны использовать при загрузке данных, чтобы убедиться, что ввод-вывод не становится блокирующим.

.cache() сохраняет данные в памяти после их загрузки с диска. Это гарантирует, что набор данных не станет узким местом при обучении вашей модели. Если ваш набор данных слишком велик для размещения в памяти, вы также можете использовать этот метод для создания производительного кеша на диске, который более эффективен для чтения, чем многие небольшие файлы.

.prefetch() перекрывает предварительную обработку данных и выполнение модели во время обучения.

Вы можете узнать больше об обоих методах, а также о том, как кэшировать данные на диск, в руководстве по производительности данных .

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)

Создать модель

Пришло время создать нашу нейронную сеть:

embedding_dim = 16
model = tf.keras.Sequential([
  layers.Embedding(max_features + 1, embedding_dim),
  layers.Dropout(0.2),
  layers.GlobalAveragePooling1D(),
  layers.Dropout(0.2),
  layers.Dense(1)])

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, None, 16)          160016    
_________________________________________________________________
dropout (Dropout)            (None, None, 16)          0         
_________________________________________________________________
global_average_pooling1d (Gl (None, 16)                0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 16)                0         
_________________________________________________________________
dense (Dense)                (None, 1)                 17        
=================================================================
Total params: 160,033
Trainable params: 160,033
Non-trainable params: 0
_________________________________________________________________

Слои располагаются последовательно для построения классификатора:

  1. Первый слой - это слой Embedding . Этот уровень принимает обзоры в целочисленной кодировке и ищет вектор внедрения для каждого индекса слова. Эти векторы изучаются по мере обучения модели. Векторы добавляют измерение к выходному массиву. В результате получаются следующие размеры: (batch, sequence, embedding) . Чтобы узнать больше о встраивании, см. Руководство по встраиванию слов .
  2. Затем слой GlobalAveragePooling1D возвращает выходной вектор фиксированной длины для каждого примера путем усреднения по измерению последовательности. Это позволяет модели обрабатывать ввод переменной длины самым простым способом.
  3. Этот выходной вектор фиксированной длины передается по конвейеру через полностью связанный ( Dense ) слой с 16 скрытыми блоками.
  4. Последний слой плотно связан с единственным выходным узлом.

Функция потерь и оптимизатор

Модель нуждается в функции потерь и оптимизаторе для обучения. Поскольку это проблема бинарной классификации и модель выводит вероятность ( losses.BinaryCrossentropy слой с активацией сигмоида), вы будете использовать функцию losses.BinaryCrossentropy потери.

Теперь настройте модель для использования оптимизатора и функции потерь:

model.compile(loss=losses.BinaryCrossentropy(from_logits=True),
              optimizer='adam',
              metrics=tf.metrics.BinaryAccuracy(threshold=0.0))

Обучите модель

Вы обучите модель, передав объект dataset методу fit.

epochs = 10
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs)
Epoch 1/10
625/625 [==============================] - 5s 6ms/step - loss: 0.6824 - binary_accuracy: 0.6126 - val_loss: 0.6136 - val_binary_accuracy: 0.7708
Epoch 2/10
625/625 [==============================] - 3s 4ms/step - loss: 0.5785 - binary_accuracy: 0.7843 - val_loss: 0.4967 - val_binary_accuracy: 0.8240
Epoch 3/10
625/625 [==============================] - 3s 4ms/step - loss: 0.4644 - binary_accuracy: 0.8365 - val_loss: 0.4192 - val_binary_accuracy: 0.8480
Epoch 4/10
625/625 [==============================] - 3s 4ms/step - loss: 0.3906 - binary_accuracy: 0.8610 - val_loss: 0.3732 - val_binary_accuracy: 0.8596
Epoch 5/10
625/625 [==============================] - 3s 4ms/step - loss: 0.3439 - binary_accuracy: 0.8731 - val_loss: 0.3445 - val_binary_accuracy: 0.8662
Epoch 6/10
625/625 [==============================] - 3s 4ms/step - loss: 0.3109 - binary_accuracy: 0.8871 - val_loss: 0.3256 - val_binary_accuracy: 0.8718
Epoch 7/10
625/625 [==============================] - 3s 4ms/step - loss: 0.2852 - binary_accuracy: 0.8963 - val_loss: 0.3121 - val_binary_accuracy: 0.8730
Epoch 8/10
625/625 [==============================] - 3s 4ms/step - loss: 0.2660 - binary_accuracy: 0.9039 - val_loss: 0.3027 - val_binary_accuracy: 0.8754
Epoch 9/10
625/625 [==============================] - 3s 4ms/step - loss: 0.2481 - binary_accuracy: 0.9109 - val_loss: 0.2963 - val_binary_accuracy: 0.8774
Epoch 10/10
625/625 [==============================] - 3s 4ms/step - loss: 0.2338 - binary_accuracy: 0.9162 - val_loss: 0.2916 - val_binary_accuracy: 0.8796

Оцените модель

Посмотрим, как модель работает. Будут возвращены два значения. Потеря (число, которое представляет нашу ошибку, чем ниже значение, тем лучше) и точность.

loss, accuracy = model.evaluate(test_ds)

print("Loss: ", loss)
print("Accuracy: ", accuracy)
782/782 [==============================] - 3s 4ms/step - loss: 0.3102 - binary_accuracy: 0.8735
Loss:  0.31024304032325745
Accuracy:  0.8734800219535828

Этот довольно наивный подход обеспечивает точность около 86%.

Постройте график точности и потерь с течением времени

model.fit() возвращает объект History который содержит словарь со всем, что произошло во время обучения:

history_dict = history.history
history_dict.keys()
dict_keys(['loss', 'binary_accuracy', 'val_loss', 'val_binary_accuracy'])

Есть четыре записи: по одной для каждой отслеживаемой метрики во время обучения и проверки. Вы можете использовать их для построения графика потерь при обучении и проверке для сравнения, а также для оценки точности обучения и проверки:

acc = history_dict['binary_accuracy']
val_acc = history_dict['val_binary_accuracy']
loss = history_dict['loss']
val_loss = history_dict['val_loss']

epochs = range(1, len(acc) + 1)

# "bo" is for "blue dot"
plt.plot(epochs, loss, 'bo', label='Training loss')
# b is for "solid blue line"
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.show()

PNG

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')

plt.show()

PNG

На этом графике точки представляют потери при обучении и точность, а сплошные линии - потери при проверке и точность.

Обратите внимание, что потери при обучении уменьшаются с каждой эпохой, а точность обучения увеличивается с каждой эпохой. Это ожидается при использовании оптимизации градиентного спуска - она ​​должна минимизировать желаемое количество на каждой итерации.

Это не относится к потерям при проверке и точности - они, кажется, достигают пика до точности обучения. Это пример переобучения: модель лучше работает с данными обучения, чем с данными, которых она никогда раньше не видела. После этого модель чрезмерно оптимизирует и изучает представления, специфичные для обучающих данных, которые не обобщаются для тестовых данных.

В этом конкретном случае вы можете предотвратить переобучение, просто остановив обучение, когда точность проверки больше не увеличивается. Один из способов сделать это - использовать tf.keras.callbacks.EarlyStopping вызов tf.keras.callbacks.EarlyStopping .

Экспорт модели

В приведенном выше коде вы применили слой TextVectorization к набору данных перед подачей текста в модель. Если вы хотите, чтобы ваша модель могла обрабатывать необработанные строки (например, чтобы упростить ее развертывание), вы можете включить слой TextVectorization в вашу модель. Для этого вы можете создать новую модель, используя только что обученные веса.

export_model = tf.keras.Sequential([
  vectorize_layer,
  model,
  layers.Activation('sigmoid')
])

export_model.compile(
    loss=losses.BinaryCrossentropy(from_logits=False), optimizer="adam", metrics=['accuracy']
)

# Test it with `raw_test_ds`, which yields raw strings
loss, accuracy = export_model.evaluate(raw_test_ds)
print(accuracy)
782/782 [==============================] - 5s 6ms/step - loss: 0.3115 - accuracy: 0.8738
0.8734800219535828

Вывод по новым данным

Чтобы получить прогнозы для новых примеров, вы можете просто вызвать model.predict() .

examples = [
  "The movie was great!",
  "The movie was okay.",
  "The movie was terrible..."
]

export_model.predict(examples)
array([[0.6198051 ],
       [0.4428497 ],
       [0.35866177]], dtype=float32)

Включение логики предварительной обработки текста в вашу модель позволяет экспортировать модель для производства, что упрощает развертывание и снижает вероятность перекоса при обучении / тестировании .

При выборе места для применения слоя TextVectorization следует учитывать разницу в производительности. Использование его вне модели позволяет выполнять асинхронную обработку ЦП и буферизацию данных при обучении на ГП. Итак, если вы тренируете свою модель на графическом процессоре, вы, вероятно, захотите воспользоваться этой опцией, чтобы получить максимальную производительность при разработке своей модели, а затем переключитесь на включение слоя TextVectorization в вашу модель, когда вы будете готовы к подготовке к развертыванию. .

Посетите это руководство, чтобы узнать больше о сохранении моделей.

Упражнение: мультиклассовая классификация вопросов о переполнении стека

В этом руководстве показано, как обучить двоичный классификатор с нуля на наборе данных IMDB. В качестве упражнения вы можете изменить эту записную книжку, чтобы обучить мультиклассовый классификатор предсказывать тег вопроса программирования в Stack Overflow .

Мы подготовили для вас набор данных , содержащий несколько тысяч вопросов по программированию (например, «Как отсортировать словарь по значению в Python?»), Размещенных в Stack Overflow. Каждый из них помечен ровно одним тегом (Python, CSharp, JavaScript или Java). Ваша задача - принять вопрос в качестве входных данных и предсказать соответствующий тег, в данном случае Python.

Набор данных, с которым вы будете работать, содержит несколько тысяч вопросов, извлеченных из гораздо более крупного общедоступного набора данных Stack Overflow на BigQuery , который содержит более 17 миллионов сообщений.

После загрузки набора данных вы обнаружите, что он имеет структуру каталогов, аналогичную набору данных IMDB, с которым вы работали ранее:

train/
...python/
......0.txt
......1.txt
...javascript/
......0.txt
......1.txt
...csharp/
......0.txt
......1.txt
...java/
......0.txt
......1.txt

Чтобы выполнить это упражнение, вам следует изменить эту записную книжку для работы с набором данных Stack Overflow, внеся следующие изменения:

  1. В верхней части записной книжки обновите код, загружающий набор данных IMDB, с помощью кода, чтобы загрузить подготовленный нами набор данных Stack Overflow . Поскольку набор данных Stack Overflow имеет аналогичную структуру каталогов, вам не нужно будет вносить много изменений.

  2. Измените последний слой вашей модели так, чтобы он читался как Dense(4) , так как теперь существует четыре выходных класса.

  3. При компиляции модели измените потери на losses.SparseCategoricalCrossentropy . Это правильная функция потерь для использования в задаче многоклассовой классификации, когда метки для каждого класса являются целыми числами (в нашем случае они могут быть 0, 1 , 2 или 3 ).

  4. После внесения этих изменений вы сможете обучать мультиклассовый классификатор.

Узнать больше

В этом руководстве с нуля была введена классификация текста. Чтобы узнать больше о рабочем процессе классификации текста в целом, мы рекомендуем прочитать это руководство от разработчиков Google.

# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.