運算語意

以下說明 XlaBuilder 介面中定義的作業語意。一般來說,這些作業會一對一對應至 xla_data.proto 中 RPC 介面中定義的作業。

命名句的注意事項:XLA 交易的一般化資料類型是一種 N 維陣列,其中包含某種統一類型的元素 (例如 32 位元浮點)。在本說明文件中,「陣列」用來表示任意維度陣列。為方便起見,特殊情況會有更明確且熟悉的名稱。例如,向量是 1 維陣列,矩陣則是 2 維陣列。

AfterAll

另請參閱 XlaBuilder::AfterAll

afterAll 採用各種不同的符記,並產生單一符記。權杖是原始類型,可在連帶效果作業之間建立執行緒,以便強制執行排序。AfterAll 可用來在集合運算之後,用於訂購作業的符記彙整。

AfterAll(operands)

引數 類型 語義
operands XlaOp 符記數量

AllGather

另請參閱 XlaBuilder::AllGather

跨備用資源執行串連。

AllGather(operand, all_gather_dim, shard_count, replica_group_ids, channel_id)

引數 類型 語義
operand XlaOp 在不同備用資源之間串連的陣列
all_gather_dim int64 串連維度
replica_groups int64 的向量向量 執行串連作業的群組
channel_id 選用 int64 跨模組通訊的選用管道 ID
  • replica_groups 是執行串連作業的備用資源群組清單 (可使用 ReplicaId 擷取目前備用資源的備用資源 ID)。每個群組中的備用資源順序會決定其輸入內容在結果中的顯示順序。replica_groups 必須是空白 (在這種情況下,所有備用資源都屬於單一群組,並以 0N - 1 的順序排列),或是包含與備用資源數量相同的元素數量。例如,replica_groups = {0, 2}, {1, 3} 會執行備用資源 02,以及 13 之間的串連。
  • shard_count 是每個備用資源群組的大小。如果 replica_groups 空白,就需要使用此方法。
  • channel_id 用於跨模組通訊:只有具有相同 channel_idall-gather 作業可以相互通訊。

輸出形狀是 all_gather_dim 的輸入形狀,使得 shard_count 倍變大。舉例來說,如果兩個備用資源有 [1.0, 2.5][3.0, 5.25] 值,而運算元在兩個備用資源上分別具有 [1.0, 2.5][3.0, 5.25] 值,則這個運算的輸出值,其中 all_gather_dim 是兩個備用資源上的 0[1.0, 2.5, 3.0, 5.25]

AllReduce

另請參閱 XlaBuilder::AllReduce

跨備用資源執行自訂運算。

AllReduce(operand, computation, replica_group_ids, channel_id)

引數 類型 語義
operand XlaOp 使用陣列或非空白元組來減少備用資源中的數量
computation XlaComputation 縮減運算
replica_groups int64 的向量向量 執行縮減作業之間的群組
channel_id 選用 int64 跨模組通訊的選用管道 ID
  • operand 是陣列的元組時,系統會對元組的每個元素執行 all-reduce。
  • replica_groups 是執行縮減作業的備用資源群組清單 (可使用 ReplicaId 擷取目前備用資源的備用資源 ID)。replica_groups 必須是空白 (在這種情況下,所有備用資源都屬於單一群組),或包含與備用資源數量相同的元素數量。例如,replica_groups = {0, 2}, {1, 3} 會在備用資源 02 之間,以及 13 之間減少。
  • channel_id 用於跨模組通訊:只有具有相同 channel_idall-reduce 作業可以相互通訊。

輸出形狀與輸入形狀相同。舉例來說,如果有兩個備用資源,且運算元在兩個備用資源上分別具有 [1.0, 2.5][3.0, 5.25] 值,則這個運算和加總運算的輸出值在兩個備用資源上都會是 [4.0, 7.75]。如果輸入是元組,則輸出結果也會是元組。

計算 AllReduce 的結果需要每個備用資源都有一個輸入內容,因此如果其中一個備用資源執行 AllReduce 節點多次,則正式的備用資源將一直等待。由於備用資源全都執行同一個程式,因此運作的方式並不多,但當迴圈的條件取決於傳入的資料,則可能的迴圈條件取決於傳入的資料,以及加載的資料會讓迴圈的資料在某個備用資源上進行多次疊代。

AllToAll

另請參閱 XlaBuilder::AllToAll

AllToAll 是集合作業,可將資料從所有核心傳送至所有核心。其中包含兩個階段:

  1. 散佈階段。在每個核心上,運算元都會沿著 split_dimensions 分割成 split_count 個區塊,且區塊會分散至所有核心,例如將第 i 個區塊傳送至第 i 個核心。
  2. 收集階段。每個核心都會在 concat_dimension 上串連收到的區塊。

可透過下列方式設定參與的核心:

  • replica_groups:每個 ReplicaGroup 包含進行運算的備用資源 ID 清單 (您可以使用 ReplicaId 擷取目前備用資源的備用資源 ID)。AllToAll 會依指定順序套用在子群組中。舉例來說,replica_groups = { {1,2,3}, {4,5,0} } 表示 AllToAll 會在備用資源 {1, 2, 3} 和收集階段中套用,而收到的區塊會以相同的 1、2、3 的順序串連。然後,在備用資源 4、5、0 中會套用另一個 AllToAll,而串連順序也是 4、5、0。如果 replica_groups 空白,所有備用資源就會依照其外觀的串連順序,屬於一個群組。

必備條件:

  • split_dimension 上運算元的維度大小可由 split_count 除盡。
  • 運算元的形狀不是元組。

AllToAll(operand, split_dimension, concat_dimension, split_count, replica_groups)

引數 類型 語義
operand XlaOp N 維輸入陣列
split_dimension int64 為維度命名且運算元分割的間隔 [0, n) 中的值
concat_dimension int64 間隔 [0, n) 中的值,該值會為維度命名,並將分割區塊串連
split_count int64 參與這項作業的核心數量。如果 replica_groups 空白,這個數量應為備用資源數量;否則,這應等於各群組中的備用資源數量。
replica_groups ReplicaGroup 向量 每個群組都包含備用資源 ID 清單。

以下是 Alltoall 的範例。

XlaBuilder b("alltoall");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);

在這個範例中,有 4 個核心參與 Alltoall。在每個核心上,運算元會按照維度 0 分為 4 個部分,因此每個部分都有形狀 f32[4,4]。這 4 個部分會分散至所有核心。接著,每個核心會按照維度 1 的順序,將收到的零件串連為 0 到 4。因此每個核心的輸出內容都有 f32[16,4] 形狀。

BatchNormGrad

如要進一步瞭解演算法,另請參閱 XlaBuilder::BatchNormGrad原始批次正規化論文

計算批次規範的梯度。

BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)

引數 類型 語義
operand XlaOp 要正規化的 n 維陣列 (x)
scale XlaOp 1 個維度陣列 (\(\gamma\))
mean XlaOp 1 個維度陣列 (\(\mu\))
variance XlaOp 1 個維度陣列 (\(\sigma^2\))
grad_output XlaOp 傳送至 BatchNormTraining 的梯度 (\(\nabla y\))
epsilon float Epsilon 值 (\(\epsilon\))
feature_index int64 operand中的特徵維度索引

對於特徵維度 (feature_indexoperand 中特徵維度的索引) 中的每個特徵,運算會使用所有其他維度的 operandoffsetscale 計算漸層。feature_index 必須是 operand 中特徵維度的有效索引。

三個漸層是由下列公式定義 (假設 4D 陣列為 operand,且特徵維度索引為 l,批次大小 m 和空間大小 wh):

\[ \begin{split} c_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right) \\\\ d_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \\\\ \nabla x_{ijkl} &= \frac{\gamma_{l} }{\sqrt{\sigma^2_{l}+\epsilon} } \left( \nabla y_{ijkl} - d_l - c_l (x_{ijkl} - \mu_{l}) \right) \\\\ \nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon} } \right) \\\\\ \nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \end{split} \]

輸入內容 meanvariance 代表批次和空間維度的時刻值。

輸出類型是三個控點的元組:

輸出 類型 語義
grad_operand XlaOp 輸入 operand 的漸層 ($\nabla x$)
grad_scale XlaOp 輸入 scale 的漸層 ($\nabla \gamma$)
grad_offset XlaOp 與輸入 offset($\nabla \beta$) 的相對漸層

BatchNormInference

如要進一步瞭解演算法,另請參閱 XlaBuilder::BatchNormInference原始批次正規化論文

將批次和空間維度的陣列正規化。

BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)

引數 類型 語義
operand XlaOp 要正規化的 n 維陣列
scale XlaOp 1 個維度陣列
offset XlaOp 1 個維度陣列
mean XlaOp 1 個維度陣列
variance XlaOp 1 個維度陣列
epsilon float Epsilon 值
feature_index int64 operand中的特徵維度索引

對於特徵維度中的每個地圖項目 (feature_indexoperand 中特徵維度的索引),運算會計算所有其他維度的平均值和變異數,並使用平均值和變異數將 operand 中的每個元素正規化。feature_index 必須是 operand 中地圖項目維度的有效索引。

BatchNormInference 相當於呼叫 BatchNormTraining,而不計算每個批次的 meanvariance。而是使用 meanvariance 做為預估值。這個運算的目的在於減少推論的延遲時間,因此名為 BatchNormInference

輸出內容是 N 維正規化陣列,其形狀與輸入 operand 相同。

BatchNormTraining

如要進一步瞭解演算法,另請參閱 XlaBuilder::BatchNormTrainingthe original batch normalization paper

將批次和空間維度的陣列正規化。

BatchNormTraining(operand, scale, offset, epsilon, feature_index)

引數 類型 語義
operand XlaOp 要正規化的 n 維陣列 (x)
scale XlaOp 1 個維度陣列 (\(\gamma\))
offset XlaOp 1 個維度陣列 (\(\beta\))
epsilon float Epsilon 值 (\(\epsilon\))
feature_index int64 operand中的特徵維度索引

對於特徵維度中的每個地圖項目 (feature_indexoperand 中特徵維度的索引),運算會計算所有其他維度的平均值和變異數,並使用平均值和變異數將 operand 中的每個元素正規化。feature_index 必須是 operand 中地圖項目維度的有效索引。

operand \(x\) 中的每個批次都會以下列方式處理,其中包含 m 元素,並以 wh 做為空間維度的大小 (假設 operand 是 4 個維度陣列):

  • 計算特徵維度中每個特徵 l \(\mu_l\) 的批次平均數: \(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\)

  • 計算批次變異數 \(\sigma^2_l\): $\sigma^2l=\frac{1}{mwh}\sum{i=1}^m\sum{j=1}^w\sum{k=1}^h (x_{ijkl} - \mu_l)^2$

  • 正規化、縮放及位移:\(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt[2]{\sigma^2_l+\epsilon} }+\beta_l\)

加入 Epsilon 值 (通常為小數值),以免發生由零分除的錯誤。

輸出類型是三個 XlaOp 的元組:

輸出 類型 語義
output XlaOp 與輸入 operand 具有相同形狀的 n 個維度陣列 (y)
batch_mean XlaOp 1 個維度陣列 (\(\mu\))
batch_var XlaOp 1 個維度陣列 (\(\sigma^2\))

batch_meanbatch_var 是使用上述公式,針對批次和空間維度計算所得的結果。

BitcastConvertType

另請參閱 XlaBuilder::BitcastConvertType

與 TensorFlow 中的 tf.bitcast 類似,可執行從資料形狀到目標形狀的元素相關位元投放作業。輸入和輸出大小必須相符:例如,s32 元素會透過 Bitcast 處理常式成為 f32 元素,而一個 s32 元素會成為四個 s8 元素。Bitcast 是以低階轉換方式實作,因此具有不同浮點表示法的機器會得到不同的結果。

BitcastConvertType(operand, new_element_type)

引數 類型 語義
operand XlaOp T 型陣列,帶有暗色 D
new_element_type PrimitiveType U 型

運算元尺寸和目標形狀必須相符,除了最後一個維度之外,最終維度也會依轉換前後的原始大小比率而改變。

來源與目的地元素類型不得為元組。

