Pomoc chronić Wielkiej Rafy Koralowej z TensorFlow na Kaggle Dołącz Wyzwanie

Wczytaj tekst

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHubPobierz notatnik

W tym samouczku przedstawiono dwa sposoby ładowania i wstępnego przetwarzania tekstu.

# Be sure you're using the stable versions of both `tensorflow` and
# `tensorflow-text`, for binary compatibility.
pip uninstall -y tf-nightly keras-nightly
pip install tensorflow
pip install tensorflow-text
import collections
import pathlib

import tensorflow as tf

from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import utils
from tensorflow.keras.layers import TextVectorization

import tensorflow_datasets as tfds
import tensorflow_text as tf_text

Przykład 1: Przewidywanie tagu dla pytania o przepełnienie stosu

Jako pierwszy przykład pobierzesz zestaw danych z pytaniami programistycznymi ze Stack Overflow. Każde pytanie ( „Jak mogę posortować słownika wartości?”) Jest oznaczony z dokładnie jeden tag ( Python , CSharp , JavaScript lub Java ). Twoim zadaniem jest opracowanie modelu, który przewiduje tag dla pytania. Jest to przykład klasyfikacji wieloklasowej — ważnego i szeroko stosowanego problemu uczenia maszynowego.

Pobierz i poznaj zbiór danych

Rozpocznij pobierając zestaw danych z przepełnieniem stosu przy użyciu tf.keras.utils.get_file i odkrywania struktury katalogów:

data_url = 'https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz'

dataset_dir = utils.get_file(
    origin=data_url,
    untar=True,
    cache_dir='stack_overflow',
    cache_subdir='')

dataset_dir = pathlib.Path(dataset_dir).parent
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz
6053888/6053168 [==============================] - 0s 0us/step
6062080/6053168 [==============================] - 0s 0us/step
list(dataset_dir.iterdir())
[PosixPath('/tmp/.keras/train'),
 PosixPath('/tmp/.keras/README.md'),
 PosixPath('/tmp/.keras/stack_overflow_16k.tar.gz'),
 PosixPath('/tmp/.keras/test')]
train_dir = dataset_dir/'train'
list(train_dir.iterdir())
[PosixPath('/tmp/.keras/train/java'),
 PosixPath('/tmp/.keras/train/csharp'),
 PosixPath('/tmp/.keras/train/javascript'),
 PosixPath('/tmp/.keras/train/python')]

train/csharp , train/java , train/python i train/javascript katalogi zawierają wiele plików tekstowych, z których każdy jest pytanie przepełnienie stosu.

Wydrukuj przykładowy plik i sprawdź dane:

sample_file = train_dir/'python/1755.txt'

with open(sample_file) as f:
  print(f.read())
why does this blank program print true x=true.def stupid():.    x=false.stupid().print x

Załaduj zbiór danych

Następnie załadujesz dane z dysku i przygotujesz je do formatu odpowiedniego do treningu. Aby to zrobić, można użyć tf.keras.utils.text_dataset_from_directory narzędzie do tworzenia oznaczony tf.data.Dataset . Jeśli jesteś nowym tf.data , to potężny zestaw narzędzi do budowy rurociągów wejściowych. (Dowiedz się więcej w tf.data: Budowa TensorFlow rurociągi wejściowe instrukcji).

tf.keras.utils.text_dataset_from_directory API spodziewa strukturę katalogów w następujący sposób:

train/
...csharp/
......1.txt
......2.txt
...java/
......1.txt
......2.txt
...javascript/
......1.txt
......2.txt
...python/
......1.txt
......2.txt

Podczas prowadzenia eksperymentu uczenia maszynowego, to najlepszym rozwiązaniem jest zestaw danych podzielić na trzy podziały: szkolenia , walidacji i testów .

Zestaw danych Stack Overflow został już podzielony na zestawy treningowe i testowe, ale brakuje w nim zestawu walidacyjnego.

Utwórz zestaw walidacji przy użyciu 80:20 podział danych treningowych za pomocą tf.keras.utils.text_dataset_from_directory z validation_split ustawionym na 0.2 (czyli 20%):

batch_size = 32
seed = 42

raw_train_ds = utils.text_dataset_from_directory(
    train_dir,
    batch_size=batch_size,
    validation_split=0.2,
    subset='training',
    seed=seed)
Found 8000 files belonging to 4 classes.
Using 6400 files for training.

Jak sugeruje poprzedni wynik komórki, w folderze szkoleniowym znajduje się 8000 przykładów, z których 80% (lub 6400) wykorzystasz do szkolenia. Dowiesz się za chwilę, że można trenować model przepuszczając tf.data.Dataset bezpośrednio do Model.fit .

Najpierw przeprowadź iterację zestawu danych i wydrukuj kilka przykładów, aby poznać dane.

for text_batch, label_batch in raw_train_ds.take(1):
  for i in range(10):
    print("Question: ", text_batch.numpy()[i])
    print("Label:", label_batch.numpy()[i])
