Mam pytanie? Połącz się ze społecznością na Forum TensorFlow Odwiedź Forum

Wczytaj tekst

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHub Pobierz notatnik

W tym samouczku przedstawiono dwa sposoby ładowania i wstępnego przetwarzania tekstu.

  • Najpierw użyjesz narzędzi i warstw Keras. Jeśli jesteś nowy w TensorFlow, powinieneś zacząć od tych.

  • Następnie użyjesz narzędzi niższego poziomu, takich jak tf.data.TextLineDataset do ładowania plików tekstowych i tf.text do wstępnego przetwarzania danych w celu uzyskania dokładniejszej kontroli.

# Be sure you're using the stable versions of both tf and tf-text, for binary compatibility.
pip install -q -U tf-nightly
pip install -q -U tensorflow-text-nightly
import collections
import pathlib
import re
import string

import tensorflow as tf

from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import preprocessing
from tensorflow.keras import utils
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

import tensorflow_datasets as tfds
import tensorflow_text as tf_text

Przykład 1: Przewidywanie tagu dla pytania o przepełnienie stosu

Jako pierwszy przykład pobierzesz zestaw danych z pytaniami programistycznymi ze Stack Overflow. Każde pytanie („Jak posortować słownik według wartości?”) jest oznaczone dokładnie jednym tagiem ( Python , CSharp , JavaScript lub Java ). Twoim zadaniem jest opracowanie modelu, który przewiduje tag dla pytania. Jest to przykład klasyfikacji wieloklasowej, ważnego i szeroko stosowanego problemu uczenia maszynowego.

Pobierz i poznaj zbiór danych

Następnie pobierzesz zestaw danych i zbadasz strukturę katalogów.

data_url = 'https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz'
dataset = utils.get_file(
    'stack_overflow_16k.tar.gz',
    data_url,
    untar=True,
    cache_dir='stack_overflow',
    cache_subdir='')
dataset_dir = pathlib.Path(dataset).parent
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz
6053888/6053168 [==============================] - 0s 0us/step
6062080/6053168 [==============================] - 0s 0us/step
list(dataset_dir.iterdir())
[PosixPath('/tmp/.keras/train'),
 PosixPath('/tmp/.keras/README.md'),
 PosixPath('/tmp/.keras/test'),
 PosixPath('/tmp/.keras/stack_overflow_16k.tar.gz.tar.gz')]
train_dir = dataset_dir/'train'
list(train_dir.iterdir())
[PosixPath('/tmp/.keras/train/java'),
 PosixPath('/tmp/.keras/train/csharp'),
 PosixPath('/tmp/.keras/train/javascript'),
 PosixPath('/tmp/.keras/train/python')]

Katalogi train/csharp , train/java , train/python i train/javascript zawierają wiele plików tekstowych, z których każdy jest pytaniem Stack Overflow. Wydrukuj plik i sprawdź dane.

sample_file = train_dir/'python/1755.txt'
with open(sample_file) as f:
  print(f.read())
why does this blank program print true x=true.def stupid():.    x=false.stupid().print x

Załaduj zbiór danych

Następnie załadujesz dane z dysku i przygotujesz je do formatu odpowiedniego do treningu. Aby to zrobić, użyjesz narzędzia text_dataset_from_directory do utworzenia zbioru danych z etykietątf.data.Dataset . Jeśli jesteś nowicjuszem w tf.data , jest to potężna kolekcja narzędzi do tworzenia potoków wejściowych.

preprocessing.text_dataset_from_directory oczekuje następującej struktury katalogów.

train/
...csharp/
......1.txt
......2.txt
...java/
......1.txt
......2.txt
...javascript/
......1.txt
......2.txt
...python/
......1.txt
......2.txt

Podczas przeprowadzania eksperymentu uczenia maszynowego najlepszym rozwiązaniem jest podzielenie zestawu danych na trzy części: trenowanie , weryfikację i testowanie . Zestaw danych Stack Overflow został już podzielony na trenowanie i testowanie, ale brakuje w nim zestawu walidacyjnego. Utwórz zestaw walidacji przy użyciu podziału 80:20 danych uczących, używając poniższego argumentu validation_split .

batch_size = 32
seed = 42

raw_train_ds = preprocessing.text_dataset_from_directory(
    train_dir,
    batch_size=batch_size,
    validation_split=0.2,
    subset='training',
    seed=seed)
Found 8000 files belonging to 4 classes.
Using 6400 files for training.

Jak widać powyżej, w folderze szkoleniowym znajduje się 8000 przykładów, z których 80% (lub 6400) wykorzystasz do szkolenia. Jak zobaczysz za chwilę, możesz trenować model, przekazująctf.data.Dataset bezpośrednio do model.fit . Najpierw przeprowadź iterację zestawu danych i wydrukuj kilka przykładów, aby poznać dane.

for text_batch, label_batch in raw_train_ds.take(1):
  for i in range(10):
    print("Question: ", text_batch.numpy()[i])
    print("Label:", label_batch.numpy()[i])