從位元轉換為不同寬度的原始類型

BitcastConvert HLO 指令可支援輸出元素類型 T' 的大小不等於輸入元素 T 的大小。由於整個運算概念上來說是位元轉換,不會變更基礎位元組,因此輸出元素的形狀必須改變。B = sizeof(T), B' = sizeof(T') 有兩種可能情況。

首先,當 B > B' 時,輸出形狀會取得大小最為 B/B' 的新尺寸。例如:

  f16[10,2]{1,0} %output = f16[10,2]{1,0} bitcast-convert(f32[10]{0} %input)

有效純量的規則也維持不變:

  f16[2]{0} %output = f16[2]{0} bitcast-convert(f32[] %input)

或者,如果是 B' > B,則指示要求輸入形狀的最後一個邏輯維度等於 B'/B,且在轉換期間會捨棄這個維度:

  f32[10]{0} %output = f32[10]{0} bitcast-convert(f16[10,2]{1,0} %input)

請注意,不同位元寬度之間的轉換不會影響元素。

廣播

另請參閱 XlaBuilder::Broadcast

複製陣列中的資料,將維度加入陣列。

Broadcast(operand, broadcast_sizes)

引數 類型 語義
operand XlaOp 要複製的陣列
broadcast_sizes ArraySlice<int64> 新維度的大小

系統會在左側插入新維度,也就是說,如果 broadcast_sizes 的值為 {a0, ..., aN},且運算元形狀維度為 {b0, ..., bM},則輸出形狀的維度為 {a0, ..., aN, b0, ..., bM}

新的維度索引會編入運算元的副本,即

output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]

舉例來說,如果 operand 是值為 2.0f 的純量 f32,而 broadcast_sizes{2, 3},結果將會是形狀為 f32[2, 3] 的陣列,且結果中的所有值都是 2.0f

BroadcastInDim

另請參閱 XlaBuilder::BroadcastInDim

複製陣列中的資料,以展開陣列的大小和排名。

BroadcastInDim(operand, out_dim_size, broadcast_dimensions)

引數 類型 語義
operand XlaOp 要複製的陣列
out_dim_size ArraySlice<int64> 目標形狀的尺寸
broadcast_dimensions ArraySlice<int64> 各個運算元形狀各維度對應的目標維度

與廣播類似,但可在任意位置新增尺寸,並展開大小為 1 的現有尺寸。

operand 會廣播到 out_dim_size 描述的形狀。broadcast_dimensions 會將 operand 的維度對應至目標形狀的尺寸,即運算元的第 個維度會對應到輸出形狀的 broadcast_dimension[i] 這個維度。operand 的尺寸必須有 1 的大小,或是與對應輸出形狀中的維度大小相同。其餘尺寸的尺寸為 1。拆解維度廣播,然後沿著這些產生的維度播送,以達到輸出形狀。如要進一步瞭解語意,請參閱廣播頁面

撥打電話

另請參閱 XlaBuilder::Call

使用指定引數叫用計算。

Call(computation, args...)

引數 類型 語義
computation XlaComputation T_0, T_1, ..., T_{N-1} -> S 類型的計算,以及 N 個任意類型的參數
args N XlaOp 秒序列 N 個任意類型的引數

args 的順序和類型必須與 computation 的參數相符。不得含有args

孔雀

另請參閱 XlaBuilder::Cholesky

計算一批對稱式 (Hermitian) 明確矩陣的 Cholesky 解譯

Cholesky(a, lower)

引數 類型 語義
a XlaOp 排名超過 2 陣列的複雜或浮點類型
lower bool 是否要使用 a 的頂端或下三角形。

如果 lowertrue,會計算低三角形矩陣 l,因此 $a = l 。l^T$。如果 lowerfalse,就會計算上角矩陣 u,以達到\(a = u^T . u\)。

lower 的值而定,系統只會從 a 的下方/上三角形讀取輸入資料。系統會忽略其他三角形的值。輸出資料會在同一個三角形中傳回;其他三角形的值則是實作定義,可以是任何值。

如果 a 的排名大於 2,系統會將 a 視為一批矩陣,其中次要維度除外的 2 個維度均為批次維度。

如果 a 不對稱度 (Hermitian) 明確,則結果是實作定義。

夾式

另請參閱 XlaBuilder::Clamp

將運算元限制在最小值和最大值之間的範圍內。

Clamp(min, operand, max)

引數 類型 語義
min XlaOp T 類型的陣列
operand XlaOp T 類型的陣列
max XlaOp T 類型的陣列

針對特定運算元以及最小值與最大值之間,如果運算元位於最小值和最大值之間,系統會傳回該運算元;如果運算元小於這個範圍,則傳回最小值;如果運算元高於這個範圍,則傳回最大值。也就是 clamp(a, x, b) = min(max(a, x), b)

這三個陣列的形狀都必須相同。或者,做為 廣播的受限形式,min 和/或 max 可以是 T 類型的純量。

純量 minmax 範例:

let operand: s32[3] = {-1, 5, 9};
let min: s32 = 0;
let max: s32 = 6;
==>
Clamp(min, operand, max) = s32[3]{0, 5, 6};

收合

另請參閱 XlaBuilder::Collapsetf.reshape 作業。

將陣列的尺寸收合為單一維度。

Collapse(operand, dimensions)

引數 類型 語義
operand XlaOp T 類型的陣列
dimensions int64 向量 按順序排列且連續的 T 維度子集。

收合會以單一維度取代指定運算元維度的特定部分。輸入引數是 T 類型的任意陣列,以及維度索引的編譯時間常數向量。維度索引必須依序排列 (低到高維度數字),以及 T 維度的連續子集。因此,{0、1、2}、{0、1} 或 {1, 2} 都是有效的維度集,但 {1, 0} 或 {0, 2} 無效。而是會由單一新維度取代,且維度序列中的對應維度位置相同,新的尺寸大小與原始尺寸大小的產物相等。dimensions 中的最低維度數字是迴圈巢狀結構中最慢的變化維度 (最主要的),其會收合這些維度,而最大維度編號最大的差異是最大的 (最輕微)。如果需要更一般的收合順序,請參閱 tf.reshape 運算子。

舉例來說, let v 是 24 個元素的陣列:

let v = f32[4x2x3] { { {10, 11, 12},  {15, 16, 17} },
{ {20, 21, 22},  {25, 26, 27} },
{ {30, 31, 32},  {35, 36, 37} },
{ {40, 41, 42},  {45, 46, 47} } };

// Collapse to a single dimension, leaving one dimension.
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37,
40, 41, 42, 45, 46, 47};

// Collapse the two lower dimensions, leaving two dimensions.
let v01 = Collapse(v, {0,1});
then v01 == f32[4x6] { {10, 11, 12, 15, 16, 17},
{20, 21, 22, 25, 26, 27},
{30, 31, 32, 35, 36, 37},
{40, 41, 42, 45, 46, 47} };

// Collapse the two higher dimensions, leaving two dimensions.
let v12 = Collapse(v, {1,2});
then v12 == f32[8x3] { {10, 11, 12},
{15, 16, 17},
{20, 21, 22},
{25, 26, 27},
{30, 31, 32},
{35, 36, 37},
{40, 41, 42},
{45, 46, 47} };

CollectivePermute

另請參閱 XlaBuilder::CollectivePermute

CollectivePermute 是跨備用資源傳送及接收資料的集合作業。

CollectivePermute(operand, source_target_pairs)

引數 類型 語義
operand XlaOp N 維輸入陣列
source_target_pairs <int64, int64> 向量 (source_copy_id、 target_不僅_id) 的組合清單。對於每個組合,運算元會從來源備用資源傳送至目標備用資源。

請注意,source_target_pair 設有下列限制:

  • 任兩個組合的目標備用資源 ID 不可相同,而且不得有相同的來源備用資源 ID。
  • 如果備用資源 ID 不是任何組合的目標,則該備用資源的輸出會是由 0 組成且形狀與輸入內容相同的張量。

串連

另請參閱 XlaBuilder::ConcatInDim

串連可將多個陣列運算元的陣列組成。陣列是與各個輸入陣列運算元相同的排名 (兩者的排名必須相同),並包含引數的指定順序。

Concatenate(operands..., dimension)

引數 類型 語義
operands N XlaOp 序列 類型 T 的 N 陣列含有尺寸 [L0, L1, ...]。必須要有 N >= 1。
dimension int64 在間隔 [0, N) 中的值,用來命名要在 operands 之間串連的維度。

除了 dimension 以外,所有維度都必須相同。這是因為 XLA 不支援「ragged」陣列。另請注意,排名-0 的值無法串連 (因為無法同時為發生串連的維度命名)。

1 維範例:

Concat({ {2, 3}, {4, 5}, {6, 7} }, 0)
>>> {2, 3, 4, 5, 6, 7}

2D 範例:

let a = {
{1, 2},
{3, 4},
{5, 6},
};
let b = {
{7, 8},
};
Concat({a, b}, 0)
>>> {
{1, 2},
{3, 4},
{5, 6},
{7, 8},
}

圖表:

須符合條件

另請參閱 XlaBuilder::Conditional

Conditional(pred, true_operand, true_computation, false_operand, false_computation)

引數 類型 語義
pred XlaOp PRED 類型的純量
true_operand XlaOp \(T_0\)類型的引數
true_computation XlaComputation \(T_0 \to S\)類型的 XlaComputation
false_operand XlaOp \(T_1\)類型的引數
false_computation XlaComputation \(T_1 \to S\)類型的 XlaComputation

如果 predtrue,就會執行 true_computation;如果 predfalse,則執行 false_computation,然後傳回結果。

true_computation 必須納入類型為 \(T_0\) 的單一引數,並使用 true_operand (必須屬於相同類型) 叫用。false_computation 必須納入類型為 \(T_1\) 的單一引數,且將透過 false_operand 叫用 (必須是相同類型)。回傳值的 true_computationfalse_computation 類型必須相同。

請注意,系統會根據 pred 的值執行 true_computationfalse_computation 其中之一。

Conditional(branch_index, branch_computations, branch_operands)

引數 類型 語義
branch_index XlaOp S32 類型的純量
branch_computations N XlaComputation 序列 \(T_0 \to S , T_1 \to S , ..., T_{N-1} \to S\)類型的 XlaComputations
branch_operands N XlaOp 序列 \(T_0 , T_1 , ..., T_{N-1}\)類型的引數

執行 branch_computations[branch_index],並傳回結果。如果 branch_index 是小於 0 或 >= N 的 S32,則會以預設分支版本的形式執行 branch_computations[N-1]

每個 branch_computations[b] 都必須採用類型為 \(T_b\) 的單一引數,並使用 branch_operands[b] (必須屬於相同類型) 叫用。每個 branch_computations[b] 的傳回值類型必須相同。

請注意,根據 branch_index 的值,系統只會執行其中一個 branch_computations

轉換 (對話)

另請參閱 XlaBuilder::Conv

如同 ConvWithGeneralPadding,但邊框間距是以 SAME 或 VALID 的簡短方式來指定。為輸入 (lhs) 加上零填充邊框間距,這樣在不考慮片段時,輸出內容的形狀會與輸入相同。VALID 邊框間距代表無邊框間距。

ConvWithGeneralPadding (卷積)

另請參閱 XlaBuilder::ConvWithGeneralPadding

計算類神經網路所用種類的捲積。這裡的捲積可以視為在 ND 底座上移動的 n 維窗口,然後對視窗的每個可能位置執行計算。

引數 類型 語義
lhs XlaOp 依輸入項目數量的 n+2 陣列排名
rhs XlaOp 核心權重排名 n+2 陣列
window_strides ArraySlice<int64> n-d 核心步伐
padding ArraySlice< pair<int64,int64>> n-d 陣列的 (低、高) 邊框間距
lhs_dilation ArraySlice<int64> n-d lhs 除錯因數陣列
rhs_dilation ArraySlice<int64> n-d rhs 擴散因數陣列
feature_group_count int64 特徵群組數量
batch_group_count int64 批次群組數量

