Se usó la API de Cloud Translation para traducir esta página.
Switch to English

tf.estimator.Estimator

TensorFlow 1 versión Ver código fuente en GitHub

Estimador de clase para entrenar y evaluar modelos TensorFlow.

Hereda de: Estimator

Se utiliza en los cuadernos

Se utiliza en la guía Se utiliza en los tutoriales

El Estimator objeto se ajusta un modelo que se especifica por un model_fn , que, entradas dadas y un número de otros parámetros, devuelve los ops necesario realizar formación, evaluación, o predicciones.

Todas las salidas (puestos de control, archivos de eventos, etc.) se escriben en model_dir , o un subdirectorio del mismo. Si model_dir no está definida, se utiliza un directorio temporal.

La config argumento se puede pasar tf.estimator.RunConfig objeto que contiene información sobre el entorno de ejecución. Se transmite a la model_fn , si el model_fn tiene un parámetro llamado "config" (y funciones de entrada de la misma manera). Si la config de parámetros no se pasa, se crea una instancia por el Estimator . No pasar medios de configuración que se utilizan valores por defecto útiles para la ejecución local. Estimator hace que config disponible para el modelo (por ejemplo, para permitir la especialización en función del número de trabajadores disponibles), y también utiliza algunos de sus campos para controlar internos, especialmente en relación con los puntos de control.

El params argumento contiene hiperparámetros. Se pasa a la model_fn , si el model_fn tiene un parámetro denominado "params", y para las funciones de entrada de la misma manera. Estimator sólo deja pasar a lo largo de params, no inspeccionarlo. La estructura de params , por tanto, es totalmente de los desarrolladores.

Ninguno de Estimator métodos 's se puede anular en subclases (su constructor hace cumplir esto). Las subclases deben utilizar model_fn para configurar la clase base, y pueden agregar métodos de aplicación funcionalidad especializada.

Ver estimadores para más información.

Para calentar en marcha un Estimator :

 estimator = tf.estimator.DNNClassifier(
    feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
    hidden_units=[1024, 512, 256],
    warm_start_from="/path/to/checkpoint/dir")
 

Para más detalles sobre la configuración de arranque en caliente, consulte tf.estimator.WarmStartSettings .

model_fn función del modelo. Sigue la firma:

  • features - Este es el primer elemento de regresar de la input_fn pasado al train , evaluate y predict . Esto debería ser una sola tf.Tensor o dict del mismo.
  • labels - Este es el segundo artículo devuelto desde el input_fn pasó a train , evaluate y predict . Esto debería ser una sola tf.Tensor o dict del mismo (para los modelos de cabezales múltiples). Si el modo es tf.estimator.ModeKeys.PREDICT , labels=None serán pasados. Si el model_fn firma 's no acepta mode , la model_fn aún debe ser capaz de manejar labels=None .
  • mode - Opcional. Especifica si se trata de la formación, la evaluación o la predicción. Ver tf.estimator.ModeKeys . params - Opcional dict de hiperparámetros. Recibirán lo que se pasa al Estimador en params parámetro. Esto permite configurar los estimadores de ajuste de parámetros hiper.
  • config - Opcional estimator.RunConfig objeto. Recibirán lo que se pasa al Estimador como su config de parámetros, o un valor predeterminado. Permite la creación de las cosas en su model_fn base a la configuración como num_ps_replicas , o model_dir .
  • Devoluciones - tf.estimator.EstimatorSpec
model_dir Directorio para guardar los parámetros del modelo, gráfico, etc Esto también puede ser usado para los puestos de control de carga desde el directorio en un estimador para continuar la formación de un modelo previamente guardado. Si PathLike objeto, se resolverá el camino. Si None , el model_dir en config se utilizará si se establece. Si ambos están establecidos, deben ser el mismo. Si ambos son None , se utilizará un directorio temporal.
config estimator.RunConfig objeto de configuración.
params dict de los parámetros hiper que se pasarán en model_fn . Las claves son los nombres de los parámetros, los valores son los tipos básicos de pitón.
warm_start_from Opcional cadena de ruta de archivo a un puesto de control o SavedModel de arranque en caliente de productos o la tf.estimator.WarmStartSettings objeto de configurar totalmente calentamiento inicial. Si ninguno, sólo las variables son entrenables calentamiento comenzado. Si se proporciona la ruta de archivo de una cadena en lugar tf.estimator.WarmStartSettings , a continuación, todas las variables son empezado-caliente, y se supone que los vocabularios y tf.Tensor nombres no se han modificado.

