날짜를 저장하십시오! Google I / O가 5 월 18 일부터 20 일까지 반환됩니다. 지금 등록
이 페이지는 Cloud Translation API를 통해 번역되었습니다.
Switch to English

텍스트로드

TensorFlow.org에서보기 Google Colab에서 실행 GitHub에서 소스보기 노트북 다운로드

이 가이드에서는 텍스트를로드하고 전처리하는 두 가지 방법을 보여줍니다.

  • 먼저 Keras 유틸리티와 레이어를 사용합니다. TensorFlow를 처음 사용하는 경우 다음부터 시작해야합니다.

  • 다음으로 tf.data.TextLineDataset 과 같은 하위 수준 유틸리티를 사용하여 텍스트 파일을로드하고 tf.text 를 사용하여보다 세밀한 제어를 위해 데이터를 전처리합니다.

# Be sure you're using the stable versions of both tf and tf-text, for binary compatibility.
pip install -q -U tensorflow
pip install -q -U tensorflow-text
import collections
import pathlib
import re
import string

import tensorflow as tf

from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import preprocessing
from tensorflow.keras import utils
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

import tensorflow_datasets as tfds
import tensorflow_text as tf_text

예 1 : 스택 오버플로 질문에 대한 태그 예측

첫 번째 예로서 Stack Overflow에서 프로그래밍 질문 데이터 세트를 다운로드합니다. 각 질문 ( "사전을 값별로 정렬하는 방법")은 정확히 하나의 태그 ( Python , CSharp , JavaScript 또는 Java )로 레이블이 지정됩니다. 당신의 임무는 질문에 대한 태그를 예측하는 모델을 개발하는 것입니다. 이것은 중요하고 광범위하게 적용 가능한 기계 학습 문제인 다중 클래스 분류의 예입니다.

데이터 세트 다운로드 및 탐색

다음으로 데이터 세트를 다운로드하고 디렉토리 구조를 탐색합니다.

data_url = 'https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz'
dataset = utils.get_file(
    'stack_overflow_16k.tar.gz',
    data_url,
    untar=True,
    cache_dir='stack_overflow',
    cache_subdir='')
dataset_dir = pathlib.Path(dataset).parent
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz
6053888/6053168 [==============================] - 0s 0us/step
list(dataset_dir.iterdir())
[PosixPath('/tmp/.keras/train'),
 PosixPath('/tmp/.keras/README.md'),
 PosixPath('/tmp/.keras/test'),
 PosixPath('/tmp/.keras/stack_overflow_16k.tar.gz.tar.gz')]
train_dir = dataset_dir/'train'
list(train_dir.iterdir())
[PosixPath('/tmp/.keras/train/java'),
 PosixPath('/tmp/.keras/train/csharp'),
 PosixPath('/tmp/.keras/train/javascript'),
 PosixPath('/tmp/.keras/train/python')]

train/csharp , train/java , train/pythontrain/javascript 디렉토리에는 많은 텍스트 파일이 포함되어 있으며 각각은 스택 오버플로 질문입니다. 파일을 인쇄하고 데이터를 검사합니다.

sample_file = train_dir/'python/1755.txt'
with open(sample_file) as f:
  print(f.read())
why does this blank program print true x=true.def stupid():.    x=false.stupid().print x

데이터 세트로드

다음으로 디스크에서 데이터를로드하고 교육에 적합한 형식으로 준비합니다. 이렇게하려면 text_dataset_from_directory 유틸리티를 사용하여 레이블이있는tf.data.Dataset 을 만듭니다. tf.data 를 처음 사용하는 경우 입력 파이프 라인을 구축하기위한 강력한 도구 모음입니다.

preprocessing.text_dataset_from_directory 는 다음과 같은 디렉토리 구조를 예상합니다.

train/
...csharp/
......1.txt
......2.txt
...java/
......1.txt
......2.txt
...javascript/
......1.txt
......2.txt
...python/
......1.txt
......2.txt

머신 러닝 실험을 실행할 때 데이터 세트를 train , validation , test의 세 분할로 나누는 것이 좋습니다. Stack Overflow 데이터 세트는 이미 학습과 테스트로 나뉘었지만 검증 세트가 없습니다. 아래 validation_split 인수를 사용하여 훈련 데이터의 80:20 분할을 사용하여 검증 세트를 만듭니다.

batch_size = 32
seed = 42

raw_train_ds = preprocessing.text_dataset_from_directory(
    train_dir,
    batch_size=batch_size,
    validation_split=0.2,
    subset='training',
    seed=seed)
Found 8000 files belonging to 4 classes.
Using 6400 files for training.

위에서 볼 수 있듯이 training 폴더에는 8,000 개의 예제가 있으며이 중 80 % (또는 6,400)를 훈련에 사용합니다. 당신이 잠시 살펴 보 겠지만, 당신은 전달하여 모델을 학습 할 수tf.data.Dataset 직접 model.fit . 먼저 데이터 세트를 반복하고 몇 가지 예제를 인쇄하여 데이터에 대한 느낌을 얻습니다.

for text_batch, label_batch in raw_train_ds.take(1):
  for i in range(10):
    print("Question: ", text_batch.numpy()[i])
    print("Label:", label_batch.numpy()[i])