Question:  b'"my tester is going to the wrong constructor i am new to programming so if i ask a question that can be easily fixed, please forgive me. my program has a tester class with a main. when i send that to my regularpolygon class, it sends it to the wrong constructor. i have two constructors. 1 without perameters..public regularpolygon().    {.       mynumsides = 5;.       mysidelength = 30;.    }//end default constructor...and my second, with perameters. ..public regularpolygon(int numsides, double sidelength).    {.        mynumsides = numsides;.        mysidelength = sidelength;.    }// end constructor...in my tester class i have these two lines:..regularpolygon shape = new regularpolygon(numsides, sidelength);.        shape.menu();...numsides and sidelength were declared and initialized earlier in the testing class...so what i want to happen, is the tester class sends numsides and sidelength to the second constructor and use it in that class. but it only uses the default constructor, which therefor ruins the whole rest of the program. can somebody help me?..for those of you who want to see more of my code: here you go..public double vertexangle().    {.        system.out.println(""the vertex angle method: "" + mynumsides);// prints out 5.        system.out.println(""the vertex angle method: "" + mysidelength); // prints out 30..        double vertexangle;.        vertexangle = ((mynumsides - 2.0) / mynumsides) * 180.0;.        return vertexangle;.    }//end method vertexangle..public void menu().{.    system.out.println(mynumsides); // prints out what the user puts in.    system.out.println(mysidelength); // prints out what the user puts in.    gotographic();.    calcr(mynumsides, mysidelength);.    calcr(mynumsides, mysidelength);.    print(); .}// end menu...this is my entire tester class:..public static void main(string[] arg).{.    int numsides;.    double sidelength;.    scanner keyboard = new scanner(system.in);..    system.out.println(""welcome to the regular polygon program!"");.    system.out.println();..    system.out.print(""enter the number of sides of the polygon ==> "");.    numsides = keyboard.nextint();.    system.out.println();..    system.out.print(""enter the side length of each side ==> "");.    sidelength = keyboard.nextdouble();.    system.out.println();..    regularpolygon shape = new regularpolygon(numsides, sidelength);.    shape.menu();.}//end main...for testing it i sent it numsides 4 and sidelength 100."\n'
Label: 1
Question:  b'"blank code slow skin detection this code changes the color space to lab and using a threshold finds the skin area of an image. but it\'s ridiculously slow. i don\'t know how to make it faster ?    ..from colormath.color_objects import *..def skindetection(img, treshold=80, color=[255,20,147]):..    print img.shape.    res=img.copy().    for x in range(img.shape[0]):.        for y in range(img.shape[1]):.            rgbimg=rgbcolor(img[x,y,0],img[x,y,1],img[x,y,2]).            labimg=rgbimg.convert_to(\'lab\', debug=false).            if (labimg.lab_l > treshold):.                res[x,y,:]=color.            else: .                res[x,y,:]=img[x,y,:]..    return res"\n'
Label: 3
Question:  b'"option and validation in blank i want to add a new option on my system where i want to add two text files, both rental.txt and customer.txt. inside each text are id numbers of the customer, the videotape they need and the price...i want to place it as an option on my code. right now i have:...add customer.rent return.view list.search.exit...i want to add this as my sixth option. say for example i ordered a video, it would display the price and would let me confirm the price and if i am going to buy it or not...here is my current code:..  import blank.io.*;.    import blank.util.arraylist;.    import static blank.lang.system.out;..    public class rentalsystem{.    static bufferedreader input = new bufferedreader(new inputstreamreader(system.in));.    static file file = new file(""file.txt"");.    static arraylist<string> list = new arraylist<string>();.    static int rows;..    public static void main(string[] args) throws exception{.        introduction();.        system.out.print(""nn"");.        login();.        system.out.print(""nnnnnnnnnnnnnnnnnnnnnn"");.        introduction();.        string repeat;.        do{.            loadfile();.            system.out.print(""nwhat do you want to do?nn"");.            system.out.print(""n                    - - - - - - - - - - - - - - - - - - - - - - -"");.            system.out.print(""nn                    |     1. add customer    |   2. rent return |n"");.            system.out.print(""n                    - - - - - - - - - - - - - - - - - - - - - - -"");.            system.out.print(""nn                    |     3. view list       |   4. search      |n"");.            system.out.print(""n                    - - - - - - - - - - - - - - - - - - - - - - -"");.            system.out.print(""nn                                             |   5. exit        |n"");.            system.out.print(""n                                              - - - - - - - - - -"");.            system.out.print(""nnchoice:"");.            int choice = integer.parseint(input.readline());.            switch(choice){.                case 1:.                    writedata();.                    break;.                case 2:.                    rentdata();.                    break;.                case 3:.                    viewlist();.                    break;.                case 4:.                    search();.                    break;.                case 5:.                    system.out.println(""goodbye!"");.                    system.exit(0);.                default:.                    system.out.print(""invalid choice: "");.                    break;.            }.            system.out.print(""ndo another task? [y/n] "");.            repeat = input.readline();.        }while(repeat.equals(""y""));..        if(repeat!=""y"") system.out.println(""ngoodbye!"");..    }..    public static void writedata() throws exception{.        system.out.print(""nname: "");.        string cname = input.readline();.        system.out.print(""address: "");.        string add = input.readline();.        system.out.print(""phone no.: "");.        string pno = input.readline();.        system.out.print(""rental amount: "");.        string ramount = input.readline();.        system.out.print(""tapenumber: "");.        string tno = input.readline();.        system.out.print(""title: "");.        string title = input.readline();.        system.out.print(""date borrowed: "");.        string dborrowed = input.readline();.        system.out.print(""due date: "");.        string ddate = input.readline();.        createline(cname, add, pno, ramount,tno, title, dborrowed, ddate);.        rentdata();.    }..    public static void createline(string name, string address, string phone , string rental, string tapenumber, string title, string borrowed, string due) throws exception{.        filewriter fw = new filewriter(file, true);.        fw.write(""nname: ""+name + ""naddress: "" + address +""nphone no.: ""+ phone+""nrentalamount: ""+rental+""ntape no.: ""+ tapenumber+""ntitle: ""+ title+""ndate borrowed: ""+borrowed +""ndue date: ""+ due+"":rn"");.        fw.close();.    }..    public static void loadfile() throws exception{.        try{.            list.clear();.            fileinputstream fstream = new fileinputstream(file);.            bufferedreader br = new bufferedreader(new inputstreamreader(fstream));.            rows = 0;.            while( br.ready()).            {.                list.add(br.readline());.                rows++;.            }.            br.close();.        } catch(exception e){.            system.out.println(""list not yet loaded."");.        }.    }..    public static void viewlist(){.        system.out.print(""n~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print("" |list of all costumers|"");.        system.out.print(""~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        for(int i = 0; i <rows; i++){.            system.out.println(list.get(i));.        }.    }.        public static void rentdata()throws exception.    {   system.out.print(""n~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print("" |rent data list|"");.        system.out.print(""~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print(""nenter customer name: "");.        string cname = input.readline();.        system.out.print(""date borrowed: "");.        string dborrowed = input.readline();.        system.out.print(""due date: "");.        string ddate = input.readline();.        system.out.print(""return date: "");.        string rdate = input.readline();.        system.out.print(""rent amount: "");.        string ramount = input.readline();..        system.out.print(""you pay:""+ramount);...    }.    public static void search()throws exception.    {   system.out.print(""n~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print("" |search costumers|"");.        system.out.print(""~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print(""nenter costumer name: "");.        string cname = input.readline();.        boolean found = false;..        for(int i=0; i < rows; i++){.            string temp[] = list.get(i).split("","");..            if(cname.equals(temp[0])){.            system.out.println(""search result:nyou are "" + temp[0] + "" from "" + temp[1] + "".""+ temp[2] + "".""+ temp[3] + "".""+ temp[4] + "".""+ temp[5] + "" is "" + temp[6] + "".""+ temp[7] + "" is "" + temp[8] + ""."");.                found = true;.            }.        }..        if(!found){.            system.out.print(""no results."");.        }..    }..        public static boolean evaluate(string uname, string pass){.        if (uname.equals(""admin"")&&pass.equals(""12345"")) return true;.        else return false;.    }..    public static string login()throws exception{.        bufferedreader input=new bufferedreader(new inputstreamreader(system.in));.        int counter=0;.        do{.            system.out.print(""username:"");.            string uname =input.readline();.            system.out.print(""password:"");.            string pass =input.readline();..            boolean accept= evaluate(uname,pass);..            if(accept){.                break;.                }else{.                    system.out.println(""incorrect username or password!"");.                    counter ++;.                    }.        }while(counter<3);..            if(counter !=3) return ""login successful"";.            else return ""login failed"";.            }.        public static void introduction() throws exception{..        system.out.println(""                  - - - - - - - - - - - - - - - - - - - - - - - - -"");.        system.out.println(""                  !                  r e n t a l                  !"");.        system.out.println(""                   ! ~ ~ ~ ~ ~ !  =================  ! ~ ~ ~ ~ ~ !"");.        system.out.println(""                  !                  s y s t e m                  !"");.        system.out.println(""                  - - - - - - - - - - - - - - - - - - - - - - - - -"");.        }..}"\n'
Label: 1
Question:  b'"exception: dynamic sql generation for the updatecommand is not supported against a selectcommand that does not return any key i dont know what is the problem this my code : ..string nomtable;..datatable listeetablissementtable = new datatable();.datatable listeinteretstable = new datatable();.dataset ds = new dataset();.sqldataadapter da;.sqlcommandbuilder cmdb;..private void listeinterets_click(object sender, eventargs e).{.    nomtable = ""listeinteretstable"";.    d.cnx.open();.    da = new sqldataadapter(""select nome from offices"", d.cnx);.    ds = new dataset();.    da.fill(ds, nomtable);.    datagridview1.datasource = ds.tables[nomtable];.}..private void sauvgarder_click(object sender, eventargs e).{.    d.cnx.open();.    cmdb = new sqlcommandbuilder(da);.    da.update(ds, nomtable);.    d.cnx.close();.}"\n'
Label: 0
Question:  b'"parameter with question mark and super in blank, i\'ve come across a method that is formatted like this:..public final subscription subscribe(final action1<? super t> onnext, final action1<throwable> onerror) {.}...in the first parameter, what does the question mark and super mean?"\n'
Label: 1
Question:  b'call two objects wsdl the first time i got a very strange wsdl. ..i would like to call the object (interface - invoicecheck_out) do you know how?....i would like to call the object (variable) do you know how?..try to call (it`s ok)....try to call (how call this?)\n'
Label: 0
Question:  b"how to correctly make the icon for systemtray in blank using icon sizes of any dimension for systemtray doesn't look good overall. .what is the correct way of making icons for windows system tray?..screenshots: http://imgur.com/zsibwn9..icon: http://imgur.com/vsh4zo8\n"
Label: 0
Question:  b'"is there a way to check a variable that exists in a different script than the original one? i\'m trying to check if a variable, which was previously set to true in 2.py in 1.py, as 1.py is only supposed to continue if the variable is true...2.py..import os..completed = false..#some stuff here..completed = true...1.py..import 2 ..if completed == true.   #do things...however i get a syntax error at ..if completed == true"\n'
Label: 3
Question:  b'"blank control flow i made a number which asks for 2 numbers with blank and responds with  the corresponding message for the case. how come it doesnt work  for the second number ? .regardless what i enter for the second number , i am getting the message ""your number is in the range 0-10""...using system;.using system.collections.generic;.using system.linq;.using system.text;..namespace consoleapplication1.{.    class program.    {.        static void main(string[] args).        {.            string myinput;  // declaring the type of the variables.            int myint;..            string number1;.            int number;...            console.writeline(""enter a number"");.            myinput = console.readline(); //muyinput is a string  which is entry input.            myint = int32.parse(myinput); // myint converts the string into an integer..            if (myint > 0).                console.writeline(""your number {0} is greater than zero."", myint);.            else if (myint < 0).                console.writeline(""your number {0} is  less  than zero."", myint);.            else.                console.writeline(""your number {0} is equal zero."", myint);..            console.writeline(""enter another number"");.            number1 = console.readline(); .            number = int32.parse(myinput); ..            if (number < 0 || number == 0).                console.writeline(""your number {0} is  less  than zero or equal zero."", number);.            else if (number > 0 && number <= 10).                console.writeline(""your number {0} is  in the range from 0 to 10."", number);.            else.                console.writeline(""your number {0} is greater than 10."", number);..            console.writeline(""enter another number"");..        }.    }    .}"\n'
Label: 0
Question:  b'"credentials cannot be used for ntlm authentication i am getting org.apache.commons.httpclient.auth.invalidcredentialsexception: credentials cannot be used for ntlm authentication: exception in eclipse..whether it is possible mention eclipse to take system proxy settings directly?..public class httpgetproxy {.    private static final string proxy_host = ""proxy.****.com"";.    private static final int proxy_port = 6050;..    public static void main(string[] args) {.        httpclient client = new httpclient();.        httpmethod method = new getmethod(""https://kodeblank.org"");..        hostconfiguration config = client.gethostconfiguration();.        config.setproxy(proxy_host, proxy_port);..        string username = ""*****"";.        string password = ""*****"";.        credentials credentials = new usernamepasswordcredentials(username, password);.        authscope authscope = new authscope(proxy_host, proxy_port);..        client.getstate().setproxycredentials(authscope, credentials);..        try {.            client.executemethod(method);..            if (method.getstatuscode() == httpstatus.sc_ok) {.                string response = method.getresponsebodyasstring();.                system.out.println(""response = "" + response);.            }.        } catch (ioexception e) {.            e.printstacktrace();.        } finally {.            method.releaseconnection();.        }.    }.}...exception:...  dec 08, 2017 1:41:39 pm .          org.apache.commons.httpclient.auth.authchallengeprocessor selectauthscheme.         info: ntlm authentication scheme selected.       dec 08, 2017 1:41:39 pm org.apache.commons.httpclient.httpmethoddirector executeconnect.         severe: credentials cannot be used for ntlm authentication: .           org.apache.commons.httpclient.usernamepasswordcredentials.           org.apache.commons.httpclient.auth.invalidcredentialsexception: credentials .         cannot be used for ntlm authentication: .        enter code here .          org.apache.commons.httpclient.usernamepasswordcredentials.      at org.apache.commons.httpclient.auth.ntlmscheme.authenticate(ntlmscheme.blank:332).        at org.apache.commons.httpclient.httpmethoddirector.authenticateproxy(httpmethoddirector.blank:320).      at org.apache.commons.httpclient.httpmethoddirector.executeconnect(httpmethoddirector.blank:491).      at org.apache.commons.httpclient.httpmethoddirector.executewithretry(httpmethoddirector.blank:391).      at org.apache.commons.httpclient.httpmethoddirector.executemethod(httpmethoddirector.blank:171).      at org.apache.commons.httpclient.httpclient.executemethod(httpclient.blank:397).      at org.apache.commons.httpclient.httpclient.executemethod(httpclient.blank:323).      at httpgetproxy.main(httpgetproxy.blank:31).  dec 08, 2017 1:41:39 pm org.apache.commons.httpclient.httpmethoddirector processproxyauthchallenge.  info: failure authenticating with ntlm @proxy.****.com:6050"\n'
Label: 1

