หน้านี้ได้รับการแปลโดย Cloud Translation API
Switch to English

tf.ensure_shape

TensorFlow 1 รุ่น ดูโค้ดบน GitHub

ปรับปรุงรูปร่างของเมตริกซ์และการตรวจสอบที่รันไทม์ที่รูปร่างถือ

ด้วยการดำเนินการกระตือรือร้นที่นี้คือการยืนยันรูปร่างที่ส่งกลับการป้อนข้อมูล:

x = tf.constant([1,2,3])
print(x.shape)
(3,)
x = tf.ensure_shape(x, [3])
x = tf.ensure_shape(x, [5])
Traceback (most recent call last):

tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not
  compatible with expected shape [5]. [Op:EnsureShape]

ภายใน tf.function หรือ v1.Graph บริบทมันจะตรวจสอบทั้ง buildtime และรูปร่างรันไทม์ นี่คือเข้มงวดกว่า tf.Tensor.set_shape ซึ่งตรวจสอบเฉพาะรูปร่าง buildtime

ยกตัวอย่างเช่นการโหลดภาพที่มีขนาดที่รู้จักกัน:

@tf.function
def decode_image(png):
  image = tf.image.decode_png(png, channels=3)
  # the `print` executes during tracing.
  print("Initial shape: ", image.shape)
  image = tf.ensure_shape(image,[28, 28, 3])
  print("Final shape: ", image.shape)
  return image

เมื่อติดตามฟังก์ชั่นไม่ Ops มีการดำเนินการรูปร่างอาจจะไม่รู้จัก ดู ฟังก์ชั่นคอนกรีตคู่มือ สำหรับรายละเอียด

concrete_decode = decode_image.get_concrete_function(
    tf.TensorSpec([], dtype=tf.string))
Initial shape:  (None, None, 3)
Final shape:  (28, 28, 3)
image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
image = tf.cast(image,tf.uint8)
png = tf.image.encode_png(image)
image2 = concrete_decode(png)
print(image2.shape)
(28, 28, 3)
image = tf.concat([image,image], axis=0)
print(image.shape)
(56, 28, 3)
png = tf.image.encode_png(image)
image2 = concrete_decode(png)
Traceback (most recent call last):

tf.errors.InvalidArgumentError:  Shape of tensor DecodePng [56,28,3] is not
  compatible with expected shape [28,28,3].
@tf.function
def bad_decode_image(png):
  image = tf.image.decode_png(png, channels=3)
  # the `print` executes during tracing.
  print("Initial shape: ", image.shape)
  # BAD: forgot to use the returned tensor.
  tf.ensure_shape(image,[28, 28, 3])
  print("Final shape: ", image.shape)
  return image
image = bad_decode_image(png)
Initial shape:  (None, None, 3)
Final shape:  (None, None, 3)
print(image.shape)
(56, 28, 3)

x Tensor
shape TensorShape ตัวแทนรูปร่างของเมตริกซ์นี้ TensorShapeProto , รายการทูเปิลหรือไม่
name ชื่อสำหรับการดำเนินการนี้ A (อุปกรณ์เสริม) เริ่มต้นที่ "EnsureShape"

Tensor มีชนิดเดียวกันและเนื้อหาเป็น x ที่รันไทม์ยก tf.errors.InvalidArgumentError ถ้า shape ไม่เข้ากันกับรูปร่างของ x