Cette page a été traduite par l'API Cloud Translation.
Switch to English

Classification des articles Bangla avec TF-Hub

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher sur GitHub Télécharger le carnet

Ce colab est une démonstration de l'utilisation de Tensorflow Hub pour la classification de texte dans des langues non anglaises / locales. Ici, nous choisissons Bangla comme langue locale et utilisons des incorporations de mots pré-entraînés pour résoudre une tâche de classification multiclasse où nous classons les articles de presse Bangla en 5 catégories. Les incorporations pré-entraînées pour Bangla proviennent de fastText, une bibliothèque de Facebook avec des vecteurs de mots pré-entraînés publiés pour 157 langues.

Nous utiliserons d'abord l'exportateur d'intégration pré-formé de TF-Hub pour convertir les embeddings de mots en un module d'intégration de texte, puis utiliserons le module pour former un classificateur avec tf.keras , l'API conviviale de haut niveau de Tensorflow pour créer des modèles d'apprentissage en profondeur. Même si nous utilisons les incorporations fastText ici, il est possible d'exporter d'autres incorporations pré-entraînées à partir d'autres tâches et d'obtenir rapidement des résultats avec le hub Tensorflow.

Installer

# https://github.com/pypa/setuptools/issues/1694#issuecomment-466010982
pip install -q gdown --no-use-pep517
sudo apt-get install -y unzip
Reading package lists...
Building dependency tree...
Reading state information...
unzip is already the newest version (6.0-21ubuntu1).
The following packages were automatically installed and are no longer required:
  dconf-gsettings-backend dconf-service dkms freeglut3 freeglut3-dev
  glib-networking glib-networking-common glib-networking-services
  gsettings-desktop-schemas libcairo-gobject2 libcolord2 libdconf1
  libegl1-mesa libepoxy0 libglu1-mesa libglu1-mesa-dev libgtk-3-0
  libgtk-3-common libice-dev libjansson4 libjson-glib-1.0-0
  libjson-glib-1.0-common libproxy1v5 librest-0.7-0 libsm-dev
  libsoup-gnome2.4-1 libsoup2.4-1 libwayland-cursor0 libwayland-egl1 libxfont2
  libxi-dev libxkbcommon0 libxkbfile1 libxmu-dev libxmu-headers libxnvctrl0
  libxt-dev linux-gcp-headers-5.0.0-1026 linux-headers-5.0.0-1026-gcp
  linux-image-5.0.0-1026-gcp linux-modules-5.0.0-1026-gcp pkg-config
  policykit-1-gnome python3-xkit screen-resolution-extra x11-xkb-utils
  xserver-common xserver-xorg-core-hwe-18.04
Use 'sudo apt autoremove' to remove them.
0 upgraded, 0 newly installed, 0 to remove and 94 not upgraded.

import os

import tensorflow as tf
import tensorflow_hub as hub

import gdown
import numpy as np
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import seaborn as sns

Base de données

Nous utiliserons BARD (Bangla Article Dataset) qui compte environ 3 76226 articles collectés à partir de différents portails d'informations Bangla et étiquetés avec 5 catégories: économie, état, international, sports et divertissement. Nous téléchargeons le fichier à partir de Google Drive auquel ce lien ( bit.ly/BARD_DATASET ) fait référence à partir de ce référentiel GitHub.

gdown.download(
    url='https://drive.google.com/uc?id=1Ag0jd21oRwJhVFIBohmX_ogeojVtapLy',
    output='bard.zip',
    quiet=True
)
'bard.zip'
unzip -qo bard.zip

Exporter des vecteurs de mots pré-entraînés vers le module TF-Hub

TF-Hub fournit des scripts pratiques pour convertir mot incorporations aux modules de plongement texte TF-hub ici . Pour créer le module pour Bangla ou toute autre langue, nous devons simplement télécharger le fichier .txt ou .vec d'incorporation de mots dans le même répertoire que export_v2.py et exécuter le script.

L'exportateur lit les vecteurs d'incorporation et les exporte vers un modèle Tensorflow SavedModel . Un SavedModel contient un programme TensorFlow complet comprenant des poids et un graphique. TF-Hub peut charger le SavedModel en tant que module que nous utiliserons pour créer le modèle de classification de texte. Puisque nous utilisons tf.keras pour construire le modèle, nous utiliserons hub.KerasLayer qui fournit un wrapper pour un module hub à utiliser comme couche Keras.

