Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

tf.compat.v1.estimator.Estimator

Lihat sumber di GitHub

kelas estimator untuk melatih dan mengevaluasi model TensorFlow.

Digunakan di notebook

Digunakan dalam tutorial

The Estimator objek membungkus model yang ditentukan oleh model_fn , yang, diberikan masukan dan sejumlah parameter lain, mengembalikan ops diperlukan untuk melakukan pelatihan, evaluasi, atau prediksi.

Semua output (pos pemeriksaan, file event, dll) ditulis untuk model_dir , atau subdirektori daripadanya. Jika model_dir tidak diatur, direktori sementara digunakan.

The config argumen dapat dilewatkan tf.estimator.RunConfig objek yang berisi informasi tentang lingkungan eksekusi. Hal ini diteruskan ke model_fn , jika model_fn memiliki parameter bernama "config" (dan fungsi masukan dengan cara yang sama). Jika config parameter tidak lulus, itu dipakai oleh Estimator . Tidak lulus config berarti bahwa default berguna untuk eksekusi lokal yang digunakan. Estimator membuat konfigurasi yang tersedia untuk model (misalnya, untuk memungkinkan spesialisasi berdasarkan jumlah pekerja yang tersedia), dan juga menggunakan beberapa bidang untuk mengendalikan internal, terutama mengenai checkpointing.

The params Argumen berisi hyperparameters. Hal ini diteruskan ke model_fn , jika model_fn memiliki parameter bernama "params", dan untuk fungsi masukan dengan cara yang sama. Estimator hanya melewati params bersama, itu tidak memeriksanya. Struktur params karena itu sepenuhnya terserah pengembang.

Tak satu pun dari Estimator metode 's dapat diganti dalam subclass (konstruktor memberlakukan ini). Subclass harus menggunakan model_fn untuk mengkonfigurasi kelas dasar, dan dapat menambahkan metode melaksanakan fungsi khusus.

Lihat estimator untuk informasi lebih lanjut.

Untuk menghangatkan-memulai Estimator :

 estimator = tf.estimator.DNNClassifier(
    feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
    hidden_units=[1024, 512, 256],
    warm_start_from="/path/to/checkpoint/dir")
 

Untuk rincian lebih lanjut tentang konfigurasi hangat-start, lihat tf.estimator.WarmStartSettings .

model_fn fungsi model. Berikut tanda tangan:

  • features - ini adalah item pertama kembali dari input_fn dilewatkan ke train , evaluate , dan predict . Ini harus menjadi satu tf.Tensor atau dict dari yang sama.
  • labels - ini adalah item kedua kembali dari input_fn dilewatkan ke train , evaluate , dan predict . Ini harus menjadi satu tf.Tensor atau dict dari yang sama (untuk model multi-head). Jika mode tf.estimator.ModeKeys.PREDICT , labels=None akan diteruskan. Jika model_fn tanda tangan 's tidak menerima mode , yang model_fn masih harus mampu menangani labels=None .
  • mode - Opsional. Menentukan apakah ini pelatihan, evaluasi atau prediksi. Lihat tf.estimator.ModeKeys . params - Opsional dict dari hyperparameters. Akan menerima apa yang akan diteruskan ke Pengukur di params parameter. Hal ini memungkinkan untuk mengkonfigurasi estimator dari parameter hiper tuning.
  • config - Opsional estimator.RunConfig objek. Akan menerima apa yang akan diteruskan ke Pengukur sebagai yang config parameter, atau nilai default. Memungkinkan pengaturan hal-hal di Anda model_fn berdasarkan konfigurasi seperti num_ps_replicas , atau model_dir .
  • Pengembalian - tf.estimator.EstimatorSpec
model_dir Direktori untuk menyimpan parameter model, grafik dan lain-lain ini juga dapat digunakan untuk pos-pos pemeriksaan beban dari direktori menjadi estimator untuk terus melatih model disimpan sebelumnya. Jika PathLike objek, jalan akan diselesaikan. Jika None , yang model_dir di config akan digunakan jika diatur. Jika keduanya ditetapkan, mereka harus sama. Jika keduanya None , direktori sementara akan digunakan.
config estimator.RunConfig konfigurasi objek.
params dict parameter hiper yang akan diteruskan ke model_fn . Kuncinya adalah nama parameter, nilai-nilai jenis python dasar.
warm_start_from Opsional tali filepath ke sebuah pos pemeriksaan atau SavedModel hangat-mulai dari, atau tf.estimator.WarmStartSettings keberatan untuk sepenuhnya mengkonfigurasi hangat-mulai. Jika ada, variabel hanya dilatih hangat-mulai. Jika filepath string yang disediakan bukannya tf.estimator.WarmStartSettings , maka semua variabel yang hangat-mulai, dan diasumsikan bahwa kosakata dan tf.Tensor nama tidak berubah.

ValueError parameter model_fn tidak cocok params .
ValueError jika ini disebut melalui subclass dan jika kelas yang menimpa anggota dari Estimator .