n 代表空間維度的數量。lhs 引數是描述基本區域的 n+2 陣列。即使 Rh 也是輸入,也稱為「輸入」。在類神經網路中,這些是輸入啟用事件。n+2 維度的順序如下:

  • batch:這個維度中的每個座標都代表執行卷積的獨立輸入。
  • z/depth/features:底面積的每個 (y,x) 位置都有與其相關聯的向量,將轉換為這個維度。
  • spatial_dims:說明 n 空間維度,用於定義視窗間移動的基本區域。

rhs 引數是排名 n+2 陣列,用於說明卷積篩選器/核心/窗口。維度的順序如下:

  • output-z:輸出內容的 z 維度。
  • input-z:這個維度乘以 feature_group_count 的大小應等於 lhs 中的 z 尺寸大小。
  • spatial_dims:說明 n 空間維度,用於定義橫跨基本區域的 n-d 窗口。

window_strides 引數會在空間維度中指定卷積窗的步距。舉例來說,如果第一個空間維度的步長為 3,那麼窗口只能放在第一個空間索引是 3 除盡的座標。

padding 引數會指定要套用至底區域的零邊框間距量。邊框間距量可以是負數 -- 負邊框間距的絕對值代表在執行卷積之前要從指定維度中移除的元素數量。padding[0] 會指定維度 ypadding[1] 的邊框間距,並指定為維度 x 的邊框間距。每個組合的第一個元素都有低邊框間距,第二個元素則採用高邊框間距。低邊框間距會以較低的索引方向套用,而高邊框間距會以較高索引方向的方向套用。舉例來說,如果 padding[1](2,3),左側會出現邊框間距 20,右側為第二個空間維度 30。使用邊框間距,就等同先在輸入 (lhs) 中插入相同的零值,再執行卷積。

lhs_dilationrhs_dilation 引數會指定在每個空間維度中,要套用至面板和 Rh 的分割因數。如果空間維度的縮小係數已內嵌,則該維度中每個項目之間會以隱含方式放置 d-1 洞,增加陣列的大小。這些孔會填滿無人工管理的值,而卷積的意思為零。

rh 的擴散也稱為的能力卷積。詳情請參閱 tf.nn.atrous_conv2d。鏡頭的縮放也稱為轉置卷積。詳情請參閱 tf.nn.conv2d_transpose

feature_group_count 引數 (預設值 1) 可用於分組的對話。feature_group_count 需同時是輸入與輸出特徵維度的除數。如果 feature_group_count 大於 1,表示從概念上來說,輸入和輸出特徵維度和 rhs 輸出特徵維度會平均分割為多個 feature_group_count 群組,每個群組都包含連續的子序列特徵。rhs 的輸入特徵維度必須等於 lhs 輸入特徵維度除以 feature_group_count 所得的值 (因此已設有輸入特徵群組的大小)。這個 i-th 群組可搭配使用,針對許多獨立卷積計算 feature_group_count。這些卷積的結果會在輸出特徵維度中串連在一起。

如果是深度卷積,系統會將 feature_group_count 引數設為輸入特徵維度,並將篩選器從 [filter_height, filter_width, in_channels, channel_multiplier] 重設為 [filter_height, filter_width, 1, in_channels * channel_multiplier]。詳情請參閱 tf.nn.depthwise_conv2d

batch_group_count (預設值 1) 引數可用於反向傳播期間的分組篩選器。batch_group_count 必須是 lhs (輸入) 批次維度大小的除數。如果 batch_group_count 大於 1,表示輸出批次維度的大小應為 input batch / batch_group_countbatch_group_count 必須是輸出特徵大小的除數。

輸出形狀包含以下維度:

  • batch:這個維度乘以 batch_group_count 的大小應等於 Google 相簿中的 batch 尺寸。
  • z:大小與核心上的 output-z 相同 (rhs)。
  • spatial_dims:每個卷積視窗的有效位置一個值。

上圖顯示 batch_group_count 欄位的運作方式。實際上,我們會將每個 lh 分批分割為 batch_group_count 群組,並針對輸出特徵執行相同操作。然後,我們會針對每個群組進行配對卷積,並將輸出內容與輸出特徵維度串連在一起。所有其他維度 (特徵和空間) 的作業語意維持不變。

卷積窗的有效位置取決於邊框間距後基本區域的大小。

如要說明卷積的功用,請考慮使用 2D 卷積,並在輸出內容中挑選一些固定的 batchzyx 座標。然後,(y,x) 是視窗在基本區域內的角落 (例如左上角,視如何解讀空間維度而定)。現在有一個 2D 視窗,從基礎區域取得,每個 2D 點都與一個 1D 向量相關聯,因此會得到 3D 方塊。在卷積核心中,我們修正了輸出座標 z,因此還有 3D 方塊。這兩個方塊具有相同的尺寸,因此可計算兩個方塊之間元素的各項產品總和 (類似點積)。這是輸出值。

請注意,如果 output-z 是5,那麼視窗的每個位置都會在輸出的 z 維度中產生 5 個值。這些值與卷積核心使用的組成部分不同,每個 output-z 座標都有獨立的 3D 方塊值。這可以視為 5 個不同的捲積,每個卷軸都有不同的篩選器。

以下是搭配填充和擷取的 2D 卷積的虛擬程式碼:

for (b, oz, oy, ox) {  // output coordinates
  value = 0;
  for (iz, ky, kx) {  // kernel coordinates and input z
    iy = oy*stride_y + ky - pad_low_y;
    ix = ox*stride_x + kx - pad_low_x;
    if ((iy, ix) inside the base area considered without padding) {
      value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
    }
  }
  output(b, oz, oy, ox) = value;
}

ConvertElementType

另請參閱 XlaBuilder::ConvertElementType

與 C++ 中的元素相關 static_cast 類似,執行從資料形狀到目標形狀的各元素轉換作業。維度必須相符,且轉換是層級明確的轉換。例如,s32 元素會透過 s32f32 的轉換日常安排成為 f32 元素。

ConvertElementType(operand, new_element_type)

引數 類型 語義
operand XlaOp T 型陣列,帶有暗色 D
new_element_type PrimitiveType U 型

運算元維度和目標形狀必須相符。來源與目的地元素類型不得為元組。

T=s32U=f32 這類轉換會執行正規化的浮動轉換處理常式,例如來回平均。

let a: s32[3] = {0, 1, 2};
let b: f32[3] = convert(a, f32);
then b == f32[3]{0.0, 1.0, 2.0}

CrossReplicaSum

以加總計算執行 AllReduce

CustomCall

另請參閱 XlaBuilder::CustomCall

在計算中呼叫使用者提供的函式。

CustomCall(target_name, args..., shape)

引數 類型 語義
target_name string 函式的名稱。系統會發出指定這個符號名稱的通話指示。
args N XlaOp 秒序列 N 任意類型的引數,會傳遞至函式。
shape Shape 函式的輸出形狀

無論引數或引數類型為何,函式簽章都相同:

extern "C" void target_name(void* out, void** in);

舉例來說,假設 CustomCall 使用如下:

let x = f32[2] {1,2};
let y = f32[2x3] { {10, 20, 30}, {40, 50, 60} };

CustomCall("myfunc", {x, y}, f32[3x3])

以下是 myfunc 的實作範例:

extern "C" void myfunc(void* out, void** in) {
  float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
  float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
  EXPECT_EQ(1, x[0]);
  EXPECT_EQ(2, x[1]);
  EXPECT_EQ(10, y[0][0]);
  EXPECT_EQ(20, y[0][1]);
  EXPECT_EQ(30, y[0][2]);
  EXPECT_EQ(40, y[1][0]);
  EXPECT_EQ(50, y[1][1]);
  EXPECT_EQ(60, y[1][2]);
  float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
  z[0][0] = x[1] + y[1][0];
  // ...
}

使用者提供函式不得有副作用,且其執行方式必須為冪等。

Dot

另請參閱 XlaBuilder::Dot

Dot(lhs, rhs)

引數 類型 語義
lhs XlaOp T 類型的陣列
rhs XlaOp T 類型的陣列

這項作業的確切語意取決於運算元的排名:

輸入資料 輸出結果 語義
向量 [n] dot 向量 [n] 純量 向量內積
矩陣 [m x k] dot 向量 [k] 向量 [m] 矩陣向量乘法
矩陣 [m x k] dot 矩陣 [k x n] 矩陣 [m x n] 矩陣與矩陣乘法

運算會執行 lhs 第二個維度 (如果排名為 1,第一個為第 1) 和第一個維度 rhs 的產品總和。這就是「約定」維度。lhsrhs 的約定維度大小必須相同。在實務上,可用於在向量、向量/矩陣乘法或矩陣/矩陣乘法之間執行內積。

DotGeneral

另請參閱 XlaBuilder::DotGeneral

DotGeneral(lhs, rhs, dimension_numbers)

引數 類型 語義
lhs XlaOp T 類型的陣列
rhs XlaOp T 類型的陣列
dimension_numbers DotDimensionNumbers 合約和批次維度編號

與 Dot 類似,但允許同時為 lhsrhs 指定合約和批次維度編號。

DotDimensionsNumbers 欄位 類型 語義
lhs_contracting_dimensions 重複的 int64 lhs 個約定維度編號
rhs_contracting_dimensions 重複的 int64 rhs 個約定維度編號
lhs_batch_dimensions 重複的 int64 lhs 個批次維度編號
rhs_batch_dimensions 重複的 int64 rhs 個批次維度編號

DotGeneral 會根據 dimension_numbers 中指定的合約維度執行產品總和。

lhsrhs 的關聯約定維度編號不一定要相同,但維度大小必須相同。

合約維度編號範例:

lhs = { {1.0, 2.0, 3.0},
{4.0, 5.0, 6.0} }

rhs = { {1.0, 1.0, 1.0},
{2.0, 2.0, 2.0} }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(1);

DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
{15.0, 30.0} }

lhsrhs 中相關聯的批次維度編號必須具有相同的維度大小。

包含批次維度編號 (批次大小 2、2x2 矩陣) 的範例:

lhs = { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }

rhs = { { {1.0, 0.0},
{0.0, 1.0} },
{ {1.0, 0.0},
{0.0, 1.0} } }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);

DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }
輸入資料 輸出結果 語義
[b0, m, k] dot [b0, k, n] [b0、m、n] Batch Matmul
[b0, b1, m, k] dot [b0、b1、k、n] [b0、b1、m、n] Batch Matmul

其計算結果是其開頭為批次維度、lhs 非合約/非批次維度,以及 rhs 非合約/非批次維度。

DynamicSlice

另請參閱 XlaBuilder::DynamicSlice

DynamicSlice 會從動態 start_indices 的輸入陣列擷取子陣列。每個維度中的配量大小會傳入 size_indices 中,進而指定每個維度的專屬配量間隔的終點:[起始、開始 + 大小]。start_indices 的形狀必須排名 == 1,且維度大小等於 operand 的排名。

DynamicSlice(operand, start_indices, size_indices)

引數 類型 語義
operand XlaOp T 類型的 N 維陣列
start_indices N XlaOp 序列 列出 N 純量整數,內含每個維度切片的起始索引。值必須大於或等於 0。
size_indices ArraySlice<int64> 列出 N 個整數,內含每個維度的配量大小。每個值都必須大於 0,且開頭 + 大小必須小於或等於維度尺寸,以避免納入模數維度大小。

有效的配量索引會先對 [1, N) 中的每個索引 i 套用下列轉換,再執行配量:

start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])

這可確保擷取的切片一律與運算元陣列有關。如果切片在套用轉換前是邊界,轉換就不會產生任何作用。

1 維範例:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let s = {2}

DynamicSlice(a, s, {2}) produces:
{2.0, 3.0}

2D 範例:

let b =
{ {0.0,  1.0,  2.0},
{3.0,  4.0,  5.0},
{6.0,  7.0,  8.0},
{9.0, 10.0, 11.0} }
let s = {2, 1}

DynamicSlice(b, s, {2, 2}) produces:
{ { 7.0,  8.0},
{10.0, 11.0} }

DynamicUpdateSlice

另請參閱 XlaBuilder::DynamicUpdateSlice