Question:  b'"my tester is going to the wrong constructor i am new to programming so if i ask a question that can be easily fixed, please forgive me. my program has a tester class with a main. when i send that to my regularpolygon class, it sends it to the wrong constructor. i have two constructors. 1 without perameters..public regularpolygon().    {.       mynumsides = 5;.       mysidelength = 30;.    }//end default constructor...and my second, with perameters. ..public regularpolygon(int numsides, double sidelength).    {.        mynumsides = numsides;.        mysidelength = sidelength;.    }// end constructor...in my tester class i have these two lines:..regularpolygon shape = new regularpolygon(numsides, sidelength);.        shape.menu();...numsides and sidelength were declared and initialized earlier in the testing class...so what i want to happen, is the tester class sends numsides and sidelength to the second constructor and use it in that class. but it only uses the default constructor, which therefor ruins the whole rest of the program. can somebody help me?..for those of you who want to see more of my code: here you go..public double vertexangle().    {.        system.out.println(""the vertex angle method: "" + mynumsides);// prints out 5.        system.out.println(""the vertex angle method: "" + mysidelength); // prints out 30..        double vertexangle;.        vertexangle = ((mynumsides - 2.0) / mynumsides) * 180.0;.        return vertexangle;.    }//end method vertexangle..public void menu().{.    system.out.println(mynumsides); // prints out what the user puts in.    system.out.println(mysidelength); // prints out what the user puts in.    gotographic();.    calcr(mynumsides, mysidelength);.    calcr(mynumsides, mysidelength);.    print(); .}// end menu...this is my entire tester class:..public static void main(string[] arg).{.    int numsides;.    double sidelength;.    scanner keyboard = new scanner(system.in);..    system.out.println(""welcome to the regular polygon program!"");.    system.out.println();..    system.out.print(""enter the number of sides of the polygon ==> "");.    numsides = keyboard.nextint();.    system.out.println();..    system.out.print(""enter the side length of each side ==> "");.    sidelength = keyboard.nextdouble();.    system.out.println();..    regularpolygon shape = new regularpolygon(numsides, sidelength);.    shape.menu();.}//end main...for testing it i sent it numsides 4 and sidelength 100."\n'
Label: 1
Question:  b'"blank code slow skin detection this code changes the color space to lab and using a threshold finds the skin area of an image. but it\'s ridiculously slow. i don\'t know how to make it faster ?    ..from colormath.color_objects import *..def skindetection(img, treshold=80, color=[255,20,147]):..    print img.shape.    res=img.copy().    for x in range(img.shape[0]):.        for y in range(img.shape[1]):.            rgbimg=rgbcolor(img[x,y,0],img[x,y,1],img[x,y,2]).            labimg=rgbimg.convert_to(\'lab\', debug=false).            if (labimg.lab_l > treshold):.                res[x,y,:]=color.            else: .                res[x,y,:]=img[x,y,:]..    return res"\n'
Label: 3
Question:  b'"option and validation in blank i want to add a new option on my system where i want to add two text files, both rental.txt and customer.txt. inside each text are id numbers of the customer, the videotape they need and the price...i want to place it as an option on my code. right now i have:...add customer.rent return.view list.search.exit...i want to add this as my sixth option. say for example i ordered a video, it would display the price and would let me confirm the price and if i am going to buy it or not...here is my current code:..  import blank.io.*;.    import blank.util.arraylist;.    import static blank.lang.system.out;..    public class rentalsystem{.    static bufferedreader input = new bufferedreader(new inputstreamreader(system.in));.    static file file = new file(""file.txt"");.    static arraylist<string> list = new arraylist<string>();.    static int rows;..    public static void main(string[] args) throws exception{.        introduction();.        system.out.print(""nn"");.        login();.        system.out.print(""nnnnnnnnnnnnnnnnnnnnnn"");.        introduction();.        string repeat;.        do{.            loadfile();.            system.out.print(""nwhat do you want to do?nn"");.            system.out.print(""n                    - - - - - - - - - - - - - - - - - - - - - - -"");.            system.out.print(""nn                    |     1. add customer    |   2. rent return |n"");.            system.out.print(""n                    - - - - - - - - - - - - - - - - - - - - - - -"");.            system.out.print(""nn                    |     3. view list       |   4. search      |n"");.            system.out.print(""n                    - - - - - - - - - - - - - - - - - - - - - - -"");.            system.out.print(""nn                                             |   5. exit        |n"");.            system.out.print(""n                                              - - - - - - - - - -"");.            system.out.print(""nnchoice:"");.            int choice = integer.parseint(input.readline());.            switch(choice){.                case 1:.                    writedata();.                    break;.                case 2:.                    rentdata();.                    break;.                case 3:.                    viewlist();.                    break;.                case 4:.                    search();.                    break;.                case 5:.                    system.out.println(""goodbye!"");.                    system.exit(0);.                default:.                    system.out.print(""invalid choice: "");.                    break;.            }.            system.out.print(""ndo another task? [y/n] "");.            repeat = input.readline();.        }while(repeat.equals(""y""));..        if(repeat!=""y"") system.out.println(""ngoodbye!"");..    }..    public static void writedata() throws exception{.        system.out.print(""nname: "");.        string cname = input.readline();.        system.out.print(""address: "");.        string add = input.readline();.        system.out.print(""phone no.: "");.        string pno = input.readline();.        system.out.print(""rental amount: "");.        string ramount = input.readline();.        system.out.print(""tapenumber: "");.        string tno = input.readline();.        system.out.print(""title: "");.        string title = input.readline();.        system.out.print(""date borrowed: "");.        string dborrowed = input.readline();.        system.out.print(""due date: "");.        string ddate = input.readline();.        createline(cname, add, pno, ramount,tno, title, dborrowed, ddate);.        rentdata();.    }..    public static void createline(string name, string address, string phone , string rental, string tapenumber, string title, string borrowed, string due) throws exception{.        filewriter fw = new filewriter(file, true);.        fw.write(""nname: ""+name + ""naddress: "" + address +""nphone no.: ""+ phone+""nrentalamount: ""+rental+""ntape no.: ""+ tapenumber+""ntitle: ""+ title+""ndate borrowed: ""+borrowed +""ndue date: ""+ due+"":rn"");.        fw.close();.    }..    public static void loadfile() throws exception{.        try{.            list.clear();.            fileinputstream fstream = new fileinputstream(file);.            bufferedreader br = new bufferedreader(new inputstreamreader(fstream));.            rows = 0;.            while( br.ready()).            {.                list.add(br.readline());.                rows++;.            }.            br.close();.        } catch(exception e){.            system.out.println(""list not yet loaded."");.        }.    }..    public static void viewlist(){.        system.out.print(""n~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print("" |list of all costumers|"");.        system.out.print(""~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        for(int i = 0; i <rows; i++){.            system.out.println(list.get(i));.        }.    }.        public static void rentdata()throws exception.    {   system.out.print(""n~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print("" |rent data list|"");.        system.out.print(""~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print(""nenter customer name: "");.        string cname = input.readline();.        system.out.print(""date borrowed: "");.        string dborrowed = input.readline();.        system.out.print(""due date: "");.        string ddate = input.readline();.        system.out.print(""return date: "");.        string rdate = input.readline();.        system.out.print(""rent amount: "");.        string ramount = input.readline();..        system.out.print(""you pay:""+ramount);...    }.    public static void search()throws exception.    {   system.out.print(""n~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print("" |search costumers|"");.        system.out.print(""~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~"");.        system.out.print(""nenter costumer name: "");.        string cname = input.readline();.        boolean found = false;..        for(int i=0; i < rows; i++){.            string temp[] = list.get(i).split("","");..            if(cname.equals(temp[0])){.            system.out.println(""search result:nyou are "" + temp[0] + "" from "" + temp[1] + "".""+ temp[2] + "".""+ temp[3] + "".""+ temp[4] + "".""+ temp[5] + "" is "" + temp[6] + "".""+ temp[7] + "" is "" + temp[8] + ""."");.                found = true;.            }.        }..        if(!found){.            system.out.print(""no results."");.        }..    }..        public static boolean evaluate(string uname, string pass){.        if (uname.equals(""admin"")&&pass.equals(""12345"")) return true;.        else return false;.    }..    public static string login()throws exception{.        bufferedreader input=new bufferedreader(new inputstreamreader(system.in));.        int counter=0;.        do{.            system.out.print(""username:"");.            string uname =input.readline();.            system.out.print(""password:"");.            string pass =input.readline();..            boolean accept= evaluate(uname,pass);..            if(accept){.                break;.                }else{.                    system.out.println(""incorrect username or password!"");.                    counter ++;.                    }.        }while(counter<3);..            if(counter !=3) return ""login successful"";.            else return ""login failed"";.            }.        public static void introduction() throws exception{..        system.out.println(""                  - - - - - - - - - - - - - - - - - - - - - - - - -"");.        system.out.println(""                  !                  r e n t a l                  !"");.        system.out.println(""                   ! ~ ~ ~ ~ ~ !  =================  ! ~ ~ ~ ~ ~ !"");.        system.out.println(""                  !                  s y s t e m                  !"");.        system.out.println(""                  - - - - - - - - - - - - - - - - - - - - - - - - -"");.        }..}"\n'
Label: 1
Question:  b'"exception: dynamic sql generation for the updatecommand is not supported against a selectcommand that does not return any key i dont know what is the problem this my code : ..string nomtable;..datatable listeetablissementtable = new datatable();.datatable listeinteretstable = new datatable();.dataset ds = new dataset();.sqldataadapter da;.sqlcommandbuilder cmdb;..private void listeinterets_click(object sender, eventargs e).{.    nomtable = ""listeinteretstable"";.    d.cnx.open();.    da = new sqldataadapter(""select nome from offices"", d.cnx);.    ds = new dataset();.    da.fill(ds, nomtable);.    datagridview1.datasource = ds.tables[nomtable];.}..private void sauvgarder_click(object sender, eventargs e).{.    d.cnx.open();.    cmdb = new sqlcommandbuilder(da);.    da.update(ds, nomtable);.    d.cnx.close();.}"\n'
Label: 0
Question:  b'"parameter with question mark and super in blank, i\'ve come across a method that is formatted like this:..public final subscription subscribe(final action1<? super t> onnext, final action1<throwable> onerror) {.}...in the first parameter, what does the question mark and super mean?"\n'
Label: 1
Question:  b'call two objects wsdl the first time i got a very strange wsdl. ..i would like to call the object (interface - invoicecheck_out) do you know how?....i would like to call the object (variable) do you know how?..try to call (it`s ok)....try to call (how call this?)\n'
Label: 0
Question:  b"how to correctly make the icon for systemtray in blank using icon sizes of any dimension for systemtray doesn't look good overall. .what is the correct way of making icons for windows system tray?..screenshots: http://imgur.com/zsibwn9..icon: http://imgur.com/vsh4zo8\n"
Label: 0
Question:  b'"is there a way to check a variable that exists in a different script than the original one? i\'m trying to check if a variable, which was previously set to true in 2.py in 1.py, as 1.py is only supposed to continue if the variable is true...2.py..import os..completed = false..#some stuff here..completed = true...1.py..import 2 ..if completed == true.   #do things...however i get a syntax error at ..if completed == true"\n'
Label: 3
Question:  b'"blank control flow i made a number which asks for 2 numbers with blank and responds with  the corresponding message for the case. how come it doesnt work  for the second number ? .regardless what i enter for the second number , i am getting the message ""your number is in the range 0-10""...using system;.using system.collections.generic;.using system.linq;.using system.text;..namespace consoleapplication1.{.    class program.    {.        static void main(string[] args).        {.            string myinput;  // declaring the type of the variables.            int myint;..            string number1;.            int number;...            console.writeline(""enter a number"");.            myinput = console.readline(); //muyinput is a string  which is entry input.            myint = int32.parse(myinput); // myint converts the string into an integer..            if (myint > 0).                console.writeline(""your number {0} is greater than zero."", myint);.            else if (myint < 0).                console.writeline(""your number {0} is  less  than zero."", myint);.            else.                console.writeline(""your number {0} is equal zero."", myint);..            console.writeline(""enter another number"");.            number1 = console.readline(); .            number = int32.parse(myinput); ..            if (number < 0 || number == 0).                console.writeline(""your number {0} is  less  than zero or equal zero."", number);.            else if (number > 0 && number <= 10).                console.writeline(""your number {0} is  in the range from 0 to 10."", number);.            else.                console.writeline(""your number {0} is greater than 10."", number);..            console.writeline(""enter another number"");..        }.    }    .}"\n'
Label: 0
Question:  b'"credentials cannot be used for ntlm authentication i am getting org.apache.commons.httpclient.auth.invalidcredentialsexception: credentials cannot be used for ntlm authentication: exception in eclipse..whether it is possible mention eclipse to take system proxy settings directly?..public class httpgetproxy {.    private static final string proxy_host = ""proxy.****.com"";.    private static final int proxy_port = 6050;..    public static void main(string[] args) {.        httpclient client = new httpclient();.        httpmethod method = new getmethod(""https://kodeblank.org"");..        hostconfiguration config = client.gethostconfiguration();.        config.setproxy(proxy_host, proxy_port);..        string username = ""*****"";.        string password = ""*****"";.        credentials credentials = new usernamepasswordcredentials(username, password);.        authscope authscope = new authscope(proxy_host, proxy_port);..        client.getstate().setproxycredentials(authscope, credentials);..        try {.            client.executemethod(method);..            if (method.getstatuscode() == httpstatus.sc_ok) {.                string response = method.getresponsebodyasstring();.                system.out.println(""response = "" + response);.            }.        } catch (ioexception e) {.            e.printstacktrace();.        } finally {.            method.releaseconnection();.        }.    }.}...exception:...  dec 08, 2017 1:41:39 pm .          org.apache.commons.httpclient.auth.authchallengeprocessor selectauthscheme.         info: ntlm authentication scheme selected.       dec 08, 2017 1:41:39 pm org.apache.commons.httpclient.httpmethoddirector executeconnect.         severe: credentials cannot be used for ntlm authentication: .           org.apache.commons.httpclient.usernamepasswordcredentials.           org.apache.commons.httpclient.auth.invalidcredentialsexception: credentials .         cannot be used for ntlm authentication: .        enter code here .          org.apache.commons.httpclient.usernamepasswordcredentials.      at org.apache.commons.httpclient.auth.ntlmscheme.authenticate(ntlmscheme.blank:332).        at org.apache.commons.httpclient.httpmethoddirector.authenticateproxy(httpmethoddirector.blank:320).      at org.apache.commons.httpclient.httpmethoddirector.executeconnect(httpmethoddirector.blank:491).      at org.apache.commons.httpclient.httpmethoddirector.executewithretry(httpmethoddirector.blank:391).      at org.apache.commons.httpclient.httpmethoddirector.executemethod(httpmethoddirector.blank:171).      at org.apache.commons.httpclient.httpclient.executemethod(httpclient.blank:397).      at org.apache.commons.httpclient.httpclient.executemethod(httpclient.blank:323).      at httpgetproxy.main(httpgetproxy.blank:31).  dec 08, 2017 1:41:39 pm org.apache.commons.httpclient.httpmethoddirector processproxyauthchallenge.  info: failure authenticating with ntlm @proxy.****.com:6050"\n'
Label: 1

