CTCLossV2

classe finale publique CTCLossV2

Calcule la perte CTC (probabilité logarithmique) pour chaque entrée de lot. Calcule également

le dégradé. Cette classe effectue l'opération softmax pour vous, donc les entrées doivent être par exemple des projections linéaires des sorties par un LSTM.

Classes imbriquées

classe CTCLossV2.Options Attributs facultatifs pour CTCLossV2

Méthodes publiques

CTCLossV2 statique
créer ( Scope scope, Opérande <Float> entrées, Opérande <Long> labelsIndices, Opérande <Integer> labelsValues, Opérande <Integer> séquenceLength, Options... options)
Méthode d'usine pour créer une classe encapsulant une nouvelle opération CTCLossV2.
CTCLossV2.Options statique
ctcMergeRepeated (booléen ctcMergeRepeated)
Sortie <Flottant>
pente ()
Le gradient de « perte ».
CTCLossV2.Options statique
ignoreLongerOutputsThanInputs (booléen ignoreLongerOutputsThanInputs)
Sortie <Flottant>
perte ()
Un vecteur (lot) contenant des log-probabilités.
CTCLossV2.Options statique
preprocessCollapseRepeated (préprocessus booléenCollapseRepeated)

Méthodes héritées

Méthodes publiques

création de CTCLossV2 statique publique (portée de portée , entrées d' opérande <Float>, opérande <Long> labelsIndices, opérande <Integer> labelsValues, opérande <Integer> séquenceLength, options... options)

Méthode d'usine pour créer une classe encapsulant une nouvelle opération CTCLossV2.

Paramètres
portée portée actuelle
contributions 3-D, forme : `(max_time x batch_size x num_classes)`, les logits. L'étiquette vide par défaut est 0 plutôt que num_classes - 1.
étiquettesIndices Les indices d'un `SparseTensor `. `labels_indices(i, :) == [b, t]` signifie que `labels_values(i)` stocke l'identifiant pour `(lot b, heure t)`.
étiquettesValeurs Les valeurs (étiquettes) associées au lot et à l'heure donnés.
séquenceLongueur Un vecteur contenant des longueurs de séquence (lot).
choix porte des valeurs d'attributs facultatifs
Retour
  • une nouvelle instance de CTCLossV2

public statique CTCLossV2.Options ctcMergeRepeated (booléen ctcMergeRepeated)

Paramètres
ctcMergeRepeated Scalaire. Si la valeur est false, pendant le calcul CTC, les étiquettes non vides répétées ne seront pas fusionnées et seront interprétées comme des étiquettes individuelles. Il s'agit d'une version simplifiée de CTC.

sortie publique <Float> gradient ()

Le gradient de « perte ». 3D, forme : `(max_time x batch_size x num_classes)`.

public statique CTCLossV2.Options ignoreLongerOutputsThanInputs (booléen ignoreLongerOutputsThanInputs)

Paramètres
ignoreLongerOutputsThanInputs Scalaire. S'il est défini sur true, lors du calcul CTC, les éléments qui ont des séquences de sortie plus longues que les séquences d'entrée sont ignorés : ils ne contribuent pas au terme de perte et ont un gradient nul.

Sortie publique <Float> perte ()

Un vecteur (lot) contenant des log-probabilités.

public statique CTCLossV2.Options preprocessCollapseRepeated (booléen preprocessCollapseRepeated)

Paramètres
preprocessCollapseRepeated Scalaire, si vrai, les étiquettes répétées sont réduites avant le calcul du CTC.