DynamicUpdateSlice 產生結果,該結果為輸入陣列 operand 的值,並在 start_indices 覆寫切片 updateupdate 的形狀會決定結果的子陣列形狀。start_indices 的形狀排名必須等於排名 == 1,且維度大小等於 operand 的排名。

DynamicUpdateSlice(operand, update, start_indices)

引數 類型 語義
operand XlaOp T 類型的 N 維陣列
update XlaOp 包含切片更新的 T 類型 N 維陣列。每個更新形狀的維度都必須大於零,而「開始 + 更新」的值必須小於或等於每個維度的運算元大小,以免產生超出範圍的更新索引。
start_indices N XlaOp 序列 列出 N 純量整數,內含每個維度切片的起始索引。值必須大於或等於 0。

有效的配量索引會先對 [1, N) 中的每個索引 i 套用下列轉換,再執行配量:

start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])

這可確保更新過的切片一律與運算元陣列相關的範圍內。如果切片在套用轉換前是邊界,轉換就不會產生任何作用。

1 維範例:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let u = {5.0, 6.0}
let s = {2}

DynamicUpdateSlice(a, u, s) produces:
{0.0, 1.0, 5.0, 6.0, 4.0}

2D 範例:

let b =
{ {0.0,  1.0,  2.0},
{3.0,  4.0,  5.0},
{6.0,  7.0,  8.0},
{9.0, 10.0, 11.0} }
let u =
{ {12.0,  13.0},
{14.0,  15.0},
{16.0,  17.0} }

let s = {1, 1}

DynamicUpdateSlice(b, u, s) produces:
{ {0.0,  1.0,  2.0},
{3.0, 12.0, 13.0},
{6.0, 14.0, 15.0},
{9.0, 16.0, 17.0} }

元素級別二元算術運算

另請參閱 XlaBuilder::Add

系統支援一組元素相關二元算術運算。

Op(lhs, rhs)

其中 OpAdd (加)、Sub (減法)、Div (除法)、Rem (剩餘)、Max (最大)、Min (最小)、LogicalAnd (邏輯 AND) 或 LogicalOr (邏輯 OR)。Mul

引數 類型 語義
lhs XlaOp 左側運算元:T 類型的陣列
rhs XlaOp 右側運算元:T 類型的陣列

引數的形狀必須相似或相容。請參閱廣播說明文件,瞭解形狀相容的意義。運算的結果具有形狀,便是播送兩個輸入陣列的結果。在這個變因中,除非其中一個運算元是純量,否則系統「不」支援不同排名陣列之間的作業。

如果 OpRem,則結果的正負號會從被除數取出,且結果的絕對值一律小於除數的絕對值。

整數除法溢位 (帶正負號/未正負號的除法/餘數為零,或 INT_SMIN 的已簽署除法/remainder,具有 -1) 會產生實作定義值。

這類作業提供支援不同排名廣播支援的替代變化版本:

Op(lhs, rhs, broadcast_dimensions)

其中 Op 與上述相同。此運算的變化版本應用於不同排名陣列之間的算術運算 (例如將矩陣加入向量)。

額外的 broadcast_dimensions 運算元是整數切片,用於將排名較低的運算元排名提升到排名較高的運算元排名。broadcast_dimensions 會將較低排名形狀的維度對應至較高排名形狀的維度。展開形狀中未對應的尺寸會填入大小為 1 的尺寸。解壓縮維度廣播,然後根據這些產生的尺寸播送形狀,讓兩個運算元的形狀均等。如要進一步瞭解語意,請參閱廣播頁面

元素相關比較作業

另請參閱 XlaBuilder::Eq

支援一組標準元素相關二元比較運算。請注意,比較浮點類型時,適用標準 IEEE 754 浮點比較語意。

Op(lhs, rhs)

其中 OpEq (等於)、Ne (不等於)、Ge (大於或等於)、Gt (大於)、Le (小於或等於)、Lt (小於) 之一。另一組運算子 (EqTotalOrder、NTotalOrder、GeTotalOrder、GtTotalOrder、LeTotalOrder 和 LtTotalOrder,可強制執行 -NaN < -Inf < -Inf < -Finite +Nain +Nain +

引數 類型 語義
lhs XlaOp 左側運算元:T 類型的陣列
rhs XlaOp 右側運算元:T 類型的陣列

引數的形狀必須相似或相容。請參閱廣播說明文件,瞭解形狀相容的意義。運算的結果具有形狀,代表廣播元素類型為 PRED 的兩個輸入陣列的結果。在這個變因中,系統「不支援」不同排名陣列之間的作業,除非其中一個運算元是純量。

這類作業提供支援不同排名廣播支援的替代變化版本:

Op(lhs, rhs, broadcast_dimensions)

其中 Op 與上述相同。此運算變化版本應用於不同排名陣列之間的比較作業 (例如將矩陣加入向量)。

額外的 broadcast_dimensions 運算元是整數的一部分,用於指定要用於播送運算元的維度。如要進一步瞭解語意,請參閱廣播頁面

元素相關一元函式

XlaBuilder 支援下列元素相關一元函式:

Abs(operand) 元素相關抽象 x -> |x|

Ceil(operand) 元素相關宣告 x -> ⌈x⌉

Cos(operand) 元素相關餘弦 x -> cos(x)

Exp(operand) 元素寬度自然指數 x -> e^x

Floor(operand) 元素層級下限 x -> ⌊x⌋

Imag(operand) 複雜 (或實際) 形狀中的元素寬度虛部分。x -> imag(x)。如果運算元是浮點類型,就會傳回 0。

IsFinite(operand) 測試 operand 的每個元素是否有限,即並非正無限或負無限大,而不是 NaN。傳回與輸入具有相同形狀的 PRED 值陣列,其中每個元素都是 true 時,只有在對應的輸入元素是有限時。

Log(operand) 元素相關自然對數 x -> ln(x)

LogicalNot(operand) 元素相關邏輯不是 x -> !(x)

Logistic(operand) 元素相關邏輯函式計算 x -> logistic(x)

PopulationCount(operand) 會計算 operand 每個元素中設定的位元數。

Neg(operand) 元素相關否定 x -> -x

Real(operand) 元素相關實際部分,為複雜 (或實) 形狀。x -> real(x)。如果運算元是浮點類型,則會傳回相同的值。

Rsqrt(operand) 平方根運算 x -> 1.0 / sqrt(x) 的元素彈性輪廓。

Sign(operand) 元素相關符號作業 x -> sgn(x),其中

\[\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}\]

operand 元素類型的比較運算子。

Sqrt(operand) 元素左右平方根運算 x -> sqrt(x)

Cbrt(operand) 元素相關立方根作業 x -> cbrt(x)

Tanh(operand) 元素寬度雙曲正切 x -> tanh(x)

Round(operand) 元素層級捨入,與零相同。

RoundNearestEven(operand) 元素相關捨入,與最接近的偶數相連。

引數 類型 語義
operand XlaOp 函式的運算元

該函式會套用至 operand 陣列中的每個元素,進而產生具有相同形狀的陣列。operand 可以是純量 (排名 0)。

Fft

XLA FFT 運算會實作向前和反向的 Fourier 轉換,適用於實際和複雜的輸入/輸出。支援最多 3 軸的多維度 FFT。

另請參閱 XlaBuilder::Fft

引數 類型 語義
operand XlaOp 我們正在轉換傅立葉的陣列。
fft_type FftType 請參閱下表。
fft_length ArraySlice<int64> 轉換的軸長度網域長度。由於 RFFT(fft_length=[16]) 的輸出形狀與 RFFT(fft_length=[17]) 相同,因此 IRFFT 必須縮小最內層的軸大小。
FftType 語義
FFT 將複雜到複雜的 FFT。形狀未變更。
IFFT 複雜至複雜的 FFT。形狀未變更。
RFFT 轉寄實際到複雜 FFT。如果 fft_length[-1] 是非零的值,則最內軸的形狀會縮減為 fft_length[-1] // 2 + 1,但會省略轉換信號的反向串連部分,超過 Nyquist 頻率。
IRFFT 非實際為複雜 FFT (即需要複雜、傳回真實)。如果 fft_length[-1] 是非零值,則最內軸的形狀會展開為 fft_length[-1],從 1 的反向合併為 Nyquist 頻率開始推斷已轉換信號的部分,從 1 的反轉到 fft_length[-1] // 2 + 1 項目。

多維度 FFT

如果提供超過 1 個 fft_length,即等同於對最內層的每個軸套用 FFT 運算。請注意,在實際 > 複雜且複雜 - 真實的情況下,最內層的轉換會 (有效) 執行 (RFFT,最後用於 IRFFT),因此最內層是改變大小的因素。其他軸轉換則會是複雜、>複雜。

實作詳情

CPU FFT 採用 Eigen 的 TensorFFT 技術。GPU FFT 使用 cuFFT。

收集

XLA 收集運算將輸入陣列的多個配量 (每個切片可能會有不同的執行階段偏移) 拼接在一起。

一般語意

另請參閱 XlaBuilder::Gather。如需更簡單易懂的說明,請參閱下方的「正式說明」一節。

gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)

引數 類型 語義
operand XlaOp 我們正在收集的陣列,
start_indices XlaOp 包含我們所收集切片起始索引的陣列。
index_vector_dim int64 start_indices 中「包含」起始索引的維度。詳情請見下文說明。
offset_dims ArraySlice<int64> 輸出形狀中的一組維度,偏移為從運算元切割的陣列。
slice_sizes ArraySlice<int64> slice_sizes[i] 是維度 i 區塊的邊界。
collapsed_slice_dims ArraySlice<int64> 每個區塊中收合的維度組合。這些尺寸的大小必須為 1。
start_index_map ArraySlice<int64> 這張地圖說明如何將 start_indices 中的索引對應至運算元。
indices_are_sorted bool 是否保證呼叫均會由呼叫端排序。
unique_indices bool 呼叫端保證索引不會重複。

為方便起見,我們會將輸出陣列中的維度 (在 offset_dims 中) 標示為 batch_dims

輸出結果是排名 batch_dims.size 加上 offset_dims.size 的陣列。

operand.rank 必須等於 offset_dims.sizecollapsed_slice_dims.size 的總和。此外,slice_sizes.size 必須等於 operand.rank

如果 index_vector_dim 等於 start_indices.rank,我們隱含地認為 start_indices 有結尾的 1 維度 (例如,如果 start_indices 是形狀 [6,7],而 index_vector_dim2,我們默示會將 start_indices 的形狀視為 [6,7,1])。