레이블은 0 , 1 , 2 또는 3 입니다. 이들 중 어느 문자열 레이블에 해당하는지 확인하려면 데이터 세트의 class_names 속성을 확인하면됩니다.

for i, label in enumerate(raw_train_ds.class_names):
  print("Label", i, "corresponds to", label)
Label 0 corresponds to csharp
Label 1 corresponds to java
Label 2 corresponds to javascript
Label 3 corresponds to python

다음으로 유효성 검사 및 테스트 데이터 세트를 만듭니다. 검증을 위해 교육 세트의 나머지 1,600 개 리뷰를 사용합니다.

raw_val_ds = preprocessing.text_dataset_from_directory(
    train_dir,
    batch_size=batch_size,
    validation_split=0.2,
    subset='validation',
    seed=seed)
Found 8000 files belonging to 4 classes.
Using 1600 files for validation.
test_dir = dataset_dir/'test'
raw_test_ds = preprocessing.text_dataset_from_directory(
    test_dir, batch_size=batch_size)
Found 8000 files belonging to 4 classes.

훈련을위한 데이터 세트 준비

다음으로 preprocessing.TextVectorization 레이어를 사용하여 데이터를 표준화, 토큰 화 및 벡터화합니다.

  • 표준화는 일반적으로 데이터 세트를 단순화하기 위해 구두점이나 HTML 요소를 제거하기 위해 텍스트를 사전 처리하는 것을 말합니다.

  • 토큰 화는 문자열을 토큰으로 분할하는 것을 말합니다 (예 : 공백으로 분할하여 문장을 개별 단어로 분할).

  • 벡터화는 토큰을 숫자로 변환하여 신경망에 공급하는 것을 말합니다.

