このページは Cloud Translation API によって翻訳されました。
Switch to English

tf.estimator.Estimator

TensorFlow 1つのバージョン GitHubの上のソースを表示

訓練し、TensorFlowモデルを評価するための見積もりクラス。

:から継承Estimator

ノートPCで使用されます

ガイドで使用チュートリアルで使用されます

Estimatorオブジェクトは、で指定されたモデルラップmodel_fn 、入力および他のパラメータの数が与えられると、訓練、評価、又は予測を行うために必要なオペレーションを返します。

すべての出力(などのチェックポイント、イベントファイルは、)に書かれているmodel_dir 、またはそのサブディレクトリ。場合model_dir設定されていない場合、一時ディレクトリが使用されます。

config引数を渡すことができtf.estimator.RunConfig実行環境に関する情報を含むオブジェクトを。これは、に渡されmodel_fn場合、 model_fn (同じようにして入力機能)「設定」という名前のパラメータがあります。場合はconfigパラメータが渡されていない、それがによってインスタンス化されるEstimator 。ローカル実行のための便利なデフォルトが使用されている設定手段を渡していません。 Estimator (利用可能な労働者の数に基づいて特殊化を可能にするために、例えば)モデルに利用可能な構成、および、特にチェックポイントについて、内部を制御するために、そのフィールドの一部を使用することができます。

params引数には、ハイパーが含まれています。それはに渡されmodel_fn場合、 model_fn 「のparams」という名前のパラメータを持っており、同様に入力機能に。 Estimator唯一、それを検査していない、のparamsを通ります。構造params完全最大現像することです。

いずれもEstimatorの方法は、(そのコンストラクタはこれを強制する)サブクラスでオーバーライドすることはできません。サブクラスでは使用すべきmodel_fn基本クラスを設定するには、専門機能を実装する方法を追加することができます。

参照推定量の詳細については。

ウォームスタートするにはEstimator

 estimator = tf.estimator.DNNClassifier(
    feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
    hidden_units=[1024, 512, 256],
    warm_start_from="/path/to/checkpoint/dir")
 

ウォームスタート構成の詳細については、参照tf.estimator.WarmStartSettings

model_fn モデルの機能。署名を、次のとおりです。

  • features -これは最初の項目から返されinput_fnに渡さtrainevaluate 、およびpredict 。これは、単一である必要がありtf.Tensordictと同じの。
  • labels -これは2番目の項目から返されinput_fnに渡さtrainevaluate 、およびpredict 。これは、単一であるべきであるtf.Tensor又はdict (マルチヘッドモデル)同様の。モードであればtf.estimator.ModeKeys.PREDICTlabels=None渡されません。場合model_fnの署名が受け入れていないmodemodel_fnまだ処理できなければならないlabels=None
  • mode -オプション。指定これは、トレーニング、評価や予測がある場合。参照してくださいtf.estimator.ModeKeysparams -オプションdictハイパーの。見積もりに渡されたものを受け取ることになりますparamsパラメータ。これは、ハイパーパラメータのチューニングから推定量を設定することができます。
  • config -オプションestimator.RunConfigオブジェクト。そのよう見積もりに渡されたものを受信するconfigパラメータ、またはデフォルト値。自分の中で物事を設定することができますmodel_fnのような構成に基づいてnum_ps_replicas 、またはmodel_dir
  • 戻り値- tf.estimator.EstimatorSpec
model_dir これは、以前に保存したモデルのトレーニングを続行するために推定にディレクトリからロードチェックポイントにも使用することができるモデルのパラメータ、グラフ等を保存するディレクトリ。場合PathLikeオブジェクト、パスが解決されます。場合はNone 、でmodel_dir config設定されている場合に使用されます。両方が設定されている場合、それらは同じでなければなりません。両方がされていない場合はNone 、一時ディレクトリが使用されます。
config estimator.RunConfig設定オブジェクト。
params dictに渡されるハイパーパラメータのmodel_fn 。キーは、パラメータの名前で、値は、基本的なPythonのタイプです。
warm_start_from ウォームスタートするから、またはチェックポイントまたはSavedModelに任意の文字列ファイルパスtf.estimator.WarmStartSettings完全に構成するウォーム開始に反対します。なしている場合、唯一のトレーニング可能な変数は、ウォーム開始します。文字列のファイルパスが代わりに提供されている場合tf.estimator.WarmStartSettings 、その後、すべての変数は、ウォーム開始され、語彙や想定されるtf.Tensor名前が変更されていません。

ValueError パラメータmodel_fn一致していないparams
ValueError これは、サブクラスを経由して、そのクラスがメンバーオーバーライドする場合呼び出された場合はEstimator

イーガーの互換性