沿著維度 i 的輸出陣列邊界,計算方式如下:

  1. 如果 batch_dims 中存在 i (也就是說,對於某些 k 等於 batch_dims[k]),我們會挑選 start_indices.shape 中的對應維度邊界,然後略過 index_vector_dim (亦即在 k < index_vector_dimstart_indices.shape.dims[k+1] 中挑選 start_indices.shape.dims[k]。

  2. 如果 offset_dims 中存在 i (即部分 koffset_dims[k]),則在計算 collapsed_slice_dims 後,我們會挑選 slice_sizes 的對應邊界 (亦即挑選 adjusted_slice_sizes[k],其中 adjusted_slice_sizesslice_sizes,且已移除索引 collapsed_slice_dims 的邊界)。

對應輸出索引 Out 的運算元索引 In 的計算方式如下:

  1. 允許 G = { Out[k] 用於 batch_dims } 中的 k。使用 G 分割向量 S,這樣 S[i] = start_indices[Merge(G, i)],其中 merge(A, b) 會將位於位置 index_vector_dim 的 b 插入 A。請注意,即使 G 為空白,也是如此:如果 G 為空白,則 S = start_indices

  2. 使用 start_index_map 分散 S,使用 S 建立起始索引 Sinoperand更精準:

    1. Sin[start_index_map[k]] = S[k] (如果 k < start_index_map.size)。

    2. Sin[_] = 0 否則。

  3. 根據 collapsed_slice_dims 設定,在 Out 的偏移維度分散索引,藉此將索引 Oin 建立為 operand。更精準:

    1. Oin[remapped_offset_dims(k)] = Out[offset_dims[k]] 如果 k < offset_dims.size (remapped_offset_dims 定義如下)。

    2. Oin[_] = 0 否則。

  4. InOin + Sin,其中 + 代表元素相關加法。

remapped_offset_dims 是網域 [0, offset_dims.size) 和範圍 [0, operand.rank) \ collapsed_slice_dims 的單調函式。假設offset_dims.size4operand.rank6collapsed_slice_dims 為 {02},則 remapped_offset_dims 為 {01132435}。

如果將 indices_are_sorted 設為 true,XLA 就會假設 start_indices 會依照使用者排序 (遞增的 start_index_map 順序)。如果語意不同,則會定義語意。

如果將 unique_indices 設為 true,XLA 會假設所有因被分散的元素都不重複。因此 XLA 可以使用非原子運算。如果 unique_indices 設為 true,且分散的索引不不重複,就會定義語意。

非正式說明與範例

一般來說,輸出陣列中的每個索引 Out 都會對應至運算元陣列中的元素 E,計算方式如下:

  • 我們會使用 Out 中的批次維度,查詢 start_indices 的起始索引。

  • 我們會使用 start_index_map,將起始索引 (大小可能小於 operand.rank) 對應至 operand 的「full」起始索引。

  • 我們使用完整的起始索引,動態切出大小為 slice_sizes 的切片。

  • 我們會收合 collapsed_slice_dims 維度來重塑切片。由於所有收合的切片維度都必須有 1 的邊界,因此此重新形狀一律合法。

  • 我們會使用 Out 中的偏移維度,為這個配量建立索引,以取得與輸出索引 Out 對應的輸入元素 E

在後續所有範例中,index_vector_dim 設為 start_indices.rank - 1。較有趣的 index_vector_dim 值不會從根本改變作業,但會讓視覺呈現方式更加困難。

如要瞭解上述所有選項如何搭配運作,我們來看看從 [16,11] 陣列收集 5 個形狀 [8,6] 配量的範例。[16,11] 陣列中的切片位置可以表示為形狀 S64[2] 的索引向量,因此 5 個位置的組合可用 S64[5,2] 陣列表示。

收集作業的行為隨後可呈現為索引轉換,其中需要 [GO0O1]、輸出形狀中的索引,並透過下列方式對應至輸入陣列中的元素:

首先,請使用 G 從收集索引陣列選取 (XY) 向量。輸出陣列中索引 [G,O0,O1] 的元素是索引 [X+O0,Y+O1] 處輸入陣列中的元素。

slice_sizes[8,6],作用是決定 O0 和 O1 的範圍,這會決定切片的邊界。

這項收集作業可做為批次動態配量,使用 G 做為批次維度。

收集索引可能有多種面向。舉例來說,使用形狀 [4,5,2] 的「gather indices」陣列的上述範例,將轉譯索引,如下所示:

同樣的,這等同於批次動態配量 G0G1 做為批次維度。篩選器大小仍為 [8,6]

XLA 中的集合運算會用以下方式一般化上述非正式語意:

  1. 我們可以設定輸出形狀中的哪些維度是偏移維度 (上一個範例包含 O0O1 的維度)。系統會將輸出批次維度 (上例中含有 G0G1 的維度) 定義為非偏移維度的輸出維度。

  2. 輸出形狀中明確呈現的輸出偏移維度數量可能小於輸入排名。這些「缺少」的維度必須明確列為 collapsed_slice_dims,且切片大小必須為 1。由於它們的配量大小為 1,因此唯一有效的索引是 0,如果不需要任何索引,就不會產生混淆。

  3. 在上個範例中,從「Gather Indices」陣列 (XY) 擷取的切片可能比輸入陣列排名少,而明確對應會指定索引展開方式,使其具有與輸入內容相同的排名。

我們最後使用 (2) 和 (3) 實作 tf.gather_nd

G0G1 會照常從集合索引陣列切出起始索引,但起始索引只有一個元素 X。同樣地,只有一個輸出偏移索引含有 O0 值。不過,在做為輸入陣列中的索引使用之前,這些會依照「Gather Index Mapping」(正式說明中的 start_index_map) 和「Offset Mapping」(正式說明中的 remapped_offset_dims) 和「Offset Mapping」(正式說明中的 remapped_offset_dims) 分別擴充為 [X0] 和 [0O0]O0X00000OOGGGG11GatherIndicestf.gather_nd

本案件的slice_sizes[1,11]。因此,收集索引陣列中的每個索引 X 都會挑選整個資料列,結果是所有資料列的串連。

GetDimensionSize

另請參閱 XlaBuilder::GetDimensionSize

傳回指定運算元的大小。運算元必須是陣列形狀。

GetDimensionSize(operand, dimension)

引數 類型 語義
operand XlaOp N 維輸入陣列
dimension int64 指定維度的間隔 [0, n)

SetDimensionSize

另請參閱 XlaBuilder::SetDimensionSize

設定 XlaOp 指定的維度動態大小。運算元必須是陣列形狀。

SetDimensionSize(operand, size, dimension)

引數 類型 語義
operand XlaOp n 維度輸入陣列。
size XlaOp int32 代表執行階段的動態大小。
dimension int64 指定維度的間隔 [0, n) 值。

使用編譯器追蹤動態維度,將運算元傳遞為結果。

下游縮減作業會忽略填充值。

let v: f32[10] = f32[10]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
let five: s32 = 5;
let six: s32 = 6;

// Setting dynamic dimension size doesn't change the upper bound of the static
// shape.
let padded_v_five: f32[10] = set_dimension_size(v, five, /*dimension=*/0);
let padded_v_six: f32[10] = set_dimension_size(v, six, /*dimension=*/0);

// sum == 1 + 2 + 3 + 4 + 5
let sum:f32[] = reduce_sum(padded_v_five);
// product == 1 * 2 * 3 * 4 * 5
let product:f32[] = reduce_product(padded_v_five);

// Changing padding size will yield different result.
// sum == 1 + 2 + 3 + 4 + 5 + 6
let sum:f32[] = reduce_sum(padded_v_six);

GetTupleElement

另請參閱 XlaBuilder::GetTupleElement

這個外掛程式能建立具備編譯時間常數值的元組。

該值必須是編譯時間常數,這樣形狀推論才能判斷結果值的類型。

這類似於 C++ 中的 std::get<int N>(t)。概念說明:

let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
let element_1: s32 = gettupleelement(t, 1);  // Inferred shape matches s32.

另請參閱 tf.tuple

動態內廣告

另請參閱 XlaBuilder::Infeed

Infeed(shape)

引數 類型 語義
shape Shape 從 Infeed 介面讀取的資料形狀。形狀的版面配置欄位必須設為與傳送至裝置的資料版面配置相符,否則其行為是未定義。

從裝置隱含的內饋串流介面讀取單一資料項目,將資料解譯為指定的形狀和版面配置,並傳回資料的 XlaOp。一次計算中允許多個動態內作業,但動態內操作必須有總計順序。舉例來說,以下程式碼中的兩個 Infeeds 有總順序,因為迴圈之間存在依附元件。

result1 = while (condition, init = init_value) {
  Infeed(shape)
}

result2 = while (condition, init = result1) {
  Infeed(shape)
}

不支援巢狀元組形狀。如果是空白元組形狀,動態內作業實際上是免人工管理,並繼續執行,而不會從裝置動態內讀取任何資料。

器皿打擊樂

另請參閱 XlaBuilder::Iota

Iota(shape, iota_dimension)

在裝置上建構常數常值,而非可能的大型主機傳輸。建立具有指定形狀的陣列,並保留從 0 開始,並按照指定維度遞增 1 的值。如果是浮點類型,產生的陣列與 ConvertElementType(Iota(...)) 相同,其中 Iota 屬於積分類型,且轉換為浮點類型。

引數 類型 語義
shape Shape Iota() 建立的陣列形狀
iota_dimension int64 要隨時間遞增的維度。

舉例來說,Iota(s32[4, 8], 0) 會傳回

  [[0, 0, 0, 0, 0, 0, 0, 0 ],
   [1, 1, 1, 1, 1, 1, 1, 1 ],
   [2, 2, 2, 2, 2, 2, 2, 2 ],
   [3, 3, 3, 3, 3, 3, 3, 3 ]]

可退貨 (費用:Iota(s32[4, 8], 1))

  [[0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ]]

地圖

另請參閱 XlaBuilder::Map

Map(operands..., computation)

引數 類型 語義
operands N XlaOp 秒序列 N 種 T0..T{N-1} 類型的陣列
computation XlaComputation T_0, T_1, .., T_{N + M -1} -> S」類型的計算,其中 N 個類型為 T 且任意類型的 M 參數
dimensions int64 陣列 地圖維度陣列

將純量函式套用至指定的 operands 陣列,產生相同維度的陣列,其中每個元素都是對輸入陣列中對應元素套用對應函式的結果。

對應函式為不受限制的任意運算,其中含有 N 個純量類型 T,以及類型為 S 的單一輸出。輸出內容的維度與運算元相同,但元素類型 T 會替換為 S。

例如:Map(op1, op2, op3, computation, par1)elem_out <- computation(elem1, elem2, elem3, par1) 對應至輸入陣列中的每個 (多維度) 索引,以產生輸出陣列。

OptimizationBarrier

封鎖任何最佳化傳遞,使其無法跨越障礙移動計算。

確保先評估所有輸入內容,再評估任何依附於阻隔線輸出內容的運算子。

防溢乳墊

另請參閱 XlaBuilder::Pad

Pad(operand, padding_value, padding_config)

引數 類型 語義
operand XlaOp T 類型的陣列
padding_value XlaOp T 類型的純量,用於填入新增的邊框間距
padding_config PaddingConfig 兩邊 (低、高) 及各維度元素之間的邊框間距

透過邊框間距和指定的 padding_value 在陣列的元素之間加上邊框間距,以展開指定的 operand 陣列。padding_config 指定每個維度的邊緣邊框間距量和內部邊框間距。

PaddingConfigPaddingConfigDimension 的重複欄位,其中包含每個維度的三個欄位:edge_padding_lowedge_padding_highinterior_padding

edge_padding_lowedge_padding_high 分別指定在低端 (索引 0 旁邊) 和頂端 (最高索引旁邊) 新增的邊框間距量。邊緣邊框間距量可以是負值,負值邊框間距的絕對值代表要從指定維度中移除的元素數量。

interior_padding 會指定每個維度在任意兩個元素之間加入的邊框間距量,但不得為負數。內部邊框間距發生在邊緣邊框間距之前,因此如果為負數邊緣邊框間距,元素會從內部填充運算元中移除。

如果邊緣邊框間距組合全部為 (0, 0),且內部邊框間距值為 0,則此運算即免人工管理。下圖顯示二維陣列的不同 edge_paddinginterior_padding 值範例。

接球

另請參閱 XlaBuilder::Recv

Recv(shape, channel_handle)

引數 類型 語義
shape Shape 要接收的資料形狀
channel_handle ChannelHandle 每個傳送/接收組合的專屬 ID

接收來自共用相同管道控制代碼的其他運算中的 Send 指示,接收指定形狀的資料。針對收到的資料傳回 XlaOp。

Recv 作業的用戶端 API 代表同步通訊。不過,這個指令會在內部分解成 2 個 HLO 指令 (RecvRecvDone),以啟用非同步資料移轉。另請參閱 HloInstruction::CreateRecvHloInstruction::CreateRecvDone

Recv(const Shape& shape, int64 channel_id)

分配必要的資源,以便接收來自相同 channel_id 的 Send 指示的資料。傳回已分配資源的內容,下列 RecvDone 指令會使用該結構來等待資料移轉完成。結構定義是 {receive buffer (shape), 要求 ID (U32)} 的組合,只能用於 RecvDone 指令。

RecvDone(HloInstruction context)

假設結構定義是由 Recv 指令建立,就會等候資料移轉完成並傳回收到的資料。

遏止

另請參閱 XlaBuilder::Reduce

將縮減函式平行套用至一或多個陣列。

Reduce(operands..., init_values..., computation, dimensions)

引數 類型 語義
operands N XlaOp 的順序 T_0, ..., T_{N-1} 類型的 N 陣列。
init_values N XlaOp 的順序 T_0, ..., T_{N-1} 類型的 N 純量。
computation XlaComputation T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) 類型的計算。
dimensions int64 陣列 要減少的維度陣列 (未排序)。

