Cette page a été traduite par l'API Cloud Translation.
Switch to English

tf.estimator.Estimator

Tensorflow 1 Version Voir la source sur GitHub

classe estimateur pour former et évaluer les modèles tensorflow.

Hérite de: Estimator

Utilisé dans les ordinateurs portables

Utilisé dans le guide Utilisé dans les tutoriels

Le Estimator objet encapsule un modèle qui est spécifié par un model_fn qui, entrées données et un certain nombre d'autres paramètres, renvoie les ops nécessaires pour effectuer la formation, l' évaluation, ou les prévisions.

Toutes les sorties (points de contrôle, les fichiers d'événements, etc.) sont écrits model_dir , ou un sous - répertoire de celui - ci. Si model_dir n'est pas défini, un répertoire temporaire est utilisé.

La config argument peut être transmis tf.estimator.RunConfig objet contenant des informations sur l'environnement d'exécution. Il est transmis au model_fn , si le model_fn a un paramètre appelé « config » (et les fonctions d'entrée de la même manière). Si la config paramètre n'est pas passé, il est instancié par l' Estimator . Ne pas passer des moyens de configuration que par défaut utiles pour l'exécution locale sont utilisés. Estimator rend config à la disposition du modèle (par exemple, pour permettre une spécialisation en fonction du nombre de travailleurs disponibles), et utilise également certains de ses champs de contrôle internes, en particulier en ce qui concerne les points de reprise.

Le params argument contient hyperparam'etres. Il est passé à la model_fn , si le model_fn a un paramètre appelé « params », et aux fonctions d'entrée de la même manière. Estimator ne passe que params le long, il ne vérifie pas. La structure de params est donc tout à fait au développeur.

Aucun des Estimator méthodes de peut être redéfinie dans les sous - classes (son constructeur applique cette). Devraient utiliser les sous - classes model_fn pour configurer la classe de base, et peuvent ajouter des méthodes de mise en œuvre des fonctionnalités spécialisées.

Voir estimateurs pour plus d' informations.

Pour démarrage à chaud un Estimator :

 estimator = tf.estimator.DNNClassifier(
    feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
    hidden_units=[1024, 512, 256],
    warm_start_from="/path/to/checkpoint/dir")
 

Pour plus de détails sur la configuration de démarrage de chaud, voir tf.estimator.WarmStartSettings .

model_fn fonction modèle. Suite à la signature:

  • features - Ceci est le premier article retourné du input_fn passé à train , evaluate et predict . Cela devrait être un seul tf.Tensor ou dict de même.
  • labels - Ceci est le deuxième article retourné du input_fn passé à train , evaluate et predict . Cela devrait être un seul tf.Tensor ou dict de même (pour les modèles multi-têtes). Si le mode est tf.estimator.ModeKeys.PREDICT , labels=None seront transmises. Si la model_fn la signature «n'accepte pas le mode , le model_fn doit toujours être capable de gérer les labels=None .
  • mode - en option. Indique si cela est la formation, l'évaluation ou la prévision. Voir tf.estimator.ModeKeys . params - en option dict de hyperparam'etres. Recevra ce qui est passé à Estimator params paramètre. Cela permet de configurer Estimateurs de réglage des paramètres hyper.
  • config - option estimator.RunConfig objet. Recevra ce qui est passé à estimateur comme config paramètre, ou une valeur par défaut. Permet la mise en place des choses dans votre model_fn en fonction de la configuration tels que num_ps_replicas , ou model_dir .
  • Retours - tf.estimator.EstimatorSpec
model_dir Répertoire pour enregistrer les paramètres du modèle, graphique, etc Cela peut également être utilisé pour les postes de contrôle de charge à partir du répertoire dans un estimateur de continuer la formation d'un modèle précédemment enregistré. Si PathLike objet, le chemin sera résolu. Si None , le model_dir dans config sera utilisée si elle est définie. Si les deux sont définis, ils doivent être identiques. Si les deux sont None , un répertoire temporaire sera utilisé.
config estimator.RunConfig objet de configuration.
params dict des paramètres hyper qui seront transmis dans model_fn . Les clés sont les noms des paramètres, les valeurs sont les types de python de base.
warm_start_from Chaîne facultative à un point de contrôle filepath ou SavedModel à démarrage à chaud à partir, ou un tf.estimator.WarmStartSettings objet à configurer entièrement démarrage à chaud. Si aucune, sont démarrés-chaud variables ne Personnalisable. Si la chaîne filepath est fournie à la place d'un tf.estimator.WarmStartSettings , toutes les variables sont démarrés-chaud, et on suppose que les vocabulaires et les tf.Tensor noms ne changent pas.

ValueError paramètres de model_fn ne correspondent pas params .
ValueError si cela est appelé par une sous - classe et si cette classe remplace un membre Estimator .

Compatibilité Désireuse

Méthodes d'appel de Estimator fonctionneront pendant l' exécution avide est activée. Cependant, la model_fn et input_fn n'est pas exécuté avec enthousiasme, Estimator basculeront en mode graphique avant d' appeler toutes les fonctions fournies par l' utilisateur (crochets incl.), De sorte que leur code doit être compatible avec l' exécution en mode graphique. Notez que input_fn code à l' aide tf.data fonctionne généralement à la fois graphique et les modes avides.

config

export_savedmodel

model_dir

model_fn Retourne le model_fn qui est lié à self.params .
params

méthodes

eval_dir

Voir la source

Indique le nom du répertoire dans lequel les paramètres d'évaluation sont sous-évalués.

args
name Nom de l'évaluation si les besoins des utilisateurs d'exécuter plusieurs évaluations sur différents ensembles de données, telles que les données sur la formation vs données de test. Mesures pour les différentes évaluations sont enregistrées dans des dossiers séparés, et apparaissent séparément dans tensorboard.

Retour
Une chaîne qui est le chemin du répertoire contient des mesures d'évaluation.

evaluate

Voir la source

Évalue le modèle compte tenu des données d'évaluation input_fn .

Pour chaque étape, les appels input_fn , qui retourne un lot de données. Évalue jusqu'à ce que:

  • les steps lots sont traités, ou
  • input_fn déclenche une exception en fin de saisie ( tf.errors.OutOfRangeError ou StopIteration ).

args
input_fn Une fonction qui construit les données d'entrée pour l'évaluation. Voir Premade Estimateurs pour plus d' informations. La fonction doit construire et retourner un des éléments suivants:

  • Un tf.data.Dataset objet: sorties de Dataset objet doit être un tuple (features, labels) avec les mêmes contraintes que ci - dessous.
  • Un tuple (features, labels) : Où features est un tf.Tensor ou un dictionnaire de nom de la fonction de chaîne à Tensor et des labels est un Tensor ou un dictionnaire de nom d'étiquette de chaîne à Tensor . Les deux features et les labels sont consommés par model_fn . Ils doivent satisfaire les attentes des model_fn d'entrées.
steps Nombre d'étapes pour pour évaluer le modèle. Si None , jusqu'à ce que evalue input_fn soulève une exception de fin de saisie.
hooks