Etykiety są 0 , 1 , 2 lub 3 . Aby sprawdzić, które z nich odpowiadają których ciąg etykiet, można sprawdzić class_names obiekt w zbiorze:

for i, label in enumerate(raw_train_ds.class_names):
  print("Label", i, "corresponds to", label)
Label 0 corresponds to csharp
Label 1 corresponds to java
Label 2 corresponds to javascript
Label 3 corresponds to python

Następnie należy utworzyć walidacji i test ustawić za pomocą tf.keras.utils.text_dataset_from_directory . Do walidacji wykorzystasz pozostałe 1600 recenzji z zestawu szkoleniowego.

# Create a validation set.
raw_val_ds = utils.text_dataset_from_directory(
    train_dir,
    batch_size=batch_size,
    validation_split=0.2,
    subset='validation',
    seed=seed)
Found 8000 files belonging to 4 classes.
Using 1600 files for validation.
test_dir = dataset_dir/'test'

# Create a test set.
raw_test_ds = utils.text_dataset_from_directory(
    test_dir,
    batch_size=batch_size)
Found 8000 files belonging to 4 classes.

Przygotuj zbiór danych do szkolenia

Następnie można standaryzować, tokenize i wektoryzacji danych za pomocą tf.keras.layers.TextVectorization warstwę.

  • Standaryzacja oznacza przerób tekst, zwykle w celu usunięcia interpunkcyjne lub HTML elementy uproszczenie zestawu danych.
  • Atomizacja odnosi się do łańcuchów do przecinania się tokenów (na przykład rozszczepiania kary na poszczególne słowa poprzez rozdzielenie na spacji).
  • Wektoryzacji się do konwersji znaki na liczby, aby mogły być wprowadzone do sieci neuronowych.

