ประเภทส่วนขยาย

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHubดาวน์โหลดโน๊ตบุ๊ค

ติดตั้ง

!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile

ประเภทส่วนขยาย

ประเภทที่กำหนดโดยผู้ใช้สามารถทำให้โครงการอ่านง่ายขึ้น เป็นแบบแยกส่วน และบำรุงรักษาได้ อย่างไรก็ตาม TensorFlow API ส่วนใหญ่มีการสนับสนุนที่จำกัดมากสำหรับประเภท Python ที่ผู้ใช้กำหนด ซึ่งรวมถึง API ระดับสูง (เช่น Keras , tf.function , tf.SavedModel ) และ API ระดับล่าง (เช่น tf.while_loop และ tf.concat ) สามารถใช้ ประเภทส่วนขยาย TensorFlow เพื่อสร้างประเภทเชิงวัตถุที่ผู้ใช้กำหนดซึ่งทำงานได้อย่างราบรื่นกับ API ของ TensorFlow ในการสร้างประเภทส่วนขยาย เพียงกำหนดคลาส Python ด้วย tf.experimental.ExtensionType เป็นฐาน และใช้ คำอธิบายประกอบประเภท เพื่อระบุประเภทสำหรับแต่ละฟิลด์

class TensorGraph(tf.experimental.ExtensionType):
  """A collection of labeled nodes connected by weighted edges."""
  edge_weights: tf.Tensor               # shape=[num_nodes, num_nodes]
  node_labels: Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for missing/invalid values.

