RSVP для вашего местного мероприятия TensorFlow Everywhere сегодня!
Эта страница переведена с помощью Cloud Translation API.
Switch to English

Классификация статей Bangla с помощью TF-Hub

Посмотреть на TensorFlow.org Запускаем в Google Colab Посмотреть на GitHub Скачать блокнот

Этот colab является демонстрацией использования Tensorflow Hub для классификации текста на неанглийских / местных языках. Здесь мы выбираем бангла в качестве местного языка и используем предварительно обученные вложения слов для решения задачи мультиклассовой классификации, где мы классифицируем новостные статьи на бангла по 5 категориям. Предварительно обученные вложения для Bangla происходят из fastText , библиотеки Facebook с выпущенными предварительно обученными векторами слов для 157 языков.

Мы будем использовать предварительно обученный экспортер встраивания TF-Hub для преобразования встраиваемых слов в модуль встраивания текста, а затем использовать модуль для обучения классификатора с tf.keras , высокоуровневым пользовательским API Tensorflow для создания моделей глубокого обучения. Даже если мы используем здесь встраивания fastText, можно экспортировать любые другие вложения, предварительно обученные из других задач, и быстро получить результаты с помощью концентратора Tensorflow.

Настроить

# https://github.com/pypa/setuptools/issues/1694#issuecomment-466010982
pip install -q gdown --no-use-pep517
sudo apt-get install -y unzip
Reading package lists...
Building dependency tree...
Reading state information...
unzip is already the newest version (6.0-21ubuntu1).
The following packages were automatically installed and are no longer required:
  dconf-gsettings-backend dconf-service dkms freeglut3 freeglut3-dev
  glib-networking glib-networking-common glib-networking-services
  gsettings-desktop-schemas libcairo-gobject2 libcolord2 libdconf1
  libegl1-mesa libepoxy0 libglu1-mesa libglu1-mesa-dev libgtk-3-0
  libgtk-3-common libice-dev libjansson4 libjson-glib-1.0-0
  libjson-glib-1.0-common libproxy1v5 librest-0.7-0 libsm-dev
  libsoup-gnome2.4-1 libsoup2.4-1 libwayland-cursor0 libwayland-egl1 libxfont2
  libxi-dev libxkbcommon0 libxkbfile1 libxmu-dev libxmu-headers libxnvctrl0
  libxt-dev linux-gcp-headers-5.0.0-1026 linux-headers-5.0.0-1026-gcp
  linux-image-5.0.0-1026-gcp linux-modules-5.0.0-1026-gcp pkg-config
  policykit-1-gnome python3-xkit screen-resolution-extra x11-xkb-utils
  xserver-common xserver-xorg-core-hwe-18.04
Use 'sudo apt autoremove' to remove them.
0 upgraded, 0 newly installed, 0 to remove and 102 not upgraded.

import os

import tensorflow as tf
import tensorflow_hub as hub

import gdown
import numpy as np
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import seaborn as sns

Набор данных

Мы будем использовать BARD (набор данных Bangla Article Dataset), который содержит около 3,76,226 статей, собранных с различных новостных порталов Bangla и разделенных на 5 категорий: экономика, штат, международный, спорт и развлечения. Мы загружаем файл с Google Диска, на который ссылается эта ссылка ( bit.ly/BARD_DATASET ) из этого репозитория GitHub.

gdown.download(
    url='https://drive.google.com/uc?id=1Ag0jd21oRwJhVFIBohmX_ogeojVtapLy',
    output='bard.zip',
    quiet=True
)
'bard.zip'
unzip -qo bard.zip

Экспорт предварительно обученных векторов слов в модуль TF-Hub

TF-концентратор обеспечивает некоторые полезные скрипты для преобразования вложения слов в TF-концентраторов модулей текст встраиванию здесь . Чтобы сделать модуль для Bangla или любых других языков, нам просто нужно загрузить файл word embedding .txt или .vec в тот же каталог, что и export_v2.py, и запустить скрипт.

Экспортер считывает векторы внедрения и экспортирует их в SavedModel Tensorflow . SavedModel содержит полную программу TensorFlow, включая веса и график. TF-Hub может загрузить SavedModel как модуль, который мы будем использовать для построения модели классификации текста. Поскольку мы используем tf.keras для построения модели, мы будем использовать hub.KerasLayer, который предоставляет оболочку для модуля концентратора, который будет использоваться в качестве слоя Keras.