이러한 모든 작업은이 계층으로 수행 할 수 있습니다. API 문서 에서 각각에 대해 자세히 알아볼 수 있습니다.

  • 기본 표준화는 텍스트를 소문자로 변환하고 구두점을 제거합니다.

  • 기본 토크 나이 저는 공백으로 분할됩니다.

  • 기본 벡터화 모드는 int 입니다. 이것은 정수 인덱스 (토큰 당 하나)를 출력합니다. 이 모드는 단어 순서를 고려하는 모델을 구축하는 데 사용할 수 있습니다. 또한 binary 와 같은 다른 모드를 사용하여 bag-of-word 모델을 구축 할 수 있습니다.

이에 대해 자세히 알아보기 위해 두 가지 모드를 빌드합니다. 먼저 binary 모델을 사용하여 bag-of-words 모델을 만듭니다. 다음으로 1D ConvNet에서 int 모드를 사용합니다.

VOCAB_SIZE = 10000

binary_vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='binary')

int 모드의 경우 최대 어휘 크기 외에도 명시 적 최대 시퀀스 길이를 설정해야합니다. 그러면 레이어가 시퀀스를 정확히 sequence_length 값으로 채우거나 자르게됩니다.

MAX_SEQUENCE_LENGTH = 250

int_vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH)

다음으로, adapt 을 호출하여 전처리 레이어의 상태를 데이터 세트에 맞 춥니 다. 이렇게하면 모델이 문자열 인덱스를 정수로 작성합니다.

# Make a text-only dataset (without labels), then call adapt
train_text = raw_train_ds.map(lambda text, labels: text)
binary_vectorize_layer.adapt(train_text)
int_vectorize_layer.adapt(train_text)

다음 레이어를 사용하여 데이터를 전처리 한 결과를 확인하세요.

def binary_vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return binary_vectorize_layer(text), label
def int_vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return int_vectorize_layer(text), label
# Retrieve a batch (of 32 reviews and labels) from the dataset
text_batch, label_batch = next(iter(raw_train_ds))
first_question, first_label = text_batch[0], label_batch[0]
print("Question", first_question)
print("Label", first_label)
Question tf.Tensor(b'"function expected error in blank for dynamically created check box when it is clicked i want to grab the attribute value.it is working in ie 8,9,10 but not working in ie 11,chrome shows function expected error..<input type=checkbox checked=\'checked\' id=\'symptomfailurecodeid\' tabindex=\'54\' style=\'cursor:pointer;\' onclick=chkclickevt(this);  failurecodeid=""1"" >...function chkclickevt(obj) { .    alert(obj.attributes(""failurecodeid""));.}"\n', shape=(), dtype=string)
Label tf.Tensor(2, shape=(), dtype=int32)
print("'binary' vectorized question:", 
      binary_vectorize_text(first_question, first_label)[0])
'binary' vectorized question: tf.Tensor([[1. 1. 1. ... 0. 0. 0.]], shape=(1, 10000), dtype=float32)
print("'int' vectorized question:",
      int_vectorize_text(first_question, first_label)[0])
'int' vectorized question: tf.Tensor(
[[  38  450   65    7   16   12  892  265  186  451   44   11    6  685
     3   46    4 2062    2  485    1    6  158    7  479    1   26   20
   158    7  479    1  502   38  450    1 1767 1763    1    1    1    1
     1    1    1    1    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0]], shape=(1, 250), dtype=int64)

위에서 볼 수 있듯이 binary 모드는 입력에 한 번 이상 존재하는 토큰을 나타내는 배열을 반환하는 반면 int 모드는 각 토큰을 정수로 대체하여 순서를 유지합니다. 레이어에서 .get_vocabulary() 를 호출하여 각 정수에 해당하는 토큰 (문자열)을 조회 할 수 있습니다.

print("1289 ---> ", int_vectorize_layer.get_vocabulary()[1289])
print("313 ---> ", int_vectorize_layer.get_vocabulary()[313])
print("Vocabulary size: {}".format(len(int_vectorize_layer.get_vocabulary())))
1289 --->  roman
313 --->  source
Vocabulary size: 10000

모델을 훈련 할 준비가 거의되었습니다. 마지막 전처리 단계로 이전에 생성 한 TextVectorization 레이어를 학습, 유효성 검사 및 테스트 데이터 세트에 적용합니다.

binary_train_ds = raw_train_ds.map(binary_vectorize_text)
binary_val_ds = raw_val_ds.map(binary_vectorize_text)
binary_test_ds = raw_test_ds.map(binary_vectorize_text)

int_train_ds = raw_train_ds.map(int_vectorize_text)
int_val_ds = raw_val_ds.map(int_vectorize_text)
int_test_ds = raw_test_ds.map(int_vectorize_text)

성능을위한 데이터 세트 구성

다음은 I / O가 차단되지 않도록 데이터를로드 할 때 사용해야하는 두 가지 중요한 방법입니다.

.cache() 데이터가 디스크에서로드 된 후 메모리에 데이터를 보관합니다. 이렇게하면 모델을 학습하는 동안 데이터 세트가 병목 현상이 발생하지 않습니다. 데이터 세트가 너무 커서 메모리에 맞지 않는 경우이 방법을 사용하여 성능이 뛰어난 온 디스크 캐시를 만들 수도 있습니다. 이는 많은 작은 파일보다 읽기가 더 효율적입니다.

.prefetch() 학습 중 데이터 전처리 및 모델 실행과 겹칩니다.

데이터 성능 가이드 에서 두 가지 방법 및 디스크에 데이터를 캐시하는 방법에 대해 자세히 알아볼 수 있습니다.

AUTOTUNE = tf.data.AUTOTUNE

def configure_dataset(dataset):
  return dataset.cache().prefetch(buffer_size=AUTOTUNE)
binary_train_ds = configure_dataset(binary_train_ds)
binary_val_ds = configure_dataset(binary_val_ds)
binary_test_ds = configure_dataset(binary_test_ds)

int_train_ds = configure_dataset(int_train_ds)
int_val_ds = configure_dataset(int_val_ds)
int_test_ds = configure_dataset(int_test_ds)

모델 훈련

신경망을 만들 때입니다. binary 벡터화 된 데이터의 경우 간단한 bag-of-words 선형 모델을 학습시킵니다.

binary_model = tf.keras.Sequential([layers.Dense(4)])
binary_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])
history = binary_model.fit(
    binary_train_ds, validation_data=binary_val_ds, epochs=10)