呼び出し方法Estimator熱心な実行が有効になっている間に動作します。しかし、 model_fninput_fn熱心に実行されていないが、 Estimator 、それらのコードは、グラフモードの実行と適合性でなければならないので、すべてのユーザが提供する機能(含フック)を呼び出す前に、グラフモードに切り替えます。なおinput_fn使用してコードtf.data一般グラフと熱心なモードの両方で動作します。

config

export_savedmodel

model_dir

model_fn 返しmodel_fnにバインドされているself.params
params

メソッド

eval_dir

ソースを表示

評価指標がダンプされているディレクトリ名を表示します。

引数
name 評価の名前、ユーザは、このようなテストデータ対データの訓練のように異なるデータセット、上の複数の評価を実行する必要がある場合。さまざまな評価のためのメトリクスは、別々のフォルダに保存され、tensorboardで別々に表示されています。

戻り値
ディレクトリのパスを表す文字列は、評価指標が含まれています。

evaluate

ソースを表示

モデル与えられた評価データ評価input_fn

各ステップのために、呼び出しinput_fnデータのバッチを返します。評価されるまで:

  • stepsバッチが処理され、または
  • input_fnエンドの入力例外(上げるtf.errors.OutOfRangeErrorまたはStopIteration )。

引数
input_fn 評価のための入力データを構築する機能。参照既成の推定量を詳細については。関数は、構築し、次のいずれかを返す必要があります:

  • A tf.data.Datasetオブジェクト:の出力Datasetオブジェクトがタプルでなければならない(features, labels)以下と同じ制約を持ちます。
  • タプル(features, labels)featuresあるtf.Tensorかに文字列の機能名の辞書TensorlabelsあるTensorかに文字列のラベル名の辞書Tensor 。両方featureslabelsによって消費されているmodel_fn 。彼らはの期待を満たす必要がありますmodel_fn入力からを。
steps モデルを評価するためのステップの数。場合None 、まで評価さinput_fn入力終了例外が発生します。
hooks 一覧tf.train.SessionRunHookサブクラスのインスタンス。評価コール内部コールバックに使用します。
checkpoint_path 評価する具体的なチェックポイントのパス。場合None 、最新のチェックポイントmodel_dir使用されています。チェックポイントがない場合model_dir 、評価が新規に初期化して実行されたVariables代わりに、チェックポイントから復元されたものの。
name 評価の名前、ユーザは、このようなテストデータ対データの訓練のように異なるデータセット、上の複数の評価を実行する必要がある場合。さまざまな評価のためのメトリクスは、別々のフォルダに保存され、tensorboardで別々に表示されています。

戻り値
で指定された評価指標を含む辞書model_fn名をキーに、ならびにエントリglobal_stepこの評価が実行されたグローバル段差の値を含みます。缶詰の推定のために、辞書は含まloss (ミニバッチ当たりの平均損失)とaverage_loss (サンプルあたりの平均損失)。缶詰分類も返しaccuracy 。缶詰説明変数も返さlabel/meanprediction/mean

発生させます
ValueError もしsteps <= 0

experimental_export_all_saved_models

ソースを表示

輸出SavedModeltf.MetaGraphDefs各要求されたモードについて。

経由で渡された各モードについてinput_receiver_fn_map 、このメソッドは呼び出して新しいグラフを構築input_receiver_fnフィーチャーし、ラベル取得するTensor秒。次は、このメソッドを呼び出しEstimatormodel_fnグラフに(つまり、最新のチェックポイントを欠く、または)これらの機能とラベルに基づいてモデルグラフを生成するために渡されたモードでは、与えられたチェックポイントを復元します。のみのモードのいずれかがに変数を保存するために使用されSavedModel :(優先順tf.estimator.ModeKeys.TRAINtf.estimator.ModeKeys.EVAL 、その後tf.estimator.ModeKeys.PREDICT 、その結果、最大3つの) tf.MetaGraphDefs単一の変数の単一のセットで保存されSavedModelディレクトリ。

変数とについてtf.MetaGraphDefs 、以下のタイムスタンプエクスポートディレクトリexport_dir_base 、および書き込みSavedModel含むことにtf.MetaGraphDef与えられたモードと、それに関連する署名を。

予測のために、エクスポートしたMetaGraphDef 1つを提供しますSignatureDefの各要素についてexport_outputsから返さDICT model_fn 、同じキーを使用して命名します。これらのキーの一つが常にあるtf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEYサービス提供要求は1を指定していない場合に提供されるシグネチャを示します。各署名のために、出力を対応させて設けられているtf.estimator.export.ExportOutput S、および入力が常にによって提供される入力レシーバあるserving_input_receiver_fn

トレーニングと評価のため、 train_op余分なコレクションに格納され、損失、メトリクス、および予測が含まれているSignatureDef問題のモードについて。