Во- первых , мы получим наши слова вложения от FastText и внедренный экспортер из TF-концентратора репо .

curl -O https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bn.300.vec.gz
curl -O https://raw.githubusercontent.com/tensorflow/hub/master/examples/text_embeddings_v2/export_v2.py
gunzip -qf cc.bn.300.vec.gz --k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  840M  100  840M    0     0  13.2M      0  0:01:03  0:01:03 --:--:-- 13.8M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  7493  100  7493    0     0  16396      0 --:--:-- --:--:-- --:--:-- 16360

Затем мы запустим сценарий экспорта для нашего файла внедрения. Поскольку вложения fastText имеют строку заголовка и довольно большие (около 3,3 ГБ для Bangla после преобразования в модуль), мы игнорируем первую строку и экспортируем только первые 100 000 токенов в модуль встраивания текста.

python export_v2.py --embedding_file=cc.bn.300.vec --export_path=text_module --num_lines_to_ignore=1 --num_lines_to_use=100000
2020-11-24 16:38:16.506111: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-11-24 16:38:30.883816: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1
2020-11-24 16:38:31.588580: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-11-24 16:38:31.589356: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:00:05.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2020-11-24 16:38:31.589426: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-11-24 16:38:31.591384: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2020-11-24 16:38:31.593231: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcufft.so.10
2020-11-24 16:38:31.593622: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcurand.so.10
2020-11-24 16:38:31.595468: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusolver.so.10
2020-11-24 16:38:31.596320: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusparse.so.10
2020-11-24 16:38:31.599992: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudnn.so.7
2020-11-24 16:38:31.600141: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-11-24 16:38:31.600868: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-11-24 16:38:31.601510: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1858] Adding visible gpu devices: 0
2020-11-24 16:38:31.601959: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2020-11-24 16:38:31.608245: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] CPU Frequency: 2000185000 Hz
2020-11-24 16:38:31.608621: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5fa99c0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-11-24 16:38:31.608653: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2020-11-24 16:38:31.698228: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-11-24 16:38:31.699202: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5d3bde0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2020-11-24 16:38:31.699263: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Tesla V100-SXM2-16GB, Compute Capability 7.0
2020-11-24 16:38:31.699512: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-11-24 16:38:31.700189: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:00:05.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2020-11-24 16:38:31.700229: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-11-24 16:38:31.700284: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2020-11-24 16:38:31.700300: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcufft.so.10
2020-11-24 16:38:31.700312: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcurand.so.10
2020-11-24 16:38:31.700325: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusolver.so.10
2020-11-24 16:38:31.700335: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusparse.so.10
2020-11-24 16:38:31.700351: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudnn.so.7
2020-11-24 16:38:31.700419: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-11-24 16:38:31.701102: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-11-24 16:38:31.701720: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1858] Adding visible gpu devices: 0
2020-11-24 16:38:31.701767: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-11-24 16:38:32.132285: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1257] Device interconnect StreamExecutor with strength 1 edge matrix:
2020-11-24 16:38:32.132338: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1263]      0 
2020-11-24 16:38:32.132346: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1276] 0:   N 
2020-11-24 16:38:32.132566: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-11-24 16:38:32.133335: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-11-24 16:38:32.134035: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1402] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 14764 MB memory) -> physical GPU (device: 0, name: Tesla V100-SXM2-16GB, pci bus id: 0000:00:05.0, compute capability: 7.0)
2020-11-24 16:38:32.351857: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 240002400 exceeds 10% of free system memory.
2020-11-24 16:38:33.188668: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 240002400 exceeds 10% of free system memory.
INFO:tensorflow:Assets written to: text_module/assets
I1124 16:38:33.586180 140643110135616 builder_impl.py:775] Assets written to: text_module/assets

module_path = "text_module"
embedding_layer = hub.KerasLayer(module_path, trainable=False)

Модуль встраивания текста принимает пакет предложений в одномерном тензоре строк в качестве входных данных и выводит векторы внедрения формы (batch_size, embedding_dim), соответствующие предложениям. Он предварительно обрабатывает ввод, разбивая его на пробелы. sqrtn слов объединяются с встраиваемыми предложениями с sqrtn комбайнера sqrtn (см. Здесь ). Для демонстрации мы передаем список слов Bangla в качестве входных данных и получаем соответствующие векторы внедрения.