Epoch 1/10
200/200 [==============================] - 2s 9ms/step - loss: 1.2359 - accuracy: 0.5427 - val_loss: 0.9108 - val_accuracy: 0.7744
Epoch 2/10
200/200 [==============================] - 1s 3ms/step - loss: 0.8149 - accuracy: 0.8277 - val_loss: 0.7481 - val_accuracy: 0.8031
Epoch 3/10
200/200 [==============================] - 1s 3ms/step - loss: 0.6482 - accuracy: 0.8616 - val_loss: 0.6631 - val_accuracy: 0.8125
Epoch 4/10
200/200 [==============================] - 1s 3ms/step - loss: 0.5492 - accuracy: 0.8832 - val_loss: 0.6100 - val_accuracy: 0.8225
Epoch 5/10
200/200 [==============================] - 1s 3ms/step - loss: 0.4805 - accuracy: 0.9055 - val_loss: 0.5735 - val_accuracy: 0.8294
Epoch 6/10
200/200 [==============================] - 1s 3ms/step - loss: 0.4287 - accuracy: 0.9177 - val_loss: 0.5470 - val_accuracy: 0.8369
Epoch 7/10
200/200 [==============================] - 1s 3ms/step - loss: 0.3876 - accuracy: 0.9286 - val_loss: 0.5270 - val_accuracy: 0.8363
Epoch 8/10
200/200 [==============================] - 1s 3ms/step - loss: 0.3537 - accuracy: 0.9332 - val_loss: 0.5115 - val_accuracy: 0.8394
Epoch 9/10
200/200 [==============================] - 1s 3ms/step - loss: 0.3250 - accuracy: 0.9396 - val_loss: 0.4993 - val_accuracy: 0.8419
Epoch 10/10
200/200 [==============================] - 1s 3ms/step - loss: 0.3003 - accuracy: 0.9479 - val_loss: 0.4896 - val_accuracy: 0.8438

다음으로 int 벡터화 레이어를 사용하여 1D ConvNet을 구축합니다.

def create_model(vocab_size, num_labels):
  model = tf.keras.Sequential([
      layers.Embedding(vocab_size, 64, mask_zero=True),
      layers.Conv1D(64, 5, padding="valid", activation="relu", strides=2),
      layers.GlobalMaxPooling1D(),
      layers.Dense(num_labels)
  ])
  return model
# vocab_size is VOCAB_SIZE + 1 since 0 is used additionally for padding.
int_model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=4)
int_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])
history = int_model.fit(int_train_ds, validation_data=int_val_ds, epochs=5)
Epoch 1/5
200/200 [==============================] - 4s 8ms/step - loss: 1.3016 - accuracy: 0.3903 - val_loss: 0.7395 - val_accuracy: 0.6950
Epoch 2/5
200/200 [==============================] - 1s 6ms/step - loss: 0.6901 - accuracy: 0.7170 - val_loss: 0.5435 - val_accuracy: 0.7906
Epoch 3/5
200/200 [==============================] - 1s 6ms/step - loss: 0.4277 - accuracy: 0.8562 - val_loss: 0.4766 - val_accuracy: 0.8194
Epoch 4/5
200/200 [==============================] - 1s 6ms/step - loss: 0.2419 - accuracy: 0.9402 - val_loss: 0.4701 - val_accuracy: 0.8188
Epoch 5/5
200/200 [==============================] - 1s 6ms/step - loss: 0.1218 - accuracy: 0.9767 - val_loss: 0.4932 - val_accuracy: 0.8163

두 모델을 비교하십시오.

print("Linear model on binary vectorized data:")
print(binary_model.summary())
Linear model on binary vectorized data:
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 4)                 40004     
=================================================================
Total params: 40,004
Trainable params: 40,004
Non-trainable params: 0
_________________________________________________________________
None
print("ConvNet model on int vectorized data:")
print(int_model.summary())
ConvNet model on int vectorized data:
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, None, 64)          640064    
_________________________________________________________________
conv1d (Conv1D)              (None, None, 64)          20544     
_________________________________________________________________
global_max_pooling1d (Global (None, 64)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 4)                 260       
=================================================================
Total params: 660,868
Trainable params: 660,868
Non-trainable params: 0
_________________________________________________________________
None

테스트 데이터에서 두 모델을 모두 평가합니다.

binary_loss, binary_accuracy = binary_model.evaluate(binary_test_ds)
int_loss, int_accuracy = int_model.evaluate(int_test_ds)

print("Binary model accuracy: {:2.2%}".format(binary_accuracy))
print("Int model accuracy: {:2.2%}".format(int_accuracy))
250/250 [==============================] - 1s 5ms/step - loss: 0.5166 - accuracy: 0.8139
250/250 [==============================] - 1s 4ms/step - loss: 0.5116 - accuracy: 0.8117
Binary model accuracy: 81.39%
Int model accuracy: 81.17%

모델 내보내기

위의 코드에서는 모델에 텍스트를 공급하기 전에 TextVectorization 레이어를 데이터 세트에 적용했습니다. 모델이 원시 문자열을 처리 할 수 ​​있도록하려면 (예 : 배포 단순화) 모델 내부에 TextVectorization 레이어를 포함 할 수 있습니다. 이를 위해 방금 훈련 한 가중치를 사용하여 새 모델을 만들 수 있습니다.

export_model = tf.keras.Sequential(
    [binary_vectorize_layer, binary_model,
     layers.Activation('sigmoid')])

export_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer='adam',
    metrics=['accuracy'])

# Test it with `raw_test_ds`, which yields raw strings
loss, accuracy = export_model.evaluate(raw_test_ds)
print("Accuracy: {:2.2%}".format(binary_accuracy))
250/250 [==============================] - 2s 5ms/step - loss: 0.5187 - accuracy: 0.8138
Accuracy: 81.39%

이제 모델은 원시 문자열을 입력으로 model.predict 사용하여 각 레이블의 점수를 예측할 수 있습니다. 최대 점수가있는 레이블을 찾는 함수를 정의하십시오.

def get_string_labels(predicted_scores_batch):
  predicted_int_labels = tf.argmax(predicted_scores_batch, axis=1)
  predicted_labels = tf.gather(raw_train_ds.class_names, predicted_int_labels)
  return predicted_labels

새 데이터에 대한 추론 실행

inputs = [
    "how do I extract keys from a dict into a list?",  # python
    "debug public static void main(string[] args) {...}",  # java
]
predicted_scores = export_model.predict(inputs)
predicted_labels = get_string_labels(predicted_scores)
for input, label in zip(inputs, predicted_labels):
  print("Question: ", input)
  print("Predicted label: ", label.numpy())
Question:  how do I extract keys from a dict into a list?
Predicted label:  b'python'
Question:  debug public static void main(string[] args) {...}
Predicted label:  b'java'

모델 내에 텍스트 사전 처리 로직을 포함하면 배포를 단순화하고 학습 / 테스트 왜곡 가능성을 줄이는 프로덕션 용 모델을 내보낼 수 있습니다.

TextVectorization 레이어를 적용 할 위치를 선택할 때 염두에 두어야 할 성능 차이가 있습니다. 모델 외부에서 사용하면 GPU에서 학습 할 때 비동기 CPU 처리 및 데이터 버퍼링을 수행 할 수 있습니다. 따라서 GPU에서 모델을 교육하는 경우이 옵션을 사용하여 모델을 개발하는 동안 최상의 성능을 얻은 다음 배포를 준비 할 준비가되면 모델 내부에 TextVectorization 레이어를 포함하도록 전환 할 수 있습니다. .