Question:  b'"my tester is going to the wrong constructor i am new to programming so if i ask a question that can be easily fixed, please forgive me. my program has a tester class with a main. when i send that to my regularpolygon class, it sends it to the wrong constructor. i have two constructors. 1 without perameters..public regularpolygon().    {.       mynumsides = 5;.       mysidelength = 30;.    }//end default constructor...and my second, with perameters. ..public regularpolygon(int numsides, double sidelength).    {.        mynumsides = numsides;.        mysidelength = sidelength;.    }// end constructor...in my tester class i have these two lines:..regularpolygon shape = new regularpolygon(numsides, sidelength);.        shape.menu();...numsides and sidelength were declared and initialized earlier in the testing class...so what i want to happen, is the tester class sends numsides and sidelength to the second constructor and use it in that class. but it only uses the default constructor, which therefor ruins the whole rest of the program. can somebody help me?..for those of you who want to see more of my code: here you go..public double vertexangle().    {.        system.out.println(""the vertex angle method: "" + mynumsides);// prints out 5.        system.out.println(""the vertex angle method: "" + mysidelength); // prints out 30..        double vertexangle;.        vertexangle = ((mynumsides - 2.0) / mynumsides) * 180.0;.        return vertexangle;.    }//end method vertexangle..public void menu().{.    system.out.println(mynumsides); // prints out what the user puts in.    system.out.println(mysidelength); // prints out what the user puts in.    gotographic();.    calcr(mynumsides, mysidelength);.    calcr(mynumsides, mysidelength);.    print(); .}// end menu...this is my entire tester class:..public static void main(string[] arg).{.    int numsides;.    double sidelength;.    scanner keyboard = new scanner(system.in);..    system.out.println(""welcome to the regular polygon program!"");.    system.out.println();..    system.out.print(""enter the number of sides of the polygon ==> "");.    numsides = keyboard.nextint();.    system.out.println();..    system.out.print(""enter the side length of each side ==> "");.    sidelength = keyboard.nextdouble();.    system.out.println();..    regularpolygon shape = new regularpolygon(numsides, sidelength);.    shape.menu();.}//end main...for testing it i sent it numsides 4 and sidelength 100."\n'
Label: 1
Question:  b'"blank code slow skin detection this code changes the color space to lab and using a threshold finds the skin area of an image. but it\'s ridiculously slow. i don\'t know how to make it faster ?    ..from colormath.color_objects import *..def skindetection(img, treshold=80, color=[255,20,147]):..    print img.shape.    res=img.copy().    for x in range(img.shape[0]):.        for y in range(img.shape[1]):.            rgbimg=rgbcolor(img[x,y,0],img[x,y,1],img[x,y,2]).            labimg=rgbimg.convert_to(\'lab\', debug=false).            if (labimg.lab_l > treshold):.                res[x,y,:]=color.            else: .                res[x,y,:]=img[x,y,:]..    return res"\n'
Label: 3
Question:  b'"option and validation in blank i want to add a new option on my system where i want to add two text files, both rental.txt and customer.txt. inside each text are id numbers of the customer, the videotape they need and the price...i want to place it as an option on my code. right now i have:...add customer.rent return.view list.search.exit...i want to add this as my sixth option. say for example i ordered a video, it would display the price and would let me confirm the price and if i am going to buy it or not...here is my current code:..  import blank.io.*;.    import blank.util.arraylist;.    import static blank.lang.system.out;..    public class rentalsystem{.    static bufferedreader input = new bufferedreader(new inputstreamreader(system.in));.    static file file = new file(""file.txt"");.    static arraylist<string> list = new arraylist<string>();.    static int rows;..    public static void main(string[] args) throws exception{.        introduction();.        system.out.print(""nn"");.        login();.        system.out.print(""nnnnnnnnnnnnnnnnnnnnnn"");.        introduction();.        string repeat;.        do{.            loadfile();.            system.out.print(""nwhat do you want to do?nn"");.            system.out.print(""n                    - - - - - - - - - - - - - - - - - - - - - - -"");.            system.out.print(""nn                    |     1. add customer    |   2. rent return |n"");.            system.out.print(""n                    - - - - - - - - - - - - - - - - - - - - - - -"");.            system.out.print(""nn                    |     3. view list       |   4. search      |n"");.            system.out.print(""n                    - - - - - - - - - - - - - - - - - - - - - - -"");.            system.out.print(""nn                                             |   5. exit        |n"");.            system.out.print(""n                                              - - - - - - - - - -"");.            system.out.print(""nnchoice:"");.            int choice = integer.parseint(input.readline());.            switch(choice){.                case 1:.                    writedata();.                    break;.                case 2:.                    rentdata();.                    break;.                case 3:.                    viewlist();.                    break;.                case 4:.                    search();.                    break;.                case 5:.                    system.out.println(""goodbye!"");.                    system.exit(0);.                default:.                    system.out.print(""invalid choice: "");.                    break;.            }.            system.out.print(""ndo another task? [y/n] "");.            repeat = input.readline();.        }while(repeat.equals(""y""));..        if(repeat!=""y"") system.out.println(""ngoodbye!"");..    }..    public static void writedata() throws exception{.        system.out.print(""nname: "");.        string cname = input.readline();.        system.out.print(""address: "");.        string add = input.readline();.        system.out.print(""phone no.: "");.        string pno = input.readline();.        system.out.print(""rental amount: "");.        string ramount = input.readline();.        system.out.print(""tapenumber: "");.        string tno = input.readline();.        system.out.print(""title: "");.        string title = input.readline();.        system.out.print(""date borrowed: "");.        string dborrowed = input.readline();.        system.out.print(""due date: "");.        string ddate = input.readline();.        createline(cname, add, pno, ramount,tno, title, dborrowed, ddate);.        rentdata();.    }..    public static void createline(string name, string address, string phone , string rental, string tapenumber, string title, string borrowed, string due) throws exception{.        filewriter fw = new filewriter(file, true);.        fw.write(""nname: ""+name + ""naddress: "" + address +""nphone no.: ""+ phone+""nrentalamount: ""+rental+""ntape no.: ""+ tapenumber+""ntitle: ""+ title+""ndate borrowed: ""+borrowed +""ndue date: ""+ due+"":rn"");.        fw.close();.    }..    public static void loadfile() throws exception{.        try{.            list.clear();.            fileinputstream fstream = new fileinputstream(file);.            bufferedreader br = new bufferedreader(new inputstreamreader(fstream));.            rows = 0;.            while( br.ready()).            {.                list.add(br.readline());.                rows++;.            }.            br.close();.        } catch(exception e){.            system.out.println(""list not yet loaded."");.        }.    }..    public static void viewlist(){.        system.out.print(""n~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print("" |list of all costumers|"");.        system.out.print(""~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        for(int i = 0; i <rows; i++){.            system.out.println(list.get(i));.        }.    }.        public static void rentdata()throws exception.    {   system.out.print(""n~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print("" |rent data list|"");.        system.out.print(""~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print(""nenter customer name: "");.        string cname = input.readline();.        system.out.print(""date borrowed: "");.        string dborrowed = input.readline();.        system.out.print(""due date: "");.        string ddate = input.readline();.        system.out.print(""return date: "");.        string rdate = input.readline();.        system.out.print(""rent amount: "");.        string ramount = input.readline();..        system.out.print(""you pay:""+ramount);...    }.    public static void search()throws exception.    {   system.out.print(""n~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print("" |search costumers|"");.        system.out.print(""~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print(""nenter costumer name: "");.        string cname = input.readline();.        boolean found = false;..        for(int i=0; i < rows; i++){.            string temp[] = list.get(i).split("","");..            if(cname.equals(temp[0])){.            system.out.println(""search result:nyou are "" + temp[0] + "" from "" + temp[1] + "".""+ temp[2] + "".""+ temp[3] + "".""+ temp[4] + "".""+ temp[5] + "" is "" + temp[6] + "".""+ temp[7] + "" is "" + temp[8] + ""."");.                found = true;.            }.        }..        if(!found){.            system.out.print(""no results."");.        }..    }..        public static boolean evaluate(string uname, string pass){.        if (uname.equals(""admin"")&&pass.equals(""12345"")) return true;.        else return false;.    }..    public static string login()throws exception{.        bufferedreader input=new bufferedreader(new inputstreamreader(system.in));.        int counter=0;.        do{.            system.out.print(""username:"");.            string uname =input.readline();.            system.out.print(""password:"");.            string pass =input.readline();..            boolean accept= evaluate(uname,pass);..            if(accept){.                break;.                }else{.                    system.out.println(""incorrect username or password!"");.                    counter ++;.                    }.        }while(counter<3);..            if(counter !=3) return ""login successful"";.            else return ""login failed"";.            }.        public static void introduction() throws exception{..        system.out.println(""                  - - - - - - - - - - - - - - - - - - - - - - - - -"");.        system.out.println(""                  !                  r e n t a l                  !"");.        system.out.println(""                   ! ~ ~ ~ ~ ~ !  =================  ! ~ ~ ~ ~ ~ !"");.        system.out.println(""                  !                  s y s t e m                  !"");.        system.out.println(""                  - - - - - - - - - - - - - - - - - - - - - - - - -"");.        }..}"\n'
Label: 1
Question:  b'"exception: dynamic sql generation for the updatecommand is not supported against a selectcommand that does not return any key i dont know what is the problem this my code : ..string nomtable;..datatable listeetablissementtable = new datatable();.datatable listeinteretstable = new datatable();.dataset ds = new dataset();.sqldataadapter da;.sqlcommandbuilder cmdb;..private void listeinterets_click(object sender, eventargs e).{.    nomtable = ""listeinteretstable"";.    d.cnx.open();.    da = new sqldataadapter(""select nome from offices"", d.cnx);.    ds = new dataset();.    da.fill(ds, nomtable);.    datagridview1.datasource = ds.tables[nomtable];.}..private void sauvgarder_click(object sender, eventargs e).{.    d.cnx.open();.    cmdb = new sqlcommandbuilder(da);.    da.update(ds, nomtable);.    d.cnx.close();.}"\n'
Label: 0
Question:  b'"parameter with question mark and super in blank, i\'ve come across a method that is formatted like this:..public final subscription subscribe(final action1<? super t> onnext, final action1<throwable> onerror) {.}...in the first parameter, what does the question mark and super mean?"\n'
Label: 1
Question:  b'call two objects wsdl the first time i got a very strange wsdl. ..i would like to call the object (interface - invoicecheck_out) do you know how?....i would like to call the object (variable) do you know how?..try to call (it`s ok)....try to call (how call this?)\n'
Label: 0
Question:  b"how to correctly make the icon for systemtray in blank using icon sizes of any dimension for systemtray doesn't look good overall. .what is the correct way of making icons for windows system tray?..screenshots: http://imgur.com/zsibwn9..icon: http://imgur.com/vsh4zo8\n"
Label: 0
Question:  b'"is there a way to check a variable that exists in a different script than the original one? i\'m trying to check if a variable, which was previously set to true in 2.py in 1.py, as 1.py is only supposed to continue if the variable is true...2.py..import os..completed = false..#some stuff here..completed = true...1.py..import 2 ..if completed == true.   #do things...however i get a syntax error at ..if completed == true"\n'
Label: 3
Question:  b'"blank control flow i made a number which asks for 2 numbers with blank and responds with  the corresponding message for the case. how come it doesnt work  for the second number ? .regardless what i enter for the second number , i am getting the message ""your number is in the range 0-10""...using system;.using system.collections.generic;.using system.linq;.using system.text;..namespace consoleapplication1.{.    class program.    {.        static void main(string[] args).        {.            string myinput;  // declaring the type of the variables.            int myint;..            string number1;.            int number;...            console.writeline(""enter a number"");.            myinput = console.readline(); //muyinput is a string  which is entry input.            myint = int32.parse(myinput); // myint converts the string into an integer..            if (myint > 0).                console.writeline(""your number {0} is greater than zero."", myint);.            else if (myint < 0).                console.writeline(""your number {0} is  less  than zero."", myint);.            else.                console.writeline(""your number {0} is equal zero."", myint);..            console.writeline(""enter another number"");.            number1 = console.readline(); .            number = int32.parse(myinput); ..            if (number < 0 || number == 0).                console.writeline(""your number {0} is  less  than zero or equal zero."", number);.            else if (number > 0 && number <= 10).                console.writeline(""your number {0} is  in the range from 0 to 10."", number);.            else.                console.writeline(""your number {0} is greater than 10."", number);..            console.writeline(""enter another number"");..        }.    }    .}"\n'
Label: 0
Question:  b'"credentials cannot be used for ntlm authentication i am getting org.apache.commons.httpclient.auth.invalidcredentialsexception: credentials cannot be used for ntlm authentication: exception in eclipse..whether it is possible mention eclipse to take system proxy settings directly?..public class httpgetproxy {.    private static final string proxy_host = ""proxy.****.com"";.    private static final int proxy_port = 6050;..    public static void main(string[] args) {.        httpclient client = new httpclient();.        httpmethod method = new getmethod(""https://kodeblank.org"");..        hostconfiguration config = client.gethostconfiguration();.        config.setproxy(proxy_host, proxy_port);..        string username = ""*****"";.        string password = ""*****"";.        credentials credentials = new usernamepasswordcredentials(username, password);.        authscope authscope = new authscope(proxy_host, proxy_port);..        client.getstate().setproxycredentials(authscope, credentials);..        try {.            client.executemethod(method);..            if (method.getstatuscode() == httpstatus.sc_ok) {.                string response = method.getresponsebodyasstring();.                system.out.println(""response = "" + response);.            }.        } catch (ioexception e) {.            e.printstacktrace();.        } finally {.            method.releaseconnection();.        }.    }.}...exception:...  dec 08, 2017 1:41:39 pm .          org.apache.commons.httpclient.auth.authchallengeprocessor selectauthscheme.         info: ntlm authentication scheme selected.       dec 08, 2017 1:41:39 pm org.apache.commons.httpclient.httpmethoddirector executeconnect.         severe: credentials cannot be used for ntlm authentication: .           org.apache.commons.httpclient.usernamepasswordcredentials.           org.apache.commons.httpclient.auth.invalidcredentialsexception: credentials .         cannot be used for ntlm authentication: .        enter code here .          org.apache.commons.httpclient.usernamepasswordcredentials.      at org.apache.commons.httpclient.auth.ntlmscheme.authenticate(ntlmscheme.blank:332).        at org.apache.commons.httpclient.httpmethoddirector.authenticateproxy(httpmethoddirector.blank:320).      at org.apache.commons.httpclient.httpmethoddirector.executeconnect(httpmethoddirector.blank:491).      at org.apache.commons.httpclient.httpmethoddirector.executewithretry(httpmethoddirector.blank:391).      at org.apache.commons.httpclient.httpmethoddirector.executemethod(httpmethoddirector.blank:171).      at org.apache.commons.httpclient.httpclient.executemethod(httpclient.blank:397).      at org.apache.commons.httpclient.httpclient.executemethod(httpclient.blank:323).      at httpgetproxy.main(httpgetproxy.blank:31).  dec 08, 2017 1:41:39 pm org.apache.commons.httpclient.httpmethoddirector processproxyauthchallenge.  info: failure authenticating with ntlm @proxy.****.com:6050"\n'
Label: 1

