Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

tf.estimator.BaselineEstimator

TensorFlow 1 versi Lihat sumber di GitHub

Sebuah estimator yang dapat membangun dasar yang sederhana.

Mewarisi Dari: Estimator

estimator menggunakan kepala yang ditentukan pengguna.

estimator ini mengabaikan nilai-nilai fitur dan akan belajar untuk memprediksi nilai rata-rata masing-masing label. Misalnya untuk masalah klasifikasi tunggal-label, ini akan memprediksi distribusi probabilitas dari kelas seperti yang terlihat di label. Untuk masalah klasifikasi multi-label, itu akan memprediksi rasio contoh yang mengandung masing-masing kelas.

Contoh:

 
# Build baseline multi-label classifier.
estimator = tf.estimator.BaselineEstimator(
    head=tf.estimator.MultiLabelHead(n_classes=3))

# Input builders
def input_fn_train:
  # Returns tf.data.Dataset of (x, y) tuple where y represents label's class
  # index.
  pass

def input_fn_eval:
  # Returns tf.data.Dataset of (x, y) tuple where y represents label's class
  # index.
  pass

# Fit model.
estimator.train(input_fn=input_fn_train)

# Evaluates cross entropy between the test and train labels.
loss = estimator.evaluate(input_fn=input_fn_eval)["loss"]

# For each class, predicts the ratio of training examples that contain the
# class.
predictions = estimator.predict(new_samples)

 

Masukan dari train dan evaluate seharusnya fitur berikut, jika tidak akan ada KeyError :

  • jika weight_column ditentukan dalam head konstruktor (dan tidak ada) untuk kepala dilewatkan ke konstruktor BaselineEstimator, sebuah fitur dengan key=weight_column yang nilainya adalah Tensor .

head Sebuah Head contoh dibangun dengan metode seperti tf.estimator.MultiLabelHead .
model_dir Direktori untuk menyimpan parameter model, grafik dan lain-lain ini juga dapat digunakan untuk pos-pos pemeriksaan beban dari direktori ke estimator untuk terus melatih model disimpan sebelumnya.
optimizer String, tf.keras.optimizers.* Objek, atau callable yang menciptakan optimizer untuk digunakan untuk pelatihan. Jika tidak ditentukan, akan menggunakan Ftrl sebagai optimizer default.
config RunConfig objek untuk mengkonfigurasi pengaturan runtime.

config

export_savedmodel

model_dir

model_fn Mengembalikan model_fn yang terikat self.params .
params

metode

eval_dir

Lihat sumber

Menunjukkan nama direktori dimana metrik evaluasi yang dibuang.

args
name Nama evaluasi jika kebutuhan pengguna untuk menjalankan beberapa evaluasi pada set data yang berbeda, seperti pada pelatihan Data vs data uji. Metrik untuk evaluasi yang berbeda disimpan dalam folder terpisah, dan muncul secara terpisah di tensorboard.

Pengembalian
Sebuah string yang merupakan jalur direktori berisi metrik evaluasi.

evaluate

Lihat sumber

Mengevaluasi model yang diberikan data evaluasi input_fn .

Untuk setiap langkah, panggilan input_fn , yang mengembalikan satu batch data. Mengevaluasi sampai:

  • steps batch diproses, atau
  • input_fn menimbulkan end-of-masukan pengecualian ( tf.errors.OutOfRangeError atau StopIteration ).

args
input_fn Sebuah fungsi yang membangun input data untuk evaluasi. Lihat Premade estimator untuk informasi lebih lanjut. Fungsi harus membangun dan kembali salah satu berikut:

  • Sebuah tf.data.Dataset objek: Output dari Dataset objek harus menjadi tuple (features, labels) dengan kendala yang sama seperti di bawah ini.
  • Sebuah tuple (features, labels) : Dimana features adalah tf.Tensor atau kamus nama fitur string untuk Tensor dan labels adalah Tensor atau kamus nama label string untuk Tensor . Kedua features dan labels dikonsumsi oleh model_fn . Mereka harus memenuhi harapan model_fn dari input.
steps Sejumlah langkah yang untuk mengevaluasi model yang. Jika None , mengevaluasi sampai input_fn menimbulkan end-of-masukan pengecualian.
hooks Daftar tf.train.SessionRunHook kasus subclass. Digunakan untuk callback dalam panggilan evaluasi.
checkpoint_path Jalan dari pos pemeriksaan tertentu untuk mengevaluasi. Jika None , pos pemeriksaan terbaru dalam model_dir digunakan. Jika tidak ada pos-pos pemeriksaan di model_dir , evaluasi dijalankan dengan baru diinisialisasi Variables bukan yang dipulihkan dari pos pemeriksaan.
name Nama evaluasi jika kebutuhan pengguna untuk menjalankan beberapa evaluasi pada set data yang berbeda, seperti pada pelatihan Data vs data uji. Metrik untuk evaluasi yang berbeda disimpan dalam folder terpisah, dan muncul secara terpisah di tensorboard.