余分な資産を書き込むこともSavedModel経由assets_extra引数。これは、各キーがassets.extraディレクトリに対して(ファイル名を含む)宛先パスを与える辞書、であるべきです。対応する値をコピーするソース・ファイルのフルパスを与えます。たとえば、名前を変更することなく、1つのファイルをコピーする簡単な場合は、次のように指定されている{'my_asset_file.txt': '/path/to/my_asset_file.txt'}

引数
export_dir_base エクスポート含むタイムスタンプ付きのサブディレクトリを作成するディレクトリを含む文字列SavedModel秒。
input_receiver_fn_map 辞書tf.estimator.ModeKeysinput_receiver_fnマッピング、 input_receiver_fn引数を取らず、適切なサブクラスを返す関数であるInputReceiver
assets_extra エクスポートされた内assets.extraディレクトリを移入する方法を指定する辞書SavedModel 、またはNone余分な資産が必要とされていない場合。
as_text 書き込むかどうかをSavedModelテキスト形式でプロトを。
checkpoint_path 輸出へのチェックポイントのパス。場合None (デフォルト)、モデルのディレクトリ内で見つかった最新のチェックポイントが選択されています。

戻り値
バイト・オブジェクトとしてエクスポートされたディレクトリへのパス。

発生させます
ValueError もしあればinput_receiver_fnありませんNone 、何のexport_outputs提供されていない、またはチェックポイントが見つかりません。

export_saved_model

ソースを表示

グラフ推論輸出SavedModel所定のディレクトリに。

詳細なガイドについては、 推定量からSavedModelを

この方法は、最初に呼び出して新しいグラフを構築serving_input_receiver_fnフィーチャー得るために、 Tensor Sを、その後、この呼び出しEstimatormodel_fnこれらの機能に基づいたモデルのグラフを生成します。それは、新鮮なセッションでこのグラフに(つまり、最新のチェックポイントを欠く、または)与えられたチェックポイントを復元します。最後に、与えられた下のタイムスタンプエクスポートディレクトリを作成export_dir_base 、そして書き込みSavedModel 、単一含むにtf.MetaGraphDefこのセッションから保存を。

エクスポートMetaGraphDef 1つを提供しますSignatureDef 、各要素に対してexport_outputsから返された辞書model_fn同じキーを使用して名前が付けられ、。これらのキーの一つが常にあるtf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEYサービス提供要求は1を指定していない場合に提供されるシグネチャを示します。各署名のために、出力を対応させて設けられているtf.estimator.export.ExportOutput S、および入力が常にによって提供される入力レシーバあるserving_input_receiver_fn

余分な資産を書き込むこともSavedModel経由assets_extra引数。これは、各キーがassets.extraディレクトリに対して(ファイル名を含む)宛先パスを与える辞書、であるべきです。対応する値をコピーするソース・ファイルのフルパスを与えます。たとえば、名前を変更することなく、1つのファイルをコピーする簡単な場合は、次のように指定されている{'my_asset_file.txt': '/path/to/my_asset_file.txt'}

experimental_modeパラメータは以下のように単一の列車/ evalの/予測グラフをエクスポートするために使用することができますSavedModel 。参照してくださいexperimental_export_all_saved_models完全なドキュメントのために。

引数
export_dir_base エクスポート含むタイムスタンプ付きのサブディレクトリを作成するディレクトリを含む文字列SavedModel秒。
serving_input_receiver_fn 引数を取らないと返す関数tf.estimator.export.ServingInputReceiverまたはtf.estimator.export.TensorServingInputReceiver
assets_extra エクスポートされた内assets.extraディレクトリを移入する方法を指定する辞書SavedModel 、またはNone余分な資産が必要とされていない場合。
as_text 書き込むかどうかをSavedModelテキスト形式でプロトを。
checkpoint_path 輸出へのチェックポイントのパス。場合None (デフォルト)、モデルのディレクトリ内で見つかった最新のチェックポイントが選択されています。
experimental_mode tf.estimator.ModeKeysモードとを示す値がエクスポートされます。この機能は実験的であることに注意してください。

戻り値
バイト・オブジェクトとしてエクスポートされたディレクトリへのパス。

発生させます
ValueError 何場合serving_input_receiver_fn提供されていない、何のexport_outputs提供されていない、またはチェックポイントが見つかりません。

get_variable_names

ソースを表示

このモデルではすべての変数名のリストを戻します。

戻り値
名前のリスト。

発生させます
ValueError 場合はEstimatorまだチェックポイントを作成していません。

get_variable_value

ソースを表示

戻り値は、nameで指定した変数の値。

引数
name 文字列または文字列、テンソルの名前のリスト。