Tout d'abord, nous obtiendrons nos intégrations de mots de fastText et l'exportateur d'intégration à partir du repo TF-Hub.

curl -O https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bn.300.vec.gz
curl -O https://raw.githubusercontent.com/tensorflow/hub/master/examples/text_embeddings_v2/export_v2.py
gunzip -qf cc.bn.300.vec.gz --k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  840M  100  840M    0     0  10.9M      0  0:01:17  0:01:17 --:--:-- 11.2M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  7493  100  7493    0     0  26017      0 --:--:-- --:--:-- --:--:-- 25927

Ensuite, nous exécuterons le script d'exportateur sur notre fichier d'intégration. Étant donné que les incorporations fastText ont une ligne d'en-tête et sont assez volumineuses (environ 3,3 Go pour Bangla après la conversion en module), nous ignorons la première ligne et n'exportons que les 100 000 premiers jetons vers le module d'intégration de texte.

python export_v2.py --embedding_file=cc.bn.300.vec --export_path=text_module --num_lines_to_ignore=1 --num_lines_to_use=100000
2020-10-02 11:10:32.781365: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-10-02 11:10:47.823850: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1
2020-10-02 11:10:48.521345: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-10-02 11:10:48.522054: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:00:05.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2020-10-02 11:10:48.522094: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-10-02 11:10:48.524160: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2020-10-02 11:10:48.526054: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcufft.so.10
2020-10-02 11:10:48.526480: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcurand.so.10
2020-10-02 11:10:48.528338: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusolver.so.10
2020-10-02 11:10:48.529195: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusparse.so.10
2020-10-02 11:10:48.532817: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudnn.so.7
2020-10-02 11:10:48.532969: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-10-02 11:10:48.533746: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-10-02 11:10:48.534368: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1858] Adding visible gpu devices: 0
2020-10-02 11:10:48.534755: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2020-10-02 11:10:48.541500: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] CPU Frequency: 2000160000 Hz
2020-10-02 11:10:48.541926: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0xdbebdb0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-10-02 11:10:48.541959: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2020-10-02 11:10:48.634142: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-10-02 11:10:48.634959: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x40ffc80 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2020-10-02 11:10:48.634996: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Tesla V100-SXM2-16GB, Compute Capability 7.0
2020-10-02 11:10:48.635284: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-10-02 11:10:48.635947: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:00:05.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2020-10-02 11:10:48.636007: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-10-02 11:10:48.636050: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2020-10-02 11:10:48.636097: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcufft.so.10
2020-10-02 11:10:48.636109: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcurand.so.10
2020-10-02 11:10:48.636122: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusolver.so.10
2020-10-02 11:10:48.636136: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusparse.so.10
2020-10-02 11:10:48.636152: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudnn.so.7
2020-10-02 11:10:48.636231: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-10-02 11:10:48.636897: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-10-02 11:10:48.637567: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1858] Adding visible gpu devices: 0
2020-10-02 11:10:48.637610: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-10-02 11:10:49.073177: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1257] Device interconnect StreamExecutor with strength 1 edge matrix:
2020-10-02 11:10:49.073239: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1263]      0 
2020-10-02 11:10:49.073248: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1276] 0:   N 
2020-10-02 11:10:49.073480: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-10-02 11:10:49.074233: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-10-02 11:10:49.074894: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1402] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 14764 MB memory) -> physical GPU (device: 0, name: Tesla V100-SXM2-16GB, pci bus id: 0000:00:05.0, compute capability: 7.0)
INFO:tensorflow:Assets written to: text_module/assets
I1002 11:10:50.591525 139665650693952 builder_impl.py:775] Assets written to: text_module/assets

module_path = "text_module"
embedding_layer = hub.KerasLayer(module_path, trainable=False)

Le module d'intégration de texte prend un lot de phrases dans un tenseur 1D de chaînes en entrée et délivre les vecteurs d'incorporation de forme (batch_size, embedding_dim) correspondant aux phrases. Il prétraite l'entrée en fractionnant les espaces. Les incorporations de mots sont combinées aux incorporations de phrases avec le combineur sqrtn (voir ici ). Pour la démonstration, nous passons une liste de mots Bangla en entrée et obtenons les vecteurs d'incorporation correspondants.

