Esta página foi traduzida pela API Cloud Translation.
Switch to English

tf.compat.v1.estimator.Estimator

Ver fonte no GitHub

classe estimador para treinar e avaliar modelos TensorFlow.

Usado nos cadernos

Usado nos tutoriais

O Estimator objecto envolve um modelo que é especificado por um model_fn , o qual, dadas as entradas e uma série de outros parâmetros, devolve as operações necessárias para executar formação, avaliação ou previsões.

Todas as saídas (postos de controle, arquivos de eventos, etc.) são escritos para model_dir , ou um subdiretório do mesmo. Se model_dir não está definido, um diretório temporário é usado.

A config argumento pode ser passado tf.estimator.RunConfig objeto que contém informações sobre o ambiente de execução. Ele é passado para o model_fn , se o model_fn tem um parâmetro chamado "config" (e funções de entrada da mesma forma). Se a config parâmetro não é passado, ele é instanciado pelo Estimator . Não passando meios de configuração que padrões úteis para execução local são usados. Estimator faz config-disponível para o modelo (por exemplo, para permitir a especialização com base no número de trabalhadores disponíveis), e também usa alguns de seus campos para controlar internos, especialmente em relação checkpointing.

O params argumento contém hiperparâmetros. Ele é passado para o model_fn , se o model_fn tem um parâmetro denominado "params", e para as funções de entrada da mesma maneira. Estimator passa apenas params junto, não inspecioná-lo. A estrutura dos params é, portanto, inteiramente até o desenvolvedor.

Nenhum dos Estimator métodos 's pode ser substituído em subclasses (seu construtor aplica esse). Subclasses deve usar model_fn para configurar a classe base, e pode adicionar métodos de aplicação funcionalidade especializada.

Veja estimadores para mais informações.

Para aquecer-iniciar um Estimator :

 estimator = tf.estimator.DNNClassifier(
    feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
    hidden_units=[1024, 512, 256],
    warm_start_from="/path/to/checkpoint/dir")
 

Para mais detalhes sobre a configuração do warm-start, ver tf.estimator.WarmStartSettings .

model_fn função de modelo. Segue a assinatura:

  • features - Este é o primeiro item retornado do input_fn passado para train , evaluate e predict . Este deve ser um único tf.Tensor ou dict da mesma.
  • labels - Este é o segundo item retornado do input_fn passado para train , evaluate e predict . Este deve ser um único tf.Tensor ou dict da mesma (para modelos multi-cabeça). Se o modo é tf.estimator.ModeKeys.PREDICT , labels=None será passado. Se o model_fn assinatura 's não aceita mode , o model_fn ainda deve ser capaz de lidar com labels=None .
  • mode - Opcional. Especifica se este é o treinamento, avaliação ou previsão. Veja tf.estimator.ModeKeys . params - Opcional dict dos hiperparâmetros. Vai receber o que é passado para Estimador em params parâmetro. Isso permite configurar Estimators de ajuste de parâmetro hiper.
  • config - Opcional estimator.RunConfig objeto. Vai receber o que é passado para Estimador como a sua config de parâmetros, ou um valor padrão. Permite a criação de coisas em sua model_fn com base na configuração, como num_ps_replicas , ou model_dir .
  • Devoluções - tf.estimator.EstimatorSpec
model_dir Diretório para salvar os parâmetros do modelo, gráfico e etc. Isso também pode ser usado para postos de controle de carga a partir do diretório em um estimador para continuar treinando um modelo salvo anteriormente. Se PathLike objeto, o caminho será resolvido. Se None , o model_dir na config será usada se definir. Se ambos estiverem definidos, eles devem ser mesmo. Se ambos são None , será usado um diretório temporário.
config estimator.RunConfig objecto de configuração.
params dict de parâmetros hiper que serão passados em model_fn . As chaves são nomes de parâmetros, os valores são tipos básicos python.
warm_start_from Opcional filepath string para um ponto de verificação ou SavedModel aquecer-começam a partir de, ou um tf.estimator.WarmStartSettings opor-se totalmente configure warm-começando. Se None, variáveis ​​única treináveis ​​são começou-quente. Se o filepath corda é fornecido em vez de um tf.estimator.WarmStartSettings , em seguida, todas as variáveis são começou-quente, e presume-se que vocabulários e tf.Tensor nomes são inalteradas.

ValueError parâmetros de model_fn não correspondem params .
ValueError Se isso é chamado através de uma subclasse e se essa classe substitui um membro do Estimator .

Compatibilidade ansioso

Métodos de chamada de Estimator funcionará enquanto a execução ansioso está habilitado. No entanto, o model_fn e input_fn não é executado com entusiasmo, Estimator irá mudar para o modo gráfico antes de chamar todas as funções fornecidos pelo usuário (incl. Ganchos), assim que seu código tem de ser compatível com a execução modo gráfico. Note-se que input_fn código usando tf.data geralmente funciona tanto gráfico e modos ansiosos.

config

model_dir

model_fn Retorna o model_fn que é obrigado a self.params .
params

Métodos

eval_dir

Ver fonte

Mostra o nome do diretório onde métricas de avaliação são despejados.

args
name Nome da avaliação, se as necessidades do usuário para executar múltiplas avaliações sobre diferentes conjuntos de dados, como no treinamento de dados vs dados de teste. Métricas para diferentes avaliações são guardadas em pastas separadas, e aparecem separadamente em tensorboard.

Devoluções
Uma seqüência que é o caminho do diretório contém métricas de avaliação.

evaluate

Ver fonte

Avalia o modelo dado dados de avaliação input_fn .

Para cada etapa, chama input_fn , que retorna um lote de dados. Avalia até que:

  • steps lotes são processados, ou
  • input_fn levanta uma excepção fim-de-entrada ( tf.errors.OutOfRangeError ou StopIteration ).

args
input_fn Uma função que constrói os dados de entrada para a avaliação. Veja Premade Estimators para mais informações. A função deve construir e retornar um dos seguintes:

  • Um tf.data.Dataset objecto: saídas do Dataset objecto tem de ser um tuplo (features, labels) com as mesmas restrições como abaixo.
  • Uma tupla (features, labels) : Onde features é um tf.Tensor ou um dicionário de nome do recurso string para Tensor e labels é um Tensor ou um dicionário de nome do rótulo string para Tensor . Ambos os features e labels são consumidos por model_fn . Eles devem satisfazer a expectativa de model_fn de entradas.
steps Número de passos para que avaliar o modelo. Se None , avalia até input_fn levanta uma excepção fim-de-entrada.
hooks Lista de tf.train.SessionRunHook casos subclasse. Usado para retornos de chamada dentro da chamada avaliação.
checkpoint_path Caminho de um posto de controle específico para avaliar. Se None , o último posto de controle em model_dir é usado. Se não há postos de controle em model_dir , a avaliação é executado com recém-inicializado Variables em vez dos restaurados a partir do ponto de verificação.
name Nome da avaliação, se as necessidades do usuário para executar múltiplas avaliações sobre diferentes conjuntos de dados, como no treinamento de dados vs dados d