embedding_layer(['বাস', 'বসবাস', 'ট্রেন', 'যাত্রী', 'ট্রাক'])
<tf.Tensor: shape=(5, 300), dtype=float64, numpy=
array([[ 0.0462, -0.0355,  0.0129, ...,  0.0025, -0.0966,  0.0216],
       [-0.0631, -0.0051,  0.085 , ...,  0.0249, -0.0149,  0.0203],
       [ 0.1371, -0.069 , -0.1176, ...,  0.029 ,  0.0508, -0.026 ],
       [ 0.0532, -0.0465, -0.0504, ...,  0.02  , -0.0023,  0.0011],
       [ 0.0908, -0.0404, -0.0536, ..., -0.0275,  0.0528,  0.0253]])>

Преобразовать в набор данных Tensorflow

Поскольку набор данных действительно велик, вместо того, чтобы загружать весь набор данных в память, мы будем использовать генератор для получения образцов во время выполнения в пакетном режиме с использованием функций набора данных Tensorflow . Набор данных также очень несбалансирован, поэтому перед использованием генератора мы перетасуем набор данных.

dir_names = ['economy', 'sports', 'entertainment', 'state', 'international']

file_paths = []
labels = []
for i, dir in enumerate(dir_names):
  file_names = ["/".join([dir, name]) for name in os.listdir(dir)]
  file_paths += file_names
  labels += [i] * len(os.listdir(dir))

np.random.seed(42)
permutation = np.random.permutation(len(file_paths))

file_paths = np.array(file_paths)[permutation]
labels = np.array(labels)[permutation]

Мы можем проверить распределение меток в примерах обучения и проверки после перетасовки.

train_frac = 0.8
train_size = int(len(file_paths) * train_frac)
# plot training vs validation distribution
plt.subplot(1, 2, 1)
plt.hist(labels[0:train_size])
plt.title("Train labels")
plt.subplot(1, 2, 2)
plt.hist(labels[train_size:])
plt.title("Validation labels")
plt.tight_layout()

PNG

Чтобы создать набор данных с использованием генератора, мы сначала пишем функцию генератора, которая считывает каждую из статей из file_paths и метки из массива меток и выдает один обучающий пример на каждом шаге. Мы передаем эту функцию генератора методу tf.data.Dataset.from_generator и указываем типы вывода. Каждый обучающий пример представляет собой кортеж, содержащий статью с типом данных tf.string и метку с горячим кодированием. Мы разделяем набор данных с разделением на проверку поезда 80-20, используя метод skip и take .

def load_file(path, label):
    return tf.io.read_file(path), label
def make_datasets(train_size):
  batch_size = 256

  train_files = file_paths[:train_size]
  train_labels = labels[:train_size]
  train_ds = tf.data.Dataset.from_tensor_slices((train_files, train_labels))
  train_ds = train_ds.map(load_file).shuffle(5000)
  train_ds = train_ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

  test_files = file_paths[train_size:]
  test_labels = labels[train_size:]
  test_ds = tf.data.Dataset.from_tensor_slices((test_files, test_labels))
  test_ds = test_ds.map(load_file)
  test_ds = test_ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)


  return train_ds, test_ds
train_data, validation_data = make_datasets(train_size)

Модельное обучение и оценка

Поскольку мы уже добавили оболочку вокруг нашего модуля, чтобы использовать его как любой другой слой в keras, мы можем создать небольшую последовательную модель, которая представляет собой линейный стек слоев. Мы можем добавить наш модуль встраивания текста с помощью model.add как и любой другой слой. Скомпилируем модель, указав потерю и оптимизатор, и обучаем ее 10 эпох. API tf.keras может обрабатывать наборы данных тензорного потока в качестве входных данных, поэтому мы можем передать экземпляр Dataset методу fit для обучения модели. Поскольку мы используем функцию генератора, tf.data будет обрабатывать создание образцов, их пакетирование и передачу в модель.

Модель

def create_model():
  model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=[], dtype=tf.string),
    embedding_layer,
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(5),
  ])
  model.compile(loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer="adam", metrics=['accuracy'])
  return model
model = create_model()
# Create earlystopping callback
early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=3)
WARNING:tensorflow:Layer dense is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.


WARNING:tensorflow:Layer dense is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.