Wszystkie te zadania można wykonać za pomocą tej warstwy. (Możesz dowiedzieć się więcej na temat każdego z nich w tf.keras.layers.TextVectorization Dokumentacja API).

Zwróć uwagę, że:

  • Tekst domyślny nawróceni normalizacyjne na małe i usuwa znaki interpunkcyjne ( standardize='lower_and_strip_punctuation' ).
  • Domyślną tokenizer podziałów na spacji ( split='whitespace' ).
  • Domyślnym trybem jest wektoryzacja 'int' ( output_mode='int' ). Daje to indeksy liczb całkowitych (jeden na token). Ten tryb może służyć do budowania modeli uwzględniających szyk wyrazów. Można również skorzystać z innych trybów podobny 'binary' -to kompilacji bag-of-słów modeli.

Będziesz budować dwa modele, aby dowiedzieć się więcej o normalizacji, tokeny i wektoryzacji z TextVectorization :

  • Po pierwsze, można użyć 'binary' tryb wektoryzacja zbudować bag-of-słów model.
  • Następnie można użyć 'int' tryb z 1D ConvNet.
VOCAB_SIZE = 10000

binary_vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='binary')

Dla 'int' trybie, oprócz maksymalnego rozmiaru słownictwa, trzeba ustawić wyraźny długość sekwencji maksymalna ( MAX_SEQUENCE_LENGTH ), który spowoduje, że warstwa pad lub obciąć sekwencje dokładnie output_sequence_length wartości:

MAX_SEQUENCE_LENGTH = 250

int_vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH)

Następnie zadzwonić TextVectorization.adapt aby pasowały do stanu warstwy przerób do zbioru danych. Spowoduje to, że model zbuduje indeks ciągów do liczb całkowitych.

# Make a text-only dataset (without labels), then call `TextVectorization.adapt`.
train_text = raw_train_ds.map(lambda text, labels: text)
binary_vectorize_layer.adapt(train_text)
int_vectorize_layer.adapt(train_text)

Wydrukuj wynik użycia tych warstw do wstępnego przetwarzania danych:

def binary_vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return binary_vectorize_layer(text), label
def int_vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return int_vectorize_layer(text), label
# Retrieve a batch (of 32 reviews and labels) from the dataset.
text_batch, label_batch = next(iter(raw_train_ds))
first_question, first_label = text_batch[0], label_batch[0]
print("Question", first_question)
print("Label", first_label)
Question tf.Tensor(b'"what is the difference between these two ways to create an element? var a = document.createelement(\'div\');..a.id = ""mydiv"";...and..var a = document.createelement(\'div\').id = ""mydiv"";...what is the difference between them such that the first one works and the second one doesn\'t?"\n', shape=(), dtype=string)
Label tf.Tensor(2, shape=(), dtype=int32)
print("'binary' vectorized question:",
      binary_vectorize_text(first_question, first_label)[0])
'binary' vectorized question: tf.Tensor([[1. 1. 0. ... 0. 0. 0.]], shape=(1, 10000), dtype=float32)
print("'int' vectorized question:",
      int_vectorize_text(first_question, first_label)[0])
'int' vectorized question: tf.Tensor(
[[ 55   6   2 410 211 229 121 895   4 124  32 245  43   5   1   1   5   1
    1   6   2 410 211 191 318  14   2  98  71 188   8   2 199  71 178   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]], shape=(1, 250), dtype=int64)

Jak pokazano powyżej, TextVectorization jest 'binary' tryb Zwraca oznaczający Tablica które istnieją znaczniki co najmniej raz na wejściu, a 'int' tryb Następnie każdy znacznik przez liczbę całkowitą w celu zachowania ich kolejności.

Można odnośnika token (string), że każda liczba odpowiada dzwoniąc TextVectorization.get_vocabulary na warstwie:

print("1289 ---> ", int_vectorize_layer.get_vocabulary()[1289])
print("313 ---> ", int_vectorize_layer.get_vocabulary()[313])
print("Vocabulary size: {}".format(len(int_vectorize_layer.get_vocabulary())))
1289 --->  roman
313 --->  source
Vocabulary size: 10000

Jesteś już prawie gotowy do trenowania swojego modelu.

W końcowym etapie przetwarzania wstępnego, można zastosować TextVectorization utworzonych wcześniej zestawów szkolenia, walidacji i testów warstw:

binary_train_ds = raw_train_ds.map(binary_vectorize_text)
binary_val_ds = raw_val_ds.map(binary_vectorize_text)
binary_test_ds = raw_test_ds.map(binary_vectorize_text)

int_train_ds = raw_train_ds.map(int_vectorize_text)
int_val_ds = raw_val_ds.map(int_vectorize_text)
int_test_ds = raw_test_ds.map(int_vectorize_text)

Skonfiguruj zbiór danych pod kątem wydajności

Są to dwie ważne metody, których należy użyć podczas ładowania danych, aby upewnić się, że operacje we/wy nie zostaną zablokowane.

  • Dataset.cache przechowuje dane w pamięci po jego załadowaniu off dysku. Zapewni to, że zestaw danych nie stanie się wąskim gardłem podczas trenowania modelu. Jeśli zestaw danych jest zbyt duży, aby zmieścić się w pamięci, możesz również użyć tej metody, aby utworzyć wydajną pamięć podręczną na dysku, która jest bardziej wydajna do odczytu niż wiele małych plików.
  • Dataset.prefetch pokrywa danych przerób i model wykonanie podczas treningu.

Możesz dowiedzieć się więcej na temat obu metod, a także w jaki sposób buforowania danych na dysku w sekcji prefetching o wydajności lepiej z tf.data API przewodnika.

AUTOTUNE = tf.data.AUTOTUNE

def configure_dataset(dataset):
  return dataset.cache().prefetch(buffer_size=AUTOTUNE)
binary_train_ds = configure_dataset(binary_train_ds)
binary_val_ds = configure_dataset(binary_val_ds)
binary_test_ds = configure_dataset(binary_test_ds)

int_train_ds = configure_dataset(int_train_ds)
int_val_ds = configure_dataset(int_val_ds)
int_test_ds = configure_dataset(int_test_ds)

Trenuj modelkę

Czas stworzyć swoją sieć neuronową.

Dla 'binary' vectorized danych, definiowanie prosty bag-of-słów model liniowy, następnie skonfigurować i szkolić go:

binary_model = tf.keras.Sequential([layers.Dense(4)])

binary_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])

history = binary_model.fit(
    binary_train_ds, validation_data=binary_val_ds, epochs=10)
