Basic text classification

This tutorial demonstrates text classification starting from plain text files stored on disk. You will train a binary classifier to perform sentiment analysis on the IMDB dataset. At the end of the notebook, there is an exercise for you to try, in which you will train a multi-class classifier to predict the tag of a programming question on Stack Overflow.

 import matplotlib.pyplot as plt
import os
import re
import shutil
import string
import tensorflow as tf

from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import preprocessing
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
 
 print(tf.__version__)
 
2.3.0

Sentiment analysis

This notebook trains a sentiment analysis model to classify movie reviews as positive or negative, based on the text of the review. This is an example of binary (or two-class) classification, an important and widely applicable kind of machine learning problem.

You will use the Large Movie Review Dataset, which contains the text of 50,000 movie reviews from the Internet Movie Database. These are split into 25,000 reviews for training and 25,000 reviews for testing. The training and testing sets are balanced, meaning they contain an equal number of positive and negative reviews.

Download and explore the IMDB dataset

Let's download and extract the dataset, then explore the directory structure.

 url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"

dataset = tf.keras.utils.get_file("aclImdb_v1.tar.gz", url,
                                    untar=True, cache_dir='.',
                                    cache_subdir='')

dataset_dir = os.path.join(os.path.dirname(dataset), 'aclImdb')
 
Downloading data from https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
84131840/84125825 [==============================] - 7s 0us/step

 os.listdir(dataset_dir)
 
['imdb.vocab', 'train', 'test', 'README', 'imdbEr.txt']
 train_dir = os.path.join(dataset_dir, 'train')
os.listdir(train_dir)
 
['urls_pos.txt',
 'neg',
 'labeledBow.feat',
 'pos',
 'urls_neg.txt',
 'unsup',
 'unsupBow.feat',
 'urls_unsup.txt']

The aclImdb/train/pos and aclImdb/train/neg directories contain many text files, each of which is a single movie review. Let's take a look at one of them.

 sample_file = os.path.join(train_dir, 'pos/1181_9.txt')
with open(sample_file) as f:
  print(f.read())
 
Rachel Griffiths writes and directs this award winning short film. A heartwarming story about coping with grief and cherishing the memory of those we've loved and lost. Although, only 15 minutes long, Griffiths manages to capture so much emotion and truth onto film in the short space of time. Bud Tingwell gives a touching performance as Will, a widower struggling to cope with his wife's death. Will is confronted by the harsh reality of loneliness and helplessness as he proceeds to take care of Ruth's pet cow, Tulip. The film displays the grief and responsibility one feels for those they have loved and lost. Good cinematography, great direction, and superbly acted. It will bring tears to all those who have lost a loved one, and survived.

Load the dataset

Next, you will load the data off disk and prepare it into a format suitable for training. To do so, you will use the helpful text_dataset_from_directory utility, which expects a directory structure as follows.

 main_directory/
...class_a/
......a_text_1.txt
......a_text_2.txt
...class_b/
......b_text_1.txt
......b_text_2.txt
 

To prepare a dataset for binary classification, you will need two folders on disk, corresponding to class_a and class_b. These will be the positive and negative movie reviews, which can be found in aclImdb/train/pos and aclImdb/train/neg. As the IMDB dataset contains additional folders, you will remove them before using this utility.

 remove_dir = os.path.join(train_dir, 'unsup')
shutil.rmtree(remove_dir)
 

Next, you will use the text_dataset_from_directory utility to create a labeled tf.data.Dataset. tf.data is a powerful collection of tools for working with data.

When running a machine learning experiment, it is a best practice to divide your dataset into three splits: train, validation, and test.

The IMDB dataset has already been divided into train and test, but it lacks a validation set. Let's create a validation set using an 80:20 split of the training data with the validation_split argument below.

 batch_size = 32
seed = 42

raw_train_ds = tf.keras.preprocessing.text_dataset_from_directory(
    'aclImdb/train', 
    batch_size=batch_size, 
    validation_split=0.2, 
    subset='training', 
    seed=seed)
 
Found 25000 files belonging to 2 classes.
Using 20000 files for training.

As you can see above, there are 25,000 examples in the training folder, of which you will use 80% (or 20,000) for training. As you will see in a moment, you can train a model by passing a dataset directly to model.fit. If you're new to tf.data, you can also iterate over the dataset and print out a few examples as follows.

 for text_batch, label_batch in raw_train_ds.take(1):
  for i in range(3):
    print("Review", text_batch.numpy()[i])
    print("Label", label_batch.numpy()[i])
 
Review b'"Pandemonium" is a horror movie spoof that comes off more stupid than funny. Believe me when I tell you, I love comedies. Especially comedy spoofs. "Airplane", "The Naked Gun" trilogy, "Blazing Saddles", "High Anxiety", and "Spaceballs" are some of my favorite comedies that spoof a particular genre. "Pandemonium" is not up there with those films. Most of the scenes in this movie had me sitting there in stunned silence because the movie wasn\'t all that funny. There are a few laughs in the film, but when you watch a comedy, you expect to laugh a lot more than a few times and that\'s all this film has going for it. Geez, "Scream" had more laughs than this film and that was more of a horror film. How bizarre is that?<br /><br />*1/2 (out of four)'
Label 0
Review b"David Mamet is a very interesting and a very un-equal director. His first movie 'House of Games' was the one I liked best, and it set a series of films with characters whose perspective of life changes as they get into complicated situations, and so does the perspective of the viewer.<br /><br />So is 'Homicide' which from the title tries to set the mind of the viewer to the usual crime drama. The principal characters are two cops, one Jewish and one Irish who deal with a racially charged area. The murder of an old Jewish shop owner who proves to be an ancient veteran of the Israeli Independence war triggers the Jewish identity in the mind and heart of the Jewish detective.<br /><br />This is were the flaws of the film are the more obvious. The process of awakening is theatrical and hard to believe, the group of Jewish militants is operatic, and the way the detective eventually walks to the final violent confrontation is pathetic. The end of the film itself is Mamet-like smart, but disappoints from a human emotional perspective.<br /><br />Joe Mantegna and William Macy give strong performances, but the flaws of the story are too evident to be easily compensated."
Label 0
Review b'Great documentary about the lives of NY firefighters during the worst terrorist attack of all time.. That reason alone is why this should be a must see collectors item.. What shocked me was not only the attacks, but the"High Fat Diet" and physical appearance of some of these firefighters. I think a lot of Doctors would agree with me that,in the physical shape they were in, some of these firefighters would NOT of made it to the 79th floor carrying over 60 lbs of gear. Having said that i now have a greater respect for firefighters and i realize becoming a firefighter is a life altering job. The French have a history of making great documentary\'s and that is what this is, a Great Documentary.....'
Label 1

Notice the reviews contain raw text (with punctuation and occasional HTML tags like <br/>). You will see how to handle these in the following section.

The labels are 0 or 1. To see which of these correspond to positive and negative movie reviews, you can check the class_names property on the dataset.

 print("Label 0 corresponds to", raw_train_ds.class_names[0])
print("Label 1 corresponds to", raw_train_ds.class_names[1])
 
Label 0 corresponds to neg
Label 1 corresponds to pos

Next, you will create a validation and a test dataset. You will use the remaining 5,000 reviews from the training set for validation.

 raw_val_ds = tf.keras.preprocessing.text_dataset_from_directory(
    'aclImdb/train', 
    batch_size=batch_size, 
    validation_split=0.2, 
    subset='validation', 
    seed=seed)
 
Found 25000 files belonging to 2 classes.
Using 5000 files for validation.

 raw_test_ds = tf.keras.preprocessing.text_dataset_from_directory(
    'aclImdb/test', 
    batch_size=batch_size)
 
Found 25000 files belonging to 2 classes.

Prepare the dataset for training

Next, you will standardize, tokenize, and vectorize the data using the helpful preprocessing.TextVectorization layer.

Standardization refers to preprocessing the text, typically to remove punctuation or HTML elements to simplify the dataset. Tokenization refers to splitting strings into tokens (for example, splitting a sentence into individual words by splitting on whitespace). Vectorization refers to converting tokens into numbers so they can be fed into a neural network. All of these tasks can be accomplished with this layer.

As you saw above, the reviews contain various HTML tags like <br />. These tags will not be removed by the default standardizer in the TextVectorization layer (which converts text to lowercase and strips punctuation by default, but doesn't strip HTML). You will write a custom standardization function to remove the HTML.

 def custom_standardization(input_data):
  lowercase = tf.strings.lower(input_data)
  stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')
  return tf.strings.regex_replace(stripped_html,
                                  '[%s]' % re.escape(string.punctuation),
                                  '')
 

Next, you will create a TextVectorization layer. You will use this layer to standardize, tokenize, and vectorize our data. You set output_mode to int to create unique integer indices for each token.

Note that you're using the default split function, and the custom standardization function you defined above. You'll also define some constants for the model, like an explicit maximum sequence_length, which will cause the layer to pad or truncate sequences to exactly sequence_length values.

 max_features = 10000
sequence_length = 250

vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=sequence_length)
 

Next, you will call adapt to fit the state of the preprocessing layer to the dataset. This will cause the model to build an index of strings to integers.

 # Make a text-only dataset (without labels), then call adapt
train_text = raw_train_ds.map(lambda x, y: x)
vectorize_layer.adapt(train_text)
 

Let's create a function to see the result of using this layer to preprocess some data.

 def vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return vectorize_layer(text), label
 
 # retrieve a batch (of 32 reviews and labels) from the dataset
text_batch, label_batch = next(iter(raw_train_ds))
first_review, first_label = text_batch[0], label_batch[0]
print("Review", first_review)
print("Label", raw_train_ds.class_names[first_label])
print("Vectorized review", vectorize_text(first_review, first_label))
 
Review tf.Tensor(b'Silent Night, Deadly Night 5 is the very last of the series, and like part 4, it\'s unrelated to the first three except by title and the fact that it\'s a Christmas-themed horror flick.<br /><br />Except to the oblivious, there\'s some obvious things going on here...Mickey Rooney plays a toymaker named Joe Petto and his creepy son\'s name is Pino. Ring a bell, anyone? Now, a little boy named Derek heard a knock at the door one evening, and opened it to find a present on the doorstep for him. Even though it said "don\'t open till Christmas", he begins to open it anyway but is stopped by his dad, who scolds him and sends him to bed, and opens the gift himself. Inside is a little red ball that sprouts Santa arms and a head, and proceeds to kill dad. Oops, maybe he should have left well-enough alone. Of course Derek is then traumatized by the incident since he watched it from the stairs, but he doesn\'t grow up to be some killer Santa, he just stops talking.<br /><br />There\'s a mysterious stranger lurking around, who seems very interested in the toys that Joe Petto makes. We even see him buying a bunch when Derek\'s mom takes him to the store to find a gift for him to bring him out of his trauma. And what exactly is this guy doing? Well, we\'re not sure but he does seem to be taking these toys apart to see what makes them tick. He does keep his landlord from evicting him by promising him to pay him in cash the next day and presents him with a "Larry the Larvae" toy for his kid, but of course "Larry" is not a good toy and gets out of the box in the car and of course, well, things aren\'t pretty.<br /><br />Anyway, eventually what\'s going on with Joe Petto and Pino is of course revealed, and as with the old story, Pino is not a "real boy". Pino is probably even more agitated and naughty because he suffers from "Kenitalia" (a smooth plastic crotch) so that could account for his evil ways. And the identity of the lurking stranger is revealed too, and there\'s even kind of a happy ending of sorts. Whee.<br /><br />A step up from part 4, but not much of one. Again, Brian Yuzna is involved, and Screaming Mad George, so some decent special effects, but not enough to make this great. A few leftovers from part 4 are hanging around too, like Clint Howard and Neith Hunter, but that doesn\'t really make any difference. Anyway, I now have seeing the whole series out of my system. Now if I could get some of it out of my brain. 4 out of 5.', shape=(), dtype=string)
Label neg
Vectorized review (<tf.Tensor: shape=(1, 250), dtype=int64, numpy=
array([[1287,  313, 2380,  313,  661,    7,    2,   52,  229,    5,    2,
         200,    3,   38,  170,  669,   29, 5492,    6,    2,   83,  297,
         549,   32,  410,    3,    2,  186,   12,   29,    4,    1,  191,
         510,  549,    6,    2, 8229,  212,   46,  576,  175,  168,   20,
           1, 5361,  290,    4,    1,  761,  969,    1,    3,   24,  935,
        2271,  393,    7,    1, 1675,    4, 3747,  250,  148,    4,  112,
         436,  761, 3529,  548,    4, 3633,   31,    2, 1331,   28, 2096,
           3, 2912,    9,    6,  163,    4, 1006,   20,    2,    1,   15,
          85,   53,  147,    9,  292,   89,  959, 2314,  984,   27,  762,
           6,  959,    9,  564,   18,    7, 2140,   32,   24, 1254,   36,
           1,   85,    3, 3298,   85,    6, 1410,    3, 1936,    2, 3408,
         301,  965,    7,    4,  112,  740, 1977,   12,    1, 2014, 2772,
           3,    4,  428,    3, 5177,    6,  512, 1254,    1,  278,   27,
         139,   25,  308,    1,  579,    5,  259, 3529,    7,   92, 8981,
          32,    2, 3842,  230,   27,  289,    9,   35,    2, 5712,   18,
          27,  144, 2166,   56,    6,   26,   46,  466, 2014,   27,   40,
        2745,  657,  212,    4, 1376, 3002, 7080,  183,   36,  180,   52,
         920,    8,    2, 4028,   12,  969,    1,  158,   71,   53,   67,
          85, 2754,    4,  734,   51,    1, 1611,  294,   85,    6,    2,
        1164,    6,  163,    4, 3408,   15,   85,    6,  717,   85,   44,
           5,   24, 7158,    3,   48,  604,    7,   11,  225,  384,   73,
          65,   21,  242,   18,   27,  120,  295,    6,   26,  667,  129,
        4028,  948,    6,   67,   48,  158,   93,    1]])>, <tf.Tensor: shape=(), dtype=int32, numpy=0>)

As you can see above, each token has been replaced by an integer. You can look up the token (string) that each integer corresponds to by calling .get_vocabulary() on the layer.

 print("1287 ---> ",vectorize_layer.get_vocabulary()[1287])
print(" 313 ---> ",vectorize_layer.get_vocabulary()[313])
print('Vocabulary size: {}'.format(len(vectorize_layer.get_vocabulary())))
 
1287 --->  silent
 313 --->  night
Vocabulary size: 10000

You are nearly ready to train your model. As a final preprocessing step, you will apply the TextVectorization layer you created earlier to the train, validation, and test datasets.

 train_ds = raw_train_ds.map(vectorize_text)
val_ds = raw_val_ds.map(vectorize_text)
test_ds = raw_test_ds.map(vectorize_text)
 

Configure the dataset for performance

These are two important methods you should use when loading data to make sure that I/O does not become blocking.

.cache() keeps data in memory after it's loaded off disk. This will ensure the dataset does not become a bottleneck while training your model. If your dataset is too large to fit into memory, you can also use this method to create a performant on-disk cache, which is more efficient to read than many small files.

.prefetch() overlaps data preprocessing and model execution while training.

You can learn more about both methods, as well as how to cache data to disk, in the data performance guide.

 AUTOTUNE = tf.data.experimental.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)
 

Create the model

It's time to create our neural network:

 embedding_dim = 16
 
 model = tf.keras.Sequential([
  layers.Embedding(max_features + 1, embedding_dim),
  layers.Dropout(0.2),
  layers.GlobalAveragePooling1D(),
  layers.Dropout(0.2),
  layers.Dense(1)])

model.summary()
 
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, None, 16)          160016    
_________________________________________________________________
dropout (Dropout)            (None, None, 16)          0         
_________________________________________________________________
global_average_pooling1d (Gl (None, 16)                0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 16)                0         
_________________________________________________________________
dense (Dense)                (None, 1)                 17        
=================================================================
Total params: 160,033
Trainable params: 160,033
Non-trainable params: 0
_________________________________________________________________

The layers are stacked sequentially to build the classifier:

  1. The first layer is an Embedding layer. This layer takes the integer-encoded reviews and looks up an embedding vector for each word-index. These vectors are learned as the model trains. The vectors add a dimension to the output array. The resulting dimensions are: (batch, sequence, embedding). To learn more about embeddings, see the word embedding tutorial.
  2. Next, a GlobalAveragePooling1D layer returns a fixed-length output vector for each example by averaging over the sequence dimension. This allows the model to handle input of variable length, in the simplest way possible.
  3. This fixed-length output vector is piped through a fully connected (Dense) layer with 16 hidden units.
  4. The last layer is densely connected with a single output node.

Loss function and optimizer

A model needs a loss function and an optimizer for training. Since this is a binary classification problem and the model outputs a probability (a single-unit layer with a sigmoid activation), you'll use the losses.BinaryCrossentropy loss function.

Now, configure the model to use an optimizer and a loss function:

 model.compile(loss=losses.BinaryCrossentropy(from_logits=True), optimizer='adam', metrics=tf.metrics.BinaryAccuracy(threshold=0.0))
 

Train the model

You will train the model by passing the dataset object to the fit method.

 epochs = 10
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs)
 
Epoch 1/10
625/625 [==============================] - 3s 5ms/step - loss: 0.6632 - binary_accuracy: 0.6931 - val_loss: 0.6135 - val_binary_accuracy: 0.7752
Epoch 2/10
625/625 [==============================] - 3s 4ms/step - loss: 0.5472 - binary_accuracy: 0.8003 - val_loss: 0.4968 - val_binary_accuracy: 0.8220
Epoch 3/10
625/625 [==============================] - 3s 4ms/step - loss: 0.4434 - binary_accuracy: 0.8459 - val_loss: 0.4187 - val_binary_accuracy: 0.8486
Epoch 4/10
625/625 [==============================] - 3s 4ms/step - loss: 0.3770 - binary_accuracy: 0.8660 - val_loss: 0.3726 - val_binary_accuracy: 0.8622
Epoch 5/10
625/625 [==============================] - 2s 4ms/step - loss: 0.3349 - binary_accuracy: 0.8786 - val_loss: 0.3442 - val_binary_accuracy: 0.8678
Epoch 6/10
625/625 [==============================] - 2s 4ms/step - loss: 0.3046 - binary_accuracy: 0.8889 - val_loss: 0.3253 - val_binary_accuracy: 0.8722
Epoch 7/10
625/625 [==============================] - 2s 4ms/step - loss: 0.2807 - binary_accuracy: 0.8977 - val_loss: 0.3118 - val_binary_accuracy: 0.8726
Epoch 8/10
625/625 [==============================] - 2s 4ms/step - loss: 0.2609 - binary_accuracy: 0.9046 - val_loss: 0.3026 - val_binary_accuracy: 0.8762
Epoch 9/10
625/625 [==============================] - 2s 4ms/step - loss: 0.2443 - binary_accuracy: 0.9123 - val_loss: 0.2961 - val_binary_accuracy: 0.8774
Epoch 10/10
625/625 [==============================] - 2s 4ms/step - loss: 0.2309 - binary_accuracy: 0.9163 - val_loss: 0.2915 - val_binary_accuracy: 0.8804

Evaluate the model

Let's see how the model performs. Two values will be returned: loss (a number that represents our error; lower values are better) and accuracy.

 loss, accuracy = model.evaluate(test_ds)

print("Loss: ", loss)
print("Accuracy: ", accuracy)
 
782/782 [==============================] - 2s 3ms/step - loss: 0.3097 - binary_accuracy: 0.8740
Loss:  0.30967268347740173
Accuracy:  0.8740400075912476

This fairly naive approach achieves an accuracy of about 87%.

Create a plot of accuracy and loss over time

model.fit() returns a History object that contains a dictionary with everything that happened during training:

 history_dict = history.history
history_dict.keys()
 
dict_keys(['loss', 'binary_accuracy', 'val_loss', 'val_binary_accuracy'])

There are four entries: one for each monitored metric during training and validation. You can use these to plot the training and validation loss for comparison, as well as the training and validation accuracy:

 acc = history_dict['binary_accuracy']
val_acc = history_dict['val_binary_accuracy']
loss = history_dict['loss']
val_loss = history_dict['val_loss']

epochs = range(1, len(acc) + 1)

# "bo" is for "blue dot"
plt.plot(epochs, loss, 'bo', label='Training loss')
# b is for "solid blue line"
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.show()
 

[Plot: training and validation loss]

 plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')

plt.show()
 

[Plot: training and validation accuracy]

In this plot, the dots represent the training loss and accuracy, and the solid lines are the validation loss and accuracy.

Notice that the training loss decreases with each epoch and the training accuracy increases with each epoch. This is expected when using gradient descent optimization: it should minimize the desired quantity on every iteration.

This isn't the case for the validation loss and accuracy, which seem to peak before the training accuracy. This is an example of overfitting: the model performs better on the training data than it does on data it has never seen before. After this point, the model over-optimizes and learns representations specific to the training data that do not generalize to test data.

For this particular case, you could prevent overfitting by simply stopping the training when the validation accuracy is no longer increasing. One way to do so is to use the EarlyStopping callback, as sketched below.
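
A minimal sketch of wiring up tf.keras.callbacks.EarlyStopping follows; the monitored metric name, patience value, and restore_best_weights setting are illustrative assumptions rather than part of the original tutorial.

 # Assumed sketch: stop once validation accuracy stops improving for 2 epochs
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_binary_accuracy',  # matches the metric compiled above
    patience=2,                     # assumed: tolerate 2 epochs without improvement
    restore_best_weights=True)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=[early_stop])
 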

Export the model

In the code above, you applied the TextVectorization layer to the dataset before feeding text to the model. If you want to make your model capable of processing raw strings (for example, to simplify deploying it), you can include the TextVectorization layer inside your model. To do so, you can create a new model using the weights you just trained.

 export_model = tf.keras.Sequential([
  vectorize_layer,
  model,
  layers.Activation('sigmoid')
])

export_model.compile(
    loss=losses.BinaryCrossentropy(from_logits=False), optimizer="adam", metrics=['accuracy']
)

# Test it with `raw_test_ds`, which yields raw strings
loss, accuracy = export_model.evaluate(raw_test_ds)
print(accuracy)
 
782/782 [==============================] - 3s 4ms/step - loss: 0.3097 - accuracy: 0.8740
0.8740400075912476
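
Since export_model now accepts raw strings, you could also try it on a few new reviews. This is a minimal sketch; the example sentences below are illustrative assumptions, not taken from the dataset.

 # Hypothetical raw-string reviews (assumed examples, not from the IMDB dataset)
examples = ["The movie was great!",
            "The movie was okay.",
            "The movie was terrible..."]

# export_model includes the TextVectorization layer, so it accepts raw strings directly
print(export_model.predict(examples))
 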

Including the text preprocessing logic inside your model enables you to export a model for production, which simplifies deployment and reduces the potential for train/test skew.

There is a performance difference to keep in mind when choosing where to apply your TextVectorization layer. Using it outside of your model enables you to do asynchronous CPU processing and buffering of your data when training on GPU. So, if you're training your model on the GPU, you probably want to go with this option to get the best performance while developing your model, then switch to including the TextVectorization layer inside your model when you're ready to prepare for deployment.

Visit this tutorial to learn more about saving models.

Exercise: multi-class classification on Stack Overflow questions

This tutorial showed how to train a binary classifier from scratch on the IMDB dataset. As an exercise, you can modify this notebook to train a multi-class classifier to predict the tag of a programming question on Stack Overflow.

We have prepared a dataset for you to use containing the body of several thousand programming questions (for example, "How can I sort a dictionary by value in Python?") posted to Stack Overflow. Each of these is labeled with exactly one tag (either Python, CSharp, JavaScript, or Java). Your task is to take a question as input and predict the appropriate tag, in this case, Python.

The dataset you will work with contains several thousand questions extracted from the much larger public Stack Overflow dataset on BigQuery, which contains more than 17 million posts.

After downloading the dataset, you will find it has a similar directory structure to the IMDB dataset you worked with previously:

 train/
...python/
......0.txt
......1.txt
...javascript/
......0.txt
......1.txt
...csharp/
......0.txt
......1.txt
...java/
......0.txt
......1.txt
 

To complete this exercise, you should modify this notebook to work with the Stack Overflow dataset by making the following modifications:

  1. At the top of your notebook, update the code that downloads the IMDB dataset with code to download the Stack Overflow dataset we have prepared. As the Stack Overflow dataset has a similar directory structure, you will not need to make many modifications.

  2. Modify the last layer of your model to read Dense(4), as there are now four output classes.

  3. When you compile your model, change the loss to SparseCategoricalCrossentropy. This is the correct loss function to use for a multi-class classification problem, when the labels for each class are integers (in this case, they can be 0, 1, 2, or 3). See the sketch after this list for what these changes might look like.

  4. Once these changes are complete, you will be able to train a multi-class classifier.
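
A minimal sketch of the modified output layer and compile step follows, assuming the same architecture as the IMDB model above; only the Dense(4) output layer and the SparseCategoricalCrossentropy loss come from the instructions, and the metric choice is an assumption.

 # Assumed sketch: same layers as before, but with 4 output classes (one per tag)
multiclass_model = tf.keras.Sequential([
  layers.Embedding(max_features + 1, embedding_dim),
  layers.Dropout(0.2),
  layers.GlobalAveragePooling1D(),
  layers.Dropout(0.2),
  layers.Dense(4)])

multiclass_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),  # integer labels 0, 1, 2, 3
    optimizer='adam',
    metrics=['accuracy'])
 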

If you get stuck, you can find a solution here.

Learn more

This tutorial introduced text classification from scratch. To learn more about the text classification workflow in general, we recommend reading this guide from Google Developers.

 
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.