This page was translated by the Cloud Translation API.
Switch to English

tf.estimator.experimental.RNNEstimator

Ver código fuente en GitHub

Un estimador para los modelos TensorFlow RNN con cabeza especificado por el usuario.

Hereda de: Estimator

Ejemplo:

 token_sequence = sequence_categorical_column_with_hash_bucket(...)
token_emb = embedding_column(categorical_column=token_sequence, ...)

estimator = RNNEstimator(
    head=tf.estimator.RegressionHead(),
    sequence_feature_columns=[token_emb],
    units=[32, 16], cell_type='lstm')

# Or with custom RNN cell:
def rnn_cell_fn(_):
  cells = [ tf.keras.layers.LSTMCell(size) for size in [32, 16] ]
  return tf.keras.layers.StackedRNNCells(cells)

estimator = RNNEstimator(
    head=tf.estimator.RegressionHead(),
    sequence_feature_columns=[token_emb],
    rnn_cell_fn=rnn_cell_fn)

# Input builders
def input_fn_train: # returns x, y
  pass
estimator.train(input_fn=input_fn_train, steps=100)

def input_fn_eval: # returns x, y
  pass
metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
def input_fn_predict: # returns x, None
  pass
predictions = estimator.predict(input_fn=input_fn_predict)
 

Entrada del train y evaluate debería haber siguientes características, de lo contrario habrá un KeyError :

  • Si la cabeza de weight_column no es None , una característica con key=weight_column cuyo valor es un Tensor .
  • para cada column en sequence_feature_columns :
    • una característica con key=column.name cuyo value es un SparseTensor .
  • para cada column en context_feature_columns :
    • Si column es una CategoricalColumn , una característica con key=column.name cuyo value es un SparseTensor .
    • Si column es una WeightedCategoricalColumn , dos características: la primera con key el nombre de columna ID, el segundo con key el nombre de columna de peso. Ambas características value debe ser un SparseTensor .
    • Si column es una DenseColumn , una característica con key=column.name cuyo value es un Tensor .

La pérdida y la salida predicha están determinadas por la cabeza especificado.

head Un Head ejemplo. Esto especifica de salida del modelo y la pérdida de función para ser optimizado.
sequence_feature_columns Un iterable que contiene los FeatureColumn s que representan la entrada secuencial. Todos los elementos en el conjunto debe ser o bien la secuencia de columnas (por ejemplo sequence_numeric_column ) o construidas a partir de uno (por ejemplo embedding_column con sequence_categorical_column_* como entrada).
context_feature_columns Un iterable que contiene el FeatureColumn s para la entrada contextual. Los datos representados por estas columnas se replicarán y dados a la RNN en cada paso de tiempo. Estas columnas deben ser instancias de clases derivadas de DenseColumn tales como numeric_column , no el secuencial variantes.
units Iterable de número entero número de unidades ocultas por capa RNN. Si se establece, cell_type También se debe especificar y rnn_cell_fn debe ser None .
cell_type Una clase de fabricación de una pila RNN o una cadena que especifica el tipo de célula. Cuerdas soportados son: 'simple_rnn' , 'lstm' , y 'gru' . Si se establece, units también deben especificarse y rnn_cell_fn deben ser None .
rnn_cell_fn Una función que devuelve una instancia de célula RNN que será utilizado para construir el RNN. Si se establece, units y cell_type no se pueden ajustar. Esto es para usuarios avanzados que necesitan personalización adicional más allá de units y cell_type . Tenga en cuenta que tf.keras.layers.StackedRNNCells es necesario para RNNs apilados.
return_sequences Un booleano que indica si se debe devolver la última salida en la secuencia de salida, o la secuencia completa.
model_dir Directorio para guardar los parámetros del modelo, gráfico, etc Esto también puede ser usado para los puestos de control de carga desde el directorio en un estimador para continuar la formación de un modelo previamente guardado.
optimizer Una instancia de tf.Optimizer o cadena que especifica el tipo optimizador. Por defecto es Adagrad optimizador.
config RunConfig objeto de configurar los ajustes de tiempo de ejecución.