Kompatibilitas bersemangat

Metode Memanggil dari Estimator akan bekerja saat eksekusi bersemangat diaktifkan. Namun, model_fn dan input_fn tidak dijalankan dengan penuh semangat, Estimator akan beralih ke mode grafik sebelum memanggil semua fungsi yang disediakan pengguna (termasuk. Kait), sehingga kode mereka harus kompatibel dengan eksekusi modus grafik. Perhatikan bahwa input_fn kode menggunakan tf.data umumnya bekerja di kedua grafik dan mode bersemangat.

config

model_dir

model_fn Mengembalikan model_fn yang terikat self.params .
params

metode

eval_dir

Lihat sumber

Menunjukkan nama direktori dimana metrik evaluasi yang dibuang.

args
name Nama evaluasi jika kebutuhan pengguna untuk menjalankan beberapa evaluasi pada set data yang berbeda, seperti pada pelatihan Data vs data uji. Metrik untuk evaluasi yang berbeda disimpan dalam folder terpisah, dan muncul secara terpisah di tensorboard.

Pengembalian
Sebuah string yang merupakan jalur direktori berisi metrik evaluasi.

evaluate

Lihat sumber

Mengevaluasi model yang diberikan data evaluasi input_fn .

Untuk setiap langkah, panggilan input_fn , yang mengembalikan satu batch data. Mengevaluasi sampai:

  • steps batch diproses, atau
  • input_fn menimbulkan end-of-masukan pengecualian ( tf.errors.OutOfRangeError atau StopIteration ).

args
input_fn Sebuah fungsi yang membangun input data untuk evaluasi. Lihat Premade estimator untuk informasi lebih lanjut. Fungsi harus membangun dan kembali salah satu berikut:

  • Sebuah tf.data.Dataset objek: Output dari Dataset objek harus menjadi tuple (features, labels) dengan kendala yang sama seperti di bawah ini.
  • Sebuah tuple (features, labels) : Dimana features adalah tf.Tensor atau kamus nama fitur string untuk Tensor dan labels adalah Tensor atau kamus nama label string untuk Tensor . Kedua features dan labels dikonsumsi oleh model_fn . Mereka harus memenuhi harapan model_fn dari input.
steps Sejumlah langkah yang untuk mengevaluasi model yang. Jika None , mengevaluasi sampai input_fn menimbulkan end-of-masukan pengecualian.
hooks Daftar tf.train.SessionRunHook kasus subclass. Digunakan untuk callback dalam panggilan evaluasi.
checkpoint_path Jalan dari pos pemeriksaan tertentu untuk mengevaluasi. Jika None , pos pemeriksaan terbaru dalam model_dir digunakan. Jika tidak ada pos-pos pemeriksaan di model_dir , evaluasi dijalankan dengan baru diinisialisasi Variables bukan yang dipulihkan dari pos pemeriksaan.
name Nama evaluasi jika kebutuhan pengguna untuk menjalankan beberapa evaluasi pada set data yang berbeda, seperti pada pelatihan Data vs data uji. Metrik untuk evaluasi yang berbeda disimpan dalam folder terpisah, dan muncul secara terpisah di tensorboard.

Pengembalian
Sebuah dict yang berisi metrik evaluasi yang ditetapkan dalam model_fn mengetik dengan nama, serta entri global_step yang berisi nilai langkah global untuk yang evaluasi ini dilakukan. Untuk estimator kaleng, dict yang berisi loss (rata-rata kerugian per mini-batch) dan average_loss (rata-rata kerugian per sampel). Pengklasifikasi kaleng juga mengembalikan accuracy . Regressors kaleng juga mengembalikan label/mean dan prediction/mean .

kenaikan gaji
ValueError Jika steps <= 0 .

experimental_export_all_saved_models

Lihat sumber

Ekspor sebuah SavedModel dengan tf.MetaGraphDefs untuk setiap mode yang diminta.

Untuk setiap mode berlalu dalam melalui input_receiver_fn_map , metode ini membangun grafik baru dengan memanggil input_receiver_fn untuk mendapatkan fitur dan label Tensor s. Berikutnya, metode ini menyebut Estimator 's model_fn dalam modus lulus untuk menghasilkan grafik model yang didasarkan pada fitur-fitur dan label, dan mengembalikan pos pemeriksaan yang diberikan (atau, kurang itu, pos pemeriksaan terbaru) ke grafik. Hanya salah satu mode digunakan untuk menyimpan variabel ke SavedModel (urutan preferensi: tf.estimator.ModeKeys.TRAIN , tf.estimator.ModeKeys.EVAL , maka tf.estimator.ModeKeys.PREDICT ), sehingga sampai tiga tf.MetaGraphDefs disimpan dengan satu set variabel dalam satu SavedModel direktori.

Untuk variabel dan tf.MetaGraphDefs , direktori ekspor t