Epoch 1/10
200/200 [==============================] - 2s 5ms/step - loss: 1.1204 - accuracy: 0.6480 - val_loss: 0.9132 - val_accuracy: 0.7819
Epoch 2/10
200/200 [==============================] - 0s 2ms/step - loss: 0.7799 - accuracy: 0.8194 - val_loss: 0.7490 - val_accuracy: 0.7956
Epoch 3/10
200/200 [==============================] - 0s 2ms/step - loss: 0.6284 - accuracy: 0.8614 - val_loss: 0.6635 - val_accuracy: 0.8081
Epoch 4/10
200/200 [==============================] - 0s 2ms/step - loss: 0.5349 - accuracy: 0.8861 - val_loss: 0.6102 - val_accuracy: 0.8200
Epoch 5/10
200/200 [==============================] - 0s 2ms/step - loss: 0.4688 - accuracy: 0.9036 - val_loss: 0.5736 - val_accuracy: 0.8306
Epoch 6/10
200/200 [==============================] - 0s 2ms/step - loss: 0.4185 - accuracy: 0.9173 - val_loss: 0.5471 - val_accuracy: 0.8331
Epoch 7/10
200/200 [==============================] - 0s 2ms/step - loss: 0.3782 - accuracy: 0.9291 - val_loss: 0.5270 - val_accuracy: 0.8356
Epoch 8/10
200/200 [==============================] - 0s 2ms/step - loss: 0.3449 - accuracy: 0.9366 - val_loss: 0.5115 - val_accuracy: 0.8406
Epoch 9/10
200/200 [==============================] - 0s 2ms/step - loss: 0.3166 - accuracy: 0.9430 - val_loss: 0.4993 - val_accuracy: 0.8419
Epoch 10/10
200/200 [==============================] - 0s 2ms/step - loss: 0.2922 - accuracy: 0.9484 - val_loss: 0.4896 - val_accuracy: 0.8413

Następnie można użyć 'int' wektorowy warstwę zbudować 1D ConvNet:

def create_model(vocab_size, num_labels):
  model = tf.keras.Sequential([
      layers.Embedding(vocab_size, 64, mask_zero=True),
      layers.Conv1D(64, 5, padding="valid", activation="relu", strides=2),
      layers.GlobalMaxPooling1D(),
      layers.Dense(num_labels)
  ])
  return model
# `vocab_size` is `VOCAB_SIZE + 1` since `0` is used additionally for padding.
int_model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=4)
int_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])
history = int_model.fit(int_train_ds, validation_data=int_val_ds, epochs=5)
Epoch 1/5
200/200 [==============================] - 3s 4ms/step - loss: 1.1914 - accuracy: 0.4891 - val_loss: 0.7911 - val_accuracy: 0.6869
Epoch 2/5
200/200 [==============================] - 1s 3ms/step - loss: 0.6364 - accuracy: 0.7548 - val_loss: 0.5485 - val_accuracy: 0.7975
Epoch 3/5
200/200 [==============================] - 1s 3ms/step - loss: 0.3837 - accuracy: 0.8802 - val_loss: 0.4838 - val_accuracy: 0.8075
Epoch 4/5
200/200 [==============================] - 1s 3ms/step - loss: 0.2152 - accuracy: 0.9483 - val_loss: 0.4821 - val_accuracy: 0.8156
Epoch 5/5
200/200 [==============================] - 1s 3ms/step - loss: 0.1084 - accuracy: 0.9820 - val_loss: 0.5071 - val_accuracy: 0.8125

Porównaj dwa modele:

print("Linear model on binary vectorized data:")
print(binary_model.summary())
Linear model on binary vectorized data:
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 4)                 40004     
=================================================================
Total params: 40,004
Trainable params: 40,004
Non-trainable params: 0
_________________________________________________________________
None
print("ConvNet model on int vectorized data:")
print(int_model.summary())
ConvNet model on int vectorized data:
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, None, 64)          640064    
_________________________________________________________________
conv1d (Conv1D)              (None, None, 64)          20544     
_________________________________________________________________
global_max_pooling1d (Global (None, 64)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 4)                 260       
=================================================================
Total params: 660,868
Trainable params: 660,868
Non-trainable params: 0
_________________________________________________________________
None

Oceń oba modele na podstawie danych testowych:

binary_loss, binary_accuracy = binary_model.evaluate(binary_test_ds)
int_loss, int_accuracy = int_model.evaluate(int_test_ds)

print("Binary model accuracy: {:2.2%}".format(binary_accuracy))
print("Int model accuracy: {:2.2%}".format(int_accuracy))
250/250 [==============================] - 1s 3ms/step - loss: 0.5174 - accuracy: 0.8146
250/250 [==============================] - 1s 2ms/step - loss: 0.5146 - accuracy: 0.8131
Binary model accuracy: 81.46%
Int model accuracy: 81.31%

Eksportuj model

W powyższym kodzie, należy zastosować tf.keras.layers.TextVectorization do zbioru danych przed wprowadzeniem tekstu do modelu. Jeśli chcesz, aby twój model zdolny do przetwarzania surowców struny (na przykład, w celu uproszczenia wdrażania go), można zawierać TextVectorization warstwę wewnątrz modelu.

Aby to zrobić, możesz utworzyć nowy model, używając wag, które właśnie wytrenowałeś:

export_model = tf.keras.Sequential(
    [binary_vectorize_layer, binary_model,
     layers.Activation('sigmoid')])

export_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer='adam',
    metrics=['accuracy'])

# Test it with `raw_test_ds`, which yields raw strings
loss, accuracy = export_model.evaluate(raw_test_ds)
print("Accuracy: {:2.2%}".format(binary_accuracy))
250/250 [==============================] - 1s 4ms/step - loss: 0.5174 - accuracy: 0.8146
Accuracy: 81.46%

Teraz Twój model może podjąć surowe ciągi jako wejście i przewidzieć wynik dla każdej etykiety przy użyciu Model.predict . Zdefiniuj funkcję, aby znaleźć etykietę z maksymalnym wynikiem:

def get_string_labels(predicted_scores_batch):
  predicted_int_labels = tf.argmax(predicted_scores_batch, axis=1)
  predicted_labels = tf.gather(raw_train_ds.class_names, predicted_int_labels)
  return predicted_labels

Uruchom wnioskowanie na nowych danych

inputs = [
    "how do I extract keys from a dict into a list?",  # 'python'
    "debug public static void main(string[] args) {...}",  # 'java'
]
predicted_scores = export_model.predict(inputs)
predicted_labels = get_string_labels(predicted_scores)
for input, label in zip(inputs, predicted_labels):
  print("Question: ", input)
  print("Predicted label: ", label.numpy())
Question:  how do I extract keys from a dict into a list?
Predicted label:  b'python'
Question:  debug public static void main(string[] args) {...}
Predicted label:  b'java'

W tym logikę tekstu przerób wewnątrz modelu pozwala na eksport modelu do produkcji, które upraszcza wdrażanie i zmniejsza potencjał do pociągu / test skośnej .

Jest różnica wydajności, aby pamiętać przy wyborze gdzie stosuje tf.keras.layers.TextVectorization . Używanie go poza modelem umożliwia asynchroniczne przetwarzanie procesora i buforowanie danych podczas uczenia na GPU. Tak więc, jeśli trenujesz swój model na GPU, prawdopodobnie chcesz iść z tej opcji, aby uzyskać najlepszą wydajność przy opracowywaniu modelu, a następnie przełączyć się na tym TextVectorization warstwę wewnątrz modelu, kiedy jesteś gotowy, aby przygotować się do wdrożenia .

Odwiedź Zapisz i modele obciążeń samouczka, aby dowiedzieć się więcej o zapisywaniu modeli.

Przykład 2: Przewiduj autora tłumaczeń Iliady

Poniżej przedstawiono przykład użycia tf.data.TextLineDataset przykładów obciążenia z plików tekstowych i TensorFlow Tekst do Preprocesuj dane. Użyjesz trzech różnych angielskich tłumaczeń tego samego dzieła, Iliady Homera, i wytrenujesz model identyfikowania tłumacza na podstawie pojedynczej linijki tekstu.

Pobierz i poznaj zbiór danych

Teksty trzech tłumaczeń są autorstwa:

Pliki tekstowe użyte w tym samouczku przeszły kilka typowych zadań wstępnego przetwarzania, takich jak usunięcie nagłówka i stopki dokumentu, numerów wierszy i tytułów rozdziałów.

Pobierz lokalnie te lekko zmodyfikowane pliki:

DIRECTORY_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
FILE_NAMES = ['cowper.txt', 'derby.txt', 'butler.txt']

for name in FILE_NAMES:
  text_dir = utils.get_file(name, origin=DIRECTORY_URL + name)

parent_dir = pathlib.Path(text_dir).parent
list(parent_dir.iterdir())
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt
819200/815980 [==============================] - 0s 0us/step
827392/815980 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt
811008/809730 [==============================] - 0s 0us/step
819200/809730 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt
811008/807992 [==============================] - 0s 0us/step
819200/807992 [==============================] - 0s 0us/step
[PosixPath('/home/kbuilder/.keras/datasets/derby.txt'),
 PosixPath('/home/kbuilder/.keras/datasets/flower_photos.tar.gz'),
 PosixPath('/home/kbuilder/.keras/datasets/butler.txt'),
 PosixPath('/home/kbuilder/.keras/datasets/flower_photos'),
 PosixPath('/home/kbuilder/.keras/datasets/image.jpg'),
 PosixPath('/home/kbuilder/.keras/datasets/cowper.txt'),
 PosixPath('/home/kbuilder/.keras/datasets/ImageNetLabels.txt')]

Załaduj zbiór danych

Wcześniej z tf.keras.utils.text_dataset_from_directory cała zawartość pliku były traktowane jako jeden przykład. Tutaj będziesz używać tf.data.TextLineDataset , który został zaprojektowany, aby stworzyć tf.data.Dataset z pliku tekstowego, w którym każdy przykładem jest linia tekstu z oryginalnego pliku. TextLineDataset jest przydatna dla danych tekstowych, które są oparte przede wszystkim linia (na przykład, poezja czy dzienników błędów).

Iteruj przez te pliki, ładując każdy z nich do własnego zestawu danych. Każdy przykład musi być indywidualnie oznakowane, więc używaj Dataset.map zastosować funkcję Labeler do każdego z nich. To iteracyjnego Każdy przykład w zestawie danych powrotu ( example, label ) par.

def labeler(example, index):
  return example, tf.cast(index, tf.int64)
labeled_data_sets = []

for i, file_name in enumerate(FILE_NAMES):
  lines_dataset = tf.data.TextLineDataset(str(parent_dir/file_name))
  labeled_dataset = lines_dataset.map(lambda ex: labeler(ex, i))
  labeled_data_sets.append(labeled_dataset)

Następnie musisz połączyć te oznaczone zbiorów danych w jednym zbiorze danych za pomocą Dataset.concatenate i przetasować je Dataset.shuffle :

BUFFER_SIZE = 50000
BATCH_SIZE = 64
VALIDATION_SIZE = 5000
all_labeled_data = labeled_data_sets[0]
for labeled_dataset in labeled_data_sets[1:]:
  all_labeled_data = all_labeled_data.concatenate(labeled_dataset)

all_labeled_data = all_labeled_data.shuffle(
    BUFFER_SIZE, reshuffle_each_iteration=False)

Wydrukuj kilka przykładów jak poprzednio. Zbiór danych nie został jeszcze batched, stąd każdy wpis w all_labeled_data odpowiada jednemu punktu danych:

for text, label in all_labeled_data.take(10):
  print("Sentence: ", text.numpy())
  print("Label:", label.numpy())
Sentence:  b'By hostile hands laid prostrate in the dust,'
Label: 1
Sentence:  b"Watch over her no longer; all are gain'd"
Label: 1
Sentence:  b"Her home, and parents; o'er her head she threw"
Label: 1
Sentence:  b'Diomed himself with glory.'
Label: 2
Sentence:  b'therefore, the Trojans and Lycians on the one hand, and the Myrmidons'
Label: 2
Sentence:  b'Though now in blissful ignorance they feast."'
Label: 1
Sentence:  b'rich with bronze and his panting steeds in charge of Eurymedon, son of'
Label: 2
Sentence:  b'In thine esteem, and sin against the Gods."'
Label: 1
Sentence:  b'Him Hebe bathed, and with divine attire'
Label: 0
Sentence:  b"The host all seated, and the benches fill'd,"
Label: 0

Przygotuj zbiór danych do szkolenia

Zamiast używać tf.keras.layers.TextVectorization do Preprocesuj zestawu danych tekstowych, można teraz korzystać z API TensorFlow tekst do standaryzacji i tokenize danych, budowania słownictwa i wykorzystywać tf.lookup.StaticVocabularyTable mapowania znaków do liczb całkowitych do paszy do Model. (Dowiedz się więcej o TensorFlow Tekst ).

Zdefiniuj funkcję konwertującą tekst na małe litery i tokenizujący go:

  • TensorFlow Text udostępnia różne tokenizatory. W tym przykładzie, można użyć text.UnicodeScriptTokenizer do tokenize zestawu danych.
  • Będziesz korzystać Dataset.map zastosować tokenizacja do zbioru danych.
tokenizer = tf_text.UnicodeScriptTokenizer()
def tokenize(text, unused_label):
  lower_case = tf_text.case_fold_utf8(text)
  return tokenizer.tokenize(lower_case)
tokenized_ds = all_labeled_data.map(tokenize)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:206: batch_gather (from tensorflow.python.ops.array_ops) is deprecated and will be removed after 2017-10-25.
Instructions for updating:
`tf.batch_gather` is deprecated, please use `tf.gather` with `batch_dims=-1` instead.

Możesz iterować po zbiorze danych i wydrukować kilka stokenizowanych przykładów:

for text_batch in tokenized_ds.take(5):
  print("Tokens: ", text_batch.numpy())
Tokens:  [b'by' b'hostile' b'hands' b'laid' b'prostrate' b'in' b'the' b'dust' b',']
Tokens:  [b'watch' b'over' b'her' b'no' b'longer' b';' b'all' b'are' b'gain' b"'"
 b'd']
Tokens:  [b'her' b'home' b',' b'and' b'parents' b';' b'o' b"'" b'er' b'her' b'head'
 b'she' b'threw']
Tokens:  [b'diomed' b'himself' b'with' b'glory' b'.']
Tokens:  [b'therefore' b',' b'the' b'trojans' b'and' b'lycians' b'on' b'the' b'one'
 b'hand' b',' b'and' b'the' b'myrmidons']

Następnie należy zbudować słownictwo sortując tokeny według częstotliwości i utrzymanie najlepszych VOCAB_SIZE tokeny:

tokenized_ds = configure_dataset(tokenized_ds)

vocab_dict = collections.defaultdict(lambda: 0)
for toks in tokenized_ds.as_numpy_iterator():
  for tok in toks:
    vocab_dict[tok] += 1

vocab = sorted(vocab_dict.items(), key=lambda x: x[1], reverse=True)
vocab = [token for token, count in vocab]
vocab = vocab[:VOCAB_SIZE]
vocab_size = len(vocab)
print("Vocab size: ", vocab_size)
print("First five vocab entries:", vocab[:5])
Vocab size:  10000
First five vocab entries: [b',', b'the', b'and', b"'", b'of']