ValueError Si units , cell_type y rnn_cell_fn no son compatibles.

Compatibilidad ansiosos

Estimadores no son compatibles con la ejecución ansiosos.

config

model_dir

model_fn Devuelve el model_fn que está obligado a self.params .
params

Métodos

eval_dir

Ver fuente

Muestra el nombre del directorio en el que se vierten métricas de evaluación.

args
name Nombre de la evaluación si las necesidades del usuario para ejecutar múltiples evaluaciones en diferentes conjuntos de datos, como en el entrenamiento de los datos frente a los datos de prueba. Métricas para diferentes evaluaciones se guardan en carpetas separadas, y aparecen por separado en tensorboard.

Devoluciones
Una cadena que es la ruta de directorio contiene métricas de evaluación.

evaluate

Ver fuente

Evalúa la modelo dado datos de la evaluación input_fn .

Para cada paso, las llamadas input_fn , que devuelve un lote de datos. Evalúa hasta que:

args
input_fn Una función que construye los datos de entrada para la evaluación. Ver preparado de antemano estimadores para más información. La función debe construir y devolver uno de los siguientes:

  • A tf.data.Dataset objeto: Salidas de Dataset de objeto debe ser una tupla (features, labels) con mismas restricciones que a continuación.
  • Una tupla (features, labels) : ¿Dónde features es un tf.Tensor o un diccionario de nombre de elemento de cadena para Tensor y labels es una Tensor o un diccionario de nombre de etiqueta cadena a Tensor . Ambas features y labels son consumidos por model_fn . Deben satisfacer las expectativas de model_fn de entradas.
steps Número de pasos para la cual evaluar modelo. Si None , evalúa hasta input_fn genera una excepción al final de la entrada.
hooks Lista de tf.train.SessionRunHook casos subclase. Se utiliza para las devoluciones de llamada dentro de la llamada evaluación.
checkpoint_path Trayectoria de un punto de control específico para evaluar. Si None , el último puesto de control en model_dir se utiliza. Si no hay puestos de control en model_dir , la evaluación se ejecuta con iniciada recientemente Variables en lugar de los restaurados de puesto de control.
name Nombre de la evaluación si las necesidades del usuario para ejecutar múltiples evaluaciones en diferentes conjuntos de datos, como en el entrenamiento de los datos frente a los datos de prueba. Métricas para diferentes evaluaciones se guardan en carpetas separadas, y aparecen por separado en tensorboard.

Devoluciones
A dict que contiene las métricas de evaluación especificados en model_fn tecleado por el nombre, así como una entrada global_step que contiene el valor de la etapa global para el que se realizó esta evaluación. Para estimadores enlatados, el dict contiene la loss (pérdida media por mini lotes) y el average_loss (pérdida media por muestra). Clasificadores enlatados también devuelven la accuracy . Regresores enlatados también vuelven la label/mean y la prediction/mean .

aumentos
ValueError Si steps <= 0 .

experimental_export_all_saved_models

Ver fuente

Las exportaciones un SavedModel con tf.MetaGraphDefs para cada modo de operación deseado.

Para cada modo aprobada en a través de la input_receiver_fn_map , este método construye un nuevo gráfico llamando a la input_receiver_fn obtener característica y la etiqueta Tensor s. A continuación, este método llama al Estimator 's model_fn en el modo pasado para generar el gráfico modelo basado en las características y etiquetas, y restaura el punto de control dado (o, a falta de eso, el último punto de control) en el gráfico. Sólo uno de los modos se utiliza para ahorrar variables al SavedModel (orden de preferencia: tf.estimator.ModeKeys.TRAIN , tf.estimator.ModeKeys.EVAL , entonces tf.estimator.ModeKeys.PREDICT ), de manera que hasta tres tf.MetaGraphDefs se guardan con un único conjunto de variables en una sola SavedModel directorio.

Para las variables y tf.MetaGraphDefs , un directorio de exportación sellos de tiempo a continuación export_dir_base , y escribe un SavedModel en ella que contiene el tf.MetaGraphDef para el modo dado y sus firmas asociadas.

Para la predicción, la exportado MetaGraphDef proporcionará una