모델 저장에 대해 자세히 알아 보려면이 튜토리얼 을 참조하십시오.

예제 2 : Illiad 번역의 저자 예측

다음은 tf.data.TextLineDataset 을 사용하여 텍스트 파일에서 예제를로드하고 tf.text 를 사용하여 데이터를 전처리하는 예제를 제공합니다. 이 예에서는 동일한 작업의 세 가지 다른 영어 번역 인 Homer 's Illiad를 사용하고 한 줄의 텍스트로 번역자를 식별하는 모델을 훈련시킵니다.

데이터 세트 다운로드 및 탐색

세 가지 번역의 텍스트는 다음과 같습니다.

이 자습서에 사용 된 텍스트 파일은 문서 머리글 및 바닥 글, 줄 번호 및 장 제목 제거와 같은 몇 가지 일반적인 전처리 작업을 거쳤습니다. 이 가벼운 파일을 로컬로 다운로드하십시오.

DIRECTORY_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
FILE_NAMES = ['cowper.txt', 'derby.txt', 'butler.txt']

for name in FILE_NAMES:
  text_dir = utils.get_file(name, origin=DIRECTORY_URL + name)

parent_dir = pathlib.Path(text_dir).parent
list(parent_dir.iterdir())
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt
819200/815980 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt
811008/809730 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt
811008/807992 [==============================] - 0s 0us/step
[PosixPath('/home/kbuilder/.keras/datasets/Giant Panda'),
 PosixPath('/home/kbuilder/.keras/datasets/derby.txt'),
 PosixPath('/home/kbuilder/.keras/datasets/flower_photos.tar.gz'),
 PosixPath('/home/kbuilder/.keras/datasets/spa-eng'),
 PosixPath('/home/kbuilder/.keras/datasets/heart.csv'),
 PosixPath('/home/kbuilder/.keras/datasets/iris_test.csv'),
 PosixPath('/home/kbuilder/.keras/datasets/train.csv'),
 PosixPath('/home/kbuilder/.keras/datasets/butler.txt'),
 PosixPath('/home/kbuilder/.keras/datasets/flower_photos'),
 PosixPath('/home/kbuilder/.keras/datasets/image.jpg'),
 PosixPath('/home/kbuilder/.keras/datasets/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg'),
 PosixPath('/home/kbuilder/.keras/datasets/shakespeare.txt'),
 PosixPath('/home/kbuilder/.keras/datasets/Fireboat'),
 PosixPath('/home/kbuilder/.keras/datasets/iris_training.csv'),
 PosixPath('/home/kbuilder/.keras/datasets/cowper.txt'),
 PosixPath('/home/kbuilder/.keras/datasets/320px-Felis_catus-cat_on_snow.jpg'),
 PosixPath('/home/kbuilder/.keras/datasets/jena_climate_2009_2016.csv.zip'),
 PosixPath('/home/kbuilder/.keras/datasets/fashion-mnist'),
 PosixPath('/home/kbuilder/.keras/datasets/ImageNetLabels.txt'),
 PosixPath('/home/kbuilder/.keras/datasets/mnist.npz'),
 PosixPath('/home/kbuilder/.keras/datasets/jena_climate_2009_2016.csv'),
 PosixPath('/home/kbuilder/.keras/datasets/spa-eng.zip')]

데이터 세트로드

텍스트 파일에서tf.data.Dataset 을 생성하도록 설계된 TextLineDataset 을 사용합니다. 각 예제는 원본 파일의 텍스트 text_dataset_from_directory 반면 text_dataset_from_directory 는 파일의 모든 내용을 단일 예제로 취급합니다. TextLineDataset 은 주로 행 기반의 텍스트 데이터 (예 :시 또는 오류 로그)에 유용합니다.

이러한 파일을 반복하여 각 파일을 자체 데이터 세트에로드합니다. 각 예제는 개별적으로 레이블이 지정되어야하므로 tf.data.Dataset.map 을 사용하여 각 예제에 레이 블러 함수를 적용하십시오. 이렇게하면 데이터 세트의 모든 예를 반복하여 ( example, label ) 쌍을 반환합니다.

def labeler(example, index):
  return example, tf.cast(index, tf.int64)
labeled_data_sets = []

for i, file_name in enumerate(FILE_NAMES):
  lines_dataset = tf.data.TextLineDataset(str(parent_dir/file_name))
  labeled_dataset = lines_dataset.map(lambda ex: labeler(ex, i))
  labeled_data_sets.append(labeled_dataset)

다음으로 레이블이 지정된 데이터 세트를 단일 데이터 세트로 결합하고 섞습니다.

BUFFER_SIZE = 50000
BATCH_SIZE = 64
VALIDATION_SIZE = 5000
all_labeled_data = labeled_data_sets[0]
for labeled_dataset in labeled_data_sets[1:]:
  all_labeled_data = all_labeled_data.concatenate(labeled_dataset)

all_labeled_data = all_labeled_data.shuffle(
    BUFFER_SIZE, reshuffle_each_iteration=False)

이전과 같이 몇 가지 예를 인쇄하십시오. 데이터 세트가 아직 일괄 처리되지 않았으므로 all_labeled_data 각 항목은 하나의 데이터 포인트에 해당합니다.

for text, label in all_labeled_data.take(10):
  print("Sentence: ", text.numpy())
  print("Label:", label.numpy())
Sentence:  b'To chariot driven, thou maim thyself and me.'
Label: 0
Sentence:  b'On choicest marrow, and the fat of lambs;'
Label: 1
Sentence:  b'And through the gorgeous breastplate, and within'
Label: 1
Sentence:  b'To visit there the parent of the Gods'
Label: 0
Sentence:  b'For safe escape from danger and from death.'
Label: 0
Sentence:  b'Achilles, ye at least the fight decline'
Label: 0
Sentence:  b"Which done, Achilles portion'd out to each"
Label: 0
Sentence:  b'Whom therefore thou devourest; else themselves'
Label: 0
Sentence:  b'Drove them afar into the host of Greece.'
Label: 0
Sentence:  b"Their succour; then I warn thee, while 'tis time,"
Label: 1

훈련을위한 데이터 세트 준비

TextVectorization 레이어를 사용하여 텍스트 데이터 세트를 전처리하는 대신 이제tf.text API 를 사용하여 데이터를 표준화 및 토큰 화하고, 어휘를 구축하고, StaticVocabularyTable 을 사용하여 토큰을 정수에 매핑하여 모델에 공급합니다.

tf.text는 다양한 토크 나이저를 제공하지만 UnicodeScriptTokenizer 를 사용하여 데이터 세트를 토큰 화합니다. 텍스트를 소문자로 변환하고 토큰 화하는 함수를 정의하십시오. tf.data.Dataset.map 을 사용하여 토큰 화를 데이터 세트에 적용합니다.

tokenizer = tf_text.UnicodeScriptTokenizer()
def tokenize(text, unused_label):
  lower_case = tf_text.case_fold_utf8(text)
  return tokenizer.tokenize(lower_case)
tokenized_ds = all_labeled_data.map(tokenize)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201: batch_gather (from tensorflow.python.ops.array_ops) is deprecated and will be removed after 2017-10-25.
Instructions for updating:
`tf.batch_gather` is deprecated, please use `tf.gather` with `batch_dims=-1` instead.

