XlaSendTPUEmbeddingGradients

パブリック最終クラスXlaSendTPUEmbeddingGradients

埋め込みテーブルの勾配更新を実行する操作。

gradients 引数は、XlaRecvTPUEmbeddingActivations の戻り値と同じ長さと形状を持つ TensorList ですが、埋め込みアクティベーションに関するモデルの損失の勾配が含まれています。埋め込みテーブルは、tpu.initialize_system に指定された TPUEmbeddingConfiguration プロトで指定されたオプティマイザを介して、これらの勾配から更新されます。

パブリックメソッド

静的XlaSendTPUEmbeddingGradients
create (スコープscope、Iterable< Operand <Float>> gradients、Iterable< Operand <Float>> learningRates、 Operand <?> deduplicationData、String config)
新しい XlaSendTPUEmbeddingGradients 操作をラップするクラスを作成するファクトリ メソッド。

継承されたメソッド

パブリックメソッド

public static XlaSendTPUEmbeddingGradients create (スコープscope、Iterable< Operand <Float>> gradients、Iterable< Operand <Float>> learningRates、 Operand <?> deduplicationData、String config)

新しい XlaSendTPUEmbeddingGradients 操作をラップするクラスを作成するファクトリ メソッド。

パラメーター
範囲現在のスコープ
グラデーション埋め込みテーブルを更新するための勾配の TensorList。
学習率オプティマイザーを介して埋め込みテーブルを更新するために使用される学習率の TensorList。 TensorList の長さは、TPUEmbeddingConfiguration プロトで指定された動的学習率タグの数と等しくなければなりません。
重複排除データ重複排除データを含む type=DT_VARIANT の Tensor。テンソルは、N 個の要素を含む XLA ネストされたタプルです (N は、TPU チップあたりのテンソル コアに対する埋め込みの数の比率です)。入れ子になったタプルの各要素は、ランク 1 テンソルのタプルです。各テンソルには、TensorCore での埋め込みルックアップ用のインデックス (DT_UINT32) または埋め込みルックアップ操作の出力に適用する重み (DT_FLOAT) が含まれています。
構成シリアル化された TPUEmbeddingConfiguration プロト。
戻り値
  • XlaSendTPUEmbeddingGradients の新しいインスタンス