class CSRSparseMatrix(tf.experimental.ExtensionType):
  """Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
  values: tf.Tensor     # shape=[num_nonzero]; dtype=any
  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64
  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64

คลาสพื้นฐาน tf.experimental.ExtensionType ทำงานคล้ายกับการ typing.NamedTuple และ @dataclasses.dataclass จากไลบรารี Python มาตรฐาน โดยเฉพาะอย่างยิ่ง จะเพิ่มตัวสร้างและวิธีการพิเศษโดยอัตโนมัติ (เช่น __repr__ และ __eq__ ) ตามคำอธิบายประกอบประเภทฟิลด์

โดยทั่วไป ประเภทส่วนขยายมักจะจัดเป็นหนึ่งในสองประเภท:

  • โครงสร้างข้อมูล ซึ่งจัดกลุ่มคอลเลกชั่นของค่าที่เกี่ยวข้องกัน และสามารถจัดเตรียมการดำเนินการที่เป็นประโยชน์ตามค่าเหล่านั้น โครงสร้างข้อมูลอาจค่อนข้างทั่วไป (เช่นตัวอย่าง TensorGraph ด้านบน) หรืออาจปรับแต่งให้เข้ากับรุ่นที่เฉพาะเจาะจงได้อย่างมาก

  • ประเภทคล้ายเทนเซอร์ ซึ่งเชี่ยวชาญหรือขยายแนวคิดของ "เทนเซอร์" ประเภทในหมวดหมู่นี้มี rank shape และมักจะเป็น dtype ; และมันสมเหตุสมผลที่จะใช้พวกมันกับการทำงานของเทนเซอร์ (เช่น tf.stack , tf.add หรือ tf.matmul ) MaskedTensor และ CSRSparseMatrix เป็นตัวอย่างของประเภทที่คล้ายเทนเซอร์

API ที่รองรับ

ประเภทส่วนขยายได้รับการสนับสนุนโดย TensorFlow API ต่อไปนี้:

  • Keras : ประเภทส่วนขยายสามารถใช้เป็นอินพุตและเอาต์พุตสำหรับ Keras Models และ Layers
  • tf.data.Dataset : สามารถรวมประเภทส่วนขยายใน Datasets และส่งคืนโดย dataset Iterators
  • ฮับ ​​Tensorflow : ประเภทส่วนขยายสามารถใช้เป็นอินพุตและเอาต์พุตสำหรับโมดูล tf.hub
  • SavedModel : ประเภทส่วนขยายสามารถใช้เป็นอินพุตและเอาต์พุตสำหรับฟังก์ชัน SavedModel
  • tf.function : ประเภทส่วนขยายสามารถใช้เป็นอาร์กิวเมนต์และส่งคืนค่าสำหรับฟังก์ชันที่หุ้มด้วย @tf.function decorator
  • while loops : ประเภทส่วนขยายสามารถใช้เป็นตัวแปรลูปใน tf.while_loop และสามารถใช้เป็นอาร์กิวเมนต์และส่งกลับค่าสำหรับเนื้อความของ while-loop
  • conditionals : สามารถเลือกประเภทส่วนขยายตามเงื่อนไขได้โดยใช้ tf.cond และ tf.case
  • py_function : ประเภทส่วนขยายสามารถใช้เป็นอาร์กิวเมนต์และส่งคืนค่าสำหรับอาร์กิวเมนต์ func ไปที่ tf.py_function ได้
  • Tensor ops : สามารถขยายประเภทส่วนขยายเพื่อรองรับ TensorFlow ops ส่วนใหญ่ที่ยอมรับอินพุตของ Tensor (เช่น tf.matmul , tf.gather และ tf.reduce_sum ) ดูส่วน " จัดส่ง " ด้านล่างสำหรับข้อมูลเพิ่มเติม
  • กลยุทธ์การกระจาย : ประเภทส่วนขยายสามารถใช้เป็นค่าต่อแบบจำลองได้

สำหรับรายละเอียดเพิ่มเติม โปรดดูหัวข้อ "TensorFlow APIs ที่รองรับ ExtensionTypes" ด้านล่าง

ความต้องการ

ประเภทฟิลด์

ต้องประกาศฟิลด์ทั้งหมด (ตัวแปรอินสแตนซ์ที่เรียกว่า) และต้องระบุคำอธิบายประกอบสำหรับแต่ละฟิลด์ รองรับคำอธิบายประกอบประเภทต่อไปนี้:

พิมพ์ ตัวอย่าง
จำนวนเต็มหลาม i: int
งูหลามลอย f: float
สตริง Python s: str
Python บูลีน b: bool
Python ไม่มี n: None
รูปร่างเทนเซอร์ shape: tf.TensorShape
เทนเซอร์ dtypes dtype: tf.DType
เทนเซอร์ t: tf.Tensor
ประเภทส่วนขยาย mt: MyMaskedTensor
Ragged Tensors rt: tf.RaggedTensor
เซนเซอร์แบบกระจัดกระจาย st: tf.SparseTensor
ชิ้นที่จัดทำดัชนี s: tf.IndexedSlices
ตัวเลือกเทนเซอร์ o: tf.experimental.Optional
พิมพ์สหภาพแรงงาน int_or_float: typing.Union[int, float]
ทูเปิลส์ params: typing.Tuple[int, float, tf.Tensor, int]
tuples ยาว lengths: typing.Tuple[int, ...]
การทำแผนที่ tags: typing.Mapping[str, tf.Tensor]
ค่าทางเลือก weight: typing.Optional[tf.Tensor]

การกลายพันธุ์

ประเภทส่วนขยายจะต้องไม่เปลี่ยนรูป เพื่อให้แน่ใจว่าสามารถติดตามได้อย่างถูกต้องโดยกลไกการติดตามกราฟของ TensorFlow หากคุณพบว่าตัวเองต้องการเปลี่ยนค่าประเภทส่วนขยาย ให้พิจารณากำหนดวิธีการที่แปลงค่าแทน ตัวอย่างเช่น แทนที่จะกำหนดเมธอด set_mask เพื่อกลายพันธุ์ MaskedTensor คุณสามารถกำหนดเมธอด replace_mask ที่ส่งคืน MaskedTensor ใหม่:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def replace_mask(self, new_mask):
      self.values.shape.assert_is_compatible_with(new_mask.shape)
      return MaskedTensor(self.values, new_mask)

ฟังก์ชันที่เพิ่มโดย ExtensionType

คลาสพื้นฐาน ExtensionType มีฟังก์ชันการทำงานต่อไปนี้:

  • ตัวสร้าง ( __init__ )
  • วิธีการแสดงที่พิมพ์ได้ ( __repr__ )
  • ตัวดำเนินการความเท่าเทียมกันและความไม่เท่าเทียมกัน ( __eq__ )
  • วิธีการตรวจสอบความถูกต้อง ( __validate__ )
  • บังคับไม่เปลี่ยนรูป
  • TypeSpec ที่ซ้อนกัน
  • รองรับการส่ง Tensor API

ดูส่วน "การปรับแต่งประเภทส่วนขยาย" ด้านล่างสำหรับข้อมูลเพิ่มเติมเกี่ยวกับการปรับแต่งฟังก์ชันนี้

ตัวสร้าง

ตัวสร้างที่เพิ่มโดย ExtensionType ใช้แต่ละฟิลด์เป็นอาร์กิวเมนต์ที่มีชื่อ (ตามลำดับที่ระบุไว้ในคำจำกัดความของคลาส) ตัวสร้างนี้จะพิมพ์-ตรวจสอบแต่ละพารามิเตอร์ และแปลงตามความจำเป็น โดยเฉพาะอย่างยิ่ง ฟิลด์ Tensor จะถูกแปลงโดยใช้ tf.convert_to_tensor ; ฟิลด์ทู Tuple จะถูกแปลงเป็น tuple s; และฟิลด์ Mapping จะถูกแปลงเป็น dicts ที่ไม่เปลี่ยนรูป

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])

# Fields are type-checked and converted to the declared types.
# E.g., mt.values is converted to a Tensor.
print(mt.values)
tf.Tensor(
[[1 2 3]
 [4 5 6]], shape=(2, 3), dtype=int32)

ตัวสร้างจะเพิ่ม TypeError หากค่าฟิลด์ไม่สามารถแปลงเป็นประเภทที่ประกาศได้:

try:
  MaskedTensor([1, 2, 3], None)
except TypeError as e:
  print(f"Got expected TypeError: {e}")
Got expected TypeError: mask: expected a Tensor, got None

ค่าเริ่มต้นสำหรับฟิลด์สามารถระบุได้โดยการตั้งค่าที่ระดับคลาส:

class Pencil(tf.experimental.ExtensionType):
  color: str = "black"
  has_erasor: bool = True
  length: tf.Tensor = 1.0

Pencil()
Pencil(color='black', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
Pencil(length=0.5, color="blue")
Pencil(color='blue', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=0.5>)

ตัวแทนที่พิมพ์ได้

ExtensionType เพิ่มวิธีการแสดงการพิมพ์เริ่มต้น ( __repr__ ) ที่มีชื่อคลาสและค่าสำหรับแต่ละฟิลด์:

print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))
MaskedTensor(values=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, mask=<tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True,  True, False])>)

ตัวดำเนินการความเท่าเทียมกัน

ExtensionType เพิ่มตัวดำเนินการความเท่าเทียมกันเริ่มต้น ( __eq__ และ __ne__ ) ซึ่งถือว่าค่าสองค่าเท่ากันหากมีประเภทเดียวกันและฟิลด์ทั้งหมดเท่ากัน ฟิลด์เทนเซอร์จะถือว่าเท่ากันหากมีรูปร่างเหมือนกันและมีค่าเท่ากันสำหรับองค์ประกอบทั้งหมด

a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")
a == a: True
a == b: False
a == a.values: False

วิธีการตรวจสอบ

ExtensionType เพิ่มเมธอด __validate__ ซึ่งสามารถแทนที่ได้เพื่อทำการตรวจสอบความถูกต้องบนฟิลด์ มันทำงานหลังจากเรียกตัวสร้าง และหลังจากที่ฟิลด์ได้รับการตรวจสอบประเภทและแปลงเป็นประเภทที่ประกาศแล้ว ดังนั้นจึงสามารถสันนิษฐานได้ว่าฟิลด์ทั้งหมดมีประเภทที่ประกาศ

เขาติดตามตัวอย่างอัปเดต MaskedTensor เพื่อตรวจสอบ shape s และ dtype s ของฟิลด์:

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor
  def __validate__(self):
    self.values.shape.assert_is_compatible_with(self.mask.shape)
    assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
  MaskedTensor([1, 2, 3], [0, 1, 0])  # wrong dtype for mask.
except AssertionError as e:
  print(f"Got expected AssertionError: {e}")
Got expected AssertionError: mask.dtype must be bool
try:
  MaskedTensor([1, 2, 3], [True, False])  # shapes don't match.
except ValueError as e:
  print(f"Got expected ValueError: {e}")
Got expected ValueError: Shapes (3,) and (2,) are incompatible

บังคับไม่เปลี่ยนรูป

ExtensionType จะแทนที่ __setattr__ และ __delattr__ เพื่อป้องกันการกลายพันธุ์ เพื่อให้แน่ใจว่าค่าของประเภทส่วนขยายจะไม่เปลี่ยนแปลง

mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
  mt.mask = [True, True, True]
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
try:
  mt.mask[0] = False
except TypeError as e:
  print(f"Got expected TypeError: {e}")
Got expected TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
ตัวยึดตำแหน่ง25 l10n-ตัวยึดตำแหน่ง
try:
  del mt.mask
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.

ประเภทที่ซ้อนกันSpec

แต่ละคลาส ExtensionType มีคลาส TypeSpec ที่สอดคล้องกัน ซึ่งถูกสร้างขึ้นโดยอัตโนมัติและจัดเก็บเป็น <extension_type_name>.Spec

คลาสนี้รวบรวมข้อมูลทั้งหมดจากค่า ยกเว้น ค่าของเทนเซอร์ที่ซ้อนกัน โดยเฉพาะอย่างยิ่ง TypeSpec สำหรับค่าจะถูกสร้างขึ้นโดยการแทนที่ Tensor, ExtensionType หรือ CompositeTensor ที่ซ้อนกันด้วย TypeSpec

class Player(tf.experimental.ExtensionType):
  name: tf.Tensor
  attributes: Mapping[str, tf.Tensor]

anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name)  # Records dtype and shape, but not the string value.
print(anne_spec.attributes)  # Records keys and TensorSpecs for values.
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
TensorSpec(shape=(), dtype=tf.string, name=None)
ImmutableDict({'height': TensorSpec(shape=(), dtype=tf.float32, name=None), 'speed': TensorSpec(shape=(), dtype=tf.float32, name=None)})

ค่า TypeSpec สามารถสร้างได้อย่างชัดเจน หรือสามารถสร้างขึ้นจากค่า ExtensionType โดยใช้ tf.type_spec_from_value :

spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)

TypeSpec ถูกใช้โดย TensorFlow เพื่อแบ่งค่าออกเป็น องค์ประกอบคงที่และองค์ประกอบ แบบไดนามิก :

  • องค์ประกอบแบบคงที่ (ซึ่งได้รับการแก้ไขในเวลาที่สร้างกราฟ) ถูกเข้ารหัสด้วย tf.TypeSpec
  • องค์ประกอบไดนามิก (ซึ่งสามารถเปลี่ยนแปลงได้ทุกครั้งที่รันกราฟ) ถูกเข้ารหัสเป็นรายการของ tf.Tensor

ตัวอย่างเช่น tf.function ย้อนฟังก์ชันที่ห่อหุ้มไว้เมื่อใดก็ตามที่อาร์กิวเมนต์มี TypeSpec ที่มองไม่เห็นก่อนหน้านี้:

@tf.function
def anonymize_player(player):
  print("<<TRACING>>")
  return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
<<TRACING>>
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.3>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=28.1>}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.1>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=25.3>}))
ตัวยึดตำแหน่ง35
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))
<<TRACING>>
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=11.0>, 'jump': <tf.Tensor: shape=(), dtype=float32, numpy=5.3>}))

สำหรับข้อมูลเพิ่มเติม โปรดดูที่ tf.function Guide

การปรับแต่งประเภทส่วนขยาย

นอกจากการประกาศฟิลด์และประเภทแล้ว ประเภทส่วนขยายอาจ:

  • แทนที่การแสดงการพิมพ์เริ่มต้น ( __repr__ )
  • กำหนดวิธีการ
  • กำหนด classmethods และ staticmethods
  • กำหนดคุณสมบัติ
  • แทนที่คอนสตรัคเตอร์เริ่มต้น ( __init__ )
  • แทนที่ตัวดำเนินการความเท่าเทียมกันเริ่มต้น ( __eq__ )
  • กำหนดตัวดำเนินการ (เช่น __add__ และ __lt__ )
  • ประกาศค่าเริ่มต้นสำหรับฟิลด์
  • กำหนดคลาสย่อย

การลบล้างการแสดงแทนการพิมพ์เริ่มต้น

คุณสามารถแทนที่โอเปอเรเตอร์การแปลงสตริงเริ่มต้นนี้สำหรับประเภทส่วนขยาย ตัวอย่างต่อไปนี้จะอัพเดตคลาส MaskedTensor เพื่อสร้างการแสดงสตริงที่อ่านง่ายขึ้นเมื่อค่าถูกพิมพ์ในโหมด Eager

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for invalid values.

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

def masked_tensor_str(values, mask):
  if isinstance(values, tf.Tensor):
    if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
      return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
    else:
      return f'MaskedTensor(values={values}, mask={mask})'
  if len(values.shape) == 1:
    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
  else:
    items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
  return '[%s]' % ', '.join(items)

mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])
print(mt)
<MaskedTensor [[1, 2, _], [4, _, 6]]>
ตัวยึดตำแหน่ง39

การกำหนดวิธีการ

ประเภทส่วนขยายอาจกำหนดวิธีการ เช่นเดียวกับคลาส Python ทั่วไป ตัวอย่างเช่น ประเภท MaskedTensor สามารถกำหนดวิธีการ with_default ที่ส่งกลับสำเนาของ self ด้วยค่าที่ปิดบังแทนที่ด้วย default ที่กำหนด สามารถเลือกวิธีการเพิ่มเติมด้วย @tf.function decorator

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 0, 3], dtype=int32)>
ตัวยึดตำแหน่ง41

การกำหนด classmethods และ staticmethods

ประเภทส่วนขยายอาจกำหนดวิธีการโดยใช้ตัวตกแต่ง @classmethod และ @staticmethod ตัวอย่างเช่น ประเภท MaskedTensor สามารถกำหนดวิธีการของโรงงานที่ปิดบังองค์ประกอบใดๆ ด้วยค่าที่กำหนด:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  @staticmethod
  def from_tensor_and_value_to_mask(values, value_to_mask):
    return MaskedTensor(values, values == value_to_mask)

x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)
<MaskedTensor [[_, 0, _], [_, 0, 0]]>
ตัวยึดตำแหน่ง43

การกำหนดคุณสมบัติ

ประเภทส่วนขยายอาจกำหนดคุณสมบัติโดยใช้ @property decorator เช่นเดียวกับคลาส Python ปกติ ตัวอย่างเช่น ประเภท MaskedTensor สามารถกำหนดคุณสมบัติ dtype ที่เป็นชวเลขสำหรับ dtype ของค่า:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  @property
  def dtype(self):
    return self.values.dtype

MaskedTensor([1, 2, 3], [True, False, True]).dtype
tf.int32

เอาชนะคอนสตรัคเตอร์เริ่มต้น

คุณสามารถแทนที่คอนสตรัคเตอร์เริ่มต้นสำหรับประเภทส่วนขยายได้ ตัวสร้างแบบกำหนดเองต้องตั้งค่าสำหรับทุกฟิลด์ที่ประกาศ และหลังจากที่คอนสตรัคเตอร์แบบกำหนดเองกลับมา ฟิลด์ทั้งหมดจะถูกตรวจสอบประเภท และค่าจะถูกแปลงตามที่อธิบายไว้ข้างต้น

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor
  def __init__(self, name, price, discount=0):
    self.name = name
    self.price = price * (1 - discount)

print(Toy("ball", 5.0, discount=0.2))  # On sale -- 20% off!
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)

หรือคุณอาจพิจารณาปล่อยให้ตัวสร้างเริ่มต้นตามที่เป็นอยู่ แต่เพิ่มวิธีการจากโรงงานอย่างน้อยหนึ่งวิธี เช่น:

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor

  @staticmethod
  def new_toy_with_discount(name, price, discount):
    return Toy(name, price * (1 - discount))

print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)

แทนที่ตัวดำเนินการความเท่าเทียมกันเริ่มต้น ( __eq__ )

คุณสามารถแทนที่ตัวดำเนินการ __eq__ เริ่มต้นสำหรับประเภทส่วนขยายได้ ตัวอย่างต่อไปนี้จะอัปเดต MaskedTensor เพื่อละเว้นองค์ประกอบที่ถูกปิดบังเมื่อเปรียบเทียบเพื่อความเท่าเทียมกัน

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def __eq__(self, other):
    result = tf.math.equal(self.values, other.values)
    result = result | ~(self.mask & other.mask)
    return tf.reduce_all(result)

x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)
tf.Tensor(True, shape=(), dtype=bool)
ตัวยึดตำแหน่ง51

การใช้การอ้างอิงไปข้างหน้า

หากยังไม่ได้กำหนดประเภทของฟิลด์ คุณสามารถใช้สตริงที่มีชื่อประเภทนั้นแทนได้ ในตัวอย่างต่อไปนี้ สตริง "Node" ใช้เพื่อใส่คำอธิบาย children ให้กับฟิลด์ชายน์ เนื่องจากยังไม่ได้กำหนดประเภท Node (ทั้งหมด)

class Node(tf.experimental.ExtensionType):
  value: tf.Tensor
  children: Tuple["Node", ...] = ()

Node(3, [Node(5), Node(2)])
Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=3>, children=(Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>, children=()), Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=2>, children=())))
ตัวยึดตำแหน่ง53

การกำหนดคลาสย่อย

ประเภทส่วนขยายอาจถูกจัดประเภทย่อยโดยใช้ไวยากรณ์ Python มาตรฐาน คลาสย่อยประเภทส่วนขยายอาจเพิ่มฟิลด์ เมธอด และคุณสมบัติใหม่ และอาจแทนที่คอนสตรัคเตอร์ การแทนแบบพิมพ์ได้ และตัวดำเนินการความเท่าเทียมกัน ตัวอย่างต่อไปนี้กำหนดคลาส TensorGraph พื้นฐานที่ใช้สามฟิลด์ Tensor เพื่อเข้ารหัสชุดของขอบระหว่างโหนด จากนั้นกำหนดคลาสย่อยที่เพิ่มฟิลด์ Tensor เพื่อบันทึก "ค่าคุณลักษณะ" สำหรับแต่ละโหนด คลาสย่อยยังกำหนดวิธีการเผยแพร่ค่าคุณลักษณะตามขอบ

class TensorGraph(tf.experimental.ExtensionType):
  num_nodes: tf.Tensor
  edge_src: tf.Tensor   # edge_src[e] = index of src node for edge e.
  edge_dst: tf.Tensor   # edge_dst[e] = index of dst node for edge e.

class TensorGraphWithNodeFeature(TensorGraph):
  node_features: tf.Tensor  # node_features[n] = feature value for node n.

  def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
    updates = tf.gather(self.node_features, self.edge_src) * weight
    new_node_features = tf.tensor_scatter_nd_add(
        self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
    return TensorGraphWithNodeFeature(
        self.num_nodes, self.edge_src, self.edge_dst, new_node_features)

g = TensorGraphWithNodeFeature(  # Edges: 0->1, 4->3, 2->2, 2->1
    num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
    node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])

print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)
Original features: tf.Tensor([10.  0.  2.  5. -1.  0.], shape=(6,), dtype=float32)
After propagating: tf.Tensor([10. 12.  4.  4. -1.  0.], shape=(6,), dtype=float32)
ตัวยึดตำแหน่ง55

การกำหนดฟิลด์ส่วนตัว

ฟิลด์ของประเภทส่วนขยายอาจถูกทำเครื่องหมายเป็นส่วนตัวโดยนำหน้าด้วยเครื่องหมายขีดล่าง (ตามแบบแผน Python มาตรฐาน) สิ่งนี้ไม่กระทบต่อวิธีที่ TensorFlow ปฏิบัติต่อฟิลด์ในทางใดทางหนึ่ง แต่เพียงทำหน้าที่เป็นสัญญาณให้กับผู้ใช้ประเภทส่วนขยายว่าฟิลด์เหล่านี้เป็นแบบส่วนตัว

การปรับแต่ง TypeSpec ของ ExtensionType

แต่ละคลาส ExtensionType มีคลาส TypeSpec ที่สอดคล้องกัน ซึ่งถูกสร้างขึ้นโดยอัตโนมัติและจัดเก็บเป็น <extension_type_name>.Spec สำหรับข้อมูลเพิ่มเติม โปรดดูส่วน "Nested TypeSpec" ด้านบน

ในการปรับแต่ง TypeSpec เพียงแค่กำหนดคลาสที่ซ้อนกันของคุณชื่อ Spec และ ExtensionType จะใช้สิ่งนั้นเป็นพื้นฐานสำหรับ TypeSpec ที่สร้างขึ้นโดยอัตโนมัติ คุณสามารถปรับแต่งคลาส Spec ได้โดย:

  • การลบล้างการแสดงแทนการพิมพ์เริ่มต้น
  • การแทนที่คอนสตรัคเตอร์เริ่มต้น
  • การกำหนดเมธอด classmethods staticmethods และคุณสมบัติ

ตัวอย่างต่อไปนี้ปรับแต่งคลาส MaskedTensor.Spec เพื่อให้ใช้งานง่ายขึ้น:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def with_values(self, new_values):
    return MaskedTensor(new_values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    def __repr__(self):
      return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

การส่ง Tensor API

ประเภทส่วนขยายสามารถ "เหมือนเทนเซอร์" ได้ ในแง่ที่ว่าพวกเขาเชี่ยวชาญหรือขยายอินเทอร์เฟซที่กำหนดโดยประเภท tf.Tensor ตัวอย่างของประเภทส่วนขยายที่เหมือนเทนเซอร์ ได้แก่ RaggedTensor , SparseTensor และ MaskedTensor ตัว ตกแต่ง Dispatch สามารถใช้เพื่อแทนที่การทำงานเริ่มต้นของการดำเนินการ TensorFlow เมื่อนำไปใช้กับประเภทส่วนขยายที่เหมือนเทนเซอร์ ปัจจุบัน TensorFlow กำหนดผู้ตกแต่งการจัดส่งสามคน:

ส่งสำหรับ API . เดียว

มัณฑนากร tf.experimental.dispatch_for_api จะแทนที่การทำงานเริ่มต้นของการดำเนินการ TensorFlow ที่ระบุ เมื่อถูกเรียกด้วยลายเซ็นที่ระบุ ตัวอย่างเช่น คุณสามารถใช้มัณฑนากรนี้เพื่อระบุว่า tf.stack ควรประมวลผลค่า MaskedTensor :

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))

สิ่งนี้จะแทนที่การใช้งานเริ่มต้นสำหรับ tf.stack ทุกครั้งที่มีการเรียกด้วยรายการค่า MaskedTensor (เนื่องจากอาร์กิวเมนต์ values มีคำอธิบายประกอบด้วย typing.List[MaskedTensor] ):

x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])
<MaskedTensor [[1, 2, _], [_, 5, 6]]>

ในการอนุญาตให้ tf.stack จัดการรายการค่า MaskedTensor และ Tensor แบบผสม คุณสามารถปรับแต่งคำอธิบายประกอบประเภทสำหรับพารามิเตอร์ values และอัปเดตเนื้อหาของฟังก์ชันได้อย่างเหมาะสม:

tf.experimental.unregister_dispatch_for(masked_stack)

def convert_to_masked_tensor(x):
  if isinstance(x, MaskedTensor):
    return x
  else:
    return MaskedTensor(x, tf.ones_like(x, tf.bool))

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
  values = [convert_to_masked_tensor(v) for v in values]
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])
<MaskedTensor [[1, 2, _], [4, 5, 6], [1, 2, _]]>

สำหรับรายการ API ที่สามารถแทนที่ได้ โปรดดูเอกสารประกอบ API สำหรับ tf.experimental.dispatch_for_api

จัดส่งสำหรับ unary elementwise APIs ทั้งหมด

มัณฑนากร tf.experimental.dispatch_for_unary_elementwise_apis จะแทนที่การทำงานเริ่มต้นของ ops ที่เป็นเอกเทศ ทั้งหมด (เช่น tf.math.cos ) เมื่อใดก็ตามที่ค่าสำหรับอาร์กิวเมนต์แรก (โดยทั่วไปจะมีชื่อว่า x ) ตรงกับประเภทหมายเหตุประกอบ x_type ฟังก์ชั่นที่ตกแต่งควรมีสองอาร์กิวเมนต์:

  • api_func : ฟังก์ชันที่รับพารามิเตอร์ตัวเดียวและดำเนินการตามองค์ประกอบ (เช่น tf.abs )
  • x : อาร์กิวเมนต์แรกของการดำเนินการตามองค์ประกอบ

ตัวอย่างต่อไปนี้จะอัพเดตการดำเนินการ unary elementwise ทั้งหมดเพื่อจัดการกับประเภท MaskedTensor :

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
 def masked_tensor_unary_elementwise_api_handler(api_func, x):
   return MaskedTensor(api_func(x.values), x.mask)

ตอนนี้ฟังก์ชันนี้จะถูกใช้เมื่อใดก็ตามที่มีการเรียกการดำเนินการตามองค์ประกอบแบบเอกภาคบน MaskedTensor

x = MaskedTensor([1, -2, -3], [True, False, True])
 print(tf.abs(x))
<MaskedTensor [1, _, 3]>
print(tf.ones_like(x, dtype=tf.float32))
<MaskedTensor [1.0, _, 1.0]>

จัดส่งสำหรับไบนารี API ทุกองค์ประกอบ

ในทำนองเดียวกัน tf.experimental.dispatch_for_binary_elementwise_apis สามารถใช้เพื่ออัปเดตการดำเนินการไบนารี elementwise ทั้งหมดเพื่อจัดการกับประเภท MaskedTensor :

@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
  return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)
<MaskedTensor [[5, _, 1], [_, _, _]]>

สำหรับรายการ API ของ elementwise ที่ถูกแทนที่ โปรดดูเอกสารประกอบ API สำหรับ tf.experimental.dispatch_for_unary_elementwise_apis และ tf.experimental.dispatch_for_binary_elementwise_apis

ประเภทส่วนขยายแบบแบตช์ได้

ExtensionType สามารถ แบ ทช์ได้หากสามารถใช้อินสแตนซ์เดียวเพื่อแสดงชุดของค่าได้ โดยทั่วไป สามารถทำได้โดยการเพิ่มขนาดแบทช์ให้กับ Tensor ที่ซ้อนกันทั้งหมด TensorFlow APIs ต่อไปนี้ต้องการให้อินพุตประเภทส่วนขยายเป็นแบตช์ได้:

โดยค่าเริ่มต้น BatchableExtensionType จะสร้างค่าแบทช์โดยการแบทช์ Tensor ที่ซ้อนกัน , CompositeTensor s และ ExtensionType ที่ซ้อนกัน หากสิ่งนี้ไม่เหมาะกับชั้นเรียนของคุณ คุณจะต้องใช้ tf.experimental.ExtensionTypeBatchEncoder เพื่อแทนที่การทำงานเริ่มต้นนี้ ตัวอย่างเช่น ไม่เหมาะสมที่จะสร้างชุดของค่า tf.SparseTensor โดยเพียงแค่ซ้อนค่าเทนเซอร์แบบเบาบางแต่ละ values indices และ dense_shape ฟิลด์ - ในกรณีส่วนใหญ่ คุณไม่สามารถซ้อนเทนเซอร์เหล่านี้ได้ เนื่องจากพวกมันมีรูปร่างที่เข้ากันไม่ได้ ; และแม้ว่าคุณจะทำได้ ผลลัพธ์ก็ไม่ใช่ SparseTensor ที่ถูกต้อง

BatchableExtensionType ตัวอย่าง: Network

ตัวอย่างเช่น ให้พิจารณาคลาส Network แบบง่ายที่ใช้สำหรับการทำโหลดบาลานซ์ ซึ่งจะติดตามว่าแต่ละโหนดเหลืองานมากเพียงใด และแบนด์วิดท์เท่าใดที่พร้อมใช้งานเพื่อย้ายงานระหว่างโหนด:

class Network(tf.experimental.ExtensionType):  # This version is not batchable.
  work: tf.Tensor       # work[n] = work left to do at node n
  bandwidth: tf.Tensor  # bandwidth[n1, n2] = bandwidth from n1->n2

net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])

ในการทำให้ประเภทนี้เป็นแบบแบตช์ได้ ให้เปลี่ยนประเภทฐานเป็น BatchableExtensionType และปรับรูปร่างของแต่ละฟิลด์เพื่อรวมมิติแบตช์ที่เป็นทางเลือก ตัวอย่างต่อไปนี้ยังเพิ่มฟิลด์ shape เพื่อติดตามรูปร่างของแบทช์ ฟิลด์ shape นี้ ไม่ ต้องการโดย tf.data.Dataset หรือ tf.map_fn แต่ tf.Keras จำเป็นต้องใช้

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape.  A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)

def network_repr(network):
  work = network.work
  bandwidth = network.bandwidth
  if hasattr(work, 'numpy'):
    work = ' '.join(str(work.numpy()).split())
  if hasattr(bandwidth, 'numpy'):
    bandwidth = ' '.join(str(bandwidth.numpy()).split())
  return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
    work=tf.stack([net1.work, net2.work]),
    bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
net1=<Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]>
net2=<Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>
batch=<Network shape=(2,) work=[[5. 3. 8.] [3. 4. 2.]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>

จากนั้น คุณสามารถใช้ tf.data.Dataset เพื่อวนซ้ำผ่านกลุ่มเครือข่าย:

dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
  print(f"Batch element {i}: {network}")
Batch element 0: <Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]>
Batch element 1: <Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>

และคุณยังสามารถใช้ map_fn เพื่อใช้ฟังก์ชันกับแต่ละองค์ประกอบชุดงาน:

def balance_work_greedy(network):
  delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
  delta /= 4
  delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
  new_work = network.work + tf.reduce_sum(delta, -1)
  return Network(new_work, network.bandwidth)

tf.map_fn(balance_work_greedy, batch_of_networks)
<Network shape=(2,) work=[[5.5 1.25 9.25] [3. 4.75 1.25]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>

TensorFlow APIs ที่รองรับ ExtensionTypes

@tf.function

tf.function เป็นมัณฑนากรที่คำนวณล่วงหน้ากราฟ TensorFlow สำหรับฟังก์ชัน Python ซึ่งสามารถปรับปรุงประสิทธิภาพของโค้ด TensorFlow ของคุณได้อย่างมาก ค่าประเภทส่วนขยายสามารถใช้ได้อย่างโปร่งใสด้วยฟังก์ชัน @tf.function -decorated

class Pastry(tf.experimental.ExtensionType):
  sweetness: tf.Tensor  # 2d embedding that encodes sweetness
  chewiness: tf.Tensor  # 2d embedding that encodes chewiness

@tf.function
def combine_pastry_features(x: Pastry):
  return (x.sweetness + x.chewiness) / 2

cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>

หากคุณต้องการระบุ input_signature สำหรับ tf.function อย่างชัดเจน คุณสามารถทำได้โดยใช้ TypeSpec ของประเภทส่วนขยาย

pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))

@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
  return Pastry(x.sweetness + delta, x.chewiness)

increase_sweetness(cookie)
Pastry(sweetness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([2.2, 1.4], dtype=float32)>, chewiness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.8, 0.2], dtype=float32)>)

ฟังก์ชั่นคอนกรีต

ฟังก์ชันที่เป็นรูปธรรมจะห่อหุ้มกราฟที่ลากเส้นแต่ละรายการที่สร้างขึ้นโดย tf.function สามารถใช้ชนิดต่อขยายได้อย่างโปร่งใสด้วยฟังก์ชันที่เป็นรูปธรรม

cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>

ควบคุมการทำงานของโฟลว์

ประเภทส่วนขยายได้รับการสนับสนุนโดยการดำเนินการควบคุมการไหลของ TensorFlow:

# Example: using tf.cond to select between two MaskedTensors.  Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
<MaskedTensor [1.0, _, 3.0]>
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
  return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])
<MaskedTensor [26.285717, 37.285698, 112.285736, _]>

โฟลว์การควบคุมลายเซ็น

ประเภทส่วนขยายยังได้รับการสนับสนุนโดยคำสั่งควบคุมโฟลว์ใน tf.function (โดยใช้ลายเซ็นต์) ในตัวอย่างต่อไปนี้ คำสั่ง if และ for statement จะถูกแปลงเป็นการดำเนินการ tf.cond และ tf.while_loop โดยอัตโนมัติ ซึ่งสนับสนุนประเภทส่วนขยาย

@tf.function
def fn(x, b):
  if b:
    x = MaskedTensor(x, tf.less(x, 0))
  else:
    x = MaskedTensor(x, tf.greater(x, 0))
  for i in tf.range(5 if b else 7):
    x = x.with_values(x.values + 1 / 2)
  return x

print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
<MaskedTensor [_, 0.5, _]>
<MaskedTensor [4.5, _, 6.5]>

Keras

tf.keras เป็น API ระดับสูงของ TensorFlow สำหรับการสร้างและฝึกอบรมโมเดลการเรียนรู้เชิงลึก ประเภทส่วนขยายอาจถูกส่งผ่านเป็นอินพุตไปยังโมเดล Keras ส่งผ่านระหว่างเลเยอร์ Keras และส่งคืนโดยโมเดล Keras ปัจจุบัน Keras กำหนดข้อกำหนดสองประการสำหรับประเภทส่วนขยาย:

  • ต้องเป็นแบบแบตช์ได้ (ดู "Batchable ExtensionTypes" ด้านบน)
  • ต้องมีฟิลด์หรือคุณสมบัติชื่อ shape shape[0] ถือว่าเป็นมิติชุดงาน

สองส่วนย่อยต่อไปนี้แสดงตัวอย่างว่าประเภทส่วนขยายสามารถใช้กับ Keras ได้อย่างไร

ตัวอย่าง Keras: Network

สำหรับตัวอย่างแรก ให้พิจารณาคลาส Network ที่กำหนดไว้ในส่วน "Batchable ExtensionTypes" ด้านบน ซึ่งสามารถใช้สำหรับการทำโหลดบาลานซ์ระหว่างโหนด คำจำกัดความของมันถูกทำซ้ำที่นี่:

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape.  A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)
single_network = Network(  # A single network w/ 4 nodes.
    work=[8.0, 5, 12, 2],
    bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])

batch_of_networks = Network(  # Batch of 2 networks, each w/ 2 nodes.
    work=[[8.0, 5], [3, 2]],
    bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])

คุณสามารถกำหนดเลเยอร์ Keras ใหม่ที่ประมวลผล Network s

class BalanceNetworkLayer(tf.keras.layers.Layer):
  """Layer that balances work between nodes in a network.

  Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
  """
  def call(self, inputs):
    # This function is defined above, in "Batchable ExtensionTypes" section.
    return balance_work_greedy(inputs)

จากนั้นคุณสามารถใช้เลเยอร์นี้เพื่อสร้างแบบจำลองอย่างง่าย ในการป้อน ExtensionType ลงในโมเดล คุณสามารถใช้เลเยอร์ tf.keras.layer.Input โดยตั้งค่า type_spec เป็น TypeSpec ของประเภทส่วนขยาย หากจะใช้โมเดล Keras เพื่อประมวลผลชุดงาน ดังนั้น type_spec จะต้องรวมมิติชุดงานด้วย

input_spec = Network.Spec(shape=None,
                          work=tf.TensorSpec(None, tf.float32),
                          bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    BalanceNetworkLayer(),
    ])

สุดท้าย คุณสามารถใช้โมเดลกับเครือข่ายเดียวและกับกลุ่มเครือข่ายได้

model(single_network)
<Network shape=() work=[ 9.25 5. 14. -1.25] bandwidth=[[0. 1. 2. 2.] [1. 0. 0. 2.] [2. 0. 0. 1.] [2. 2. 1. 0.]]>
model(batch_of_networks)
<Network shape=(2,) work=[[8.75 4.25] [3.25 1.75]] bandwidth=[[[0. 1.] [1. 0.]] [[0. 2.] [2. 0.]]]>

ตัวอย่าง Keras: MaskedTensor

ในตัวอย่างนี้ MaskedTensor ถูกขยายเพื่อรองรับ Keras shape ถูกกำหนดเป็นคุณสมบัติที่คำนวณจากฟิลด์ values Keras ต้องการให้คุณเพิ่มคุณสมบัตินี้ให้กับทั้งประเภทส่วนขยายและ TypeSpec MaskedTensor ยังกำหนดตัวแปร __name__ ซึ่งจะจำเป็นสำหรับการทำให้เป็นอนุกรมของ SavedModel (ด้านล่าง)

class MaskedTensor(tf.experimental.BatchableExtensionType):
  # __name__ is required for serialization in SavedModel; see below for details.
  __name__ = 'extension_type_colab.MaskedTensor'

  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

    def with_shape(self):
      return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
                               tf.TensorSpec(shape, self.mask.dtype))

ถัดไป ตัวตกแต่งการจัดส่งจะถูกใช้เพื่อแทนที่การทำงานเริ่มต้นของ TensorFlow API หลายตัว เนื่องจาก API เหล่านี้ถูกใช้โดยเลเยอร์ Keras มาตรฐาน (เช่น เลเยอร์ Dense ) การแทนที่สิ่งเหล่านี้จะทำให้เราใช้เลเยอร์เหล่านั้นกับ MaskedTensor สำหรับจุดประสงค์ของตัวอย่างนี้ matmul สำหรับเทนเซอร์ที่ปิดบังถูกกำหนดให้ถือว่าค่าที่มาสก์เป็นศูนย์ (กล่าวคือ ไม่รวมค่าเหล่านี้ในผลิตภัณฑ์)

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
 return MaskedTensor(op(x.values), x.mask)

@tf.experimental.dispatch_for_binary_elementwise_apis(
    Union[MaskedTensor, tf.Tensor],
    Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
  x = convert_to_masked_tensor(x)
  y = convert_to_masked_tensor(y)
  return MaskedTensor(op(x.values, y.values), x.mask & y.mask)

@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
                  transpose_a=False, transpose_b=False,
                  adjoint_a=False, adjoint_b=False,
                  a_is_sparse=False, b_is_sparse=False,
                  output_type=None):
  if isinstance(a, MaskedTensor):
    a = a.with_default(0)
  if isinstance(b, MaskedTensor):
    b = b.with_default(0)
  return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
                   adjoint_b, a_is_sparse, b_is_sparse, output_type)

จากนั้นคุณสามารถสร้างโมเดล Keras ที่ยอมรับอินพุต MaskedTensor โดยใช้เลเยอร์ Keras มาตรฐาน:

input_spec = MaskedTensor.Spec([None, 2], tf.float32)

masked_tensor_model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
                  [[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))
Epoch 1/3
1/1 [==============================] - 1s 955ms/step - loss: 10.2833
Epoch 2/3
1/1 [==============================] - 0s 5ms/step - loss: 10.2833
Epoch 3/3
1/1 [==============================] - 0s 5ms/step - loss: 10.2833
tf.Tensor(
[[-0.09944128]
 [-0.7225147 ]
 [-1.3020657 ]], shape=(3, 1), dtype=float32)

รูปแบบที่บันทึกไว้

SavedModel เป็นโปรแกรม TensorFlow แบบอนุกรม ซึ่งรวมถึงทั้งน้ำหนักและการคำนวณ สามารถสร้างได้จากโมเดล Keras หรือจากโมเดลแบบกำหนดเอง ไม่ว่าในกรณีใด ประเภทส่วนขยายสามารถใช้ได้อย่างโปร่งใสด้วยฟังก์ชันและเมธอดที่กำหนดโดย SavedModel

SavedModel สามารถบันทึกโมเดล เลเยอร์ และฟังก์ชันที่ประมวลผลประเภทส่วนขยาย ตราบใดที่ประเภทส่วนขยายมีฟิลด์ __name__ ชื่อนี้ใช้เพื่อลงทะเบียนประเภทส่วนขยาย จึงสามารถระบุตำแหน่งได้เมื่อโหลดแบบจำลอง

ตัวอย่าง: การบันทึกโมเดล Keras

โมเดล Keras ที่ใช้ประเภทส่วนขยายอาจถูกบันทึกโดยใช้ SavedModel

masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)
2021-11-06 01:25:14.285250: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel.
INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets
INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[-0.09944128],
       [-0.7225147 ],
       [-1.3020657 ]], dtype=float32)>

ตัวอย่าง: การบันทึกโมเดลแบบกำหนดเอง

SavedModel ยังสามารถใช้เพื่อบันทึกคลาสย่อย tf.Module ที่กำหนดเองด้วยฟังก์ชันที่ประมวลผลประเภทส่วนขยาย

class CustomModule(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def grow(self, x: MaskedTensor):
    """Increase values in `x` by multiplying them by `self.v`."""
    return MaskedTensor(x.values * self.v, x.mask)

module = CustomModule(100.0)

module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
                                                    dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))
INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets
INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets
<MaskedTensor [_, 200.0, _]>

กำลังโหลด SavedModel เมื่อ ExtensionType ไม่พร้อมใช้งาน

หากคุณโหลด SavedModel ที่ใช้ ExtensionType แต่ไม่มี ExtensionType นั้น (เช่น ยังไม่ได้นำเข้า) คุณจะเห็นคำเตือนและ TensorFlow จะถอยกลับไปใช้อ็อบเจ็กต์ "ประเภทส่วนขยายที่ไม่ระบุชื่อ" ออบเจ็กต์นี้จะมีฟิลด์เดียวกันกับประเภทดั้งเดิม แต่จะไม่มีการปรับแต่งเพิ่มเติมใดๆ ที่คุณได้เพิ่มสำหรับประเภทดังกล่าว เช่น วิธีการหรือคุณสมบัติแบบกำหนดเอง

การใช้ ExtensionTypes กับการให้บริการ TensorFlow

ในปัจจุบัน การ ให้บริการ TensorFlow (และผู้บริโภครายอื่นๆ ของพจนานุกรม "ลายเซ็น" ของ SavedModel) กำหนดให้อินพุตและเอาต์พุตทั้งหมดเป็นเมตริกซ์แบบดิบ หากคุณต้องการใช้ TensorFlow ที่ให้บริการกับโมเดลที่ใช้ประเภทส่วนขยาย คุณสามารถเพิ่มเมธอดของ wrapper ที่เขียนหรือแยกส่วนค่าของประเภทส่วนขยายจากเทนเซอร์ เช่น:

class CustomModuleWrapper(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def var_weighted_mean(self, x: MaskedTensor):
    """Mean value of unmasked values in x, weighted by self.v."""
    x = MaskedTensor(x.values * self.v, x.mask)
    return (tf.reduce_sum(x.with_default(0)) /
            tf.reduce_sum(tf.cast(x.mask, x.dtype)))

  @tf.function()
  def var_weighted_mean_wrapper(self, x_values, x_mask):
    """Raw tensor wrapper for var_weighted_mean."""
    return self.var_weighted_mean(MaskedTensor(x_values, x_mask))

module = CustomModuleWrapper([3., 2., 8., 5.])

module.var_weighted_mean_wrapper.get_concrete_function(
    tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets
INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets
<tf.Tensor: shape=(), dtype=float32, numpy=12.0>

ชุดข้อมูล

tf.data เป็น API ที่ช่วยให้คุณสร้างไพพ์ไลน์อินพุตที่ซับซ้อนจากชิ้นส่วนที่เรียบง่ายและนำกลับมาใช้ใหม่ได้ โครงสร้างข้อมูลหลักของมันคือ tf.data.Dataset ซึ่งแสดงถึงลำดับขององค์ประกอบ ซึ่งแต่ละองค์ประกอบประกอบด้วยองค์ประกอบตั้งแต่หนึ่งองค์ประกอบขึ้นไป

การสร้างชุดข้อมูลด้วยประเภทส่วนขยาย

ชุดข้อมูลสามารถสร้างจากค่าประเภทส่วนขยายได้โดยใช้ Dataset.from_tensors , Dataset.from_tensor_slices หรือ Dataset.from_generator :

ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
Pastry(sweetness=<tf.Tensor: shape=(), dtype=int32, numpy=5>, chewiness=<tf.Tensor: shape=(), dtype=int32, numpy=5>)
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
  print(value)
<MaskedTensor [0, 1, 2, 3]>
<MaskedTensor [4, 5, 6, 7]>
<MaskedTensor [8, 9, 10, 11]>
<MaskedTensor [12, 13, 14, 15]>
<MaskedTensor [16, 17, 18, 19]>
def value_gen():
  for i in range(2, 7):
    yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])

ds = tf.data.Dataset.from_generator(
    value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
  print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]>
<MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]>
<MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]>
<MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]>
<MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>

แบทช์และเลิกแบทช์ชุดข้อมูลที่มีประเภทส่วนขยาย

ชุดข้อมูลที่มีประเภทส่วนขยายสามารถแบทช์และยกเลิกการแบทช์ได้โดยใช้ Dataset.batch และ Dataset.unbatch

batched_ds = ds.batch(2)
for value in batched_ds:
  print(value)
<MaskedTensor [[_, 1, _, 3, _, 5, _, 7, _, 9], [_, 1, 2, _, 4, 5, _, 7, 8, _]]>
<MaskedTensor [[_, 1, 2, 3, _, 5, 6, 7, _, 9], [_, 1, 2, 3, 4, _, 6, 7, 8, 9]]>
<MaskedTensor [[_, 1, 2, 3, 4, 5, _, 7, 8, 9]]>
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
  print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]>
<MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]>
<MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]>
<MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]>
<MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>