在此情況下:

  • N 必須大於或等於 1。
  • 計算作業必須「大致」具有關聯性 (請參閱下文)。
  • 所有輸入陣列的尺寸都必須相同。
  • 所有初始值都必須在 computation 底下形成身分。
  • 如果值為 N = 1Collate(T)T
  • 如果值為 N > 1Collate(T_0, ..., T_{N-1})T 類型的 N 元素元組。

這項作業會將每個輸入陣列的一或多個維度縮減為純量。每個傳回陣列的排名為 rank(operand) - len(dimensions)。運算的輸出內容為 Collate(Q_0, ..., Q_N),其中 Q_iT_i 類型的陣列,說明的維度如下。

允許不同的後端與縮減運算建立關聯。這可能會導致數值差異,因為某些縮減函式 (例如加法) 無法與浮點值建立關聯。不過,如果資料範圍有限,那麼加入浮點數就足以與大多數實際用途相關。

範例

在含有 [10, 11, 12, 13] 值且使用縮減函式 f (為 computation) 的單一 1D 陣列中縮減一個維度時,計算方式為

f(10, f(11, f(12, f(init_value, 13)))

但還有其他可能性,例如

f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))

以下是簡化的虛擬程式碼實作方式範例,其中使用加總做為初始值為 0 的縮減計算。

result_shape <- remove all dims in dimensions from operand_shape

# Iterate over all elements in result_shape. The number of r's here is equal
# to the rank of the result
for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
  # Initialize this result element
  result[r0, r1...] <- 0

  # Iterate over all the reduction dimensions
  for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
    # Increment the result element with the value of the operand's element.
    # The index of the operand's element is constructed from all ri's and di's
    # in the right order (by construction ri's and di's together index over the
    # whole operand shape).
    result[r0, r1...] += operand[ri... di]

以下範例說明如何縮減 2D 陣列 (矩陣)。形狀有排名第 2 的位置,大小為 2 的維度 0,大小為 3 的維度 1:

使用「add」函式將維度 0 或 1 縮小的結果:

請注意,兩個縮減結果都是 1D 陣列。為方便起見,圖表會以欄的方式顯示另一列,另一列為一列。

如需更複雜的範例,以下是 3D 陣列。其排名為 3,大小為 4 的維度 0,大小為 2 的維度 1,大小為 3 的維度 2。為簡單起見,值 1 到 6 會跨維度 0 複製。

和 2D 範例一樣,我們只需減少一個維度即可。舉例來說,如果我們減少維度 0,就會取得 rank-2 陣列,其中所有維度 0 的所有值都折疊成純量:

|  4   8  12 |
| 16  20  24 |

如果減少維度 2,我們也會取得排名-2 陣列,其中維度 2 的所有值都折疊成純量:

| 6  15 |
| 6  15 |
| 6  15 |
| 6  15 |

請注意,輸入中其餘維度之間的相對順序會保留在輸出中,但某些維度可能會指派新數字 (因為排名變更之後)。

此外,也可以減少多個維度。將維度 0 和 1 加入減少後會產生 1D 陣列 [20, 28, 36]

將 3D 陣列超過所有維度就會產生純量 84

抑制止痛

設為 N > 1 時,減少函式應用程式會稍微複雜,因為系統會同時套用至所有輸入內容。運算元會依下列順序提供給計算:

  • 第一個運算元執行縮減的值
  • ...
  • 執行系統會縮減 N'th 運算元的值
  • 第一個運算元的輸入值
  • ...
  • 第 N 個運算元的輸入值

舉例來說,請考量下列縮減函式,該函式可平行計算 1-D 陣列的最大值和 argmax:

f: (Float, Int, Float, Int) -> Float, Int
f(max, argmax, value, index):
  if value >= max:
    return (value, index)
  else:
    return (max, argmax)

如果 1-D 輸入陣列 V = Float[N], K = Int[N] 和 init 值 I_V = Float, I_K = Int,減少整個輸入維度的結果 f_(N-1) 等同於下列遞迴應用程式:

f_0 = f(I_V, I_K, V_0, K_0)
f_1 = f(f_0.first, f_0.second, V_1, K_1)
...
f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))

將此縮減套用到值陣列與連續索引陣列 (例如 iota) 將會同時處理陣列,並傳回包含最大值和相符索引的元組。

ReducePrecision

另請參閱 XlaBuilder::ReducePrecision

模擬將浮點值轉換為較低精確度格式 (例如 IEEE-FP16) 並改回原始格式的效果。儘管某些硬體實作可能不支援所有位元大小,但您可以任意指定較低精確度格式的指數和 mantisa 位元數量。

ReducePrecision(operand, mantissa_bits, exponent_bits)

引數 類型 語義
operand XlaOp 浮點類型 T 的陣列。
exponent_bits int32 精確度較低的指數位元數
mantissa_bits int32 較低精確度格式的 mantisa 位元數量

結果是 T 類型的陣列。輸入值會四捨五入至最接近的值,以指定的反正弦位元數量表示 (使用「和偶數」語意),而超出指數位元數指定範圍的任何值都會限制為正或負無限大。NaN 值會保留,不過可能會轉換成標準 NaN 值。

較低精確度格式必須有至少一個指數位元 (為了區分零值與無限大,因為兩者都有零 mantisa),且必須具有非負數的尾數位元。指數或 mantisa 數量可能會超過 T 類型的對應值;轉換的對應部分僅為免人工管理。

ReduceScatter

另請參閱 XlaBuilder::ReduceScatter

ReduceScatter 是集體作業,可有效執行 AllReduce,然後沿著 scatter_dimension 將結果分割成 shard_count 區塊,讓備用資源群組中的備用資源 i 收到 ith 資料分割。

ReduceScatter(operand, computation, scatter_dim, shard_count, replica_group_ids, channel_id)

引數 類型 語義
operand XlaOp 用來減少備用資源的陣列或非空白元組。
computation XlaComputation 縮減運算
scatter_dimension int64 要散佈的維度。
shard_count int64 要分割「scatter_dimension」的區塊數量
replica_groups int64 的向量 要執行降低幅度的群組
channel_id 自選int64 跨模組通訊的選用管道 ID
  • operand 是陣列的元組時,系統會對元組的每個元素執行減散器。
  • replica_groups 是執行縮減作業的備用資源群組清單 (可使用 ReplicaId 擷取目前備用資源的備用資源 ID)。每個群組中的備用資源順序將決定所有減少結果的分散順序。replica_groups 必須是空白 (在這種情況下,所有備用資源都屬於單一群組),或包含與備用資源數量相同的元素數量。如有多個備用資源群組,則所有備用資源群組的大小必須相同。例如,replica_groups = {0, 2}, {1, 3} 會減少備用資源 02 之間的資料,以及 13 之間的結果,然後分散結果。
  • shard_count 是每個備用資源群組的大小。如果 replica_groups 空白,就需要使用此方法。如果 replica_groups 並非空白,shard_count 必須等於每個備用資源群組的大小。
  • channel_id 用於跨模組通訊:只有具有相同 channel_idreduce-scatter 作業才能相互通訊。

輸出形狀是輸入形狀,scatter_dimension 縮小了 shard_count 倍。舉例來說,如果兩個備用資源有 [1.0, 2.25][3.0, 5.25] 值,而運算元在兩個備用資源上分別具有 [1.0, 2.25][3.0, 5.25] 值,那麼這個運算的輸出值就是第一個備用資源的 0[4.0],第二個備用資源的 [7.5]scatter_dim

ReduceWindow

另請參閱 XlaBuilder::ReduceWindow

對 N 多維陣列序列中的每個元素套用約化函式,以產生 N 多維度陣列的單一或元組做為輸出。每個輸出陣列的元素數量與視窗的有效位置數量相同。集區層可以表示為 ReduceWindow。與 Reduce 類似,套用的 computation 一律會在左側傳遞 init_values

ReduceWindow(operands..., init_values..., computation, window_dimensions, window_strides, padding)

引數 類型 語義
operands N XlaOps T_0,..., T_{N-1} 類型的 N 多維陣列序列,每個陣列都代表視窗放置的基礎區域。
init_values N XlaOps 指定縮減的 N 起始值,每個 N 運算元都有一個。詳情請參閱縮減
computation XlaComputation 用於 T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) 類型的縮減函式,會套用至所有輸入運算元每個視窗中的元素。
window_dimensions ArraySlice<int64> 視窗維度值的整數陣列
window_strides ArraySlice<int64> 區間步值的整數陣列
base_dilations ArraySlice<int64> 底數的整數陣列
window_dilations ArraySlice<int64> 窗型除法值的整數陣列
padding Padding 視窗邊框間距類型 (Padding::kSame 會邊框間距為 1 時採用與輸入相同的輸出形狀;或 Padding::kValid,其不再使用邊框間距和「停靠點」期限)

在此情況下:

  • N 必須大於或等於 1。
  • 所有輸入陣列的尺寸都必須相同。
  • 如果值為 N = 1Collate(T)T
  • 如果值為 N > 1Collate(T_0, ..., T_{N-1})(T0,...T{N-1}) 類型的 N 元素元組。

在程式碼和圖下方顯示 ReduceWindow 的使用範例。輸入是大小 [4x6] 的矩陣,window_dimensions 和 window_stride_dimensions 都是 [2x3]。

// Create a computation for the reduction (maximum).
XlaComputation max;
{
  XlaBuilder builder(client_, "max");
  auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
  auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
  builder.Max(y, x);
  max = builder.Build().value();
}

// Create a ReduceWindow computation with the max reduction computation.
XlaBuilder builder(client_, "reduce_window_2x3");
auto shape = ShapeUtil::MakeShape(F32, {4, 6});
auto input = builder.Parameter(0, shape, "input");
builder.ReduceWindow(
    input,
    /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
    *max,
    /*window_dimensions=*/{2, 3},
    /*window_stride_dimensions=*/{2, 3},
    Padding::kValid);

維度中的 1 步距指定維度中的視窗位置與相鄰視窗中的距離為 1 個元素。為了指定沒有任何視窗彼此重疊,window_stride_dimensions 應等於 window_dimensions。下圖說明如何使用兩個不同的步長值。邊框間距會套用至輸入的每個維度,而計算結果與填補後附帶的維度相同。

如需簡單的邊框間距範例,請考慮使用維度 3 計算縮減視窗下限 (初始值為 MAX_FLOAT),並將跨越 2 除以輸入陣列 [10000, 1000, 100, 10, 1]。邊框間距 kValid 會計算兩個有效區間的最小值:[10000, 1000, 100][100, 10, 1],輸出 [100, 1]。新增 kSame 會先填充陣列,讓縮減視窗之後的形狀在兩邊新增初始元素,因此會與跨行進程的輸入內容「相同」,以取得 [MAX_VALUE, 10000, 1000, 100, 10, 1, MAX_VALUE]。對填充陣列執行縮小視窗時,會在三個視窗 [MAX_VALUE, 10000, 1000][1000, 100, 10][10, 1, MAX_VALUE] 和 Yields [1000, 10, 1] 上運作。

縮減函式的評估順序是任意的,可能不具確定性。因此,減少函式不應過於敏感至重新關聯。詳情請參閱 Reduce 結構定義中有關關聯性的討論。

ReplicaId

另請參閱 XlaBuilder::ReplicaId

傳回備用資源的專屬 ID (U32 純量)。

ReplicaId()

每個備用資源的專屬 ID 是 [0, N) 間隔值中的無正負號整數,其中 N 是備用資源的數量。由於所有備用資源都執行同一程式,因此程式中的 ReplicaId() 呼叫會在每個備用資源上傳回不同的值。

重塑

另請參閱 XlaBuilder::ReshapeCollapse 作業。

將陣列的尺寸重新調整成新設定。

Reshape(operand, new_sizes) Reshape(operand, dimensions, new_sizes)

引數 類型 語義
operand XlaOp T 類型的陣列
dimensions int64 向量 依展開順序
new_sizes int64 向量 各種大小的向量