Etykiety to 0 , 1 , 2 lub 3 . Aby zobaczyć, które z nich odpowiadają danej etykiecie ciągu, możesz sprawdzić właściwość class_names w zestawie danych.

for i, label in enumerate(raw_train_ds.class_names):
  print("Label", i, "corresponds to", label)
Label 0 corresponds to csharp
Label 1 corresponds to java
Label 2 corresponds to javascript
Label 3 corresponds to python

Następnie utworzysz zestaw danych walidacyjnych i testowych. Do walidacji wykorzystasz pozostałe 1600 recenzji z zestawu szkoleniowego.

raw_val_ds = preprocessing.text_dataset_from_directory(
    train_dir,
    batch_size=batch_size,
    validation_split=0.2,
    subset='validation',
    seed=seed)
Found 8000 files belonging to 4 classes.
Using 1600 files for validation.
test_dir = dataset_dir/'test'
raw_test_ds = preprocessing.text_dataset_from_directory(
    test_dir, batch_size=batch_size)
Found 8000 files belonging to 4 classes.

Przygotuj zbiór danych do szkolenia

Następnie dokonasz standaryzacji, tokenizacji i wektoryzacji danych za pomocą warstwy preprocessing.TextVectorization .

  • Standaryzacja odnosi się do wstępnego przetwarzania tekstu, zwykle w celu usunięcia interpunkcji lub elementów HTML w celu uproszczenia zestawu danych.

  • Tokenizacja odnosi się do dzielenia ciągów na tokeny (na przykład dzielenie zdania na pojedyncze słowa poprzez dzielenie na odstępach).

  • Wektoryzacja odnosi się do konwersji tokenów na liczby, aby można je było wprowadzić do sieci neuronowej.

Wszystkie te zadania można wykonać za pomocą tej warstwy. Więcej informacji na temat każdego z nich można znaleźć w dokumentacji API .

  • Domyślna standaryzacja konwertuje tekst na małe litery i usuwa znaki interpunkcyjne.

  • Domyślny tokenizer dzieli się na białych znakach.

  • Domyślnym trybem wektoryzacji jest int . Daje to indeksy liczb całkowitych (jeden na token). Ten tryb może służyć do budowania modeli uwzględniających szyk wyrazów. Możesz także użyć innych trybów, takich jak binary , do budowania modeli typu bag-of-word.

Zbudujesz dwa tryby, aby dowiedzieć się o nich więcej. Najpierw użyjesz modelu binary do zbudowania modelu worka słów. Następnie użyjesz trybu int z ConvNet 1D.

