This colab demonstrates using TensorFlow Hub for text classification in non-English/local languages. Here we choose Bangla as the local language and use pretrained word embeddings to solve a multiclass classification task in which we classify Bangla news articles into 5 categories. The pretrained embeddings for Bangla come from fastText, a library by Facebook with released pretrained word vectors for 157 languages.
We'll first use TF-Hub's pretrained embedding exporter to convert the word embeddings to a text embedding module, and then use the module to train a classifier with tf.keras, TensorFlow's high-level, user-friendly API for building deep learning models. Although we use fastText embeddings here, it's possible to export any other embeddings pretrained on other tasks and quickly get results with TensorFlow Hub.
Setup
# https://github.com/pypa/setuptools/issues/1694#issuecomment-466010982
pip install -q gdown --no-use-pep517
sudo apt-get install -y unzip
Reading package lists... Building dependency tree... Reading state information... unzip is already the newest version (6.0-21ubuntu1). The following packages were automatically installed and are no longer required: dconf-gsettings-backend dconf-service dkms freeglut3 freeglut3-dev glib-networking glib-networking-common glib-networking-services gsettings-desktop-schemas libcairo-gobject2 libcolord2 libdconf1 libegl1-mesa libepoxy0 libglu1-mesa libglu1-mesa-dev libgtk-3-0 libgtk-3-common libice-dev libjansson4 libjson-glib-1.0-0 libjson-glib-1.0-common libproxy1v5 librest-0.7-0 libsm-dev libsoup-gnome2.4-1 libsoup2.4-1 libwayland-cursor0 libwayland-egl1 libxfont2 libxi-dev libxkbcommon0 libxkbfile1 libxmu-dev libxmu-headers libxnvctrl0 libxt-dev linux-gcp-headers-5.0.0-1026 linux-headers-5.0.0-1026-gcp linux-image-5.0.0-1026-gcp linux-modules-5.0.0-1026-gcp pkg-config policykit-1-gnome python3-xkit screen-resolution-extra x11-xkb-utils xserver-common xserver-xorg-core-hwe-18.04 Use 'sudo apt autoremove' to remove them. 0 upgraded, 0 newly installed, 0 to remove and 102 not upgraded.
import os
import tensorflow as tf
import tensorflow_hub as hub
import gdown
import numpy as np
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import seaborn as sns
Dataset
We will use BARD (Bangla Article Dataset), which has around 376,226 articles collected from different Bangla news portals and labelled with 5 categories: economy, state, international, sports, and entertainment. We download the file from Google Drive; the link (bit.ly/BARD_DATASET) given in the dataset's GitHub repository points to it.
gdown.download(
url='https://drive.google.com/uc?id=1Ag0jd21oRwJhVFIBohmX_ogeojVtapLy',
output='bard.zip',
quiet=True
)
'bard.zip'
unzip -qo bard.zip
Export pretrained word vectors to TF-Hub module
TF-Hub provides some handy scripts for converting word embeddings to TF-Hub text embedding modules here. To make the module for Bangla or any other language, we simply download the word embedding .txt or .vec file to the same directory as export_v2.py and run the script.
The exporter reads the embedding vectors and exports them to a TensorFlow SavedModel. A SavedModel contains a complete TensorFlow program, including weights and graph. TF-Hub can load the SavedModel as a module, which we will use to build the model for text classification. Since we are using tf.keras to build the model, we will use hub.KerasLayer, which wraps a hub module so it can be used as a Keras layer.
First we will get our word embeddings from fastText and the embedding exporter from the TF-Hub repo.
curl -O https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bn.300.vec.gz
curl -O https://raw.githubusercontent.com/tensorflow/hub/master/examples/text_embeddings_v2/export_v2.py
gunzip -qf --keep cc.bn.300.vec.gz
% Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 840M 100 840M 0 0 13.2M 0 0:01:03 0:01:03 --:--:-- 13.8M % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 7493 100 7493 0 0 16396 0 --:--:-- --:--:-- --:--:-- 16360
Then we will run the exporter script on our embedding file. Since the fastText embeddings have a header line and are pretty large (around 3.3 GB for Bangla after converting to a module), we ignore the first line and export only the first 100,000 tokens to the text embedding module.
python export_v2.py --embedding_file=cc.bn.300.vec --export_path=text_module --num_lines_to_ignore=1 --num_lines_to_use=100000
2020-11-24 16:38:16.506111: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1 2020-11-24 16:38:30.883816: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1 2020-11-24 16:38:31.588580: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2020-11-24 16:38:31.589356: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: pciBusID: 0000:00:05.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0 coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s 2020-11-24 16:38:31.589426: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1 2020-11-24 16:38:31.591384: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10 2020-11-24 16:38:31.593231: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcufft.so.10 2020-11-24 16:38:31.593622: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcurand.so.10 2020-11-24 16:38:31.595468: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusolver.so.10 2020-11-24 16:38:31.596320: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusparse.so.10 2020-11-24 16:38:31.599992: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudnn.so.7 2020-11-24 16:38:31.600141: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so 
returning NUMA node zero 2020-11-24 16:38:31.600868: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2020-11-24 16:38:31.601510: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1858] Adding visible gpu devices: 0 2020-11-24 16:38:31.601959: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. 2020-11-24 16:38:31.608245: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] CPU Frequency: 2000185000 Hz 2020-11-24 16:38:31.608621: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5fa99c0 initialized for platform Host (this does not guarantee that XLA will be used). Devices: 2020-11-24 16:38:31.608653: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version 2020-11-24 16:38:31.698228: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2020-11-24 16:38:31.699202: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5d3bde0 initialized for platform CUDA (this does not guarantee that XLA will be used). 
Devices: 2020-11-24 16:38:31.699263: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Tesla V100-SXM2-16GB, Compute Capability 7.0 2020-11-24 16:38:31.699512: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2020-11-24 16:38:31.700189: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: pciBusID: 0000:00:05.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0 coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s 2020-11-24 16:38:31.700229: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1 2020-11-24 16:38:31.700284: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10 2020-11-24 16:38:31.700300: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcufft.so.10 2020-11-24 16:38:31.700312: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcurand.so.10 2020-11-24 16:38:31.700325: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusolver.so.10 2020-11-24 16:38:31.700335: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusparse.so.10 2020-11-24 16:38:31.700351: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudnn.so.7 2020-11-24 16:38:31.700419: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2020-11-24 16:38:31.701102: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA 
node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2020-11-24 16:38:31.701720: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1858] Adding visible gpu devices: 0 2020-11-24 16:38:31.701767: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1 2020-11-24 16:38:32.132285: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1257] Device interconnect StreamExecutor with strength 1 edge matrix: 2020-11-24 16:38:32.132338: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1263] 0 2020-11-24 16:38:32.132346: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1276] 0: N 2020-11-24 16:38:32.132566: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2020-11-24 16:38:32.133335: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2020-11-24 16:38:32.134035: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1402] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 14764 MB memory) -> physical GPU (device: 0, name: Tesla V100-SXM2-16GB, pci bus id: 0000:00:05.0, compute capability: 7.0) 2020-11-24 16:38:32.351857: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 240002400 exceeds 10% of free system memory. 2020-11-24 16:38:33.188668: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 240002400 exceeds 10% of free system memory. INFO:tensorflow:Assets written to: text_module/assets I1124 16:38:33.586180 140643110135616 builder_impl.py:775] Assets written to: text_module/assets
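To make the effect of the two flags concrete, here is a minimal pure-Python sketch of reading a fastText-style text file while skipping the header line and keeping only the first N tokens. The function `read_vec` and the tiny two-dimensional sample are hypothetical and only mirror the intent of `--num_lines_to_ignore` and `--num_lines_to_use`; this is not the code of export_v2.py, whose exact counting rules may differ.

```python
import io


def read_vec(stream, num_lines_to_ignore=0, num_lines_to_use=None):
    """Parse fastText-style text vectors into (vocabulary, vectors)."""
    vocabulary, vectors = [], []
    for index, line in enumerate(stream):
        if index < num_lines_to_ignore:
            continue  # skip header lines such as "<vocab_size> <dim>"
        if num_lines_to_use is not None and len(vocabulary) >= num_lines_to_use:
            break  # keep only the first N tokens
        token, *values = line.rstrip("\n").split(" ")
        vocabulary.append(token)
        vectors.append([float(value) for value in values])
    return vocabulary, vectors


# Tiny in-memory stand-in for cc.bn.300.vec (2 dimensions instead of 300).
sample_file = io.StringIO(
    "3 2\n"
    "বাস 0.1 0.2\n"
    "ট্রেন 0.3 0.4\n"
    "যাত্রী 0.5 0.6\n"
)
vocabulary, vectors = read_vec(sample_file, num_lines_to_ignore=1, num_lines_to_use=2)
print(vocabulary)
print(vectors)
```

Because fastText files list tokens roughly by frequency, truncating to the first rows keeps the most common words, which is why a 100,000-token module can still cover most of the text.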
module_path = "text_module"
embedding_layer = hub.KerasLayer(module_path, trainable=False)
The text embedding module takes a batch of sentences in a 1D tensor of strings as input and outputs the embedding vectors of shape (batch_size, embedding_dim) corresponding to the sentences. It preprocesses the input by splitting on spaces. Word embeddings are combined into sentence embeddings with the sqrtn combiner (see here). For demonstration, we pass a list of Bangla words as input and get the corresponding embedding vectors.
embedding_layer(['বাস', 'বসবাস', 'ট্রেন', 'যাত্রী', 'ট্রাক'])
<tf.Tensor: shape=(5, 300), dtype=float64, numpy= array([[ 0.0462, -0.0355, 0.0129, ..., 0.0025, -0.0966, 0.0216], [-0.0631, -0.0051, 0.085 , ..., 0.0249, -0.0149, 0.0203], [ 0.1371, -0.069 , -0.1176, ..., 0.029 , 0.0508, -0.026 ], [ 0.0532, -0.0465, -0.0504, ..., 0.02 , -0.0023, 0.0011], [ 0.0908, -0.0404, -0.0536, ..., -0.0275, 0.0528, 0.0253]])>
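As a rough illustration of the sqrtn combiner, the sketch below splits a sentence on spaces, looks up each token, and divides the sum of the word vectors by the square root of the number of tokens. The lookup table `toy_vectors` and the helper `embed_sentence` are hypothetical, use 2 dimensions instead of 300, and ignore out-of-vocabulary handling, so this is a sketch of the arithmetic rather than the module's implementation.

```python
import math

# Hypothetical 2-dimensional word vectors; the real module uses 300 dimensions.
toy_vectors = {
    "বাস": [3.0, 0.0],
    "ট্রেন": [0.0, 4.0],
}


def embed_sentence(sentence, table):
    """Split on spaces, look up each token, and combine with sqrtn."""
    tokens = sentence.split(" ")
    vectors = [table[token] for token in tokens]  # toy: all tokens must be known
    scale = math.sqrt(len(vectors))
    dim = len(vectors[0])
    return [sum(vec[d] for vec in vectors) / scale for d in range(dim)]


single = embed_sentence("বাস", toy_vectors)
pair = embed_sentence("বাস ট্রেন", toy_vectors)
print(single)
print(pair)
```

For a single word the scale is 1, so the sentence embedding equals the word vector, which matches the demonstration above where each input is one word.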
Convert to TensorFlow Dataset
Since the dataset is really large, instead of loading the entire dataset into memory we will read the articles from disk at run time and feed them to the model in batches using tf.data.Dataset. The dataset is also very imbalanced, so we shuffle the file paths and labels before building the datasets.
dir_names = ['economy', 'sports', 'entertainment', 'state', 'international']
file_paths = []
labels = []
# The index of each directory in dir_names is used as the integer class label.
for i, dir_name in enumerate(dir_names):
    file_names = ["/".join([dir_name, name]) for name in os.listdir(dir_name)]
    file_paths += file_names
    labels += [i] * len(file_names)
np.random.seed(42)
permutation = np.random.permutation(len(file_paths))
file_paths = np.array(file_paths)[permutation]
labels = np.array(labels)[permutation]
We can check the distribution of labels in the training and validation examples after shuffling.
train_frac = 0.8
train_size = int(len(file_paths) * train_frac)
# plot training vs validation distribution
plt.subplot(1, 2, 1)
plt.hist(labels[0:train_size])
plt.title("Train labels")
plt.subplot(1, 2, 2)
plt.hist(labels[train_size:])
plt.title("Validation labels")
plt.tight_layout()
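The shuffle-then-split step can be checked without plotting. The sketch below is a small standard-library analogue under stated assumptions: the ten file names are hypothetical, and `random.Random(42)` stands in for the NumPy permutation, so the resulting order differs from the notebook's.

```python
import random
from collections import Counter

# Hypothetical imbalanced toy data: 2 "economy" files and 8 "state" files.
paths = [f"economy/{i}.txt" for i in range(2)] + [f"state/{i}.txt" for i in range(8)]
labels = [0] * 2 + [1] * 8

# One shared permutation keeps every path aligned with its label.
order = list(range(len(paths)))
random.Random(42).shuffle(order)
paths = [paths[i] for i in order]
labels = [labels[i] for i in order]

train_size = int(len(paths) * 0.8)
train_counts = Counter(labels[:train_size])
validation_counts = Counter(labels[train_size:])
print("train:", dict(train_counts))
print("validation:", dict(validation_counts))
```

Applying the same index order to both lists is the important part: shuffling paths and labels independently would silently scramble the ground truth.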
To create the datasets, we build a tf.data.Dataset from the file paths and labels with tf.data.Dataset.from_tensor_slices and map a small load_file function over it, which reads each article from disk with tf.io.read_file. Each example is a tuple containing an article of tf.string data type and its integer class label. We split the data 80-20 into training and validation sets by slicing the shuffled arrays at train_size, then batch and prefetch both datasets; only the training dataset is additionally passed through a shuffle buffer.
def load_file(path, label):
    # Read the raw article text; the label passes through unchanged.
    return tf.io.read_file(path), label

def make_datasets(train_size):
    batch_size = 256

    # Training split: read files lazily, shuffle within a buffer, then batch.
    train_files = file_paths[:train_size]
    train_labels = labels[:train_size]
    train_ds = tf.data.Dataset.from_tensor_slices((train_files, train_labels))
    train_ds = train_ds.map(load_file).shuffle(5000)
    train_ds = train_ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

    # Validation split: kept in order so predictions line up with labels[train_size:].
    test_files = file_paths[train_size:]
    test_labels = labels[train_size:]
    test_ds = tf.data.Dataset.from_tensor_slices((test_files, test_labels))
    test_ds = test_ds.map(load_file)
    test_ds = test_ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

    return train_ds, test_ds
train_data, validation_data = make_datasets(train_size)
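The idea behind this pipeline, reading files only when they are needed and grouping them into batches, can be sketched in plain Python. This is an analogue for intuition, not tf.data itself: the helpers `iter_examples` and `in_batches` and the five temporary files are hypothetical, and there is no shuffling or prefetching here.

```python
import os
import tempfile
from itertools import islice


def iter_examples(paths, labels):
    """Lazily yield (text, label) pairs, reading one file at a time."""
    for path, label in zip(paths, labels):
        with open(path, encoding="utf-8") as f:
            yield f.read(), label


def in_batches(examples, batch_size):
    """Group an iterator of examples into lists of at most batch_size."""
    iterator = iter(examples)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


with tempfile.TemporaryDirectory() as tmp_dir:
    toy_paths = []
    for i in range(5):
        path = os.path.join(tmp_dir, f"{i}.txt")
        with open(path, "w", encoding="utf-8") as f:
            f.write(f"article {i}")
        toy_paths.append(path)
    toy_labels = [0, 1, 0, 1, 0]
    batches = list(in_batches(iter_examples(toy_paths, toy_labels), batch_size=2))

print([len(batch) for batch in batches])
```

As in the tf.data version, the last batch may be smaller than batch_size when the number of examples is not a multiple of it.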
Model Training and Evaluation
Since we have already wrapped our module so that it can be used like any other Keras layer, we can create a small Sequential model, which is a linear stack of layers, and include the text embedding module in it just like any other layer. We compile the model by specifying the loss and optimizer and train it for 5 epochs. The tf.keras API can handle TensorFlow datasets as input, so we can pass a Dataset instance to the fit method for model training. tf.data handles reading the samples, batching them, and feeding them to the model.
Model
def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=[], dtype=tf.string),
        embedding_layer,
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(5),
    ])
    model.compile(loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer="adam", metrics=['accuracy'])
    return model
model = create_model()
# Create earlystopping callback
early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=3)
WARNING:tensorflow:Layer dense is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because its dtype defaults to floatx. If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2. To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.
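The model's last layer is Dense(5) with no activation, and the loss is created with from_logits=True, so the loss applies softmax to the raw scores internally. A minimal pure-Python sketch of that per-example computation follows; the function name and the example logits are hypothetical, and Keras additionally averages the loss over the batch.

```python
import math


def sparse_categorical_crossentropy(logits, label):
    """Cross-entropy of one example: -log(softmax(logits)[label])."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[label]


uniform_loss = sparse_categorical_crossentropy([0.0] * 5, label=2)
confident_loss = sparse_categorical_crossentropy([0.1, 0.2, 6.0, 0.3, 0.1], label=2)
wrong_loss = sparse_categorical_crossentropy([0.1, 0.2, 6.0, 0.3, 0.1], label=0)
print(uniform_loss, confident_loss, wrong_loss)
```

With five classes, an untrained model that scores every class equally has a loss of log(5), about 1.61, which is a useful reference point when reading the training log.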
Training
history = model.fit(train_data,
                    validation_data=validation_data,
                    epochs=5,
                    callbacks=[early_stopping_callback])
Epoch 1/5
1176/1176 [==============================] - 53s 45ms/step - loss: 0.2174 - accuracy: 0.9273 - val_loss: 0.1489 - val_accuracy: 0.9478
Epoch 2/5
1176/1176 [==============================] - 51s 43ms/step - loss: 0.1425 - accuracy: 0.9504 - val_loss: 0.1393 - val_accuracy: 0.9503
Epoch 3/5
1176/1176 [==============================] - 51s 43ms/step - loss: 0.1308 - accuracy: 0.9535 - val_loss: 0.1292 - val_accuracy: 0.9534
Epoch 4/5
1176/1176 [==============================] - 51s 43ms/step - loss: 0.1238 - accuracy: 0.9554 - val_loss: 0.1250 - val_accuracy: 0.9548
Epoch 5/5
1176/1176 [==============================] - 52s 44ms/step - loss: 0.1196 - accuracy: 0.9568 - val_loss: 0.1193 - val_accuracy: 0.9567
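The early-stopping callback never triggers in this run because val_loss improves at every epoch. The sketch below is a simplified pure-Python model of the patience rule (stop once val_loss has failed to improve for `patience` consecutive epochs); `early_stopping_epoch` is a hypothetical helper, not the Keras implementation, and the second loss sequence is made up for contrast.

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# val_loss values from the training log above: every epoch improves.
logged = [0.1489, 0.1393, 0.1292, 0.1250, 0.1193]
print(early_stopping_epoch(logged))

# Hypothetical run that stalls after the second epoch.
stalled = [0.50, 0.40, 0.41, 0.42, 0.43, 0.30]
print(early_stopping_epoch(stalled))
```

In the stalled example the run stops before the sixth epoch is reached, which illustrates the usual trade-off of a small patience value.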
Evaluation
We can visualize the accuracy and loss curves for the training and validation data using the history object returned by the fit method, which contains the loss and accuracy values for each epoch.
# Plot training & validation accuracy values
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()
# Plot training & validation loss values
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()
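The same information can be read numerically. Assuming history.history is a dictionary of per-epoch lists keyed by metric name, as the plotting code implies, this sketch rebuilds it by hand from the five-epoch training log shown earlier and finds the best epoch and the train/validation accuracy gap. The name `history_dict` is hypothetical.

```python
# Values copied from the five-epoch training log shown earlier.
history_dict = {
    "accuracy": [0.9273, 0.9504, 0.9535, 0.9554, 0.9568],
    "val_accuracy": [0.9478, 0.9503, 0.9534, 0.9548, 0.9567],
    "loss": [0.2174, 0.1425, 0.1308, 0.1238, 0.1196],
    "val_loss": [0.1489, 0.1393, 0.1292, 0.1250, 0.1193],
}

val_loss = history_dict["val_loss"]
# 1-based index of the epoch with the lowest validation loss.
best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__) + 1

# Gap between training and validation accuracy per epoch.
gaps = [
    round(train - val, 4)
    for train, val in zip(history_dict["accuracy"], history_dict["val_accuracy"])
]
print(best_epoch, gaps)
```

After the first epoch the gap stays below 0.001, so the curves give no sign of overfitting within these five epochs.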
Prediction
We can get the predictions for the validation data and check the model's performance for each of the 5 classes. The predict method returns an array of per-class scores (logits, since the final Dense layer has no activation), which we convert to class labels using np.argmax.
y_pred = model.predict(validation_data)
y_pred = np.argmax(y_pred, axis=1)
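# y_pred[i] belongs to the i-th validation file, so index from train_size.
samples = file_paths[train_size:train_size + 3]
for i, sample in enumerate(samples):
    with open(sample, encoding="utf-8") as f:
        text = f.read()
    print(text[0:100])
    print("True Class: ", sample.split("/")[0])
    print("Predicted Class: ", dir_names[y_pred[i]])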
বাংলা বর্ষ ১৪২২ বিদায় এবং নববর্ষ ১৪২৩ উদ্যাপন উপলক্ষে বাংলাদেশ জাতীয় জাদুঘর তিন দিনব্যাপী লোকজ ম
True Class: entertainment
Predicted Class: state
শেরপুরের ঝিনাইগাতী সীমান্তে ফের বন্য হাতির আক্রমণে উত্তম মারাক (৫৫) নামের এক কৃষক নিহত হয়েছেন। নিহত
True Class: state
Predicted Class: state
আর্জেন্টিনার খেলা চলছে। চিলির মোটেল ফ্যান্টাসিয়ায় তিল ধারণের ঠাঁই নেই। সবাই অপেক্ষা করে আছেন, কখন
True Class: sports
Predicted Class: state
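Because the model ends in Dense(5) without an activation and is compiled with from_logits=True, predict returns raw scores rather than probabilities. Softmax preserves the ordering of the scores, so taking the argmax of the logits picks the same class as taking the argmax of the probabilities. A small pure-Python sketch, with hypothetical logits for two articles:

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)


dir_names = ['economy', 'sports', 'entertainment', 'state', 'international']

# Hypothetical logits for two articles (one row per article).
logits_batch = [
    [-1.2, 4.1, 0.3, 0.9, -0.5],
    [0.2, -0.7, 0.1, 3.3, 1.0],
]
predicted = [argmax(row) for row in logits_batch]
predicted_from_probs = [argmax(softmax(row)) for row in logits_batch]
print([dir_names[i] for i in predicted])
```

If calibrated probabilities are needed, for example to threshold low-confidence predictions, apply softmax to the output of predict first.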
Compare Performance
Now we can take the correct labels for the validation data from labels and compare them with our predictions to get the classification_report.
y_true = np.array(labels[train_size:])
print(classification_report(y_true, y_pred, target_names=dir_names))
               precision    recall  f1-score   support

      economy       0.84      0.75      0.79      3897
       sports       0.99      0.98      0.99     10204
entertainment       0.91      0.94      0.92      6256
        state       0.97      0.98      0.97     48512
international       0.94      0.92      0.93      6377

     accuracy                           0.96     75246
    macro avg       0.93      0.91      0.92     75246
 weighted avg       0.96      0.96      0.96     75246
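To see where the numbers in the report come from, the sketch below computes precision, recall, and F1 for one class from label lists, then recomputes the macro and weighted precision averages from the rounded per-class values in the report above. The helper `precision_recall_f1` and the toy labels are hypothetical; scikit-learn's own implementation handles more edge cases.

```python
def precision_recall_f1(y_true, y_pred, cls):
    """Precision, recall and F1 for one class from parallel label lists."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == cls and p == cls for t, p in pairs)
    fp = sum(t != cls and p == cls for t, p in pairs)
    fn = sum(t == cls and p != cls for t, p in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Hypothetical toy labels to exercise the per-class formulas.
toy_true = [0, 0, 1, 1, 1, 2]
toy_pred = [0, 1, 1, 1, 2, 2]
print(precision_recall_f1(toy_true, toy_pred, cls=1))

# Per-class precision and support copied from the report above.
precisions = [0.84, 0.99, 0.91, 0.97, 0.94]
supports = [3897, 10204, 6256, 48512, 6377]
macro = sum(precisions) / len(precisions)
weighted = sum(p * s for p, s in zip(precisions, supports)) / sum(supports)
print(round(macro, 2), round(weighted, 2))
```

The weighted average is higher than the macro average because the state class, which makes up roughly two thirds of the validation set, has high precision, while the small economy class pulls the unweighted mean down.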
We can also compare our model's performance with the published results from the original paper, which reports a 0.96 precision. The original authors describe many preprocessing steps performed on the dataset, such as dropping punctuation and digits and removing the 25 most frequent stop words. As we can see in the classification_report, we also obtain a 0.96 precision and accuracy after training for only 5 epochs without any preprocessing!
In this example, when we created the Keras layer from our embedding module we set trainable=False, which means the embedding weights will not be updated during training. Try setting it to True to reach 97% accuracy on this dataset after only 2 epochs.