embedding_layer(['বাস', 'বসবাস', 'ট্রেন', 'যাত্রী', 'ট্রাক']) 
<tf.Tensor: shape=(5, 300), dtype=float64, numpy=
array([[ 0.0462, -0.0355,  0.0129, ...,  0.0025, -0.0966,  0.0216],
       [-0.0631, -0.0051,  0.085 , ...,  0.0249, -0.0149,  0.0203],
       [ 0.1371, -0.069 , -0.1176, ...,  0.029 ,  0.0508, -0.026 ],
       [ 0.0532, -0.0465, -0.0504, ...,  0.02  , -0.0023,  0.0011],
       [ 0.0908, -0.0404, -0.0536, ..., -0.0275,  0.0528,  0.0253]])>

Convertir en ensemble de données Tensorflow

Étant donné que l'ensemble de données est vraiment volumineux au lieu de charger l'ensemble de données en mémoire, nous utiliserons un générateur pour générer des échantillons à l'exécution par lots à l'aide des fonctionnalités de l'ensemble de données Tensorflow . L'ensemble de données est également très déséquilibré, donc avant d'utiliser le générateur, nous allons mélanger l'ensemble de données.

dir_names = ['economy', 'sports', 'entertainment', 'state', 'international']

file_paths = []
labels = []
for i, dir in enumerate(dir_names):
  file_names = ["/".join([dir, name]) for name in os.listdir(dir)]
  file_paths += file_names
  labels += [i] * len(os.listdir(dir))
  
np.random.seed(42)
permutation = np.random.permutation(len(file_paths))

file_paths = np.array(file_paths)[permutation]
labels = np.array(labels)[permutation]

Nous pouvons vérifier la distribution des étiquettes dans les exemples de formation et de validation après mélange.

train_frac = 0.8
train_size = int(len(file_paths) * train_frac)
# plot training vs validation distribution
plt.subplot(1, 2, 1)
plt.hist(labels[0:train_size])
plt.title("Train labels")
plt.subplot(1, 2, 2)
plt.hist(labels[train_size:])
plt.title("Validation labels")
plt.tight_layout()

png

Pour créer un ensemble de données à l' aide du générateur, nous écrivons d'abord une fonction de générateur qui lit chacun des articles de file_paths et les étiquettes du tableau d'étiquettes, et donne un exemple d'apprentissage à chaque étape. Nous passons cette fonction de générateur à la méthode tf.data.Dataset.from_generator et spécifions les types de sortie. Chaque exemple d'apprentissage est un tuple contenant un article de type de données tf.string et une étiquette codée à chaud. Nous avons divisé l'ensemble de données avec un fractionnement de validation de train de 80-20 en utilisant la méthode skip and take .

def load_file(path, label):
    return tf.io.read_file(path), label
def make_datasets(train_size):
  batch_size = 256

  train_files = file_paths[:train_size]
  train_labels = labels[:train_size]
  train_ds = tf.data.Dataset.from_tensor_slices((train_files, train_labels))
  train_ds = train_ds.map(load_file).shuffle(5000)
  train_ds = train_ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

  test_files = file_paths[train_size:]
  test_labels = labels[train_size:]
  test_ds = tf.data.Dataset.from_tensor_slices((test_files, test_labels))
  test_ds = test_ds.map(load_file)
  test_ds = test_ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)


  return train_ds, test_ds
train_data, validation_data = make_datasets(train_size)

Formation et évaluation des modèles

Puisque nous avons déjà ajouté un wrapper autour de notre module pour l'utiliser comme n'importe quelle autre couche dans keras, nous pouvons créer un petit modèle séquentiel qui est une pile linéaire de couches. Nous pouvons ajouter notre module d'incorporation de texte avec model.add comme n'importe quel autre calque. Nous compilons le modèle en spécifiant la perte et l'optimiseur et le formons pour 10 époques. tf.keras API tf.keras peut gérer les ensembles de données tensorflow en entrée, afin que nous puissions transmettre une instance de Dataset à la méthode fit pour l'entraînement du modèle. Puisque nous utilisons une fonction de générateur, tf.data se chargera de générer les échantillons, de les regrouper et de les alimenter dans le modèle.

Modèle

def create_model():
  model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=[], dtype=tf.string),
    embedding_layer,
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(5),
  ])
  model.compile(loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer="adam", metrics=['accuracy'])
  return model
model = create_model()
# Create earlystopping callback
early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=3)
WARNING:tensorflow:Layer dense is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.


Warning:tensorflow:Layer dense is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.


Entraînement


history = model.fit(train_data, 
                    validation_data=validation_data, 
                    epochs=5, 
                    callbacks=[early_stopping_callback])