VOCAB_SIZE = 10000

binary_vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='binary')

W trybie int , oprócz maksymalnego rozmiaru słownika, musisz ustawić jawną maksymalną długość sekwencji, co spowoduje, że warstwa będzie wypełniać lub skracać sekwencje do dokładnie wartości sequence_length.

MAX_SEQUENCE_LENGTH = 250

int_vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH)

Następnie wywołasz adapt aby dopasować stan warstwy przetwarzania wstępnego do zestawu danych. Spowoduje to, że model zbuduje indeks ciągów do liczb całkowitych.

# Make a text-only dataset (without labels), then call adapt
train_text = raw_train_ds.map(lambda text, labels: text)
binary_vectorize_layer.adapt(train_text)
int_vectorize_layer.adapt(train_text)

Zobacz wynik użycia tych warstw do wstępnego przetwarzania danych:

def binary_vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return binary_vectorize_layer(text), label
def int_vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return int_vectorize_layer(text), label
# Retrieve a batch (of 32 reviews and labels) from the dataset
text_batch, label_batch = next(iter(raw_train_ds))
first_question, first_label = text_batch[0], label_batch[0]
print("Question", first_question)
print("Label", first_label)
Question tf.Tensor(b'"what is the difference between these two ways to create an element? var a = document.createelement(\'div\');..a.id = ""mydiv"";...and..var a = document.createelement(\'div\').id = ""mydiv"";...what is the difference between them such that the first one works and the second one doesn\'t?"\n', shape=(), dtype=string)
Label tf.Tensor(2, shape=(), dtype=int32)
print("'binary' vectorized question:", 
      binary_vectorize_text(first_question, first_label)[0])
'binary' vectorized question: tf.Tensor([[1. 1. 0. ... 0. 0. 0.]], shape=(1, 10000), dtype=float32)
print("'int' vectorized question:",
      int_vectorize_text(first_question, first_label)[0])
'int' vectorized question: tf.Tensor(
[[ 55   6   2 410 211 229 121 895   4 124  32 245  43   5   1   1   5   1
    1   6   2 410 211 191 318  14   2  98  71 188   8   2 199  71 178   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]], shape=(1, 250), dtype=int64)

Jak widać powyżej, tryb binary zwraca tablicę oznaczającą, które tokeny istnieją co najmniej raz na wejściu, podczas gdy tryb int zastępuje każdy token liczbą całkowitą, zachowując w ten sposób ich kolejność. Możesz .get_vocabulary() token (łańcuch), .get_vocabulary() odpowiada każda liczba całkowita, wywołując .get_vocabulary() na warstwie.

print("1289 ---> ", int_vectorize_layer.get_vocabulary()[1289])
print("313 ---> ", int_vectorize_layer.get_vocabulary()[313])
print("Vocabulary size: {}".format(len(int_vectorize_layer.get_vocabulary())))
1289 --->  roman
313 --->  source
Vocabulary size: 10000

Jesteś już prawie gotowy do trenowania swojego modelu. Jako ostatni etap przetwarzania wstępnego zastosujesz utworzone wcześniej warstwy TextVectorization do zestawu danych do pociągu, walidacji i testu.

binary_train_ds = raw_train_ds.map(binary_vectorize_text)
binary_val_ds = raw_val_ds.map(binary_vectorize_text)
binary_test_ds = raw_test_ds.map(binary_vectorize_text)

int_train_ds = raw_train_ds.map(int_vectorize_text)
int_val_ds = raw_val_ds.map(int_vectorize_text)
int_test_ds = raw_test_ds.map(int_vectorize_text)

Skonfiguruj zbiór danych pod kątem wydajności

Są to dwie ważne metody, których należy użyć podczas ładowania danych, aby upewnić się, że operacje we/wy nie zostaną zablokowane.

.cache() przechowuje dane w pamięci po ich załadowaniu z dysku. Zapewni to, że zestaw danych nie stanie się wąskim gardłem podczas trenowania modelu. Jeśli zestaw danych jest zbyt duży, aby zmieścić się w pamięci, możesz również użyć tej metody, aby utworzyć wydajną pamięć podręczną na dysku, która jest bardziej wydajna do odczytu niż wiele małych plików.

.prefetch() nakłada się na .prefetch() przetwarzanie danych i wykonywanie modelu podczas uczenia.

Więcej informacji na temat obu metod oraz sposobu buforowania danych na dysku można znaleźć w przewodniku po wydajności danych .

AUTOTUNE = tf.data.AUTOTUNE

def configure_dataset(dataset):
  return dataset.cache().prefetch(buffer_size=AUTOTUNE)
binary_train_ds = configure_dataset(binary_train_ds)
binary_val_ds = configure_dataset(binary_val_ds)
binary_test_ds = configure_dataset(binary_test_ds)

int_train_ds = configure_dataset(int_train_ds)
int_val_ds = configure_dataset(int_val_ds)
int_test_ds = configure_dataset(int_test_ds)

Trenuj modelkę

Czas stworzyć naszą sieć neuronową. W przypadku binary danych wektoryzowanych wytrenuj prosty model liniowy typu bag-of-words:

binary_model = tf.keras.Sequential([layers.Dense(4)])
binary_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])
history = binary_model.fit(
    binary_train_ds, validation_data=binary_val_ds, epochs=10)
Epoch 1/10
200/200 [==============================] - 2s 5ms/step - loss: 1.1138 - accuracy: 0.6502 - val_loss: 0.9118 - val_accuracy: 0.7775
Epoch 2/10
200/200 [==============================] - 0s 2ms/step - loss: 0.7773 - accuracy: 0.8153 - val_loss: 0.7493 - val_accuracy: 0.7981
Epoch 3/10
200/200 [==============================] - 0s 2ms/step - loss: 0.6270 - accuracy: 0.8598 - val_loss: 0.6643 - val_accuracy: 0.8081
Epoch 4/10
200/200 [==============================] - 0s 2ms/step - loss: 0.5339 - accuracy: 0.8845 - val_loss: 0.6113 - val_accuracy: 0.8238
Epoch 5/10
200/200 [==============================] - 0s 2ms/step - loss: 0.4681 - accuracy: 0.9027 - val_loss: 0.5748 - val_accuracy: 0.8350
Epoch 6/10
200/200 [==============================] - 0s 2ms/step - loss: 0.4179 - accuracy: 0.9175 - val_loss: 0.5483 - val_accuracy: 0.8375
Epoch 7/10
200/200 [==============================] - 0s 2ms/step - loss: 0.3777 - accuracy: 0.9292 - val_loss: 0.5284 - val_accuracy: 0.8388
Epoch 8/10
200/200 [==============================] - 0s 2ms/step - loss: 0.3445 - accuracy: 0.9359 - val_loss: 0.5129 - val_accuracy: 0.8381
Epoch 9/10
200/200 [==============================] - 0s 2ms/step - loss: 0.3163 - accuracy: 0.9430 - val_loss: 0.5007 - val_accuracy: 0.8369
Epoch 10/10
200/200 [==============================] - 0s 2ms/step - loss: 0.2919 - accuracy: 0.9498 - val_loss: 0.4910 - val_accuracy: 0.8388

Następnie użyjesz warstwy wektorowej int do zbudowania sieci ConvNet 1D.

def create_model(vocab_size, num_labels):
  model = tf.keras.Sequential([
      layers.Embedding(vocab_size, 64, mask_zero=True),
      layers.Conv1D(64, 5, padding="valid", activation="relu", strides=2),
      layers.GlobalMaxPooling1D(),
      layers.Dense(num_labels)
  ])
  return model