ValueError parámetros de model_fn no coinciden params .
ValueError Si esto se llama a través de una subclase y si esa clase anula un miembro del Estimator .

Compatibilidad ansiosos

Métodos de llamada de Estimator trabajarán mientras que la ejecución ansiosa está activado. Sin embargo, el model_fn y input_fn no se ejecuta con entusiasmo, Estimator cambiará al modo gráfico antes de llamar a todas las funciones proporcionadas por el usuario (incl. Ganchos), por lo que su código tiene que ser compatible con la ejecución del modo gráfico. Tenga en cuenta que input_fn código usando tf.data general funciona tanto en el modo gráfico y ansiosos.

config

export_savedmodel

model_dir

model_fn Devuelve el model_fn que está obligado a self.params .
params

Métodos

eval_dir

Ver fuente

Muestra el nombre del directorio en el que se vierten métricas de evaluación.

args
name Nombre de la evaluación si las necesidades del usuario para ejecutar múltiples evaluaciones en diferentes conjuntos de datos, como en el entrenamiento de los datos frente a los datos de prueba. Métricas para diferentes evaluaciones se guardan en carpetas separadas, y aparecen por separado en tensorboard.

Devoluciones
Una cadena que es la ruta de directorio contiene métricas de evaluación.

evaluate

Ver fuente

Evalúa la modelo dado datos de la evaluación input_fn .

Para cada paso, las llamadas input_fn , que devuelve un lote de datos. Evalúa hasta que:

args
input_fn Una función que construye los datos de entrada para la evaluación. Ver preparado de antemano estimadores para más información. La función debe construir y devolver uno de los siguientes:

  • A tf.data.Dataset objeto: Salidas de Dataset de objeto debe ser una tupla (features, labels) con mismas restricciones que a continuación.
  • Una tupla (features, labels) : ¿Dónde features es un tf.Tensor o un diccionario de nombre de elemento de cadena para Tensor y labels es una Tensor o un diccionario de nombre de etiqueta cadena a Tensor . Ambas features y labels son consumidos por model_fn . Deben satisfacer las expectativas de model_fn de entradas.
steps Número de pasos para la cual evaluar modelo. Si None , evalúa hasta input_fn genera una excepción al final de la entrada.
hooks Lista de tf.train.SessionRunHook casos subclase. Se utiliza para las devoluciones de llamada dentro de la llamada evaluación.
checkpoint_path Trayectoria de un punto de control específico para evaluar. Si None , el último puesto de control en model_dir se utiliza. Si no hay puestos de control en model_dir , la evaluación se ejecuta con iniciada recientemente Variables en lugar de los restaurados de puesto de control.
name Nombre de la evaluación si las necesidades del usuario para ejecutar múltiples evaluaciones en diferentes conjuntos de datos, como en el entrenamiento de los datos frente a los datos de prueba. Métricas para diferentes evaluaciones se guardan en carpetas separadas, y aparecen por separado en tensorboard.

Devoluciones
A dict que contiene las métricas de evaluación especificados en model_fn tecleado por el nombre, así como una entrada global_step que contiene el valor de la etapa global para el que se realizó esta evaluación. Para estimadores enlatados, el dict contiene la loss (pérdida media por mini lotes) y el average_loss (pérdida media por muestra). Clasificadores enlatados también devuelven la accuracy . Regresores enlatados también vuelven la label/mean y la prediction/mean .

aumentos
ValueError Si steps <= 0 .

experimental_export_all_saved_models

Ver fuente

Las exportaciones un SavedModel con tf.MetaGraphDefs para cada modo de operación deseado.

Para cada modo aprobada en a través de la input_receiver_fn_map , este método construye un nuevo gráfico llamando a la input_receiver_fn obtener característica y la etiqueta Tensor s. A continuación, este método llama al Estimator 's model_fn en el modo pasado para generar el gráfico modelo basado en las características y etiquetas, y restaura el punto de control dado (o, a falta de eso, el último punto de control) en el gráfico. Sólo uno de los modos se utiliza para ahorrar variables al SavedModel (orden de preferencia: tf.estimator.ModeKeys.TRAIN , tf.estimator.ModeKeys.EVAL , entonces tf.estimator.ModeKeys.PREDICT ), de manera que hasta tres tf.MetaGraphDefs se guardan con un único conjunto de variables en una sola SavedModel directorio.

Para las variables y tf.MetaGraphDefs , un directorio de exportación sellos de tiempo a continuación