데이터 세트를 반복하고 토큰 화 된 몇 가지 예를 인쇄 할 수 있습니다.

for text_batch in tokenized_ds.take(5):
  print("Tokens: ", text_batch.numpy())
Tokens:  [b'to' b'chariot' b'driven' b',' b'thou' b'maim' b'thyself' b'and' b'me'
 b'.']
Tokens:  [b'on' b'choicest' b'marrow' b',' b'and' b'the' b'fat' b'of' b'lambs' b';']
Tokens:  [b'and' b'through' b'the' b'gorgeous' b'breastplate' b',' b'and' b'within']
Tokens:  [b'to' b'visit' b'there' b'the' b'parent' b'of' b'the' b'gods']
Tokens:  [b'for' b'safe' b'escape' b'from' b'danger' b'and' b'from' b'death' b'.']

다음으로 빈도별로 토큰을 정렬하고 상위 VOCAB_SIZE 토큰을 유지하여 어휘를 구축합니다.

tokenized_ds = configure_dataset(tokenized_ds)

vocab_dict = collections.defaultdict(lambda: 0)
for toks in tokenized_ds.as_numpy_iterator():
  for tok in toks:
    vocab_dict[tok] += 1

vocab = sorted(vocab_dict.items(), key=lambda x: x[1], reverse=True)
vocab = [token for token, count in vocab]
vocab = vocab[:VOCAB_SIZE]
vocab_size = len(vocab)
print("Vocab size: ", vocab_size)
print("First five vocab entries:", vocab[:5])
Vocab size:  10000
First five vocab entries: [b',', b'the', b'and', b"'", b'of']

정수로 토큰을 변환하려면 사용 vocab 크리에이트로 설정 StaticVocabularyTable . 토큰을 [ 2 , vocab_size + 2 ] 범위의 정수로 매핑합니다. TextVectorization 레이어와 마찬가지로 0 은 패딩을 나타 내기 위해 예약되고 1 은 OOV (Out-of-vocabulary) 토큰을 나타 내기 위해 예약됩니다.

keys = vocab
values = range(2, len(vocab) + 2)  # reserve 0 for padding, 1 for OOV

init = tf.lookup.KeyValueTensorInitializer(
    keys, values, key_dtype=tf.string, value_dtype=tf.int64)

num_oov_buckets = 1
vocab_table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets)

마지막으로 토크 나이저 및 룩업 테이블을 사용하여 데이터 세트를 표준화, 토큰 화 및 벡터화하는 기능을 정의합니다.

def preprocess_text(text, label):
  standardized = tf_text.case_fold_utf8(text)
  tokenized = tokenizer.tokenize(standardized)
  vectorized = vocab_table.lookup(tokenized)
  return vectorized, label

단일 예제에서 이것을 시도하여 출력을 볼 수 있습니다.

example_text, example_label = next(iter(all_labeled_data))
print("Sentence: ", example_text.numpy())
vectorized_text, example_label = preprocess_text(example_text, example_label)
print("Vectorized sentence: ", vectorized_text.numpy())
Sentence:  b'To chariot driven, thou maim thyself and me.'
Vectorized sentence:  [   8  195  716    2   47 5605  552    4   40    7]

이제 tf.data.Dataset.map 사용하여 데이터 세트에서 전처리 함수를 실행합니다.

all_encoded_data = all_labeled_data.map(preprocess_text)

데이터 세트를 학습 및 테스트로 분할

TextVectorization 레이어는 벡터화 된 데이터를 일괄 처리하고 TextVectorization . 배치 내부의 예는 크기와 모양이 동일해야하지만 이러한 데이터 세트의 예는 모두 같은 크기가 아니기 때문에 패딩이 필요합니다. 텍스트의 각 줄에는 단어 수가 다릅니다.tf.data.Dataset 은 분할 및 패딩 일괄 처리 데이터 세트를 지원합니다.

train_data = all_encoded_data.skip(VALIDATION_SIZE).shuffle(BUFFER_SIZE)
validation_data = all_encoded_data.take(VALIDATION_SIZE)
train_data = train_data.padded_batch(BATCH_SIZE)
validation_data = validation_data.padded_batch(BATCH_SIZE)

이제 validation_datatrain_data 는 ( example, label ) 쌍의 모음이 아니라 배치 모음입니다. 각 배치는 배열로 표현되는 쌍 ( 많은 예 , 많은 레이블 )입니다. 설명하기 위해 :

sample_text, sample_labels = next(iter(validation_data))
print("Text batch shape: ", sample_text.shape)
print("Label batch shape: ", sample_labels.shape)
print("First text example: ", sample_text[0])
print("First label example: ", sample_labels[0])
Text batch shape:  (64, 16)
Label batch shape:  (64,)
First text example:  tf.Tensor(
[   8  195  716    2   47 5605  552    4   40    7    0    0    0    0
    0    0], shape=(16,), dtype=int64)
First label example:  tf.Tensor(0, shape=(), dtype=int64)

패딩에 0 을 사용하고 OOV (Out-of-vocabulary) 토큰에 1 을 사용하기 때문에 어휘 크기가 2만큼 증가했습니다.

vocab_size += 2

이전과 같이 더 나은 성능을 위해 데이터 세트를 구성하십시오.

train_data = configure_dataset(train_data)
validation_data = configure_dataset(validation_data)

모델 훈련

이전과 같이이 데이터 세트에서 모델을 학습시킬 수 있습니다.

model = create_model(vocab_size=vocab_size, num_labels=3)
model.compile(
    optimizer='adam',
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
history = model.fit(train_data, validation_data=validation_data, epochs=3)
Epoch 1/3
697/697 [==============================] - 30s 12ms/step - loss: 0.6900 - accuracy: 0.6660 - val_loss: 0.3815 - val_accuracy: 0.8368
Epoch 2/3
697/697 [==============================] - 5s 7ms/step - loss: 0.3173 - accuracy: 0.8705 - val_loss: 0.3622 - val_accuracy: 0.8460
Epoch 3/3
697/697 [==============================] - 4s 6ms/step - loss: 0.2159 - accuracy: 0.9167 - val_loss: 0.3895 - val_accuracy: 0.8466
loss, accuracy = model.evaluate(validation_data)

print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))
79/79 [==============================] - 1s 2ms/step - loss: 0.3895 - accuracy: 0.8466
Loss:  0.3894515335559845
Accuracy: 84.66%

모델 내보내기

모델이 원시 문자열을 입력으로 사용할 수 있도록하려면 사용자 지정 전처리 기능과 동일한 단계를 수행하는 TextVectorization 레이어를 만듭니다. 이미 어휘를 훈련 set_vocaublary 새로운 어휘를 훈련시키는 adapt 대신 set_vocaublary 를 사용할 수 있습니다.

preprocess_layer = TextVectorization(
    max_tokens=vocab_size,
    standardize=tf_text.case_fold_utf8,
    split=tokenizer.tokenize,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH)
preprocess_layer.set_vocabulary(vocab)
export_model = tf.keras.Sequential(
    [preprocess_layer, model,
     layers.Activation('sigmoid')])

export_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer='adam',
    metrics=['accuracy'])