# vocab_size is VOCAB_SIZE + 1 since 0 is used additionally for padding.
int_model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=4)
int_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])
history = int_model.fit(int_train_ds, validation_data=int_val_ds, epochs=5)
Epoch 1/5
200/200 [==============================] - 2s 4ms/step - loss: 1.1919 - accuracy: 0.4784 - val_loss: 0.8055 - val_accuracy: 0.6837
Epoch 2/5
200/200 [==============================] - 1s 3ms/step - loss: 0.6476 - accuracy: 0.7472 - val_loss: 0.5603 - val_accuracy: 0.7844
Epoch 3/5
200/200 [==============================] - 1s 3ms/step - loss: 0.3888 - accuracy: 0.8766 - val_loss: 0.4817 - val_accuracy: 0.8156
Epoch 4/5
200/200 [==============================] - 1s 3ms/step - loss: 0.2171 - accuracy: 0.9494 - val_loss: 0.4721 - val_accuracy: 0.8238
Epoch 5/5
200/200 [==============================] - 1s 3ms/step - loss: 0.1098 - accuracy: 0.9817 - val_loss: 0.4930 - val_accuracy: 0.8181

Porównaj dwa modele:

print("Linear model on binary vectorized data:")
print(binary_model.summary())
Linear model on binary vectorized data:
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 4)                 40004     
=================================================================
Total params: 40,004
Trainable params: 40,004
Non-trainable params: 0
_________________________________________________________________
None
print("ConvNet model on int vectorized data:")
print(int_model.summary())
ConvNet model on int vectorized data:
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, None, 64)          640064    
_________________________________________________________________
conv1d (Conv1D)              (None, None, 64)          20544     
_________________________________________________________________
global_max_pooling1d (Global (None, 64)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 4)                 260       
=================================================================
Total params: 660,868
Trainable params: 660,868
Non-trainable params: 0
_________________________________________________________________
None

Oceń oba modele na podstawie danych testowych:

binary_loss, binary_accuracy = binary_model.evaluate(binary_test_ds)
int_loss, int_accuracy = int_model.evaluate(int_test_ds)

print("Binary model accuracy: {:2.2%}".format(binary_accuracy))
print("Int model accuracy: {:2.2%}".format(int_accuracy))
250/250 [==============================] - 1s 3ms/step - loss: 0.5180 - accuracy: 0.8151
250/250 [==============================] - 1s 2ms/step - loss: 0.5222 - accuracy: 0.8083
Binary model accuracy: 81.51%
Int model accuracy: 80.83%

Eksportuj model

W powyższym kodzie zastosowałeś warstwę TextVectorization do zestawu danych przed wprowadzeniem tekstu do modelu. Jeśli chcesz, aby Twój model mógł przetwarzać nieprzetworzone ciągi (na przykład, aby uprościć jego wdrażanie), możesz dołączyć warstwę TextVectorization do swojego modelu. Aby to zrobić, możesz utworzyć nowy model, korzystając z wag, które właśnie wytrenowałeś.

export_model = tf.keras.Sequential(
    [binary_vectorize_layer, binary_model,
     layers.Activation('sigmoid')])

export_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer='adam',
    metrics=['accuracy'])

# Test it with `raw_test_ds`, which yields raw strings
loss, accuracy = export_model.evaluate(raw_test_ds)
print("Accuracy: {:2.2%}".format(binary_accuracy))
250/250 [==============================] - 1s 4ms/step - loss: 0.5180 - accuracy: 0.8151
Accuracy: 81.51%

Teraz Twój model może pobierać nieprzetworzone ciągi jako dane wejściowe i przewidywać wynik dla każdej etykiety za pomocą model.predict . Zdefiniuj funkcję, aby znaleźć etykietę z maksymalnym wynikiem:

def get_string_labels(predicted_scores_batch):
  predicted_int_labels = tf.argmax(predicted_scores_batch, axis=1)
  predicted_labels = tf.gather(raw_train_ds.class_names, predicted_int_labels)
  return predicted_labels

Uruchom wnioskowanie na nowych danych

inputs = [
    "how do I extract keys from a dict into a list?",  # python
    "debug public static void main(string[] args) {...}",  # java
]
predicted_scores = export_model.predict(inputs)
predicted_labels = get_string_labels(predicted_scores)
for input, label in zip(inputs, predicted_labels):
  print("Question: ", input)
  print("Predicted label: ", label.numpy())
Question:  how do I extract keys from a dict into a list?
Predicted label:  b'python'
Question:  debug public static void main(string[] args) {...}
Predicted label:  b'java'

Dołączenie logiki wstępnego przetwarzania tekstu do modelu umożliwia wyeksportowanie modelu do produkcji, co upraszcza wdrażanie i zmniejsza ryzyko pochylenia trenowania/testowania .

Przy wyborze miejsca zastosowania warstwy TextVectorization należy pamiętać o różnicy w wydajności. Używanie go poza modelem umożliwia asynchroniczne przetwarzanie procesora i buforowanie danych podczas uczenia na GPU. Jeśli więc trenujesz swój model na GPU, prawdopodobnie chcesz skorzystać z tej opcji, aby uzyskać najlepszą wydajność podczas opracowywania modelu, a następnie przełącz się na włączenie warstwy TextVectorization do modelu, gdy będziesz gotowy do przygotowania do wdrożenia .

Odwiedź ten samouczek, aby dowiedzieć się więcej o zapisywaniu modeli.

Przykład 2: Wytypuj autora tłumaczeń Iliady

Poniżej przedstawiono przykład użycia tf.data.TextLineDataset do ładowania przykładów z plików tekstowych i tf.text do wstępnego przetwarzania danych. W tym przykładzie użyjesz trzech różnych angielskich przekładów tego samego dzieła, Iliady Homera, i wytrenujesz model identyfikowania tłumacza na podstawie pojedynczej linijki tekstu.

Pobierz i poznaj zbiór danych

Teksty trzech tłumaczeń są autorstwa:

Pliki tekstowe użyte w tym samouczku przeszły kilka typowych zadań wstępnego przetwarzania, takich jak usunięcie nagłówka i stopki dokumentu, numerów wierszy i tytułów rozdziałów. Pobierz lokalnie te lekko zmodyfikowane pliki.

DIRECTORY_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
FILE_NAMES = ['cowper.txt', 'derby.txt', 'butler.txt']

for name in FILE_NAMES:
  text_dir = utils.get_file(name, origin=DIRECTORY_URL + name)

parent_dir = pathlib.Path(text_dir).parent
list(parent_dir.iterdir())
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt
819200/815980 [==============================] - 0s 0us/step
827392/815980 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt
811008/809730 [==============================] - 0s 0us/step
819200/809730 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt
811008/807992 [==============================] - 0s 0us/step
819200/807992 [==============================] - 0s 0us/step
[PosixPath('/home/kbuilder/.keras/datasets/derby.txt'),
 PosixPath('/home/kbuilder/.keras/datasets/flower_photos.tar.gz'),
 PosixPath('/home/kbuilder/.keras/datasets/heart.csv'),
 PosixPath('/home/kbuilder/.keras/datasets/iris_test.csv'),
 PosixPath('/home/kbuilder/.keras/datasets/train.csv'),
 PosixPath('/home/kbuilder/.keras/datasets/butler.txt'),
 PosixPath('/home/kbuilder/.keras/datasets/flower_photos'),
 PosixPath('/home/kbuilder/.keras/datasets/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg'),
 PosixPath('/home/kbuilder/.keras/datasets/iris_training.csv'),
 PosixPath('/home/kbuilder/.keras/datasets/cowper.txt'),
 PosixPath('/home/kbuilder/.keras/datasets/320px-Felis_catus-cat_on_snow.jpg'),
 PosixPath('/home/kbuilder/.keras/datasets/fashion-mnist'),
 PosixPath('/home/kbuilder/.keras/datasets/HIGGS.csv.gz'),
 PosixPath('/home/kbuilder/.keras/datasets/mnist.npz')]

