יום הקהילה ML הוא 9 בנובמבר! הצטרפו אלינו עדכונים מ- TensorFlow, JAX, ועוד למידע נוסף

זרימת טנסור :: אופ :: GatherV2

#include <array_ops.h>

אסוף פרוסות מציר axis של params לפי indices .

סיכום

indices חייבים להיות טנזור שלם של כל מימד (בדרך כלל 0-D או 1-D). מייצר טנסור פלט עם צורה params.shape[:axis] + indices.shape + params.shape[axis + 1:] שם:

    # Scalar indices (output is rank(params) - 1).
    output[a_0, ..., a_n, b_0, ..., b_n] =
      params[a_0, ..., a_n, indices, b_0, ..., b_n]

    # Vector indices (output is rank(params)).
    output[a_0, ..., a_n, i, b_0, ..., b_n] =
      params[a_0, ..., a_n, indices[i], b_0, ..., b_n]

    # Higher rank indices (output is rank(params) + rank(indices) - 1).
    output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] =
      params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n]

שים לב שבמעבד, אם נמצא אינדקס מחוץ לתחום, מוחזרת שגיאה. ב- GPU, אם נמצא אינדקס מחוץ לתחום, 0 נשמר בערך הפלט המתאים.

ראה גם tf.batch_gather ו- tf.gather_nd .

טענות:

  • היקף: אובייקט Scope
  • params: הטנסור שממנו ניתן לאסוף ערכים. חייב להיות לפחות axis + 1 דרגה axis + 1 .
  • מדדים: טנסור אינדקס. חייב להיות בטווח [0, params.shape[axis]) .
  • ציר: הציר params לאיסוף indices . ברירת מחדל לממד הראשון. תומך באינדקס שלילי.

החזרות:

  • Output : ערכים params שנאספו ממדדים הניתנים על ידי indices , עם צורה params.shape[:axis] + indices.shape + params.shape[axis + 1:] .

קונסטרוקטורים ומשחתנים

GatherV2 (const :: tensorflow::Scope & scope, :: tensorflow::Input params, :: tensorflow::Input indices, :: tensorflow::Input axis)
GatherV2 (const :: tensorflow::Scope & scope, :: tensorflow::Input params, :: tensorflow::Input indices, :: tensorflow::Input axis, const GatherV2::Attrs & attrs)

תכונות ציבוריות

operation
output

פונקציות ציבוריות

node () const
::tensorflow::Node *
operator::tensorflow::Input () const
operator::tensorflow::Output () const

פונקציות סטטיות ציבוריות

BatchDims (int64 x)

סטרוקטורים

tensorflow :: ops :: GatherV2 :: Attrs

קובעי תכונות אופציונליים עבור GatherV2 .

תכונות ציבוריות

מבצע

Operation operation

תְפוּקָה

::tensorflow::Output output

פונקציות ציבוריות

GatherV2

 GatherV2(
  const ::tensorflow::Scope & scope,
  ::tensorflow::Input params,
  ::tensorflow::Input indices,
  ::tensorflow::Input axis
)

GatherV2

 GatherV2(
  const ::tensorflow::Scope & scope,
  ::tensorflow::Input params,
  ::tensorflow::Input indices,
  ::tensorflow::Input axis,
  const GatherV2::Attrs & attrs
)

צוֹמֶת

::tensorflow::Node * node() const 

אופרטור :: זרימת טנסור :: קלט

 operator::tensorflow::Input() const 

אופרטור :: זרימת טנסור :: פלט

 operator::tensorflow::Output() const 

פונקציות סטטיות ציבוריות

BatchDims

Attrs BatchDims(
  int64 x
)