Pengembalian
Sebuah dict yang berisi metrik evaluasi yang ditetapkan dalam model_fn mengetik dengan nama, serta entri global_step yang berisi nilai langkah global untuk yang evaluasi ini dilakukan. Untuk estimator kaleng, dict yang berisi loss (rata-rata kerugian per mini-batch) dan average_loss (rata-rata kerugian per sampel). Pengklasifikasi kaleng juga mengembalikan accuracy . Regressors kaleng juga mengembalikan label/mean dan prediction/mean .

kenaikan gaji
ValueError Jika steps <= 0 .

experimental_export_all_saved_models

Lihat sumber

Ekspor sebuah SavedModel dengan tf.MetaGraphDefs untuk setiap mode yang diminta.

Untuk setiap mode berlalu dalam melalui input_receiver_fn_map , metode ini membangun grafik baru dengan memanggil input_receiver_fn untuk mendapatkan fitur dan label Tensor s. Berikutnya, metode ini menyebut Estimator 's model_fn dalam modus lulus untuk menghasilkan grafik model yang didasarkan pada fitur-fitur dan label, dan mengembalikan pos pemeriksaan yang diberikan (atau, kurang itu, pos pemeriksaan terbaru) ke grafik. Hanya salah satu mode digunakan untuk menyimpan variabel ke SavedModel (urutan preferensi: tf.estimator.ModeKeys.TRAIN , tf.estimator.ModeKeys.EVAL , maka tf.estimator.ModeKeys.PREDICT ), sehingga sampai tiga tf.MetaGraphDefs disimpan dengan satu set variabel dalam satu SavedModel direktori.

Untuk variabel dan tf.MetaGraphDefs , direktori ekspor timestamped bawah export_dir_base , dan menulis SavedModel ke dalamnya berisi tf.MetaGraphDef untuk modus diberikan dan tanda tangan yang terkait.

Untuk prediksi, yang diekspor MetaGraphDef akan memberikan satu SignatureDef untuk setiap elemen dari export_outputs dict kembali dari model_fn , bernama menggunakan tombol yang sama. Salah satu kunci ini selalu tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY , yang menunjukkan tanda tangan akan disajikan ketika permintaan melayani tidak menentukan satu. Untuk setiap tanda tangan, output yang disediakan oleh sesuai tf.estimator.export.ExportOutput s, dan masukan selalu penerima masukan yang diberikan oleh serving_input_receiver_fn .

Untuk pelatihan dan evaluasi, train_op disimpan dalam koleksi ekstra, dan kehilangan, metrik, dan prediksi termasuk dalam SignatureDef untuk modus yang bersangkutan.

Aset tambahan dapat ditulis ke dalam SavedModel melalui assets_extra argumen. Ini harus menjadi dict, di mana setiap tombol memberikan jalur tujuan (termasuk nama file) relatif ke direktori assets.extra. Nilai yang sesuai memberikan path lengkap dari file sumber yang akan disalin. Sebagai contoh, kasus sederhana menyalin file tunggal tanpa nama itu ditetapkan sebagai {'my_asset_file.txt': '/path/to/my_asset_file.txt'} .

args
export_dir_base String yang berisi direktori di mana untuk membuat subdirektori timestamped mengandung diekspor SavedModel s.
input_receiver_fn_map dict dari tf.estimator.ModeKeys untuk input_receiver_fn pemetaan, di mana input_receiver_fn adalah fungsi yang tidak mengambil argumen dan mengembalikan subclass sesuai InputReceiver .
assets_extra Sebuah dict menentukan bagaimana untuk mengisi direktori assets.extra dalam diekspor SavedModel , atau None jika tidak ada aset tambahan yang diperlukan.
as_text apakah untuk menulis SavedModel proto dalam format teks.
checkpoint_path Jalan pos pemeriksaan untuk ekspor. Jika None (default), paling pos pemeriksaan baru-baru ini ditemukan dalam direktori model yang dipilih.

Pengembalian
Path ke direktori diekspor sebagai objek byte.

kenaikan gaji
ValueError jika ada input_receiver_fn adalah None , tidak ada export_outputs disediakan, atau tidak ada pos dapat ditemukan.

export_saved_model

Lihat sumber

Ekspor inferensi grafik sebagai SavedModel ke dir diberikan.

Untuk panduan rinci, lihat SavedModel dari estimator .

Metode ini membangun grafik baru dengan terlebih dahulu memanggil serving_input_receiver_fn untuk mendapatkan fitur Tensor s, dan kemudian memanggil ini Estimator 's model_fn untuk menghasilkan grafik model yang didasarkan pada fitur tersebut. Memulihkan pos pemeriksaan yang diberikan (atau, kurang itu, pos pemeriksaan terbaru) ke grafik ini dalam sesi segar. Akhirnya ia menciptakan sebuah direktori ekspor timestamped bawah diberikan export_dir_base , dan menulis SavedModel ke dalamnya mengandung satu tf.MetaGraphDef diselamatkan dari sesi ini.

The diekspor MetaGraphDef akan menyediakan satu SignatureDef untuk setiap elemen dari export_outputs dict kembali dari model_fn , bernama menggunakan tombol yang sama. Salah satu kunci ini selalu tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY , yang menunjukkan tanda tangan akan disajikan ketika permintaan melayani tidak menentukan satu. Unt