Załaduj zbiór danych

tf.data.Dataset TextLineDataset , który jest przeznaczony do tworzeniatf.data.Dataset z pliku tekstowego, w którym każdy przykład jest wierszem tekstu z oryginalnego pliku, podczas gdy text_dataset_from_directory traktuje całą zawartość pliku jako pojedynczy przykład. TextLineDataset jest przydatny w przypadku danych tekstowych, które są głównie oparte na wierszach (na przykład poezja lub dzienniki błędów).

Iteruj przez te pliki, ładując każdy z nich do własnego zestawu danych. Każdy przykład musi być indywidualnie oznaczony, więc użyj tf.data.Dataset.map aby zastosować funkcję etykietowania do każdego z nich. Spowoduje to iterację po każdym przykładzie w zestawie danych, zwracając pary ( example, label ).

def labeler(example, index):
  return example, tf.cast(index, tf.int64)
labeled_data_sets = []

for i, file_name in enumerate(FILE_NAMES):
  lines_dataset = tf.data.TextLineDataset(str(parent_dir/file_name))
  labeled_dataset = lines_dataset.map(lambda ex: labeler(ex, i))
  labeled_data_sets.append(labeled_dataset)

Następnie połączysz te oznaczone etykietami zestawy danych w jeden zestaw danych i przetasujesz go.

BUFFER_SIZE = 50000
BATCH_SIZE = 64
VALIDATION_SIZE = 5000
all_labeled_data = labeled_data_sets[0]
for labeled_dataset in labeled_data_sets[1:]:
  all_labeled_data = all_labeled_data.concatenate(labeled_dataset)

all_labeled_data = all_labeled_data.shuffle(
    BUFFER_SIZE, reshuffle_each_iteration=False)

Wydrukuj kilka przykładów, jak poprzednio. Zestaw danych nie został jeszcze all_labeled_data na all_labeled_data , dlatego każdy wpis w all_labeled_data odpowiada jednemu punktowi danych:

for text, label in all_labeled_data.take(10):
  print("Sentence: ", text.numpy())
  print("Label:", label.numpy())
Sentence:  b'Thus marched the host like a consuming fire, and the earth groaned'
Label: 2
Sentence:  b'Eluded, had not Phoebus, at his last,'
Label: 0
Sentence:  b"alike; with the one spear he struck Achilles' shield, but did not"
Label: 2
Sentence:  b'another in his heart; therefore I will say what I mean. I will be'
Label: 2
Sentence:  b'to you of all birds, that I may see it with my own eyes and trust it as'
Label: 2
Sentence:  b'The gulfs of ocean, where she sat beside'
Label: 0
Sentence:  b'When thou didst send me in yon field surprised'
Label: 0
Sentence:  b"A boy, amid them, from a clear-ton'd harp"
Label: 1
Sentence:  b'Oh Lycians! where is your heroic might?'
Label: 0
Sentence:  b'To him, despite thy railing, as of right'
Label: 1

Przygotuj zbiór danych do szkolenia

Zamiast używać warstwy Keras TextVectorization do wstępnego przetwarzania naszego zestawu danych tekstowych, będziesz teraz używać interfejsu API tf.text do standaryzacji i tokenizacji danych, tworzenia słownika i używania StaticVocabularyTable do mapowania tokenów na liczby całkowite, które zostaną przekazane do modelu.

Chociaż tf.text udostępnia różne tokenizatory, użyjesz UnicodeScriptTokenizer do UnicodeScriptTokenizer naszego zestawu danych. Zdefiniuj funkcję, która konwertuje tekst na małe litery i tokenizuje go. tf.data.Dataset.map aby zastosować tokenizację do zestawu danych.

tokenizer = tf_text.UnicodeScriptTokenizer()
def tokenize(text, unused_label):
  lower_case = tf_text.case_fold_utf8(text)
  return tokenizer.tokenize(lower_case)
tokenized_ds = all_labeled_data.map(tokenize)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:206: batch_gather (from tensorflow.python.ops.array_ops) is deprecated and will be removed after 2017-10-25.
Instructions for updating:
`tf.batch_gather` is deprecated, please use `tf.gather` with `batch_dims=-1` instead.

Możesz iterować po zbiorze danych i wydrukować kilka stokenizowanych przykładów.

for text_batch in tokenized_ds.take(5):
  print("Tokens: ", text_batch.numpy())
Tokens:  [b'thus' b'marched' b'the' b'host' b'like' b'a' b'consuming' b'fire' b','
 b'and' b'the' b'earth' b'groaned']
Tokens:  [b'eluded' b',' b'had' b'not' b'phoebus' b',' b'at' b'his' b'last' b',']
Tokens:  [b'alike' b';' b'with' b'the' b'one' b'spear' b'he' b'struck' b'achilles'
 b"'" b'shield' b',' b'but' b'did' b'not']
Tokens:  [b'another' b'in' b'his' b'heart' b';' b'therefore' b'i' b'will' b'say'
 b'what' b'i' b'mean' b'.' b'i' b'will' b'be']
Tokens:  [b'to' b'you' b'of' b'all' b'birds' b',' b'that' b'i' b'may' b'see' b'it'
 b'with' b'my' b'own' b'eyes' b'and' b'trust' b'it' b'as']

Następnie zbudujesz słownictwo, sortując tokeny według częstotliwości i zachowując najwyższe tokeny VOCAB_SIZE .

tokenized_ds = configure_dataset(tokenized_ds)

vocab_dict = collections.defaultdict(lambda: 0)
for toks in tokenized_ds.as_numpy_iterator():
  for tok in toks:
    vocab_dict[tok] += 1

vocab = sorted(vocab_dict.items(), key=lambda x: x[1], reverse=True)
vocab = [token for token, count in vocab]
vocab = vocab[:VOCAB_SIZE]
vocab_size = len(vocab)
print("Vocab size: ", vocab_size)
print("First five vocab entries:", vocab[:5])
Vocab size:  10000
First five vocab entries: [b',', b'the', b'and', b"'", b'of']

Aby przekonwertować tokeny na liczby całkowite, użyj zestawu vocab aby utworzyć StaticVocabularyTable . vocab_size + 2 tokeny na liczby całkowite z zakresu [ 2 , vocab_size + 2 ]. Podobnie jak w TextVectorization warstwy TextVectorization , 0 jest zarezerwowane do oznaczenia dopełnienia, a 1 jest zarezerwowane do oznaczenia tokenu poza słownictwem (OOV).

keys = vocab
values = range(2, len(vocab) + 2)  # reserve 0 for padding, 1 for OOV

init = tf.lookup.KeyValueTensorInitializer(
    keys, values, key_dtype=tf.string, value_dtype=tf.int64)

num_oov_buckets = 1
vocab_table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets)

Na koniec zdefiniuj funkcję do standaryzacji, tokenizacji i wektoryzacji zbioru danych za pomocą tokenizera i tabeli przeglądowej:

def preprocess_text(text, label):
  standardized = tf_text.case_fold_utf8(text)
  tokenized = tokenizer.tokenize(standardized)
  vectorized = vocab_table.lookup(tokenized)
  return vectorized, label

Możesz spróbować tego na jednym przykładzie, aby zobaczyć wynik:

example_text, example_label = next(iter(all_labeled_data))
print("Sentence: ", example_text.numpy())
vectorized_text, example_label = preprocess_text(example_text, example_label)
print("Vectorized sentence: ", vectorized_text.numpy())
Sentence:  b'Thus marched the host like a consuming fire, and the earth groaned'
Vectorized sentence:  [  45 3768    3  118  158   15 3769  238    2    4    3  181 3286]

Teraz uruchom funkcję preprocess w zestawie danych przy użyciu tf.data.Dataset.map .

all_encoded_data = all_labeled_data.map(preprocess_text)

Podziel zbiór danych na pociąg i test

Warstwa Keras TextVectorization również TextVectorization dane wektorowe. Dopełnienie jest wymagane, ponieważ przykłady w partii muszą mieć ten sam rozmiar i kształt, ale przykłady w tych zestawach danych nie mają tego samego rozmiaru — każdy wiersz tekstu ma inną liczbę słów.tf.data.Dataset obsługuje dzielenie itf.data.Dataset zestawów danych:

train_data = all_encoded_data.skip(VALIDATION_SIZE).shuffle(BUFFER_SIZE)
validation_data = all_encoded_data.take(VALIDATION_SIZE)
train_data = train_data.padded_batch(BATCH_SIZE)
validation_data = validation_data.padded_batch(BATCH_SIZE)

Teraz validation_data i train_data nie są kolekcjami par ( na example, label ), ale kolekcjami partii. Każda partia to para ( wiele przykładów , wiele etykiet ) reprezentowanych jako tablice. Ilustrować:

sample_text, sample_labels = next(iter(validation_data))
print("Text batch shape: ", sample_text.shape)
print("Label batch shape: ", sample_labels.shape)
print("First text example: ", sample_text[0])
print("First label example: ", sample_labels[0])
Text batch shape:  (64, 19)
Label batch shape:  (64,)
First text example:  tf.Tensor(
[  45 3768    3  118  158   15 3769  238    2    4    3  181 3286    0
    0    0    0    0    0], shape=(19,), dtype=int64)
First label example:  tf.Tensor(2, shape=(), dtype=int64)

Ponieważ używamy 0 do wypełnienia i 1 do tokenów poza słownictwem (OOV), rozmiar słownictwa zwiększył się o dwa.

vocab_size += 2

Skonfiguruj zestawy danych, aby uzyskać lepszą wydajność, jak poprzednio.

train_data = configure_dataset(train_data)
validation_data = configure_dataset(validation_data)

Trenuj modelkę

Możesz trenować model na tym zbiorze danych jak poprzednio.

model = create_model(vocab_size=vocab_size, num_labels=3)
model.compile(
    optimizer='adam',
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
history = model.fit(train_data, validation_data=validation_data, epochs=3)
Epoch 1/3
697/697 [==============================] - 27s 8ms/step - loss: 0.5258 - accuracy: 0.7661 - val_loss: 0.3681 - val_accuracy: 0.8416
Epoch 2/3
697/697 [==============================] - 2s 3ms/step - loss: 0.2844 - accuracy: 0.8851 - val_loss: 0.3532 - val_accuracy: 0.8514
Epoch 3/3
697/697 [==============================] - 2s 4ms/step - loss: 0.1904 - accuracy: 0.9288 - val_loss: 0.3813 - val_accuracy: 0.8484
loss, accuracy = model.evaluate(validation_data)

print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))
79/79 [==============================] - 1s 2ms/step - loss: 0.3813 - accuracy: 0.8484
Loss:  0.3813416361808777
Accuracy: 84.84%

Eksportuj model

Aby nasz model mógł przyjmować nieprzetworzone ciągi jako dane wejściowe, utworzysz warstwę TextVectorization , która wykonuje te same czynności, co nasza niestandardowa funkcja wstępnego przetwarzania. Ponieważ już set_vocaublary się słownictwa, możesz użyć set_vocaublary zamiast adapt która set_vocaublary nowego słownictwa.

preprocess_layer = TextVectorization(
    max_tokens=vocab_size,
    standardize=tf_text.case_fold_utf8,
    split=tokenizer.tokenize,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH)
preprocess_layer.set_vocabulary(vocab)
export_model = tf.keras.Sequential(
    [preprocess_layer, model,
     layers.Activation('sigmoid')])

export_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer='adam',
    metrics=['accuracy'])
# Create a test dataset of raw strings
test_ds = all_labeled_data.take(VALIDATION_SIZE).batch(BATCH_SIZE)
test_ds = configure_dataset(test_ds)
loss, accuracy = export_model.evaluate(test_ds)
print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))
79/79 [==============================] - 5s 8ms/step - loss: 0.5407 - accuracy: 0.7926
Loss:  0.5406631827354431
Accuracy: 79.26%

Zgodnie z oczekiwaniami straty i dokładność modelu w zakodowanym zestawie walidacyjnym i wyeksportowanego modelu w surowym zestawie walidacyjnym są takie same.

Uruchom wnioskowanie na nowych danych

inputs = [
    "Join'd to th' Ionians with their flowing robes,",  # Label: 1
    "the allies, and his armour flashed about him so that he seemed to all",  # Label: 2
    "And with loud clangor of his arms he fell.",  # Label: 0
]
predicted_scores = export_model.predict(inputs)
predicted_labels = tf.argmax(predicted_scores, axis=1)
for input, label in zip(inputs, predicted_labels):
  print("Question: ", input)
  print("Predicted label: ", label.numpy())
Question:  Join'd to th' Ionians with their flowing robes,
Predicted label:  1
Question:  the allies, and his armour flashed about him so that he seemed to all
Predicted label:  2
Question:  And with loud clangor of his arms he fell.
Predicted label:  0

Pobieranie większej liczby zestawów danych za pomocą zestawów danych TensorFlow (TFDS)

Możesz pobrać o wiele więcej zestawów danych z TensorFlow Datasets . Na przykład pobierzesz zestaw danych IMDB Large Movie Review i użyjesz go do trenowania modelu klasyfikacji tonacji.

train_ds = tfds.load(
    'imdb_reviews',
    split='train',
    batch_size=BATCH_SIZE,
    shuffle_files=True,
    as_supervised=True)
val_ds = tfds.load(
    'imdb_reviews',
    split='train',
    batch_size=BATCH_SIZE,
    shuffle_files=True,
    as_supervised=True)

Wydrukuj kilka przykładów.

for review_batch, label_batch in val_ds.take(1):
  for i in range(5):
    print("Review: ", review_batch[i].numpy())
    print("Label: ", label_batch[i].numpy())
