Wiki40B 언어 모델

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 보기 노트북 다운로드 TF Hub 모델보기

TensorFlow Hub에서 Wiki40B 언어 모델을 사용하여 Wikipedia와 유사한 텍스트를 생성하세요!

이 노트북은 다음 방법을 보여줍니다.

  • TF-Hub에서 Wiki40b-LM 컬렉션의 일부인 41개의 단일 언어 및 2개의 다국어 모델을 로드합니다.
  • 이들 모델을 사용하여 주어진 텍스트 조각에 대한 복잡도, 레이어별 활성화 및 단어 임베딩을 얻습니다.
  • 시드 텍스트 조각에서 토큰별로 텍스트를 생성합니다.

언어 모델은 TensorFlow 데이터세트에서 제공하는 새로 게시되고 정리된 Wiki40B 데이터세트로부터 훈련합니다. 훈련 설정은 “Wiki-40B: 다국어 모델 데이터세트” 논문을 기초로 합니다.

설정

Installing Dependencies

Imports

언어 선택하기

TF-Hub에서 로드할 언어 모델과 생성할 텍스트 길이를 선택하겠습니다.

Using the https://tfhub.dev/google/wiki40b-lm-en/1 model to generate sequences of max length 20.

모델 빌드하기

이제 사용할 사전 훈련된 모델을 구성했으므로 최대 max_gen_len까지 텍스트를 생성하도록 구성하겠습니다. TF-Hub에서 언어 모델을 로드하고, 시작 텍스트 조각을 입력한 다음 생성되는 토큰을 반복적으로 피드해야 합니다.

Load the language model pieces

2022-12-14 20:50:26.451211: W tensorflow/core/common_runtime/graph_constructor.cc:1511] Importing a graph with a lower producer version 359 into an existing graph with producer version 987. Shape inference will have run different parts of the graph with different producer versions.

Construct the per-token generation graph

Build the statically unrolled graph for max_gen_len tokens

일부 텍스트 생성하기

일부 텍스트를 생성해 보겠습니다! 언어 모델을 표시하기 위해 텍스트 seed를 설정합니다.

미리 정의된 시드 중 하나를 사용하거나 선택적으로 고유한 시드를 입력할 수 있습니다. 이 텍스트는 다음에 생성할 내용을 언어 모델에 표시하는 데 도움이 되는 언어 모델의 시드로 사용됩니다.

생성된 기사의 특정 부분 앞에 다음의 특수 토큰을 사용할 수 있습니다. _START_ARTICLE_을 사용하여 기사의 시작을 나타내고 _START_SECTION_를 사용하여 섹션의 시작을 나타내며 _START_PARAGRAPH_를 사용하여 기사의 텍스트를 생성합니다.

Predefined Seeds

Enter your own seed (Optional).

Generating text from seed:

_START_ARTICLE_
1882 Prince Edward Island general election
_START_PARAGRAPH_
The 1882 Prince Edward Island election was held on May 8, 1882 to elect members of the House of Assembly of the province of Prince Edward Island, Canada.

Initialize session.

Generate text

_NEWLINE__NEWLINE_The 1880 president of the parliamentary urpea election but did not entitle the Chamber

복잡도, 토큰 ID, 중간 활성화, 임베딩 등 모델의 다른 출력도 볼 수 있습니다.

ppl_result
array([23.50776], dtype=float32)
token_ids_result
array([[   8,    3, 6794, 1579, 1582,  721,  489,  448,    8,    5,   26,
        6794, 1579, 1582,  721,  448,   17,  245,   22,  166, 2928, 6794,
          16, 7690,  384,   11,    7,  402,   11, 1172,   11,    7, 2115,
          11, 1579, 1582,  721,    9,  646,   10]], dtype=int32)
activations_result.shape
(12, 1, 39, 768)
embeddings_result
array([[[ 0.12262525,  5.548009  ,  1.4743135 , ...,  2.4388404 ,
         -2.2788858 ,  2.172028  ],
        [-2.3905468 , -0.97108954, -1.5513545 , ...,  8.458472  ,
         -2.8723319 ,  0.6534524 ],
        [-0.83790785,  0.41630274, -0.8740793 , ...,  1.6446769 ,
         -0.9074106 ,  0.3339265 ],
        ...,
        [-0.8054745 , -1.2495526 ,  2.6232922 , ...,  2.893288  ,
         -0.91287214, -1.1259722 ],
        [ 0.64944506,  3.3696785 ,  0.09543293, ..., -0.7839227 ,
         -1.3573489 ,  1.862214  ],
        [-1.2970612 ,  0.5961366 ,  3.3531897 , ...,  3.2853985 ,
         -1.6212384 ,  0.30257902]]], dtype=float32)