Подготовка

history = model.fit(train_data, 
                    validation_data=validation_data, 
                    epochs=5, 
                    callbacks=[early_stopping_callback])
Epoch 1/5
1176/1176 [==============================] - 53s 45ms/step - loss: 0.2174 - accuracy: 0.9273 - val_loss: 0.1489 - val_accuracy: 0.9478
Epoch 2/5
1176/1176 [==============================] - 51s 43ms/step - loss: 0.1425 - accuracy: 0.9504 - val_loss: 0.1393 - val_accuracy: 0.9503
Epoch 3/5
1176/1176 [==============================] - 51s 43ms/step - loss: 0.1308 - accuracy: 0.9535 - val_loss: 0.1292 - val_accuracy: 0.9534
Epoch 4/5
1176/1176 [==============================] - 51s 43ms/step - loss: 0.1238 - accuracy: 0.9554 - val_loss: 0.1250 - val_accuracy: 0.9548
Epoch 5/5
1176/1176 [==============================] - 52s 44ms/step - loss: 0.1196 - accuracy: 0.9568 - val_loss: 0.1193 - val_accuracy: 0.9567

Оценка

Мы можем визуализировать кривые точности и потерь для данных обучения и проверки, используя объект history возвращаемый методом fit который содержит значение потерь и точности для каждой эпохи.

# Plot training & validation accuracy values
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

# Plot training & validation loss values
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

PNG

PNG

Предсказание

Мы можем получить прогнозы для данных проверки и проверить матрицу путаницы, чтобы увидеть производительность модели для каждого из 5 классов. Как predict метод возвращает нас к - й массив для вероятностей для каждого класса , который мы преобразуем к классу меток с использованием np.argmax .

y_pred = model.predict(validation_data)
y_pred = np.argmax(y_pred, axis=1)
samples = file_paths[0:3]
for i, sample in enumerate(samples):
  f = open(sample)
  text = f.read()
  print(text[0:100])
  print("True Class: ", sample.split("/")[0])
  print("Predicted Class: ", dir_names[y_pred[i]])
  f.close()

বাংলা বর্ষ ১৪২২ বিদায় এবং নববর্ষ ১৪২৩ উদ্‌যাপন উপলক্ষে বাংলাদেশ জাতীয় জাদুঘর তিন দিনব্যাপী লোকজ ম
True Class:  entertainment
Predicted Class:  state

শেরপুরের ঝিনাইগাতী সীমান্তে ফের বন্য হাতির আক্রমণে উত্তম মারাক (৫৫) নামের এক কৃষক নিহত হয়েছেন। নিহত
True Class:  state
Predicted Class:  state

আর্জেন্টিনার খেলা চলছে। চিলির মোটেল ফ্যান্টাসিয়ায় তিল ধারণের ঠাঁই নেই। সবাই অপেক্ষা করে আছেন, কখন 
True Class:  sports
Predicted Class:  state

Сравнить производительность

Теперь мы можем взять правильные метки для данных проверки из labels и сравнить их с нашими прогнозами, чтобы получить классификационный отчет .

y_true = np.array(labels[train_size:])
print(classification_report(y_true, y_pred, target_names=dir_names))
               precision    recall  f1-score   support

      economy       0.84      0.75      0.79      3897
       sports       0.99      0.98      0.99     10204
entertainment       0.91      0.94      0.92      6256
        state       0.97      0.98      0.97     48512
international       0.94      0.92      0.93      6377

     accuracy                           0.96     75246
    macro avg       0.93      0.91      0.92     75246
 weighted avg       0.96      0.96      0.96     75246


Мы также можем сравнить производительность нашей модели с опубликованными результатами, полученными в исходной статье, которые сообщают о точности 0,96. Первоначальные авторы описали многие шаги предварительной обработки, выполняемые с набором данных, такие как удаление знаков препинания и цифр, удаление 25 самых частых стоп-слов. Как мы можем видеть в классификации_report, мы также получаем точность и точность 0,96 после обучения всего за 5 эпох без какой-либо предварительной обработки!

В этом примере, когда мы создали слой Keras из нашего модуля встраивания, мы установили trainable=False , что означает, что веса встраивания не будут обновляться во время обучения. Попробуйте установить для него значение True, чтобы достичь точности 97% с этим набором данных только с 2 эпохами.