Review:  b'I have been known to fall asleep during films, but this is usually due to a combination of things including, really tired, being warm and comfortable on the sette and having just eaten a lot. However on this occasion I fell asleep because the film was rubbish. The plot development was constant. Constantly slow and boring. Things seemed to happen, but with no explanation of what was causing them or why. I admit, I may have missed part of the film, but i watched the majority of it and everything just seemed to happen of its own accord without any real concern for anything else. I cant recommend this film at all.'
Label:  0
Review:  b'This is the kind of film for a snowy Sunday afternoon when the rest of the world can go ahead with its own business as you descend into a big arm-chair and mellow for a couple of hours. Wonderful performances from Cher and Nicolas Cage (as always) gently row the plot along. There are no rapids to cross, no dangerous waters, just a warm and witty paddle through New York life at its best. A family film in every sense and one that deserves the praise it received.'
Label:  1
Review:  b'As others have mentioned, all the women that go nude in this film are mostly absolutely gorgeous. The plot very ably shows the hypocrisy of the female libido. When men are around they want to be pursued, but when no "men" are around, they become the pursuers of a 14 year old boy. And the boy becomes a man really fast (we should all be so lucky at this age!). He then gets up the courage to pursue his true love.'
Label:  1
Review:  b"This is a film which should be seen by anybody interested in, effected by, or suffering from an eating disorder. It is an amazingly accurate and sensitive portrayal of bulimia in a teenage girl, its causes and its symptoms. The girl is played by one of the most brilliant young actresses working in cinema today, Alison Lohman, who was later so spectacular in 'Where the Truth Lies'. I would recommend that this film be shown in all schools, as you will never see a better on this subject. Alison Lohman is absolutely outstanding, and one marvels at her ability to convey the anguish of a girl suffering from this compulsive disorder. If barometers tell us the air pressure, Alison Lohman tells us the emotional pressure with the same degree of accuracy. Her emotional range is so precise, each scene could be measured microscopically for its gradations of trauma, on a scale of rising hysteria and desperation which reaches unbearable intensity. Mare Winningham is the perfect choice to play her mother, and does so with immense sympathy and a range of emotions just as finely tuned as Lohman's. Together, they make a pair of sensitive emotional oscillators vibrating in resonance with one another. This film is really an astonishing achievement, and director Katt Shea should be proud of it. The only reason for not seeing it is if you are not interested in people. But even if you like nature films best, this is after all animal behaviour at the sharp edge. Bulimia is an extreme version of how a tormented soul can destroy her own body in a frenzy of despair. And if we don't sympathise with people suffering from the depths of despair, then we are dead inside."
Label:  1
Review:  b'Okay, you have:<br /><br />Penelope Keith as Miss Herringbone-Tweed, B.B.E. (Backbone of England.) She\'s killed off in the first scene - that\'s right, folks; this show has no backbone!<br /><br />Peter O\'Toole as Ol\' Colonel Cricket from The First War and now the emblazered Lord of the Manor.<br /><br />Joanna Lumley as the ensweatered Lady of the Manor, 20 years younger than the colonel and 20 years past her own prime but still glamourous (Brit spelling, not mine) enough to have a toy-boy on the side. It\'s alright, they have Col. Cricket\'s full knowledge and consent (they guy even comes \'round for Christmas!) Still, she\'s considerate of the colonel enough to have said toy-boy her own age (what a gal!)<br /><br />David McCallum as said toy-boy, equally as pointlessly glamourous as his squeeze. Pilcher couldn\'t come up with any cover for him within the story, so she gave him a hush-hush job at the Circus.<br /><br />and finally:<br /><br />Susan Hampshire as Miss Polonia Teacups, Venerable Headmistress of the Venerable Girls\' Boarding-School, serving tea in her office with a dash of deep, poignant advice for life in the outside world just before graduation. Her best bit of advice: "I\'ve only been to Nancherrow (the local Stately Home of England) once. I thought it was very beautiful but, somehow, not part of the real world." Well, we can\'t say they didn\'t warn us.<br /><br />Ah, Susan - time was, your character would have been running the whole show. They don\'t write \'em like that any more. Our loss, not yours.<br /><br />So - with a cast and setting like this, you have the re-makings of "Brideshead Revisited," right?<br /><br />Wrong! They took these 1-dimensional supporting roles because they paid so well. After all, acting is one of the oldest temp-jobs there is (YOU name another!)<br /><br />First warning sign: lots and lots of backlighting. They get around it by shooting outdoors - "hey, it\'s just the sunlight!"<br /><br />Second warning sign: Leading Lady cries a lot. When not crying, her eyes are moist. That\'s the law of romance novels: Leading Lady is "dewy-eyed."<br /><br />Henceforth, Leading Lady shall be known as L.L.<br /><br />Third warning sign: L.L. actually has stars in her eyes when she\'s in love. Still, I\'ll give Emily Mortimer an award just for having to act with that spotlight in her eyes (I wonder . did they use contacts?)<br /><br />And lastly, fourth warning sign: no on-screen female character is "Mrs." She\'s either "Miss" or "Lady."<br /><br />When all was said and done, I still couldn\'t tell you who was pursuing whom and why. I couldn\'t even tell you what was said and done.<br /><br />To sum up: they all live through World War II without anything happening to them at all.<br /><br />OK, at the end, L.L. finds she\'s lost her parents to the Japanese prison camps and baby sis comes home catatonic. Meanwhile (there\'s always a "meanwhile,") some young guy L.L. had a crush on (when, I don\'t know) comes home from some wartime tough spot and is found living on the street by Lady of the Manor (must be some street if SHE\'s going to find him there.) Both war casualties are whisked away to recover at Nancherrow (SOMEBODY has to be "whisked away" SOMEWHERE in these romance stories!)<br /><br />Great drama.'
Label:  0

Możesz teraz wstępnie przetwarzać dane i trenować model tak jak poprzednio.

Przygotuj zbiór danych do szkolenia

vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH)

# Make a text-only dataset (without labels), then call adapt
train_text = train_ds.map(lambda text, labels: text)
vectorize_layer.adapt(train_text)
def vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return vectorize_layer(text), label
train_ds = train_ds.map(vectorize_text)
val_ds = val_ds.map(vectorize_text)
# Configure datasets for performance as before
train_ds = configure_dataset(train_ds)
val_ds = configure_dataset(val_ds)

Trenuj modelkę

model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=1)
model.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_2 (Embedding)      (None, None, 64)          640064    
_________________________________________________________________
conv1d_2 (Conv1D)            (None, None, 64)          20544     
_________________________________________________________________
global_max_pooling1d_2 (Glob (None, 64)                0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 65        
=================================================================
Total params: 660,673
Trainable params: 660,673
Non-trainable params: 0
_________________________________________________________________
model.compile(
    loss=losses.BinaryCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])
history = model.fit(train_ds, validation_data=val_ds, epochs=3)
Epoch 1/3
391/391 [==============================] - 5s 9ms/step - loss: 0.5030 - accuracy: 0.6984 - val_loss: 0.2978 - val_accuracy: 0.8806
Epoch 2/3
391/391 [==============================] - 2s 4ms/step - loss: 0.2754 - accuracy: 0.8800 - val_loss: 0.1712 - val_accuracy: 0.9456
Epoch 3/3
391/391 [==============================] - 2s 4ms/step - loss: 0.1679 - accuracy: 0.9348 - val_loss: 0.0938 - val_accuracy: 0.9788
loss, accuracy = model.evaluate(val_ds)

print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))
391/391 [==============================] - 1s 2ms/step - loss: 0.0938 - accuracy: 0.9788
Loss:  0.09377793222665787
Accuracy: 97.88%

Eksportuj model

export_model = tf.keras.Sequential(
    [vectorize_layer, model,
     layers.Activation('sigmoid')])

export_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer='adam',
    metrics=['accuracy'])
# 0 --> negative review
# 1 --> positive review
inputs = [
    "This is a fantastic movie.",
    "This is a bad movie.",
    "This movie was so bad that it was good.",
    "I will never say yes to watching this movie.",
]
predicted_scores = export_model.predict(inputs)
predicted_labels = [int(round(x[0])) for x in predicted_scores]
for input, label in zip(inputs, predicted_labels):
  print("Question: ", input)
  print("Predicted label: ", label)
Question:  This is a fantastic movie.
Predicted label:  1
Question:  This is a bad movie.
Predicted label:  0
Question:  This movie was so bad that it was good.
Predicted label:  0
Question:  I will never say yes to watching this movie.
Predicted label:  0

Wniosek

W tym samouczku przedstawiono kilka sposobów ładowania i wstępnego przetwarzania tekstu. W następnym kroku możesz zapoznać się z dodatkowymi samouczkami na stronie internetowej lub pobrać nowe zestawy danych z TensorFlow Datasets .