Subword tokenizers


This tutorial demonstrates how to generate a subword vocabulary from a dataset, and use it to build a text.BertTokenizer from the vocabulary.

The main advantage of a subword tokenizer is that it interpolates between word-based and character-based tokenization. Common words get a slot in the vocabulary, but the tokenizer can fall back to word pieces and individual characters for unknown words.

Overview

The tensorflow_text package includes TensorFlow implementations of many common tokenizers. This includes three subword-style tokenizers:

  • text.BertTokenizer - The BertTokenizer class is a higher level interface. It includes BERT's token splitting algorithm and a WordPieceTokenizer. It takes sentences as input and returns token-IDs.
  • text.WordpieceTokenizer - The WordPieceTokenizer class is a lower level interface. It only implements the WordPiece algorithm. You must standardize and split the text into words before calling it. It takes words as input and returns token-IDs.
  • text.SentencepieceTokenizer - The SentencepieceTokenizer requires a more complex setup. Its initializer requires a pre-trained sentencepiece model. See the google/sentencepiece repository for instructions on how to build one of these models. It can accept sentences as input when tokenizing.
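
To get a feel for the difference between the first two interfaces, here is a minimal sketch (using a tiny, made-up vocabulary and example words, purely for illustration) showing that text.WordpieceTokenizer works on text that has already been split into words, falling back to word pieces for words that aren't in its vocabulary:

import tensorflow as tf
import tensorflow_text as text

# A tiny, made-up vocabulary used only for this illustration.
toy_vocab = ["[UNK]", "the", "un", "##deni", "##able", "truth"]
toy_lookup = tf.lookup.StaticVocabularyTable(
    num_oov_buckets=1,
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=toy_vocab,
        values=tf.range(len(toy_vocab), dtype=tf.int64)))

# `text.WordpieceTokenizer` expects pre-split words as input.
# `token_out_type=tf.string` returns the pieces as text instead of IDs.
wp_tokenizer = text.WordpieceTokenizer(toy_lookup, token_out_type=tf.string)
print(wp_tokenizer.tokenize([["the", "undeniable", "truth"]]))
# "undeniable" is not in the toy vocabulary, so it falls back to the
# pieces "un", "##deni", "##able".

text.BertTokenizer, used throughout the rest of this tutorial, wraps this behavior and handles splitting raw sentences into words for you.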

This tutorial builds a WordPiece vocabulary in a top-down manner, starting from existing words. This process doesn't work for Japanese, Chinese, or Korean, since these languages don't have clear multi-character units. To tokenize these languages, consider using text.SentencepieceTokenizer, text.UnicodeCharTokenizer, or this approach.

Setup

pip install -q -U "tensorflow-text==2.11.*"
pip install -q tensorflow_datasets
import collections
import os
import pathlib
import re
import string
import sys
import tempfile
import time

import numpy as np
import matplotlib.pyplot as plt

import tensorflow_datasets as tfds
import tensorflow_text as text
import tensorflow as tf
tf.get_logger().setLevel('ERROR')
pwd = pathlib.Path.cwd()

Download the dataset

Fetch the Portuguese/English translation dataset from tfds:

examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en', with_info=True,
                               as_supervised=True)
train_examples, val_examples = examples['train'], examples['validation']

This dataset produces Portuguese/English sentence pairs:

for pt, en in train_examples.take(1):
  print("Portuguese: ", pt.numpy().decode('utf-8'))
  print("English:   ", en.numpy().decode('utf-8'))
Portuguese:  e quando melhoramos a procura , tiramos a única vantagem da impressão , que é a serendipidade .
English:    and when you improve searchability , you actually take away the one advantage of print , which is serendipity .

Note a few things about the example sentences above:

  • They're lower case.
  • There are spaces around the punctuation.
  • It's not clear whether any Unicode normalization is being used, or which one.
train_en = train_examples.map(lambda pt, en: en)
train_pt = train_examples.map(lambda pt, en: pt)

Generate the vocabulary

This section generates a wordpiece vocabulary from a dataset. If you already have a vocabulary file and just want to see how to build a text.BertTokenizer or text.WordpieceTokenizer tokenizer with it then you can skip ahead to the Build the tokenizer section.

The vocabulary generation code is included in the tensorflow_text pip package. It is not imported by default; you need to import it manually:

from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab

The bert_vocab.bert_vocab_from_dataset function will generate the vocabulary.

There are many arguments you can set to adjust its behavior. For this tutorial, you'll mostly use the defaults. If you want to learn more about the options, first read about the algorithm, and then have a look at the code.

This takes about 2 minutes.

bert_tokenizer_params=dict(lower_case=True)
reserved_tokens=["[PAD]", "[UNK]", "[START]", "[END]"]

bert_vocab_args = dict(
    # The target vocabulary size
    vocab_size = 8000,
    # Reserved tokens that must be included in the vocabulary
    reserved_tokens=reserved_tokens,
    # Arguments for `text.BertTokenizer`
    bert_tokenizer_params=bert_tokenizer_params,
    # Arguments for `wordpiece_vocab.wordpiece_tokenizer_learner_lib.learn`
    learn_params={},
)
%%time
pt_vocab = bert_vocab.bert_vocab_from_dataset(
    train_pt.batch(1000).prefetch(2),
    **bert_vocab_args
)
CPU times: user 1min 22s, sys: 2.56 s, total: 1min 24s
Wall time: 1min 18s

Here are some slices of the resulting vocabulary.

print(pt_vocab[:10])
print(pt_vocab[100:110])
print(pt_vocab[1000:1010])
print(pt_vocab[-10:])
['[PAD]', '[UNK]', '[START]', '[END]', '!', '#', '$', '%', '&', "'"]
['no', 'por', 'mais', 'na', 'eu', 'esta', 'muito', 'isso', 'isto', 'sao']
['90', 'desse', 'efeito', 'malaria', 'normalmente', 'palestra', 'recentemente', '##nca', 'bons', 'chave']
['##–', '##—', '##‘', '##’', '##“', '##”', '##⁄', '##€', '##♪', '##♫']

Write a vocabulary file:

def write_vocab_file(filepath, vocab):
  with open(filepath, 'w') as f:
    for token in vocab:
      print(token, file=f)
write_vocab_file('pt_vocab.txt', pt_vocab)

Use that function to generate a vocabulary from the English data:

%%time
en_vocab = bert_vocab.bert_vocab_from_dataset(
    train_en.batch(1000).prefetch(2),
    **bert_vocab_args
)
CPU times: user 58.7 s, sys: 2.17 s, total: 1min
Wall time: 54.4 s
print(en_vocab[:10])
print(en_vocab[100:110])
print(en_vocab[1000:1010])
print(en_vocab[-10:])
['[PAD]', '[UNK]', '[START]', '[END]', '!', '#', '$', '%', '&', "'"]
['as', 'all', 'at', 'one', 'people', 're', 'like', 'if', 'our', 'from']
['choose', 'consider', 'extraordinary', 'focus', 'generation', 'killed', 'patterns', 'putting', 'scientific', 'wait']
['##_', '##`', '##ย', '##ร', '##อ', '##–', '##—', '##’', '##♪', '##♫']

Here are the two vocabulary files:

write_vocab_file('en_vocab.txt', en_vocab)
ls *.txt
en_vocab.txt  pt_vocab.txt

Build the tokenizer

The text.BertTokenizer can be initialized by passing the vocabulary file's path as the first argument (see the section on tf.lookup for other options):

pt_tokenizer = text.BertTokenizer('pt_vocab.txt', **bert_tokenizer_params)
en_tokenizer = text.BertTokenizer('en_vocab.txt', **bert_tokenizer_params)

Now you can use these tokenizers to encode some text. Take a batch of 3 examples from the English data:

for pt_examples, en_examples in train_examples.batch(3).take(1):
  for ex in en_examples:
    print(ex.numpy())
b'and when you improve searchability , you actually take away the one advantage of print , which is serendipity .'
b'but what if it were active ?'
b"but they did n't test for curiosity ."

Run it through the BertTokenizer.tokenize method. Initially, this returns a tf.RaggedTensor with axes (batch, word, word-piece):

# Tokenize the examples -> (batch, word, word-piece)
token_batch = en_tokenizer.tokenize(en_examples)
# Merge the word and word-piece axes -> (batch, tokens)
token_batch = token_batch.merge_dims(-2,-1)

for ex in token_batch.to_list():
  print(ex)
[72, 117, 79, 1259, 1491, 2362, 13, 79, 150, 184, 311, 71, 103, 2308, 74, 2679, 13, 148, 80, 55, 4840, 1434, 2423, 540, 15]
[87, 90, 107, 76, 129, 1852, 30]
[87, 83, 149, 50, 9, 56, 664, 85, 2512, 15]

If you replace the token IDs with their text representations (using tf.gather) you can see that in the first example the words "searchability" and "serendipity" have been decomposed into "search ##ability" and "s ##ere ##nd ##ip ##ity":

# Lookup each token id in the vocabulary.
txt_tokens = tf.gather(en_vocab, token_batch)
# Join with spaces.
tf.strings.reduce_join(txt_tokens, separator=' ', axis=-1)
<tf.Tensor: shape=(3,), dtype=string, numpy=
array([b'and when you improve search ##ability , you actually take away the one advantage of print , which is s ##ere ##nd ##ip ##ity .',
       b'but what if it were active ?',
       b"but they did n ' t test for curiosity ."], dtype=object)>

To re-assemble words from the extracted tokens, use the BertTokenizer.detokenize method:

words = en_tokenizer.detokenize(token_batch)
tf.strings.reduce_join(words, separator=' ', axis=-1)
<tf.Tensor: shape=(3,), dtype=string, numpy=
array([b'and when you improve searchability , you actually take away the one advantage of print , which is serendipity .',
       b'but what if it were active ?',
       b"but they did n ' t test for curiosity ."], dtype=object)>

Customization and export

This tutorial builds the text tokenizer and detokenizer used by the Transformer tutorial. This section adds methods and processing steps to simplify that tutorial, and exports the tokenizers using tf.saved_model so they can be imported by the other tutorials.

Custom tokenization

The downstream tutorials both expect the tokenized text to include [START] and [END] tokens.

Because the reserved_tokens reserve space at the beginning of the vocabulary, [START] and [END] have the same indexes for both languages:

START = tf.argmax(tf.constant(reserved_tokens) == "[START]")
END = tf.argmax(tf.constant(reserved_tokens) == "[END]")

def add_start_end(ragged):
  count = ragged.bounding_shape()[0]
  starts = tf.fill([count,1], START)
  ends = tf.fill([count,1], END)
  return tf.concat([starts, ragged, ends], axis=1)
words = en_tokenizer.detokenize(add_start_end(token_batch))
tf.strings.reduce_join(words, separator=' ', axis=-1)
<tf.Tensor: shape=(3,), dtype=string, numpy=
array([b'[START] and when you improve searchability , you actually take away the one advantage of print , which is serendipity . [END]',
       b'[START] but what if it were active ? [END]',
       b"[START] but they did n ' t test for curiosity . [END]"],
      dtype=object)>

Custom detokenization

Before exporting the tokenizers there are a couple of things you can clean up for the downstream tutorials:

  1. They want to generate clean text output, so drop reserved tokens like [START], [END] and [PAD].
  2. They're interested in complete strings, so apply a string join along the words axis of the result.
def cleanup_text(reserved_tokens, token_txt):
  # Drop the reserved tokens, except for "[UNK]".
  bad_tokens = [re.escape(tok) for tok in reserved_tokens if tok != "[UNK]"]
  bad_token_re = "|".join(bad_tokens)

  bad_cells = tf.strings.regex_full_match(token_txt, bad_token_re)
  result = tf.ragged.boolean_mask(token_txt, ~bad_cells)

  # Join them into strings.
  result = tf.strings.reduce_join(result, separator=' ', axis=-1)

  return result
en_examples.numpy()
array([b'and when you improve searchability , you actually take away the one advantage of print , which is serendipity .',
       b'but what if it were active ?',
       b"but they did n't test for curiosity ."], dtype=object)
token_batch = en_tokenizer.tokenize(en_examples).merge_dims(-2,-1)
words = en_tokenizer.detokenize(token_batch)
words
<tf.RaggedTensor [[b'and', b'when', b'you', b'improve', b'searchability', b',', b'you',
  b'actually', b'take', b'away', b'the', b'one', b'advantage', b'of',
  b'print', b',', b'which', b'is', b'serendipity', b'.']              ,
 [b'but', b'what', b'if', b'it', b'were', b'active', b'?'],
 [b'but', b'they', b'did', b'n', b"'", b't', b'test', b'for', b'curiosity',
  b'.']                                                                    ]>
cleanup_text(reserved_tokens, words).numpy()
array([b'and when you improve searchability , you actually take away the one advantage of print , which is serendipity .',
       b'but what if it were active ?',
       b"but they did n ' t test for curiosity ."], dtype=object)

Export

The following code block builds a CustomTokenizer class to contain the text.BertTokenizer instances, the custom logic, and the @tf.function wrappers required for export.

class CustomTokenizer(tf.Module):
  def __init__(self, reserved_tokens, vocab_path):
    self.tokenizer = text.BertTokenizer(vocab_path, lower_case=True)
    self._reserved_tokens = reserved_tokens
    self._vocab_path = tf.saved_model.Asset(vocab_path)

    vocab = pathlib.Path(vocab_path).read_text().splitlines()
    self.vocab = tf.Variable(vocab)

    ## Create the signatures for export:   

    # Include a tokenize signature for a batch of strings. 
    self.tokenize.get_concrete_function(
        tf.TensorSpec(shape=[None], dtype=tf.string))

    # Include `detokenize` and `lookup` signatures for:
    #   * `Tensors` with shapes [tokens] and [batch, tokens]
    #   * `RaggedTensors` with shape [batch, tokens]
    self.detokenize.get_concrete_function(
        tf.TensorSpec(shape=[None, None], dtype=tf.int64))
    self.detokenize.get_concrete_function(
          tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int64))

    self.lookup.get_concrete_function(
        tf.TensorSpec(shape=[None, None], dtype=tf.int64))
    self.lookup.get_concrete_function(
          tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int64))

    # These `get_*` methods take no arguments
    self.get_vocab_size.get_concrete_function()
    self.get_vocab_path.get_concrete_function()
    self.get_reserved_tokens.get_concrete_function()

  @tf.function
  def tokenize(self, strings):
    enc = self.tokenizer.tokenize(strings)
    # Merge the `word` and `word-piece` axes.
    enc = enc.merge_dims(-2,-1)
    enc = add_start_end(enc)
    return enc

  @tf.function
  def detokenize(self, tokenized):
    words = self.tokenizer.detokenize(tokenized)
    return cleanup_text(self._reserved_tokens, words)

  @tf.function
  def lookup(self, token_ids):
    return tf.gather(self.vocab, token_ids)

  @tf.function
  def get_vocab_size(self):
    return tf.shape(self.vocab)[0]

  @tf.function
  def get_vocab_path(self):
    return self._vocab_path

  @tf.function
  def get_reserved_tokens(self):
    return tf.constant(self._reserved_tokens)

Build a CustomTokenizer for each language:

tokenizers = tf.Module()
tokenizers.pt = CustomTokenizer(reserved_tokens, 'pt_vocab.txt')
tokenizers.en = CustomTokenizer(reserved_tokens, 'en_vocab.txt')

Export the tokenizers as a saved_model:

model_name = 'ted_hrlr_translate_pt_en_converter'
tf.saved_model.save(tokenizers, model_name)

Reload the saved_model and test the methods:

reloaded_tokenizers = tf.saved_model.load(model_name)
reloaded_tokenizers.en.get_vocab_size().numpy()
7010
tokens = reloaded_tokenizers.en.tokenize(['Hello TensorFlow!'])
tokens.numpy()
array([[   2, 4006, 2358,  687, 1192, 2365,    4,    3]])
text_tokens = reloaded_tokenizers.en.lookup(tokens)
text_tokens
<tf.RaggedTensor [[b'[START]', b'hello', b'tens', b'##or', b'##f', b'##low', b'!',
  b'[END]']]>
round_trip = reloaded_tokenizers.en.detokenize(tokens)

print(round_trip.numpy()[0].decode('utf-8'))
hello tensorflow !

Archive it for the translation tutorials:

zip -r {model_name}.zip {model_name}
adding: ted_hrlr_translate_pt_en_converter/ (stored 0%)
  adding: ted_hrlr_translate_pt_en_converter/fingerprint.pb (stored 0%)
  adding: ted_hrlr_translate_pt_en_converter/saved_model.pb (deflated 91%)
  adding: ted_hrlr_translate_pt_en_converter/variables/ (stored 0%)
  adding: ted_hrlr_translate_pt_en_converter/variables/variables.index (deflated 33%)
  adding: ted_hrlr_translate_pt_en_converter/variables/variables.data-00000-of-00001 (deflated 51%)
  adding: ted_hrlr_translate_pt_en_converter/assets/ (stored 0%)
  adding: ted_hrlr_translate_pt_en_converter/assets/en_vocab.txt (deflated 54%)
  adding: ted_hrlr_translate_pt_en_converter/assets/pt_vocab.txt (deflated 57%)
du -h *.zip
168K    ted_hrlr_translate_pt_en_converter.zip

Optional: The algorithm

It's worth noting here that there are two versions of the WordPiece algorithm: bottom-up and top-down. In both cases the goal is the same: "Given a training corpus and a number of desired tokens D, the optimization problem is to select D wordpieces such that the resulting corpus is minimal in the number of wordpieces when segmented according to the chosen wordpiece model."

The original bottom-up WordPiece algorithm is based on byte-pair encoding. Like BPE, it starts with the alphabet and iteratively combines common bigrams to form word-pieces and words.

TensorFlow Text's vocabulary generator follows the top-down implementation from BERT: it starts with words and breaks them down into smaller components until they hit the frequency threshold or can't be broken down further. The next section describes this in detail. For Japanese, Chinese, and Korean this top-down approach doesn't work, since there are no explicit word units to start with. For those languages you need a different approach.

Choosing the vocabulary

The top-down WordPiece generation algorithm takes in a set of (word, count) pairs and a threshold T, and returns a vocabulary V.

The algorithm is iterative. It is run for k iterations, where typically k = 4, but only the first two are really important. The third and fourth (and beyond) are identical to the second. Note that each step of the binary search (used to find a threshold that reaches the target vocabulary size) runs the algorithm from scratch for k iterations.

The iterations are described below:

First iteration

  1. Iterate over every word and count pair in the input, denoted as (w, c).
  2. For each word w, generate every substring, denoted as s. E.g., for the word human, we generate {h, hu, hum, huma, human, ##u, ##um, ##uma, ##uman, ##m, ##ma, ##man, ##a, ##an, ##n}.
  3. Maintain a substring-to-count hash map, and increment the count of each s by c. E.g., if we have (human, 113) and (humas, 3) in our input, the count of s = huma will be 113+3=116.
  4. Once we've collected the counts of every substring, iterate over the (s, c) pairs starting with the longest s first.
  5. Keep any s that has a c > T. E.g., if T = 100 and we have (pers, 231); (dogs, 259); (##rint, 76), then we would keep pers and dogs.
  6. When an s is kept, subtract off its count from all of its prefixes. This is the reason for sorting all of the s by length in step 4. This is a critical part of the algorithm, because otherwise words would be double counted. For example, let's say that we've kept human and we get to (huma, 116). We know that 113 of those 116 came from human, and 3 came from humas. However, now that human is in our vocabulary, we know we will never segment human into huma ##n. So once human has been kept, then huma only has an effective count of 3.

This algorithm will generate a set of word pieces s (many of which will be whole words w), which we could use as our WordPiece vocabulary.

However, there is a problem: this algorithm will severely overgenerate word pieces. The reason is that we only subtract off counts of prefix tokens. Therefore, if we keep the word human, we will subtract off the count for h, hu, hum, huma, but not for ##u, ##um, ##uma, ##uman, and so on. So we might generate both human and ##uman as word pieces, even though ##uman will never be applied.

So why not subtract off the counts for every substring, not just every prefix? Because then we could end up subtracting off the counts multiple times. Let's say that we're processing s of length 5 and we keep both (##denia, 129) and (##eniab, 137), where 65 of those counts came from the word undeniable. If we subtract off from every substring, we would subtract 65 from the substring ##enia twice, even though we should only subtract once. However, if we only subtract off from prefixes, it will correctly only be subtracted once.
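
Here is a minimal Python sketch of this first iteration, with made-up word counts and threshold (the real implementation, the wordpiece_vocab.wordpiece_tokenizer_learner_lib mentioned earlier, differs in its details):

import collections

def first_iteration(word_counts, threshold):
  # Steps 2-3: count every substring of every word; substrings that
  # don't start at the beginning of the word get a "##" prefix.
  counts = collections.Counter()
  for word, count in word_counts.items():
    for start in range(len(word)):
      prefix = "" if start == 0 else "##"
      for end in range(start + 1, len(word) + 1):
        counts[prefix + word[start:end]] += count

  def piece_len(s):
    return len(s) - 2 if s.startswith("##") else len(s)

  # Steps 4-6: walk the substrings longest-first, keep those above the
  # threshold, and subtract each kept substring's count from its
  # prefixes so they aren't double counted.
  vocab = []
  for s in sorted(counts, key=piece_len, reverse=True):
    count = counts[s]
    if count <= threshold:
      continue
    vocab.append(s)
    prefix = "##" if s.startswith("##") else ""
    body = s[len(prefix):]
    for end in range(1, len(body)):
      counts[prefix + body[:end]] -= count
  return vocab

# Made-up counts, just to exercise the sketch.
print(first_iteration({"human": 113, "humas": 3}, threshold=100))
# -> ['human', '##uman', '##man', '##an', '##n']; note the
# overgenerated "##..." pieces discussed above.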

Second (and third ...) iteration

To solve the overgeneration issue mentioned above, we perform multiple iterations of the algorithm.

Subsequent iterations are identical to the first, with one important distinction: In step 2, instead of considering every substring, we apply the WordPiece tokenization algorithm using the vocabulary from the previous iteration, and only consider substrings which start on a split point.

For example, let's say that we're performing step 2 of the algorithm and encounter the word undeniable. In the first iteration, we would consider every substring, e.g., {u, un, und, ..., undeniable, ##n, ##nd, ..., ##ndeniable, ...}.

Now, for the second iteration, we will only consider a subset of these. Let's say that after the first iteration, the relevant word pieces are:

un, ##deni, ##able, ##ndeni, ##iable

The WordPiece algorithm will segment this into un ##deni ##able (see the section Applying WordPiece for more information). In this case, we will only consider substrings that start at a segmentation point. We will still consider every possible end position. So during the second iteration, the set of s for undeniable is:

{u, un, und, unde, unden, undeni, undenia, undeniab, undeniabl, undeniable, ##d, ##de, ##den, ##deni, ##denia, ##deniab, ##deniabl, ##deniable, ##a, ##ab, ##abl, ##able}

The algorithm is otherwise identical. In this example, in the first iteration, the algorithm produces the spurious tokens ##ndeni and ##iable. Now, these tokens are never considered, so they will not be generated by the second iteration. We perform several iterations just to make sure the results converge (although there is no literal convergence guarantee).
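
As a rough sketch, the restriction to split points looks like this in Python, assuming a segment function that applies the previous iteration's vocabulary (for example, the greedy matcher sketched in the next section):

def candidate_substrings(word, segment):
  # `segment` applies the previous iteration's vocabulary to `word`,
  # e.g. returning ['un', '##deni', '##able'] for 'undeniable'.
  candidates = set()
  start = 0
  for piece in segment(word):
    # Only consider substrings that start at a segmentation point,
    # but still consider every possible end position.
    prefix = "##" if start > 0 else ""
    for end in range(start + 1, len(word) + 1):
      candidates.add(prefix + word[start:end])
    start += len(piece) - 2 if piece.startswith("##") else len(piece)
  return candidates

# Using a hypothetical segmentation of "undeniable" from the previous
# iteration's vocabulary:
print(sorted(candidate_substrings(
    "undeniable", lambda w: ["un", "##deni", "##able"])))
# Matches the set of candidate substrings listed above.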

Applying WordPiece

Once a WordPiece vocabulary has been generated, we need to be able to apply it to new data. The algorithm is a simple greedy longest-match-first application.

For example, consider segmenting the word undeniable.

We first look up undeniable in our WordPiece dictionary, and if it's present, we're done. If not, we decrement the end point by one character and repeat, e.g., undeniabl.

Eventually, we will either find a subtoken in our vocabulary, or get down to a single-character subtoken. (In general, we assume that every character is in our vocabulary, although this might not be the case for rare Unicode characters. If we encounter a rare Unicode character that's not in the vocabulary, we simply map the entire word to <unk>.)

In this case, we find un in our vocabulary. So that's our first word piece. Then we jump to the end of un and repeat the processing, e.g., try to find ##deniable, then ##deniabl, etc. This is repeated until we've segmented the entire word.
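
Here is a minimal Python sketch of this greedy longest-match-first procedure (the small vocabulary below is made up for illustration; in practice this is what text.WordpieceTokenizer implements):

def wordpiece_segment(word, vocab, unknown_token="[UNK]"):
  # Greedy longest-match-first: at each position, take the longest
  # piece in `vocab` that matches, then jump to its end and repeat.
  pieces = []
  start = 0
  while start < len(word):
    prefix = "##" if start > 0 else ""
    end = len(word)
    while end > start:
      candidate = prefix + word[start:end]
      if candidate in vocab:
        pieces.append(candidate)
        break
      end -= 1
    else:
      # No piece matched, not even a single character: map the whole
      # word to the unknown token.
      return [unknown_token]
    start = end
  return pieces

# A made-up vocabulary for illustration.
toy_vocab = {"un", "##deni", "##able"}
print(wordpiece_segment("undeniable", toy_vocab))  # ['un', '##deni', '##able']
print(wordpiece_segment("xyz", toy_vocab))         # ['[UNK]']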

Intuition

Intuitively, WordPiece tokenization is trying to satisfy two different objectives:

  1. Tokenize the data into the fewest pieces possible. It is important to keep in mind that the WordPiece algorithm does not "want" to split words. Otherwise, it would just split every word into its characters, e.g., human -> {h, ##u, ##m, ##a, ##n}. This is one critical thing that makes WordPiece different from morphological splitters, which will split linguistic morphemes even for common words (e.g., unwanted -> {un, want, ed}).

  2. When a word does have to be split into pieces, split it into pieces that have maximal counts in the training data. For example, the reason why the word undeniable would be split into {un, ##deni, ##able} rather than alternatives like {unde, ##niab, ##le} is that the counts for un and ##able in particular will be very high, since these are common prefixes and suffixes. Even though the count for ##le must be higher than ##able, the low counts of unde and ##niab will make this a less "desirable" tokenization to the algorithm.

Optional: tf.lookup

If you need access to, or more control over, the vocabulary, you can build the lookup table yourself and pass it to BertTokenizer.

When you pass a vocabulary file path as a string, BertTokenizer does the equivalent of the following:

pt_lookup = tf.lookup.StaticVocabularyTable(
    num_oov_buckets=1,
    initializer=tf.lookup.TextFileInitializer(
        filename='pt_vocab.txt',
        key_dtype=tf.string,
        key_index = tf.lookup.TextFileIndex.WHOLE_LINE,
        value_dtype = tf.int64,
        value_index=tf.lookup.TextFileIndex.LINE_NUMBER)) 
pt_tokenizer = text.BertTokenizer(pt_lookup)

Now you have direct access to the lookup table used in the tokenizer.

pt_lookup.lookup(tf.constant(['é', 'um', 'uma', 'para', 'não']))
<tf.Tensor: shape=(5,), dtype=int64, numpy=array([7765,   85,   86,   87, 7765])>

You don't need to use a vocabulary file; tf.lookup has other initializer options. If you have the vocabulary in memory, you can use tf.lookup.KeyValueTensorInitializer:

pt_lookup = tf.lookup.StaticVocabularyTable(
    num_oov_buckets=1,
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=pt_vocab,
        values=tf.range(len(pt_vocab), dtype=tf.int64))) 
pt_tokenizer = text.BertTokenizer(pt_lookup)