# Create a test dataset of raw strings
test_ds = all_labeled_data.take(VALIDATION_SIZE).batch(BATCH_SIZE)
test_ds = configure_dataset(test_ds)
loss, accuracy = export_model.evaluate(test_ds)
print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))
79/79 [==============================] - 7s 11ms/step - loss: 0.4626 - accuracy: 0.8128
Loss:  0.4913882315158844
Accuracy: 80.50%

인코딩 된 유효성 검사 세트의 모델과 원시 유효성 검사 세트의 내 보낸 모델에 대한 손실 및 정확도는 예상대로 동일합니다.

새 데이터에 대한 추론 실행

inputs = [
    "Join'd to th' Ionians with their flowing robes,",  # Label: 1
    "the allies, and his armour flashed about him so that he seemed to all",  # Label: 2
    "And with loud clangor of his arms he fell.",  # Label: 0
]
predicted_scores = export_model.predict(inputs)
predicted_labels = tf.argmax(predicted_scores, axis=1)
for input, label in zip(inputs, predicted_labels):
  print("Question: ", input)
  print("Predicted label: ", label.numpy())
Question:  Join'd to th' Ionians with their flowing robes,
Predicted label:  1
Question:  the allies, and his armour flashed about him so that he seemed to all
Predicted label:  2
Question:  And with loud clangor of his arms he fell.
Predicted label:  0

TensorFlow 데이터 세트 (TFDS)를 사용하여 더 많은 데이터 세트 다운로드

TensorFlow Datasets 에서 더 많은 데이터 세트를 다운로드 할 수 있습니다. 예를 들어 IMDB Large Movie Review 데이터 세트를 다운로드하고이를 사용하여 감정 분류를위한 모델을 학습합니다.

train_ds = tfds.load(
    'imdb_reviews',
    split='train',
    batch_size=BATCH_SIZE,
    shuffle_files=True,
    as_supervised=True)
val_ds = tfds.load(
    'imdb_reviews',
    split='train',
    batch_size=BATCH_SIZE,
    shuffle_files=True,
    as_supervised=True)

몇 가지 예를 인쇄하십시오.

for review_batch, label_batch in val_ds.take(1):
  for i in range(5):
    print("Review: ", review_batch[i].numpy())
    print("Label: ", label_batch[i].numpy())
Review:  b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
Label:  0
Review:  b'I have been known to fall asleep during films, but this is usually due to a combination of things including, really tired, being warm and comfortable on the sette and having just eaten a lot. However on this occasion I fell asleep because the film was rubbish. The plot development was constant. Constantly slow and boring. Things seemed to happen, but with no explanation of what was causing them or why. I admit, I may have missed part of the film, but i watched the majority of it and everything just seemed to happen of its own accord without any real concern for anything else. I cant recommend this film at all.'
Label:  0
Review:  b'Mann photographs the Alberta Rocky Mountains in a superb fashion, and Jimmy Stewart and Walter Brennan give enjoyable performances as they always seem to do. <br /><br />But come on Hollywood - a Mountie telling the people of Dawson City, Yukon to elect themselves a marshal (yes a marshal!) and to enforce the law themselves, then gunfighters battling it out on the streets for control of the town? <br /><br />Nothing even remotely resembling that happened on the Canadian side of the border during the Klondike gold rush. Mr. Mann and company appear to have mistaken Dawson City for Deadwood, the Canadian North for the American Wild West.<br /><br />Canadian viewers be prepared for a Reefer Madness type of enjoyable howl with this ludicrous plot, or, to shake your head in disgust.'
Label:  0
Review:  b'This is the kind of film for a snowy Sunday afternoon when the rest of the world can go ahead with its own business as you descend into a big arm-chair and mellow for a couple of hours. Wonderful performances from Cher and Nicolas Cage (as always) gently row the plot along. There are no rapids to cross, no dangerous waters, just a warm and witty paddle through New York life at its best. A family film in every sense and one that deserves the praise it received.'
Label:  1
Review:  b'As others have mentioned, all the women that go nude in this film are mostly absolutely gorgeous. The plot very ably shows the hypocrisy of the female libido. When men are around they want to be pursued, but when no "men" are around, they become the pursuers of a 14 year old boy. And the boy becomes a man really fast (we should all be so lucky at this age!). He then gets up the courage to pursue his true love.'
Label:  1

이제 이전과 같이 데이터를 전처리하고 모델을 학습시킬 수 있습니다.

훈련을위한 데이터 세트 준비

vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH)

# Make a text-only dataset (without labels), then call adapt
train_text = train_ds.map(lambda text, labels: text)
vectorize_layer.adapt(train_text)
def vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return vectorize_layer(text), label
train_ds = train_ds.map(vectorize_text)
val_ds = val_ds.map(vectorize_text)
# Configure datasets for performance as before
train_ds = configure_dataset(train_ds)
val_ds = configure_dataset(val_ds)

모델 훈련

model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=1)
model.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_2 (Embedding)      (None, None, 64)          640064    
_________________________________________________________________
conv1d_2 (Conv1D)            (None, None, 64)          20544     
_________________________________________________________________
global_max_pooling1d_2 (Glob (None, 64)                0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 65        
=================================================================
Total params: 660,673
Trainable params: 660,673
Non-trainable params: 0
_________________________________________________________________
model.compile(
    loss=losses.BinaryCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])
history = model.fit(train_ds, validation_data=val_ds, epochs=3)
Epoch 1/3
391/391 [==============================] - 6s 13ms/step - loss: 0.6123 - accuracy: 0.5805 - val_loss: 0.2976 - val_accuracy: 0.8807
Epoch 2/3
391/391 [==============================] - 4s 10ms/step - loss: 0.3141 - accuracy: 0.8609 - val_loss: 0.1708 - val_accuracy: 0.9423
Epoch 3/3
391/391 [==============================] - 4s 10ms/step - loss: 0.1977 - accuracy: 0.9211 - val_loss: 0.0944 - val_accuracy: 0.9776
loss, accuracy = model.evaluate(val_ds)

print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))
391/391 [==============================] - 1s 3ms/step - loss: 0.0944 - accuracy: 0.9776
Loss:  0.09437894821166992
Accuracy: 97.76%

모델 내보내기

export_model = tf.keras.Sequential(
    [vectorize_layer, model,
     layers.Activation('sigmoid')])

export_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer='adam',
    metrics=['accuracy'])
# 0 --> negative review
# 1 --> positive review
inputs = [
    "This is a fantastic movie.",
    "This is a bad movie.",
    "This movie was so bad that it was good.",
    "I will never say yes to watching this movie.",
]
predicted_scores = export_model.predict(inputs)
predicted_labels = [int(round(x[0])) for x in predicted_scores]
for input, label in zip(inputs, predicted_labels):
  print("Question: ", input)
  print("Predicted label: ", label)
Question:  This is a fantastic movie.
Predicted label:  1
Question:  This is a bad movie.
Predicted label:  0
Question:  This movie was so bad that it was good.
Predicted label:  0
Question:  I will never say yes to watching this movie.
Predicted label:  0

결론

이 튜토리얼에서는 텍스트를로드하고 전처리하는 여러 방법을 보여주었습니다. 다음 단계로 웹 사이트에서 추가 가이드를 탐색하거나 TensorFlow Datasets 에서 새 데이터 세트를 다운로드 할 수 있습니다.