戻り値
numpyの配列 - テンソルの値。

発生させます
ValueError 場合はEstimatorまだチェックポイントを作成していません。

latest_checkpoint

ソースを表示

最新の保存されたチェックポイント・ファイルのファイル名検索しmodel_dir

戻り値
最新のチェックポイントまたはへの完全なパスNoneチェックポイントが見つかりませんでした場合。

predict

ソースを表示

与えられた機能の予測を生成します。

インターリーブ2が作業を出力しない予測することに注意してください。参照: 問題/ 20506

引数
input_fn 機能を構築する機能。まで予測は継続input_fnエンドの入力例外(上げるtf.errors.OutOfRangeErrorまたはStopIteration )。参照既成の推定量を詳細については。関数は、構築し、次のいずれかを返す必要があります:

  • tf.data.Datasetオブジェクト-の出力Datasetオブジェクトは、以下のように同じ制約を持つ必要があります。
  • 機能- A tf.Tensorかに文字列の機能名の辞書Tensor 。機能はによって消費されているmodel_fn 。彼らはの期待を満たす必要がありますmodel_fn入力からを。
  • 最初の項目を特徴量として抽出される場合にはタプル。
predict_keys リストstr 、キーの名前は予測します。場合は使用されているtf.estimator.EstimatorSpec.predictionsあるdict 。場合predict_keys使用され、その後の予測の残り辞書からフィルタリングされます。場合はNone 、すべて返されます。
hooks 一覧tf.train.SessionRunHookサブクラスのインスタンス。予測コール内部コールバックに使用します。
checkpoint_path 特定のチェックポイントのパスを予測します。場合None 、最新のチェックポイントmodel_dir使用されています。チェックポイントがない場合model_dir 、予測は、新たに初期化して実行されたVariables代わりに、チェックポイントから復元されたものの。
yield_single_examples 場合Falseで返されるよう、全体のバッチを得model_fn個々の要素にバッチを分解するのではなく。場合、これは有用であるmodel_fnその最初の寸法バッチサイズに等しくされていないいくつかのテンソルを返します。

収量:

評価値predictionsテンソル。

発生させます
ValueError 予測のバッチの長さは同じではなく、場合yield_single_examplesあるTrue
ValueError 間に矛盾がある場合predict_keyspredictions 。たとえば場合predict_keysませんNoneが、 tf.estimator.EstimatorSpec.predictionsないdict

train

ソースを表示

列車モデル与えられた学習データinput_fn

引数
input_fn minibatchesとしての訓練のための入力データを提供機能。参照既成の推定量を詳細については。関数は、構築し、次のいずれかを返す必要があります:

  • A tf.data.Datasetオブジェクト:の出力Datasetオブジェクトがタプルでなければならない(features, labels)以下と同じ制約を持ちます。
  • タプル(features, labels)featuresあるtf.Tensorかに文字列の機能名の辞書TensorlabelsあるTensorかに文字列のラベル名の辞書Tensor 。両方featureslabelsによって消費されているmodel_fn 。彼らはの期待を満たす必要がありますmodel_fn入力からを。
hooks 一覧tf.train.SessionRunHookサブクラスのインスタンス。トレーニングループ内コールバックに使用します。
steps モデルを訓練するためのステップの数。場合はNone 、永遠まで電車や列車input_fn発生しませんtf.errors.OutOfRangeエラーまたはStopIteration例外を。 stepsインクリメンタル動作します。あなたは2回コールした場合train(steps=10) 、その後の訓練は、合計20のステップで起こります。場合OutOfRangeまたはStopIteration途中で発生し、訓練は20段階の前に停止します。あなたはインクリメンタル行動をしたくない場合は設定してくださいmax_steps代わりに。設定した場合、 max_stepsあってはなりませんNone
max_steps モデルを訓練するための総ステップ数。場合はNone 、永遠まで電車や列車input_fn発生しませんtf.errors.OutOfRangeエラーまたはStopIteration例外を。設定した場合は、 stepsあってはなりませんNone 。場合OutOfRangeまたはStopIteration途中で発生し、トレーニングの前に停止max_steps手順を実行します。 2つの呼び出しtrain(steps=100) 200回のトレーニングの繰り返しを意味します。一方、2つの呼び出しがしますtrain(max_steps=100)の最初の呼び出しは、すべての100個のステップをやったので、2回目の呼び出しは、任意の繰り返しをしないだろうということを意味します。
saving_listeners リストCheckpointSaverListenerオブジェクト。前やチェックポイントの節約の直後に実行コールバックに使用します。

戻り値
selfチェーンのため、。

発生させます
ValueError 両方の場合stepsmax_stepsありませんんNone
ValueError いずれかの場合stepsmax_steps <= 0