Aby przekonwertować żetony do liczb całkowitych, użyj vocab zestaw do tworzenia tf.lookup.StaticVocabularyTable . Będziesz map tokeny do liczb całkowitych w przedziale [ 2 , vocab_size + 2 ]. Podobnie jak w przypadku TextVectorization warstwie 0 jest zarezerwowany dla oznaczenia wyściółki i 1 jest zarezerwowany dla oznaczenia out-of-słownika (OOV) znaczników.

keys = vocab
values = range(2, len(vocab) + 2)  # Reserve `0` for padding, `1` for OOV tokens.

init = tf.lookup.KeyValueTensorInitializer(
    keys, values, key_dtype=tf.string, value_dtype=tf.int64)

num_oov_buckets = 1
vocab_table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets)

Na koniec zdefiniuj funkcję do standaryzacji, tokenizacji i wektoryzacji zbioru danych za pomocą tokenizera i tabeli przeglądowej:

def preprocess_text(text, label):
  standardized = tf_text.case_fold_utf8(text)
  tokenized = tokenizer.tokenize(standardized)
  vectorized = vocab_table.lookup(tokenized)
  return vectorized, label

Możesz spróbować tego na jednym przykładzie, aby wydrukować dane wyjściowe:

example_text, example_label = next(iter(all_labeled_data))
print("Sentence: ", example_text.numpy())
vectorized_text, example_label = preprocess_text(example_text, example_label)
print("Vectorized sentence: ", vectorized_text.numpy())
Sentence:  b'By hostile hands laid prostrate in the dust,'
Vectorized sentence:  [  26 1007  146  339 1560   13    3  317    2]

Teraz uruchom funkcję Preprocesuj na zbiorze danych z wykorzystaniem Dataset.map :

all_encoded_data = all_labeled_data.map(preprocess_text)

Podziel zbiór danych na zestawy treningowe i testowe

Keras TextVectorization warstwą również partie i tarcze wektorowy danych. Dopełnienie jest wymagane, ponieważ przykłady w partii muszą mieć ten sam rozmiar i kształt, ale przykłady w tych zestawach danych nie mają tego samego rozmiaru — każdy wiersz tekstu ma inną liczbę słów.

tf.data.Dataset obsługuje dzielenie i wyściełane-grupujące zbiorów danych:

train_data = all_encoded_data.skip(VALIDATION_SIZE).shuffle(BUFFER_SIZE)
validation_data = all_encoded_data.take(VALIDATION_SIZE)
train_data = train_data.padded_batch(BATCH_SIZE)
validation_data = validation_data.padded_batch(BATCH_SIZE)

Teraz validation_data i train_data nie zbiorów ( example, label ) par, ale zbiory partii. Każda partia jest parą (wiele przykładów, wiele etykiet) w postaci tablic.

Aby to zilustrować:

sample_text, sample_labels = next(iter(validation_data))
print("Text batch shape: ", sample_text.shape)
print("Label batch shape: ", sample_labels.shape)
print("First text example: ", sample_text[0])
print("First label example: ", sample_labels[0])
Text batch shape:  (64, 19)
Label batch shape:  (64,)
First text example:  tf.Tensor(
[  26 1007  146  339 1560   13    3  317    2    0    0    0    0    0
    0    0    0    0    0], shape=(19,), dtype=int64)
First label example:  tf.Tensor(1, shape=(), dtype=int64)

Ponieważ używasz 0 dla dopełnienia i 1 dla out-of-słownictwa (OOV) żetonów, wielkość słownictwo wzrosła o dwóch:

vocab_size += 2

Skonfiguruj zestawy danych, aby uzyskać lepszą wydajność, jak poprzednio:

train_data = configure_dataset(train_data)
validation_data = configure_dataset(validation_data)

Trenuj modelkę

Możesz trenować model na tym zbiorze danych jak poprzednio:

model = create_model(vocab_size=vocab_size, num_labels=3)

model.compile(
    optimizer='adam',
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

history = model.fit(train_data, validation_data=validation_data, epochs=3)
Epoch 1/3
697/697 [==============================] - 26s 8ms/step - loss: 0.5180 - accuracy: 0.7704 - val_loss: 0.3602 - val_accuracy: 0.8434
Epoch 2/3
697/697 [==============================] - 2s 3ms/step - loss: 0.2801 - accuracy: 0.8854 - val_loss: 0.3480 - val_accuracy: 0.8538
Epoch 3/3
697/697 [==============================] - 2s 3ms/step - loss: 0.1896 - accuracy: 0.9269 - val_loss: 0.3800 - val_accuracy: 0.8500
loss, accuracy = model.evaluate(validation_data)

print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))
79/79 [==============================] - 1s 2ms/step - loss: 0.3800 - accuracy: 0.8500
Loss:  0.3800220191478729
Accuracy: 85.00%

Eksportuj model

Aby model zdolny do podejmowania surowych ciągi jako dane wejściowe, można utworzyć Keras TextVectorization warstwę, która wykonuje te same czynności, jak swojej funkcji przerób niestandardowej. Skoro masz już przeszkoleni słownictwa, można użyć TextVectorization.set_vocabulary (zamiast TextVectorization.adapt ), które pociągi nowego słownictwa.

preprocess_layer = TextVectorization(
    max_tokens=vocab_size,
    standardize=tf_text.case_fold_utf8,
    split=tokenizer.tokenize,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH)

preprocess_layer.set_vocabulary(vocab)
export_model = tf.keras.Sequential(
    [preprocess_layer, model,
     layers.Activation('sigmoid')])

export_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer='adam',
    metrics=['accuracy'])
# Create a test dataset of raw strings.
test_ds = all_labeled_data.take(VALIDATION_SIZE).batch(BATCH_SIZE)
test_ds = configure_dataset(test_ds)

loss, accuracy = export_model.evaluate(test_ds)

print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))
2021-10-14 01:25:02.750371: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: sequential_4/text_vectorization_2/UnicodeScriptTokenize/Assert_1/AssertGuard/branch_executed/_185
79/79 [==============================] - 6s 15ms/step - loss: 0.4715 - accuracy: 0.8116
Loss:  0.4715384840965271
Accuracy: 81.16%

Zgodnie z oczekiwaniami straty i dokładność modelu w zakodowanym zestawie walidacyjnym i wyeksportowanego modelu w surowym zestawie walidacyjnym są takie same.

Uruchom wnioskowanie na nowych danych

inputs = [
    "Join'd to th' Ionians with their flowing robes,",  # Label: 1
    "the allies, and his armour flashed about him so that he seemed to all",  # Label: 2
    "And with loud clangor of his arms he fell.",  # Label: 0
]

predicted_scores = export_model.predict(inputs)
predicted_labels = tf.argmax(predicted_scores, axis=1)

for input, label in zip(inputs, predicted_labels):
  print("Question: ", input)
  print("Predicted label: ", label.numpy())
2021-10-14 01:25:06.331899: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: sequential_4/text_vectorization_2/UnicodeScriptTokenize/Assert_1/AssertGuard/branch_executed/_185
Question:  Join'd to th' Ionians with their flowing robes,
Predicted label:  1
Question:  the allies, and his armour flashed about him so that he seemed to all
Predicted label:  2
Question:  And with loud clangor of his arms he fell.
Predicted label:  0

Pobierz więcej zestawów danych za pomocą zestawów danych TensorFlow (TFDS)

Można pobrać wiele więcej zestawów danych z TensorFlow zbiorów danych .

W tym przykładzie, można użyć IMDB Duży zbiór danych Przegląd filmów trenować model klasyfikacji nastrojów:

# Training set.
train_ds = tfds.load(
    'imdb_reviews',
    split='train[:80%]',
    batch_size=BATCH_SIZE,
    shuffle_files=True,
    as_supervised=True)
# Validation set.
val_ds = tfds.load(
    'imdb_reviews',
    split='train[80%:]',
    batch_size=BATCH_SIZE,
    shuffle_files=True,
    as_supervised=True)

Wydrukuj kilka przykładów:

for review_batch, label_batch in val_ds.take(1):
  for i in range(5):
    print("Review: ", review_batch[i].numpy())
    print("Label: ", label_batch[i].numpy())
Review:  b"Instead, go to the zoo, buy some peanuts and feed 'em to the monkeys. Monkeys are funny. People with amnesia who don't say much, just sit there with vacant eyes are not all that funny.<br /><br />Black comedy? There isn't a black person in it, and there isn't one funny thing in it either.<br /><br />Walmart buys these things up somehow and puts them on their dollar rack. It's labeled Unrated. I think they took out the topless scene. They may have taken out other stuff too, who knows? All we know is that whatever they took out, isn't there any more.<br /><br />The acting seemed OK to me. There's a lot of unfathomables tho. It's supposed to be a city? It's supposed to be a big lake? If it's so hot in the church people are fanning themselves, why are they all wearing coats?"
Label:  0
Review:  b'I remember stumbling upon this special while channel-surfing in 1965. I had never heard of Barbra before. When the show was over, I thought "This is probably the best thing on TV I will ever see in my life." 42 years later, that has held true. There is still nothing so amazing, so honestly astonishing as the talent that was displayed here. You can talk about all the super-stars you want to, this is the most superlative of them all!<br /><br />You name it, she can do it. Comedy, pathos, sultry seduction, ballads, Barbra is truly a story-teller. Her ability to pull off anything she attempts is legendary. But this special was made in the beginning, and helped to create the legend that she quickly became. In spite of rising so far in such a short time, she has fulfilled the promise, revealing more of her talents as she went along. But they are all here from the very beginning. You will not be disappointed in viewing this.'
Label:  1
Review:  b"I'm sorry but I didn't like this doc very much. I can think of a million ways it could have been better. The people who made it obviously don't have much imagination. The interviews aren't very interesting and no real insight is offered. The footage isn't assembled in a very informative way, either. It's too bad because this is a movie that really deserves spellbinding special features. One thing I'll say is that Isabella Rosselini gets more beautiful the older she gets. All considered, this only gets a '4.'"
Label:  0
Review:  b'This movie had all the elements to be a smart, sparkling comedy, but for some reason it took the dumbass route. Perhaps it didn\'t really know who its audience was: but it\'s hardly a man\'s movie given the cast and plot, yet is too slapstick and dumb-blonde to appeal fully to women.<br /><br />If you have seen Legally Blonde and its sequel, then this is like the bewilderingly awful sequel. Great actors such as Luke Wilson should expect better material. Jessica Simpson could also have managed so much more. Rachael Leigh Cook and Penelope Anne Miller languish in supporting roles that are silly rather than amusing.<br /><br />Many things in this movie were paint-by-numbers, the various uber-clich\xc3\xa9 montages, the last minute "misunderstanding", even the kids\' party chaos. This just suggests lazy scriptwriting.<br /><br />It should be possible to find this movie enjoyable if you don\'t take it seriously, but it\'s such a glaring could-do-better than you\'ll likely feel frustrated and increasingly disappointed as the scenes roll past.'
Label:  0
Review:  b'There is absolutely no plot in this movie ...no character development...no climax...nothing. But has a few good fighting scenes that are actually pretty good. So there you go...as a movie overall is pretty bad, but if you like a brainless flick that offer nothing but just good action scene then watch this movie. Do not expect nothing more that just that.Decent acting and a not so bad direction..A couple of cameos from Kimbo and Carano...I was looking to see Carano a little bit more in this movie..she is a good fighter and a really hot girl.... White is a great martial artist and a decent actor. I really hope he can land a better movie in the future so we can really enjoy his art..Imagine a film with White and Jaa together...that would be awesome'
Label:  0

Możesz teraz wstępnie przetwarzać dane i trenować model tak jak poprzednio.

Przygotuj zbiór danych do szkolenia

vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH)

# Make a text-only dataset (without labels), then call `TextVectorization.adapt`.
train_text = train_ds.map(lambda text, labels: text)
vectorize_layer.adapt(train_text)
def vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return vectorize_layer(text), label
train_ds = train_ds.map(vectorize_text)
val_ds = val_ds.map(vectorize_text)
# Configure datasets for performance as before.
train_ds = configure_dataset(train_ds)
val_ds = configure_dataset(val_ds)

Twórz, konfiguruj i trenuj model

model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=1)
model.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_2 (Embedding)      (None, None, 64)          640064    
_________________________________________________________________
conv1d_2 (Conv1D)            (None, None, 64)          20544     
_________________________________________________________________
global_max_pooling1d_2 (Glob (None, 64)                0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 65        
=================================================================
Total params: 660,673
Trainable params: 660,673
Non-trainable params: 0
_________________________________________________________________
model.compile(
    loss=losses.BinaryCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])
history = model.fit(train_ds, validation_data=val_ds, epochs=3)
Epoch 1/3
313/313 [==============================] - 3s 6ms/step - loss: 0.5411 - accuracy: 0.6645 - val_loss: 0.3777 - val_accuracy: 0.8310
Epoch 2/3
313/313 [==============================] - 1s 3ms/step - loss: 0.2992 - accuracy: 0.8698 - val_loss: 0.3194 - val_accuracy: 0.8592
Epoch 3/3
313/313 [==============================] - 1s 3ms/step - loss: 0.1811 - accuracy: 0.9298 - val_loss: 0.3261 - val_accuracy: 0.8622
loss, accuracy = model.evaluate(val_ds)

print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))
79/79 [==============================] - 0s 1ms/step - loss: 0.3261 - accuracy: 0.8622
Loss:  0.3261321783065796
Accuracy: 86.22%

Eksportuj model

export_model = tf.keras.Sequential(
    [vectorize_layer, model,
     layers.Activation('sigmoid')])

export_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer='adam',
    metrics=['accuracy'])
# 0 --> negative review
# 1 --> positive review
inputs = [
    "This is a fantastic movie.",
    "This is a bad movie.",
    "This movie was so bad that it was good.",
    "I will never say yes to watching this movie.",
]

predicted_scores = export_model.predict(inputs)
predicted_labels = [int(round(x[0])) for x in predicted_scores]

for input, label in zip(inputs, predicted_labels):
  print("Question: ", input)
  print("Predicted label: ", label)
Question:  This is a fantastic movie.
Predicted label:  1
Question:  This is a bad movie.
Predicted label:  0
Question:  This movie was so bad that it was good.
Predicted label:  0
Question:  I will never say yes to watching this movie.
Predicted label:  0

Wniosek

W tym samouczku przedstawiono kilka sposobów ładowania i wstępnego przetwarzania tekstu. W następnym kroku można zbadać dodatkowy tekst przerób TensorFlow tekstowych tutoriale, takie jak:

Można również znaleźć nowe zestawy danych na temat TensorFlow zbiorów danych . I, aby dowiedzieć się więcej o tf.data , sprawdź przewodnik na budowę rurociągów wejściowych .