Unicode strings


Introduction

NLP models often handle different languages with different character sets. Unicode is a standard encoding system that is used to represent characters from almost all languages. Every Unicode character is encoded using a unique integer code point between 0 and 0x10FFFF. A Unicode string is a sequence of zero or more code points.
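As a quick illustration in plain Python (standard library only, nothing TensorFlow-specific), the built-ins ord and chr convert between a character and its code point. The sample characters below are arbitrary choices for this sketch.

```python
# Plain-Python sketch: characters and their integer code points.
MAX_CODE_POINT = 0x10FFFF  # upper bound of the Unicode code space

samples = ["A", "\u00e9", "语", "😊"]  # "\u00e9" is the precomposed letter é
code_points = [ord(ch) for ch in samples]

for ch, cp in zip(samples, code_points):
    # Code points are conventionally written as U+ followed by hex digits.
    print(f"{ch!r} -> U+{cp:04X} ({cp})")

# chr() is the inverse of ord().
round_trip = [chr(cp) for cp in code_points]
```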

This tutorial shows how to represent Unicode strings in TensorFlow and manipulate them using Unicode equivalents of standard string ops. It separates Unicode strings into tokens based on script detection.

import tensorflow as tf
import numpy as np

The tf.string data type

The basic TensorFlow tf.string dtype allows you to build tensors of byte strings. Unicode strings are UTF-8 encoded by default.

tf.constant(u"Thanks 😊")
<tf.Tensor: shape=(), dtype=string, numpy=b'Thanks \xf0\x9f\x98\x8a'>

A tf.string tensor treats byte strings as atomic units. This enables it to store byte strings of varying lengths. The string length is not included in the tensor dimensions.

tf.constant([u"You're", u"welcome!"]).shape
TensorShape([2])

If you use Python to construct strings, note that string literals are Unicode-encoded by default.
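A small plain-Python sketch of that point, assuming Python 3: the u prefix is accepted but redundant, because every string literal is already a Unicode str. Encoding the str yields the UTF-8 bytes shown in the tf.constant output earlier in this section.

```python
# In Python 3 the u prefix is a no-op: both literals are Unicode str objects.
with_prefix = u"Thanks 😊"
without_prefix = "Thanks 😊"
same_value = with_prefix == without_prefix

# Encoding the str gives a byte string (here, UTF-8).
utf8_bytes = without_prefix.encode("utf-8")

print(type(without_prefix).__name__, type(utf8_bytes).__name__)
print(utf8_bytes)
```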

Representing Unicode

There are two standard ways to represent a Unicode string in TensorFlow:

  • string scalar — where the sequence of code points is encoded using a known character encoding.
  • int32 vector — where each position contains a single code point.

For example, the following three values all represent the Unicode string "语言处理" (which means "language processing" in Chinese):

# Unicode string, represented as a UTF-8 encoded string scalar.
text_utf8 = tf.constant(u"语言处理")
text_utf8
<tf.Tensor: shape=(), dtype=string, numpy=b'\xe8\xaf\xad\xe8\xa8\x80\xe5\xa4\x84\xe7\x90\x86'>
# Unicode string, represented as a UTF-16-BE encoded string scalar.
text_utf16be = tf.constant(u"语言处理".encode("UTF-16-BE"))
text_utf16be
<tf.Tensor: shape=(), dtype=string, numpy=b'\x8b\xed\x8a\x00Y\x04t\x06'>
# Unicode string, represented as a vector of Unicode code points.
text_chars = tf.constant([ord(char) for char in u"语言处理"])
text_chars
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([35821, 35328, 22788, 29702], dtype=int32)>

Converting between representations

TensorFlow provides operations to convert between these different representations:

tf.strings.unicode_decode(text_utf8,
                          input_encoding='UTF-8')
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([35821, 35328, 22788, 29702], dtype=int32)>
tf.strings.unicode_encode(text_chars,
                          output_encoding='UTF-8')
<tf.Tensor: shape=(), dtype=string, numpy=b'\xe8\xaf\xad\xe8\xa8\x80\xe5\xa4\x84\xe7\x90\x86'>
tf.strings.unicode_transcode(text_utf8,
                             input_encoding='UTF-8',
                             output_encoding='UTF-16-BE')
<tf.Tensor: shape=(), dtype=string, numpy=b'\x8b\xed\x8a\x00Y\x04t\x06'>
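For comparison, the same three conversions can be sketched in plain Python with str.encode, bytes.decode, ord, and chr. This is only a cross-check of the values above, not how the TensorFlow ops are implemented.

```python
text = "语言处理"
utf8_bytes = text.encode("utf-8")

# Decode: UTF-8 bytes -> list of code points.
code_points = [ord(ch) for ch in utf8_bytes.decode("utf-8")]

# Encode: list of code points -> UTF-8 bytes.
encoded = "".join(chr(cp) for cp in code_points).encode("utf-8")

# Transcode: UTF-8 bytes -> UTF-16-BE bytes.
utf16be_bytes = utf8_bytes.decode("utf-8").encode("utf-16-be")

print(code_points)
print(encoded)
print(utf16be_bytes)
```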

Batch dimensions

When decoding multiple strings, the number of characters in each string may differ. The result is a tf.RaggedTensor, where the length of the innermost dimension varies with the number of characters in each string.

# A batch of Unicode strings, each represented as a UTF-8 encoded string.
batch_utf8 = [s.encode('UTF-8') for s in
              [u'hÃllo', u'What is the weather tomorrow', u'Göödnight', u'😊']]
batch_chars_ragged = tf.strings.unicode_decode(batch_utf8,
                                               input_encoding='UTF-8')
for sentence_chars in batch_chars_ragged.to_list():
  print(sentence_chars)
[104, 195, 108, 108, 111]
[87, 104, 97, 116, 32, 105, 115, 32, 116, 104, 101, 32, 119, 101, 97, 116, 104, 101, 114, 32, 116, 111, 109, 111, 114, 114, 111, 119]
[71, 246, 246, 100, 110, 105, 103, 104, 116]
[128522]

You can use this tf.RaggedTensor directly, or convert it to a dense tf.Tensor with padding or a tf.sparse.SparseTensor using the methods tf.RaggedTensor.to_tensor and tf.RaggedTensor.to_sparse.

batch_chars_padded = batch_chars_ragged.to_tensor(default_value=-1)
print(batch_chars_padded.numpy())
[[   104    195    108    108    111     -1     -1     -1     -1     -1
      -1     -1     -1     -1     -1     -1     -1     -1     -1     -1
      -1     -1     -1     -1     -1     -1     -1     -1]
 [    87    104     97    116     32    105    115     32    116    104
     101     32    119    101     97    116    104    101    114     32
     116    111    109    111    114    114    111    119]
 [    71    246    246    100    110    105    103    104    116     -1
      -1     -1     -1     -1     -1     -1     -1     -1     -1     -1
      -1     -1     -1     -1     -1     -1     -1     -1]
 [128522     -1     -1     -1     -1     -1     -1     -1     -1     -1
      -1     -1     -1     -1     -1     -1     -1     -1     -1     -1
      -1     -1     -1     -1     -1     -1     -1     -1]]
batch_chars_sparse = batch_chars_ragged.to_sparse()

nrows, ncols = batch_chars_sparse.dense_shape.numpy()
elements = [['_' for i in range(ncols)] for j in range(nrows)]
for (row, col), value in zip(batch_chars_sparse.indices.numpy(), batch_chars_sparse.values.numpy()):
  elements[row][col] = str(value)
max_width = max(len(value) for row in elements for value in row)
print('[%s]' % '\n '.join(
    '[%s]' % ', '.join(value.rjust(max_width) for value in row)
    for row in elements))
[[   104,    195,    108,    108,    111,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _]
 [    87,    104,     97,    116,     32,    105,    115,     32,    116,    104,    101,     32,    119,    101,     97,    116,    104,    101,    114,     32,    116,    111,    109,    111,    114,    114,    111,    119]
 [    71,    246,    246,    100,    110,    105,    103,    104,    116,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _]
 [128522,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _,      _]]

When encoding multiple strings of the same length, use a tf.Tensor as the input.

tf.strings.unicode_encode([[99, 97, 116], [100, 111, 103], [99, 111, 119]],
                          output_encoding='UTF-8')
<tf.Tensor: shape=(3,), dtype=string, numpy=array([b'cat', b'dog', b'cow'], dtype=object)>

When encoding multiple strings of varying lengths, use a tf.RaggedTensor as the input.

tf.strings.unicode_encode(batch_chars_ragged, output_encoding='UTF-8')
<tf.Tensor: shape=(4,), dtype=string, numpy=
array([b'h\xc3\x83llo', b'What is the weather tomorrow',
       b'G\xc3\xb6\xc3\xb6dnight', b'\xf0\x9f\x98\x8a'], dtype=object)>

If you have a tensor with multiple strings in padded or sparse format, convert it first into a tf.RaggedTensor before calling tf.strings.unicode_encode.

tf.strings.unicode_encode(
    tf.RaggedTensor.from_sparse(batch_chars_sparse),
    output_encoding='UTF-8')
<tf.Tensor: shape=(4,), dtype=string, numpy=
array([b'h\xc3\x83llo', b'What is the weather tomorrow',
       b'G\xc3\xb6\xc3\xb6dnight', b'\xf0\x9f\x98\x8a'], dtype=object)>
tf.strings.unicode_encode(
    tf.RaggedTensor.from_tensor(batch_chars_padded, padding=-1),
    output_encoding='UTF-8')
<tf.Tensor: shape=(4,), dtype=string, numpy=
array([b'h\xc3\x83llo', b'What is the weather tomorrow',
       b'G\xc3\xb6\xc3\xb6dnight', b'\xf0\x9f\x98\x8a'], dtype=object)>
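The relationship between the ragged and padded forms can be sketched with plain Python lists. The helper strip_padding is a hypothetical name used only for this illustration; it drops a trailing run of the padding value, which is the behavior the padding argument above relies on. The value -1 is a convenient pad because it is not a valid code point.

```python
PAD = -1  # not a valid code point, so it cannot collide with real data

ragged = [[104, 195, 108, 108, 111], [71, 246, 246], [128522]]

# Ragged -> padded: extend every row to the length of the longest row.
width = max(len(row) for row in ragged)
padded = [row + [PAD] * (width - len(row)) for row in ragged]

def strip_padding(row, pad=PAD):
    """Hypothetical helper: remove the trailing run of pad values."""
    end = len(row)
    while end > 0 and row[end - 1] == pad:
        end -= 1
    return row[:end]

# Padded -> ragged: strip the trailing padding from each row.
restored = [strip_padding(row) for row in padded]

print(padded)
print(restored)
```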

Unicode operations

Character length

Use the unit parameter of the tf.strings.length op to indicate how string lengths should be computed. unit defaults to "BYTE", but it can be set to "UTF8_CHAR" to count the Unicode code points in each UTF-8 encoded string.

# Note that the final character takes up 4 bytes in UTF-8.
thanks = u'Thanks 😊'.encode('UTF-8')
num_bytes = tf.strings.length(thanks).numpy()
num_chars = tf.strings.length(thanks, unit='UTF8_CHAR').numpy()
print('{} bytes; {} UTF-8 characters'.format(num_bytes, num_chars))
11 bytes; 8 UTF-8 characters
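The difference between the two counts comes from UTF-8 being a variable-width encoding. A plain-Python sketch (standard library only) that lists the number of bytes each character contributes:

```python
text = "Thanks 😊"

num_bytes = len(text.encode("utf-8"))  # length in bytes
num_chars = len(text)                  # length in code points

# Number of UTF-8 bytes used by each character.
bytes_per_char = [len(ch.encode("utf-8")) for ch in text]

print(num_bytes, num_chars)
print(bytes_per_char)
```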

Character substrings

The tf.strings.substr op accepts the unit parameter, and uses it to determine what kind of offsets the pos and len parameters contain.

# Here, unit='BYTE' (default). Returns a single byte with len=1
tf.strings.substr(thanks, pos=7, len=1).numpy()
b'\xf0'
# Specifying unit='UTF8_CHAR', returns a single 4 byte character in this case
print(tf.strings.substr(thanks, pos=7, len=1, unit='UTF8_CHAR').numpy())
b'\xf0\x9f\x98\x8a'
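The same contrast appears in plain Python when slicing a bytes object versus a str. The sketch below also shows why byte offsets need care: a one-byte slice taken from the middle of a multi-byte character is not valid UTF-8 on its own.

```python
text = "Thanks 😊"
data = text.encode("utf-8")

# Slicing bytes uses byte offsets: this takes only the first byte of the emoji.
byte_slice = data[7:8]

# Slicing str uses character offsets: this takes the whole emoji.
char_slice = text[7:8].encode("utf-8")

# A partial multi-byte sequence cannot be decoded as UTF-8.
try:
    byte_slice.decode("utf-8")
    byte_slice_is_valid = True
except UnicodeDecodeError:
    byte_slice_is_valid = False

print(byte_slice, char_slice, byte_slice_is_valid)
```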

Split Unicode strings

The tf.strings.unicode_split operation splits Unicode strings into substrings of individual characters.

tf.strings.unicode_split(thanks, 'UTF-8').numpy()
array([b'T', b'h', b'a', b'n', b'k', b's', b' ', b'\xf0\x9f\x98\x8a'],
      dtype=object)

Byte offsets for characters

To align the character tensor generated by tf.strings.unicode_decode with the original string, it's useful to know the offset for where each character begins. The method tf.strings.unicode_decode_with_offsets is similar to unicode_decode, except that it returns a second tensor containing the start offset of each character.

codepoints, offsets = tf.strings.unicode_decode_with_offsets(u'🎈🎉🎊', 'UTF-8')

for (codepoint, offset) in zip(codepoints.numpy(), offsets.numpy()):
  print('At byte offset {}: codepoint {}'.format(offset, codepoint))
At byte offset 0: codepoint 127880
At byte offset 4: codepoint 127881
At byte offset 8: codepoint 127882
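The offsets are simply running totals of the encoded width of each preceding character. A minimal plain-Python sketch, where utf8_offsets is a hypothetical helper name used only for this illustration:

```python
def utf8_offsets(text):
    """Hypothetical helper: (code point, starting byte offset) per character."""
    pairs = []
    position = 0
    for ch in text:
        pairs.append((ord(ch), position))
        position += len(ch.encode("utf-8"))  # advance by this character's width
    return pairs

balloons = utf8_offsets("🎈🎉🎊")
mixed = utf8_offsets("a\u00e9😊")  # 1-byte, 2-byte, and 4-byte characters

for codepoint, offset in balloons:
    print('At byte offset {}: codepoint {}'.format(offset, codepoint))
```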

Unicode scripts

Each Unicode code point belongs to a single collection of code points known as a script. A character's script is helpful in determining which language the character might be in. For example, knowing that 'Б' is in Cyrillic script indicates that modern text containing that character is likely from a Slavic language such as Russian or Ukrainian.

TensorFlow provides the tf.strings.unicode_script operation to determine which script a given codepoint uses. The script codes are int32 values corresponding to International Components for Unicode (ICU) UScriptCode values.

uscript = tf.strings.unicode_script([33464, 1041])  # ['芸', 'Б']

print(uscript.numpy())  # [17, 8] == [USCRIPT_HAN, USCRIPT_CYRILLIC]
[17  8]

The tf.strings.unicode_script operation can also be applied to multidimensional tf.Tensors or tf.RaggedTensors of codepoints:

print(tf.strings.unicode_script(batch_chars_ragged))
<tf.RaggedTensor [[25, 25, 25, 25, 25],
 [25, 25, 25, 25, 0, 25, 25, 0, 25, 25, 25, 0, 25, 25, 25, 25, 25, 25, 25,
  0, 25, 25, 25, 25, 25, 25, 25, 25]                                      ,
 [25, 25, 25, 25, 25, 25, 25, 25, 25], [0]]>
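To make the integer codes easier to read, a small lookup table can translate them back to names. The table below is a hand-written subset of the ICU UScriptCode values that appear in this tutorial, and script_name is a hypothetical helper; neither is part of the TensorFlow API.

```python
# Hand-written subset of ICU UScriptCode values (an assumption of this sketch).
SCRIPT_NAMES = {
    0: "USCRIPT_COMMON",
    8: "USCRIPT_CYRILLIC",
    17: "USCRIPT_HAN",
    20: "USCRIPT_HIRAGANA",
    25: "USCRIPT_LATIN",
}

def script_name(code):
    """Hypothetical helper: readable name for a script code, if known here."""
    return SCRIPT_NAMES.get(code, "UNKNOWN({})".format(code))

han_and_cyrillic = [script_name(c) for c in [17, 8]]
emoji_row = [script_name(c) for c in [0]]

print(han_and_cyrillic)
print(emoji_row)
```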

Example: Simple segmentation

Segmentation is the task of splitting text into word-like units. This is often easy when space characters are used to separate words, but some languages (like Chinese and Japanese) do not use spaces, and some languages (like German) contain long compounds that must be split in order to analyze their meaning. In web text, different languages and scripts are frequently mixed together, as in "NY株価" (New York stock prices).

We can perform very rough segmentation (without implementing any ML models) by using changes in script to approximate word boundaries. This will work for strings like the "NY株価" example above. It will also work for most languages that use spaces, as the space characters of various scripts are all classified as USCRIPT_COMMON, a special script code that differs from that of any actual text.

# dtype: string; shape: [num_sentences]
#
# The sentences to process.  Edit this line to try out different inputs!
sentence_texts = [u'Hello, world.', u'世界こんにちは']

First, decode the sentences into character codepoints, and find the script identifier for each character.

# dtype: int32; shape: [num_sentences, (num_chars_per_sentence)]
#
# sentence_char_codepoint[i, j] is the codepoint for the j'th character in
# the i'th sentence.
sentence_char_codepoint = tf.strings.unicode_decode(sentence_texts, 'UTF-8')
print(sentence_char_codepoint)

# dtype: int32; shape: [num_sentences, (num_chars_per_sentence)]
#
# sentence_char_script[i, j] is the Unicode script of the j'th character in
# the i'th sentence.
sentence_char_script = tf.strings.unicode_script(sentence_char_codepoint)
print(sentence_char_script)
<tf.RaggedTensor [[72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 46],
 [19990, 30028, 12371, 12435, 12395, 12385, 12399]]>
<tf.RaggedTensor [[25, 25, 25, 25, 25, 0, 0, 25, 25, 25, 25, 25, 0],
 [17, 17, 20, 20, 20, 20, 20]]>

Use the script identifiers to determine where word boundaries should be added. Add a word boundary at the beginning of each sentence, and before each character whose script differs from that of the previous character.

# dtype: bool; shape: [num_sentences, (num_chars_per_sentence)]
#
# sentence_char_starts_word[i, j] is True if the j'th character in the i'th
# sentence is the start of a word.
sentence_char_starts_word = tf.concat(
    [tf.fill([sentence_char_script.nrows(), 1], True),
     tf.not_equal(sentence_char_script[:, 1:], sentence_char_script[:, :-1])],
    axis=1)

# dtype: int64; shape: [num_words]
#
# word_starts[i] is the index of the character that starts the i'th word (in
# the flattened list of characters from all sentences).
word_starts = tf.squeeze(tf.where(sentence_char_starts_word.values), axis=1)
print(word_starts)
tf.Tensor([ 0  5  7 12 13 15], shape=(6,), dtype=int64)
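The boundary rule can be checked by hand with plain Python lists. This sketch copies the script codes printed earlier for the two example sentences and applies the same rule: the first character of each sentence starts a word, and so does any character whose script differs from its predecessor's.

```python
# Script codes copied from the output above for the two example sentences.
sentence_scripts = [
    [25, 25, 25, 25, 25, 0, 0, 25, 25, 25, 25, 25, 0],
    [17, 17, 20, 20, 20, 20, 20],
]

# True where a character starts a new word.
starts_word = [
    [i == 0 or row[i] != row[i - 1] for i in range(len(row))]
    for row in sentence_scripts
]

# Indices of word starts in the flattened list of all characters.
flat_flags = [flag for row in starts_word for flag in row]
word_starts = [i for i, flag in enumerate(flat_flags) if flag]

print(word_starts)
```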

You can then use those start offsets to build a RaggedTensor containing the list of words from all sentences.

# dtype: int32; shape: [num_words, (num_chars_per_word)]
#
# word_char_codepoint[i, j] is the codepoint for the j'th character in the
# i'th word.
word_char_codepoint = tf.RaggedTensor.from_row_starts(
    values=sentence_char_codepoint.values,
    row_starts=word_starts)
print(word_char_codepoint)
<tf.RaggedTensor [[72, 101, 108, 108, 111], [44, 32], [119, 111, 114, 108, 100], [46],
 [19990, 30028], [12371, 12435, 12395, 12385, 12399]]>
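Conceptually, building rows from start offsets means slicing the flat list of values between consecutive starts, with the last row running to the end. A plain-Python sketch of that idea, using the values printed above (this illustrates the semantics only, not the TensorFlow implementation):

```python
# Flat code points for both sentences, and the word start offsets from above.
values = [72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 46,
          19990, 30028, 12371, 12435, 12395, 12385, 12399]
row_starts = [0, 5, 7, 12, 13, 15]

# Each row ends where the next one starts; the last row ends at len(values).
row_ends = row_starts[1:] + [len(values)]
rows = [values[start:end] for start, end in zip(row_starts, row_ends)]

print(rows)
```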

To finish, segment the word codepoints RaggedTensor back into sentences and encode into UTF-8 strings for readability.

# dtype: int64; shape: [num_sentences]
#
# sentence_num_words[i] is the number of words in the i'th sentence.
sentence_num_words = tf.reduce_sum(
    tf.cast(sentence_char_starts_word, tf.int64),
    axis=1)

# dtype: int32; shape: [num_sentences, (num_words_per_sentence), (num_chars_per_word)]
#
# sentence_word_char_codepoint[i, j, k] is the codepoint for the k'th character
# in the j'th word in the i'th sentence.
sentence_word_char_codepoint = tf.RaggedTensor.from_row_lengths(
    values=word_char_codepoint,
    row_lengths=sentence_num_words)
print(sentence_word_char_codepoint)

tf.strings.unicode_encode(sentence_word_char_codepoint, 'UTF-8').to_list()
<tf.RaggedTensor [[[72, 101, 108, 108, 111], [44, 32], [119, 111, 114, 108, 100], [46]],
 [[19990, 30028], [12371, 12435, 12395, 12385, 12399]]]>
[[b'Hello', b', ', b'world', b'.'],
 [b'\xe4\xb8\x96\xe7\x95\x8c',
  b'\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf']]
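The final regrouping step can likewise be sketched with plain Python lists: consume the flat list of words in chunks whose sizes are the per-sentence word counts, then turn each word's code points back into text. The values are copied from the outputs above.

```python
# Words as lists of code points, and the number of words in each sentence.
words = [[72, 101, 108, 108, 111], [44, 32], [119, 111, 114, 108, 100], [46],
         [19990, 30028], [12371, 12435, 12395, 12385, 12399]]
words_per_sentence = [4, 2]

# Group consecutive words into sentences according to the row lengths.
sentences = []
start = 0
for count in words_per_sentence:
    sentences.append(words[start:start + count])
    start += count

# Convert each word from code points back to a Python str for readability.
decoded = [["".join(chr(cp) for cp in word) for word in sentence]
           for sentence in sentences]

print(decoded)
```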