Restez organisé à l'aide des collections Enregistrez et classez les contenus selon vos préférences.

tensorflow :: opérations :: FixedUnigramCandidateSampler

#include <candidate_sampling_ops.h>

Génère des étiquettes pour l'échantillonnage candidat avec une distribution unigramme apprise.

Résumé

Un échantillonneur unigramme pourrait utiliser une distribution unigramme fixe lue à partir d'un fichier ou transmise en tant que tableau en mémoire au lieu de construire la distribution à partir de données à la volée. Il existe également une option pour fausser la distribution en appliquant une puissance de distorsion aux poids.

Le fichier de vocabulaire doit être au format CSV, le dernier champ étant le poids associé au mot.

Pour chaque lot, cette opération sélectionne un seul ensemble d'étiquettes candidates échantillonnées.

Les avantages de l'échantillonnage des candidats par lot sont la simplicité et la possibilité d'une multiplication efficace par matrice dense. L'inconvénient est que les candidats échantillonnés doivent être choisis indépendamment du contexte et des véritables étiquettes.

Arguments:

  • scope: un objet Scope
  • true_classes: une matrice batch_size * num_true, dans laquelle chaque ligne contient les ID des num_true target_classes dans l'étiquette d'origine correspondante.
  • num_true: nombre d'étiquettes vraies par contexte.
  • num_sampled: nombre de candidats à échantillonner aléatoirement.
  • unique: si unique est vrai, nous échantillonnons avec rejet, de sorte que tous les candidats échantillonnés dans un lot soient uniques. Cela nécessite une certaine approximation pour estimer les probabilités d'échantillonnage après rejet.
  • range_max: L'échantillonneur échantillonnera les entiers de l'intervalle [0, range_max).

Attributs facultatifs (voir Attrs ):

  • vocab_file: Chaque ligne valide de ce fichier (qui doit avoir un format de type CSV) correspond à un identifiant de mot valide. Les identifiants sont dans un ordre séquentiel, à partir de num_reserved_ids. On s'attend à ce que la dernière entrée de chaque ligne soit une valeur correspondant au décompte ou à la probabilité relative. Exactement l'un des vocab_file et unigrams doit être passé à cette opération.
  • distorsion: la distorsion est utilisée pour fausser la distribution de probabilité unigramme. Chaque poids est d'abord élevé à la puissance de la distorsion avant de s'ajouter à la distribution unigramme interne. En conséquence, distorsion = 1.0 donne un échantillonnage unigramme régulier (tel que défini par le fichier de vocabulaire), et distorsion = 0.0 donne une distribution uniforme.
  • num_reserved_ids: En option, certains ID réservés peuvent être ajoutés dans la plage [0, ..., num_reserved_ids) par les utilisateurs. Un cas d'utilisation est qu'un jeton de mot inconnu spécial est utilisé comme ID 0. Ces ID auront une probabilité d'échantillonnage de 0.
  • num_shards: Un échantillonneur peut être utilisé pour échantillonner à partir d'un sous-ensemble de la plage d'origine afin d'accélérer l'ensemble du calcul grâce au parallélisme. Ce paramètre (avec «shard») indique le nombre de partitions utilisées dans le calcul global.
  • shard: Un échantillonneur peut être utilisé pour échantillonner à partir d'un sous-ensemble de la plage d'origine afin d'accélérer l'ensemble du calcul par parallélisme. Ce paramètre (avec «num_shards») indique le numéro de partition particulier d'une opération d'échantillonneur, lorsque le partitionnement est utilisé.
  • unigrammes: une liste de décomptes ou de probabilités unigrammes, un par ID dans un ordre séquentiel. Exactement l'un des vocab_file et unigrams doit être passé à cette opération.
  • seed: Si seed ou seed2 est défini pour être différent de zéro, le générateur de nombres aléatoires est amorcé par la graine donnée. Sinon, il est semé par une graine aléatoire.
  • seed2: Une deuxième graine pour éviter la collision de graines.

Retour:

  • Output sampled_candidates: Un vecteur de longueur num_sampled, dans lequel chaque élément est l'ID d'un candidat échantillonné.
  • Output true_expected_count: matrice batch_size * num_true, représentant le nombre de fois que chaque candidat est censé se produire dans un lot de candidats échantillonnés. Si unique = vrai, alors c'est une probabilité.
  • Output sampled_expected_count: un vecteur de longueur num_sampled, pour chaque candidat échantillonné représentant le nombre de fois où le candidat est censé se produire dans un lot de candidats échantillonnés. Si unique = vrai, alors c'est une probabilité.

Constructeurs et destructeurs

FixedUnigramCandidateSampler (const :: tensorflow::Scope & scope, :: tensorflow::Input true_classes, int64 num_true, int64 num_sampled, bool unique, int64 range_max)
FixedUnigramCandidateSampler (const :: tensorflow::Scope & scope, :: tensorflow::Input true_classes, int64 num_true, int64 num_sampled, bool unique, int64 range_max, const FixedUnigramCandidateSampler::Attrs & attrs)

Attributs publics

operation
sampled_candidates
sampled_expected_count
true_expected_count

Fonctions statiques publiques

Distortion (float x)
NumReservedIds (int64 x)
NumShards (int64 x)
Seed (int64 x)
Seed2 (int64 x)
Shard (int64 x)
Unigrams (const gtl::ArraySlice< float > & x)
VocabFile (StringPiece x)

Structs

tensorflow :: ops :: FixedUnigramCandidateSampler :: Attrs

Définisseurs d' attributs facultatifs pour FixedUnigramCandidateSampler .

Attributs publics

opération

Operation operation

sampled_candidates

::tensorflow::Output sampled_candidates

sampled_expected_count

::tensorflow::Output sampled_expected_count

true_expected_count

::tensorflow::Output true_expected_count

Fonctions publiques

FixedUnigramCandidateSampler

 FixedUnigramCandidateSampler(
  const ::tensorflow::Scope & scope,
  ::tensorflow::Input true_classes,
  int64 num_true,
  int64 num_sampled,
  bool unique,
  int64 range_max
)

FixedUnigramCandidateSampler

 FixedUnigramCandidateSampler(
  const ::tensorflow::Scope & scope,
  ::tensorflow::Input true_classes,
  int64 num_true,
  int64 num_sampled,
  bool unique,
  int64 range_max,
  const FixedUnigramCandidateSampler::Attrs & attrs
)

Fonctions statiques publiques

Distorsion

Attrs Distortion(
  float x
)

NumReservedIds

Attrs NumReservedIds(
  int64 x
)

NumShards

Attrs NumShards(
  int64 x
)

Planter

Attrs Seed(
  int64 x
)

Graine2

Attrs Seed2(
  int64 x
)

Tesson

Attrs Shard(
  int64 x
)

Unigrammes

Attrs Unigrams(
  const gtl::ArraySlice< float > & x
)

VocabFile

Attrs VocabFile(
  StringPiece x
)