Epoch 1/5
1176/1176 [==============================] - 55s 47ms/step - loss: 0.2299 - accuracy: 0.9230 - val_loss: 0.1513 - val_accuracy: 0.9470
Epoch 2/5
1176/1176 [==============================] - 54s 46ms/step - loss: 0.1406 - accuracy: 0.9508 - val_loss: 0.1333 - val_accuracy: 0.9525
Epoch 3/5
1176/1176 [==============================] - 54s 46ms/step - loss: 0.1285 - accuracy: 0.9543 - val_loss: 0.1255 - val_accuracy: 0.9548
Epoch 4/5
1176/1176 [==============================] - 54s 46ms/step - loss: 0.1216 - accuracy: 0.9561 - val_loss: 0.1223 - val_accuracy: 0.9559
Epoch 5/5
1176/1176 [==============================] - 54s 46ms/step - loss: 0.1166 - accuracy: 0.9576 - val_loss: 0.1180 - val_accuracy: 0.9576

Évaluation

Nous pouvons visualiser les courbes de précision et de perte pour les données d'apprentissage et de validation à l'aide de l'objet history renvoyé par la méthode d' fit qui contient la valeur de perte et de précision pour chaque époque.

# Plot training & validation accuracy values
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

# Plot training & validation loss values
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

png

png

Prédiction

Nous pouvons obtenir les prédictions pour les données de validation et vérifier la matrice de confusion pour voir les performances du modèle pour chacune des 5 classes. Comme la méthode predict nous renvoie le nd tableau des probabilités pour chaque classe que nous convertissons en étiquettes de classe en utilisant np.argmax .

y_pred = model.predict(validation_data)
y_pred = np.argmax(y_pred, axis=1)
samples = file_paths[0:3]
for i, sample in enumerate(samples):
  f = open(sample)
  text = f.read()
  print(text[0:100])
  print("True Class: ", sample.split("/")[0])
  print("Predicted Class: ", dir_names[y_pred[i]])
  f.close()
  

গল্প যেমনই হোক, নায়ক কে, সেটাই তো দেখার বিষয়। অজয় দেবগনের সঙ্গে অভিনয়ের সুযোগটা তাই হাতছাড়া করতে চ
True Class:  entertainment
Predicted Class:  state

বাংলাদেশ টেলিভিশন চট্টগ্রাম কেন্দ্র থেকে প্রতিদিন ছয় ঘণ্টা অনুষ্ঠান সম্প্রচারের দাবির সঙ্গে একাত্মত
True Class:  state
Predicted Class:  state

নিজের ৪১তম লিস্ট ‘এ’ ম্যাচে এসে প্রথম সেঞ্চুরিটি পেলেন নুরুল হাসান। জাতীয় ক্রিকেট দলের নবীন এই সদস্
True Class:  sports
Predicted Class:  state

Comparez les performances

Maintenant, nous pouvons prendre les bonnes étiquettes pour les données de validation des labels et les comparer avec nos prédictions pour obtenir le classement_report .

y_true = np.array(labels[train_size:])
print(classification_report(y_true, y_pred, target_names=dir_names))
               precision    recall  f1-score   support

      economy       0.79      0.82      0.81      3897
       sports       0.98      0.99      0.99     10204
entertainment       0.92      0.94      0.93      6256
        state       0.97      0.97      0.97     48512
international       0.93      0.93      0.93      6377

     accuracy                           0.96     75246
    macro avg       0.92      0.93      0.92     75246
 weighted avg       0.96      0.96      0.96     75246


Nous pouvons également comparer les performances de notre modèle avec les résultats publiés obtenus dans l' article original qui rapportent une précision de 0,96. Les auteurs originaux ont décrit de nombreuses étapes de prétraitement effectuées sur l'ensemble de données, telles que la suppression des ponctuations et des chiffres, la suppression des 25 mots vides les plus fréquents. Comme nous pouvons le voir dans le classement_report, nous gagnons également une précision et une précision de 0.96 après un entraînement de seulement 5 époques sans aucun prétraitement!

Dans cet exemple, lorsque nous avons créé la couche Keras à partir de notre module d'intégration, nous définissons trainable=False , ce qui signifie que les poids d'incorporation ne seront pas mis à jour pendant l'entraînement. Essayez de le définir sur True pour atteindre une précision de 97% avec cet ensemble de données avec seulement 2 époques.