從概念上來說,先將陣列整併為一維資料值向量,然後再將此向量修正為新形狀。輸入引數是 T 類型的任意陣列、維度索引的編譯時間常數向量,以及結果維度大小的編譯時間常數向量。dimension 向量中的值 (如有指定) 必須是所有 T 維度的排列方式;如果未指定,則預設值為 {0, ..., rank - 1}dimensions 中的維度順序,從迴圈巢狀結構中最慢的維度 (最主要的) 到變動最快速的維度 (最次要) 的順序,會將輸入陣列收合為單一維度。new_sizes 向量會決定輸出陣列的大小。new_sizes 中索引 0 的值是維度 0 的大小,索引 1 的值是維度 1 的大小,以此類推。new_size 維度的乘積必須等於運算元尺寸大小的乘積。將收合的陣列修正為 new_sizes 定義的多維度陣列時,new_sizes 中的維度會由最慢 (最主要) 和最小變化 (最小的) 排序。

舉例來說, let v 是 24 個元素的陣列:

let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17} },
                    { {20, 21, 22}, {25, 26, 27} },
                    { {30, 31, 32}, {35, 36, 37} },
                    { {40, 41, 42}, {45, 46, 47} } };

In-order collapse:
let v012_24 = Reshape(v, {0,1,2}, {24});
then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
                         30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};

let v012_83 = Reshape(v, {0,1,2}, {8,3});
then v012_83 == f32[8x3] { {10, 11, 12}, {15, 16, 17},
                          {20, 21, 22}, {25, 26, 27},
                          {30, 31, 32}, {35, 36, 37},
                          {40, 41, 42}, {45, 46, 47} };

Out-of-order collapse:
let v021_24 = Reshape(v, {1,2,0}, {24});
then v012_24 == f32[24]  {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
                          15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47};

let v021_83 = Reshape(v, {1,2,0}, {8,3});
then v021_83 == f32[8x3] { {10, 20, 30}, {40, 11, 21},
                          {31, 41, 12}, {22, 32, 42},
                          {15, 25, 35}, {45, 16, 26},
                          {36, 46, 17}, {27, 37, 47} };


let v021_262 = Reshape(v, {1,2,0}, {2,6,2});
then v021_262 == f32[2x6x2] { { {10, 20}, {30, 40},
                              {11, 21}, {31, 41},
                              {12, 22}, {32, 42} },
                             { {15, 25}, {35, 45},
                              {16, 26}, {36, 46},
                              {17, 27}, {37, 47} } };

作為特殊案例,重塑可將單一元素陣列轉換為純量,反之亦然。比如

Reshape(f32[1x1] { {5} }, {0,1}, {}) == 5;
Reshape(5, {}, {1,1}) == f32[1x1] { {5} };

倒轉 (反向)

另請參閱 XlaBuilder::Rev

Rev(operand, dimensions)

引數 類型 語義
operand XlaOp T 類型的陣列
dimensions ArraySlice<int64> 即可反向調整

沿著指定的 dimensions 反向排序 operand 陣列中的元素順序,產生相同形狀的輸出陣列。多維度索引的運算元陣列的每個元素,會儲存在經過轉換的索引的輸出陣列中。轉換多維度索引會反轉每個維度中的索引,使其反轉 (也就是說,如果大小的 N 是反向維度之一,其索引 i 就會轉換為 N - 1 - i)。

Rev 運算的一個用途是,在類神經網路的漸層計算期間,沿著兩個區間維度反轉卷積權重陣列。

RngNormal

另請參閱 XlaBuilder::RngNormal

這個外掛程式能建構指定形狀的輸出,並根據正規分佈情形產生的隨機數字產生 \(N(\mu, \sigma)\) 。 \(\mu\) 和 \(\sigma\)參數以及輸出形狀必須具備浮點元素類型。參數更進一步必須是純量值。

RngNormal(mu, sigma, shape)

引數 類型 語義
mu XlaOp T 類型的純量 (指定產生的數字平均值)
sigma XlaOp T 類型的純量 (用於指定產生的標準差)
shape Shape T 類型的輸出形狀

RngUniform

另請參閱 XlaBuilder::RngUniform

這個外掛程式能建構指定形狀的輸出內容,並根據間隔 \([a,b)\)的統一分佈狀況產生隨機號碼。參數和輸出元素類型必須是布林值類型、積分類型或浮點類型,且類型必須一致。CPU 和 GPU 後端目前僅支援 F64、F32、F16、BF16、S64、U64、S32 和 U32。此外,參數也必須是純量值。如果 \(b <= a\) 結果是實作定義,

RngUniform(a, b, shape)

引數 類型 語義
a XlaOp T 類型的純量指定時間間隔下限
b XlaOp T 類型的純量指定間隔上限
shape Shape T 類型的輸出形狀

RngBitGenerator

使用指定演算法 (或後端預設值) 產生特定形狀的輸出內容,並以統一的隨機位元填入資料,並傳回更新後的狀態 (形狀與初始狀態相同),以及產生的隨機資料。

初始狀態是目前產生隨機號碼的初始狀態。而必要的形狀和有效值會視使用的演算法而定。

保證具有初始狀態的確定性函式,但「不保證」在後端和不同編譯器版本之間具有確定性。

RngBitGenerator(algorithm, key, shape)

引數 類型 語義
algorithm RandomAlgorithm 要使用的 PRNG 演算法。
initial_state XlaOp PRNG 演算法的初始狀態。
shape Shape 產生資料的輸出形狀。

algorithm 可用的值:

散布圖

XLA 散佈圖作業會產生一系列結果,這些結果為輸入陣列 operands 的值,其中數個配量 (位於 scatter_indices 指定的索引) 使用 update_computation 更新為 updates 中的值序列。

另請參閱 XlaBuilder::Scatter

scatter(operands..., scatter_indices, updates..., update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)

引數 類型 語義
operands N XlaOp 的順序 要分隔成 N 個 T_0, ..., T_N 類型的陣列。
scatter_indices XlaOp 陣列包含必須分散到的切片起始索引。
updates N XlaOp 的順序 T_0, ..., T_N 類型的 N 陣列。updates[i] 包含分散 operands[i] 必須使用的值。
update_computation XlaComputation 用於合併輸入陣列中現有值和散佈期間更新的計算。這項計算應是 T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N) 類型。
index_vector_dim int64 scatter_indices 中包含起始索引的維度。
update_window_dims ArraySlice<int64> updates 形狀的維度組合,採用視窗尺寸
inserted_window_dims ArraySlice<int64> 必須插入 updates 形狀的一組視窗尺寸
scatter_dims_to_operand_dims ArraySlice<int64> 維度會對應散佈圖索引和運算元索引空間。系統會將這個陣列解讀為將 i 對應至 scatter_dims_to_operand_dims[i]。請務必使用一對一的教學方式。
indices_are_sorted bool 是否保證呼叫均會由呼叫端排序。

在此情況下:

  • N 必須大於或等於 1。
  • operands[0]、...、operands[N-1] 的尺寸必須相同。
  • updates[0]、...、updates[N-1] 的尺寸必須相同。
  • 如果值為 N = 1Collate(T)T
  • 如果值為 N > 1Collate(T_0, ..., T_N)T 類型的 N 元素元組。

如果 index_vector_dim 等於 scatter_indices.rank,我們隱含地認為 scatter_indices 具有結尾的 1 維度。

我們將 ArraySlice<int64> 類型的 update_scatter_dims 定義為 updates 形狀中不在 update_window_dims 內的維度組合 (遞增順序)。

散佈工具的引數應遵循以下限制:

  • 每個 updates 陣列都必須屬於 update_window_dims.size + scatter_indices.rank - 1 排名。

  • 每個 updates 陣列中的維度 i 邊界都必須符合下列規定:

    • 如果 update_window_dims 中出現 i (意即部分 kupdate_window_dims[k]),則在計算 inserted_window_dims 之後,updates 中的維度 i 邊界不得超過 operand 的對應範圍 (亦即 adjusted_window_bounds[k],其中 adjusted_window_bounds 包含的 operand 邊界已從索引中移除 inserted_window_dims)。
    • 如果 update_scatter_dims 中有 i (即部分 kupdate_scatter_dims[k]),則 updates 中的維度 i 必須等於 scatter_indices 的對應邊界,就會略過 index_vector_dim (亦即如果 k < index_vector_dimscatter_indices.shape.dims[k+1] 則略過 index_vector_dim)。kscatter_indices.shape.dims
  • update_window_dims 必須按遞增順序排序、沒有任何重複的維度號碼,且位於範圍 [0, updates.rank)

  • inserted_window_dims 必須按遞增順序排序、沒有任何重複的維度號碼,且位於範圍 [0, operand.rank)

  • operand.rank 必須等於 update_window_dims.sizeinserted_window_dims.size 的總和。

  • scatter_dims_to_operand_dims.size 必須等於 scatter_indices.shape.dims[index_vector_dim],且值必須位於 [0, operand.rank) 範圍內。

針對每個 updates 陣列中的指定索引 U,在必須套用此更新的對應 operands 陣列中的對應索引 I,計算方式如下:

  1. 允許 G = { U[k] 用於 update_scatter_dims } 中的 k。使用 G 來查詢 scatter_indices 陣列中的索引向量 S,以便 S[i] = scatter_indices[Merge(G, i)],其中 Merge(A, b) 會將位於位置 index_vector_dim 的 b 插入 A。
  2. 使用 scatter_dims_to_operand_dims 對應 S,使用 S 建立索引 Sinoperand。更正式:
    1. Sin[scatter_dims_to_operand_dims[k]] = S[k] (如果 k < scatter_dims_to_operand_dims.size)。
    2. Sin[_] = 0 否則。
  3. 根據 inserted_window_dims 將索引散佈在 Uupdate_window_dims 處,藉此為每個 operands 陣列建立索引 Win。更正式:
    1. Win[window_dims_to_operand_dims(k)] = U[k] (如果 k 位於 update_window_dims),其中 window_dims_to_operand_dims 是網域 [0, update_window_dims.size) 和範圍 [0, operand.rank) \ inserted_window_dims 的單調函式。(例如,如果 update_window_dims.size4operand.rank6inserted_window_dims 是 {02},則 window_dims_to_operand_dims 為 {01132435})。
    2. Win[_] = 0 否則。
  4. IWin + Sin,其中 + 代表元素相關加法。

簡單來說,散佈運算可定義如下。

  • 使用 operands 初始化 output (即所有索引 J 的值),適用於 operands[J] 陣列中的所有 O
    output[J][O] = operands[J][O]
  • 針對 updates[J] 陣列中的每個索引 Uoperand[J] 陣列中對應的索引 O,如果 Ooutput 的有效索引:
    (output[0][O], ..., output[N-1][O]) =update_computation(output[0][O], ..., ,output[N-1][O],updates[0][U], ...,updates[N-1][U])

更新的套用順序不具確定性。因此,當 updates 中的多個索引參照 operands 中的同一個索引時,output 中相對應的值將非確定性。

請注意,傳遞至 update_computation 的第一個參數一律是 output 陣列中的目前值,第二個參數一律會是 updates 陣列的值。特別是當 update_computation 並非可變動時,這一點尤其重要。

如果將 indices_are_sorted 設為 true,XLA 就會假設 start_indices 會依照使用者排序 (遞增的 start_index_map 順序)。如果語意不同,則會定義語意。

相反地,散佈運算可以視為集合運算的「反向」,也就是散佈運算會更新對應集合運算所擷取的輸入元素。

如需詳細的非正式說明與範例,請參閱 Gather 下方的「正式說明」一節。

選取

另請參閱 XlaBuilder::Select

根據述詞陣列的值,從兩個輸入陣列的元素建構輸出陣列。

Select(pred, on_true, on_false)

引數 類型 語義
pred XlaOp PRED 類型的陣列
on_true XlaOp T 類型的陣列
on_false XlaOp T 類型的陣列

陣列 on_trueon_false 的形狀必須相同。這也是輸出陣列的形狀。陣列 pred 的維度必須與 on_trueon_false 相同,且其類型必須是 PRED 元素類型。

對於 pred 的每個元素 P,如果 P 的值是 true,輸出陣列的對應元素會從 on_true 取用;如果 P 的值是 false,則來自 on_false。做為廣播的受限形式,pred 可以是 PRED 類型的純量。在這種情況下,如果 predtrue,則輸出陣列會從 on_true 中取用,如果 predfalse,則會從 on_false 取得。

非純量 pred 的範例:

let pred: PRED[4] = {true, false, false, true};
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};

純量 pred 範例:

let pred: PRED = true;
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};

系統支援對元組之間的選取。為此,系統會將元組視為純量類型。如果 on_trueon_false 是元組 (形狀必須相同!),則 pred 必須是 PRED 類型的純量。

SelectAndScatter

另請參閱 XlaBuilder::SelectAndScatter

這項作業可以視為複合作業。先計算 operand 陣列的 ReduceWindow,從每個視窗選取元素,接著將 source 陣列分散至所選元素的索引,然後建構形狀與運算元陣列相同的輸出陣列。二進位 select 函式可以跨各個視窗套用元素,藉此從各個視窗選取元素,並使用屬性呼叫第一個參數的索引向量小於第二個參數的索引向量。如果選取第一個參數,則 select 函式會傳回 true;如果選取第二個參數,則會傳回 false,且函式必須具有傳輸能力 (也就是說,如果 select(a, b)select(b, c)true,則 select(a, c) 也是 true),這樣一來,所選元素就不會取決於特定視窗內的元素遍歷順序。

系統會將 scatter 函式套用至輸出陣列中的每個索引。它採用兩個純量參數:

  1. 輸出陣列中所選索引的目前值
  2. 套用至所選索引的 source 散佈值

此函式會合併兩個參數並傳回純量值,用於在輸出陣列中的所選索引更新值。輸出陣列的所有索引一開始都是設為 init_value

輸出陣列的形狀與 operand 陣列相同,且 source 陣列的形狀必須與對 operand 陣列套用 ReduceWindow 運算的結果相同。SelectAndScatter 可用來對類神經網路中的集區層的漸層值反向傳播。

SelectAndScatter(operand, select, window_dimensions, window_strides, padding, source, init_value, scatter)

引數 類型 語義
operand XlaOp 視窗會滑入 T 類型的陣列
select XlaComputation T, T -> PRED 類型的二進位計算,適用於各個視窗中的所有元素;如果選取第一個參數,會傳回 true;如果選取第二個參數,則傳回 false
window_dimensions ArraySlice<int64> 視窗維度值的整數陣列
window_strides ArraySlice<int64> 區間步值的整數陣列
padding Padding 視窗的邊框間距類型 (Padding::kSame 或 Padding::kValid)
source XlaOp T 型別陣列,其中的值要用於散佈法
init_value XlaOp 類型 T 的純量值做為輸出陣列的初始值
scatter XlaComputation T, T -> T 類型的二進位計算,用於套用每個散佈來源元素及其目的地元素

下圖顯示使用 SelectAndScatter 的範例,其中 select 函式計算其參數的最大值。請注意,當視窗重疊時,如下方圖 (2) 所示,不同區間可能會多次選取 operand 陣列的索引。在圖中,兩個頂部視窗 (藍色和紅色) 都會選取值 9 的元素,而二進位加法 scatter 函式則會產生值 8 (2 + 6) 的輸出元素。

scatter 函式的評估順序是任意值,可能不具有確定性。因此,scatter 函式不應過於敏感,藉此重新建立關聯。詳情請參閱 Reduce 結構定義中有關關聯性的討論。

傳送

另請參閱 XlaBuilder::Send

Send(operand, channel_handle)

引數 類型 語義
operand XlaOp 要傳送的資料 (T 類型的陣列)
channel_handle ChannelHandle 每個傳送/接收組合的專屬 ID

將指定的運算元資料傳送到其他共用相同管道控制代碼的其他運算中的 Recv 指令。不傳回任何資料。

Recv 作業類似,Send 作業的用戶端 API 代表同步通訊,且會在內部分解成 2 個 HLO 指令 (SendSendDone),以啟用非同步資料移轉。另請參閱 HloInstruction::CreateSendHloInstruction::CreateSendDone

Send(HloInstruction operand, int64 channel_id)

針對具有相同管道 ID 的 Recv 指令分配的資源,啟動非同步轉移運算元。傳回結構定義,下列 SendDone 指令用於等待資料移轉完成。結構定義是 {operand (shape)、要求 ID (U32)} 的元組,且只能用於 SendDone 指令。

SendDone(HloInstruction context)

假設有使用 Send 指令建立情境,系統會等待資料移轉完成。指示不會傳回任何資料。

排定頻道發布時間

每個管道 (RecvRecvDoneSendSendDone) 4 指令的執行順序如下。

  • Recv發生在 Send之前
  • Send發生在 RecvDone之前
  • Recv發生在 RecvDone之前
  • Send發生在 SendDone之前

當後端編譯器針對透過管道指令通訊的每個運算產生線性排程時,運算作業就不能沒有循環。舉例來說,下方的排程會導致死結。

配量

另請參閱 XlaBuilder::Slice

切片功能會從輸入陣列擷取子陣列。子陣列與輸入的排名相同,內含在輸入陣列內的定界框內的值,其中定界框的維度和索引會指定為配量運算的引數。

Slice(operand, start_indices, limit_indices, strides)

引數 類型 語義
operand XlaOp T 類型的 N 維陣列
start_indices ArraySlice<int64> 列出 N 個整數,內含每個維度切片的起始索引。值必須大於或等於 0。
limit_indices ArraySlice<int64> 列出 N 個整數,其中包含每個維度的切片的結尾索引 (不含索引)。每個值都必須大於或等於個別的維度 start_indices 值,且小於或等於維度大小。
strides ArraySlice<int64> N 整數清單,這些整數會決定切片的輸入步長。區塊會挑選維度 d 中的每個 strides[d] 元素。

1 維範例:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {2}, {4}) produces:
  {2.0, 3.0}

2D 範例:

let b =
 { {0.0,  1.0,  2.0},
   {3.0,  4.0,  5.0},
   {6.0,  7.0,  8.0},
   {9.0, 10.0, 11.0} }

Slice(b, {2, 1}, {4, 3}) produces:
  { { 7.0,  8.0},
    {10.0, 11.0} }

排序

另請參閱 XlaBuilder::Sort

Sort(operands, comparator, dimension, is_stable)

引數 類型 語義
operands ArraySlice<XlaOp> 要排序的運算元。
comparator XlaComputation 要使用的比較子運算。
dimension int64 要排序的維度。
is_stable bool 是否應使用穩定排序。

如果只提供一個運算元:

  • 如果運算元是 rank-1 張量 (陣列),結果就會是已排序的陣列。 如果您想要以遞增順序將陣列排序,比較器應執行小於比較的結果。正式排序後,陣列會保留所有索引位置 i, j,且 i < jcomparator(value[i], value[j]) = comparator(value[j], value[i]) = falsecomparator(value[i], value[j]) = true

  • 如果運算元排名較高,運算元會依照提供的維度排序。舉例來說,如果是 rank-2 張量 (矩陣),0 的維度值會獨立排序每欄,1 的維度值則會獨立排序每一列。如未提供維度編號,系統會預設選擇最後一個維度。排序的維度會套用相同的排序順序,與排名-1 的情況相同。

如果提供了 n > 1 運算元:

  • 所有 n 運算元都必須是相同維度的張量。張量的元素類型可能不同。

  • 所有運算元都會一起排序,而不是個別排序。概念上來說,運算元會被視為元組。檢查索引位置 ij 中的每個運算元元素是否需要替換時,系統會使用 2 * n 純量參數呼叫比較子,其中參數 2 * k 對應至 k-th 運算元位置 i 的值,而參數 2 * k + 1 則對應到來自 k-th 運算元位置 j 的值。一般而言,比較子會彼此比較 2 * k2 * k + 1 參數,並可能會使用其他參數組合做為對應的斷路器。

  • 結果是元組,由按排序順序排列的運算元 (以及上述提供的維度)。元組的 i-th 運算元與 Sort 的 i-th 運算元相對應。

舉例來說,如果有三個運算元 operand0 = [3, 1]operand1 = [42, 50]operand2 = [-3.0, 1.1],而比較子只比較小於 operand0 的值,則排序的輸出內容為 ([1, 3], [50, 42], [1.1, -3.0]) 元組。

如果 is_stable 設為 true,則排序保證會穩定,也就是說,如果比較子所視為相等的元素,系統會保留相等值的相對順序。只有在 comparator(e1, e2) = comparator(e2, e1) = false 時,兩個元素 e1e2 才算相等。根據預設,is_stable 會設為 false。

轉置

另請參閱 tf.reshape 作業。

Transpose(operand)

引數 類型 語義
operand XlaOp 要轉置的運算元。
permutation ArraySlice<int64> 如何設定維度靜音。

使用指定排列的運算元維度,因此 ∀ i . 0 ≤ i < rank ⇒ input_dimensions[permutation[i]] = output_dimensions[i]

這與 Reshape(運算元, 排列, 互斥(permutation, operand.shape.dimensions)) 相同。

TriangularSolve

另請參閱 XlaBuilder::TriangularSolve

透過向前或反向取代,解開具有較低或上三角係數矩陣的系統方程式。沿著先行維度播送,這個處理常式會解決 x 變數的其中一個矩陣系統 op(a) * x = b (或 x * op(a) = b),其由 ab 構成,其中 op(a)op(a) = aop(a) = Transpose(a)op(a) = Conj(Transpose(a))

TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)

引數 類型 語義
a XlaOp 排名超過 2 陣列的複雜或浮點類型 (含形狀 [..., M, M])。
b XlaOp 如果 left_side 為 true,則排名 > 2 且類型相同、[..., M, K] 為 true 的陣列,否則為 [..., K, M]
left_side bool 表示要解析表單 op(a) * x = b (true) 或 x * op(a) = b (false) 的系統。
lower bool 是否要使用 a 的頂端或下三角形。
unit_diagonal bool 如為 true,系統會假設 a 的對角元素為 1,且不會存取。
transpose_a Transpose 決定要依原樣使用 a、將其轉置或接受其轉讓。

lower 的值而定,系統只會從 a 的下方/上三角形讀取輸入資料。系統會忽略其他三角形的值。輸出資料會在同一個三角形中傳回;其他三角形的值則是實作定義,可以是任何值。

如果 ab 的排名大於 2,系統會將這兩者視為矩陣的批次,除了次要 2 個維度以外的所有維度。ab 的批次尺寸必須相同。

元組

另請參閱 XlaBuilder::Tuple

包含數量不定資料控點的元組,每個都有自己的形狀。

這類似於 C++ 中的 std::tuple。概念說明:

let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);

您可以透過 GetTupleElement 作業解構 (存取) 元組。

雖然

另請參閱 XlaBuilder::While

While(condition, body, init)

引數 類型 語義
condition XlaComputation T -> PRED 類型的 XlaComputation,定義迴圈的終止條件。
body XlaComputation T -> T 類型的 XlaComputation,定義迴圈主體。
init T conditionbody 參數的初始值。

依序執行 body,直到 condition 失敗。這類似於在許多其他語言中循環的迴圈,但下列差異和限制除外。

  • While 節點會傳回 T 類型的值,這是 body 上次執行後的結果。
  • 系統會靜態決定 T 類型的形狀,而且所有疊代均須相同。

運算的 T 參數會在第一次疊代中使用 init 值初始化,並在每次後續疊代時從 body 自動更新為新結果。

While 節點的其中一個主要用途是實作類神經網路中重複執行訓練。下方顯示簡化的虛擬程式碼,以及代表運算的圖形。您可以在 while_test.cc 中找到這個程式碼。此範例中 T 的類型是 Tuple,其中包含疊代次數的 int32 和累計器的 vector[10]。針對 1000 次疊代,迴圈會持續將常數向量加入累計器。

// Pseudocode for the computation.
init = {0, zero_vector[10]} // Tuple of int32 and float[10].
result = init;
while (result(0) < 1000) {
  iteration = result(0) + 1;
  new_vector = result(1) + constant_vector[10];
  result = {iteration, new_vector};
}