Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

tf.estimator.experimental.RNNClassifier

Lihat sumber di GitHub

Sebuah classifier untuk model TensorFlow RNN.

Mewarisi Dari: RNNEstimator

Kereta model jaringan saraf berulang untuk contoh mengklasifikasikan menjadi salah satu dari beberapa kelas.

Contoh:

 token_sequence = sequence_categorical_column_with_hash_bucket(...)
token_emb = embedding_column(categorical_column=token_sequence, ...)

estimator = RNNClassifier(
    sequence_feature_columns=[token_emb],
    units=[32, 16], cell_type='lstm')

# Input builders
def input_fn_train: # returns x, y
  pass
estimator.train(input_fn=input_fn_train, steps=100)

def input_fn_eval: # returns x, y
  pass
metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
def input_fn_predict: # returns x, None
  pass
predictions = estimator.predict(input_fn=input_fn_predict)
 

Masukan dari train dan evaluate seharusnya fitur berikut, jika tidak akan ada KeyError :

  • jika weight_column tidak None , fitur dengan key=weight_column yang nilainya adalah Tensor .
  • untuk setiap column di sequence_feature_columns :
    • fitur dengan key=column.name yang value adalah SparseTensor .
  • untuk setiap column di context_feature_columns :
    • jika column adalah CategoricalColumn , fitur dengan key=column.name yang value adalah SparseTensor .
    • jika column adalah WeightedCategoricalColumn , dua fitur: pertama dengan key id nama kolom, yang kedua dengan key berat nama kolom. Kedua fitur value harus menjadi SparseTensor .
    • jika column adalah DenseColumn , fitur dengan key=column.name yang value adalah Tensor .

Rugi dihitung dengan menggunakan Softmax lintas entropi.

sequence_feature_columns Sebuah iterable yang berisi FeatureColumn s yang mewakili masukan berurutan. Semua item di set baik harus urut kolom (misalnya sequence_numeric_column ) atau dibangun dari satu (misalnya embedding_column dengan sequence_categorical_column_* sebagai input).
context_feature_columns Sebuah iterable yang berisi FeatureColumn s untuk input kontekstual. Data diwakili oleh kolom ini akan direplikasi dan diberikan kepada RNN pada setiap timestep. Kolom ini harus contoh kelas yang berasal dari DenseColumn seperti numeric_column , tidak berurutan varian.
units Iterable dari nomor integer unit tersembunyi per lapisan RNN. Jika set, cell_type juga harus ditetapkan dan rnn_cell_fn harus None .
cell_type Sebuah kelas memproduksi sel RNN atau string menentukan jenis sel. String yang didukung adalah: 'simple_rnn' , 'lstm' , dan 'gru' . Jika diatur, units juga harus ditetapkan dan rnn_cell_fn harus None .
rnn_cell_fn Sebuah fungsi yang mengembalikan contoh sel RNN yang akan digunakan untuk membangun RNN. Jika set, units dan cell_type tidak dapat diatur. Hal ini untuk pengguna tingkat lanjut yang membutuhkan kustomisasi tambahan di luar units dan cell_type . Perhatikan bahwa tf.keras.layers.StackedRNNCells diperlukan untuk RNNs ditumpuk.
return_sequences Sebuah boolean yang menunjukkan apakah untuk kembali output terakhir di urutan output, atau urutan penuh. Catatan bahwa jika Benar, weight_column harus ada atau string.
model_dir Direktori untuk menyimpan parameter model, grafik dan lain-lain ini juga dapat digunakan untuk pos-pos pemeriksaan beban dari direktori ke estimator untuk terus melatih model disimpan sebelumnya.
n_classes Jumlah kelas label. Default untuk 2, klasifikasi yaitu biner. Harus> 1.
weight_column Sebuah string atau NumericColumn diciptakan oleh tf.feature_column.numeric_column mendefinisikan kolom fitur yang mewakili bobot. Hal ini digunakan untuk berat badan atau meningkatkan turun contoh selama pelatihan. Ini akan dikalikan dengan hilangnya contoh. Jika string, itu digunakan sebagai kunci untuk mengambil tensor berat dari features . Jika itu adalah NumericColumn , tensor baku diambil oleh kunci weight_column.key , maka weight_column.normalizer_fn diterapkan di atasnya untuk mendapatkan tensor berat badan.
label_vocabulary Sebuah daftar string merupakan nilai yang mungkin label. Jika diberi, label harus tipe string dan memiliki nilai di label_vocabulary . Jika tidak diberikan, yang berarti label sudah dikodekan sebagai integer atau pelampung dalam [0, 1] untuk n_classes=2 dan dikodekan sebagai nilai-nilai integer dalam {0, 1, ..., n_classes-1} untuk n_classes > 2. Juga akan ada kesalahan jika kosakata tidak disediakan dan label tali.
optimizer Sebuah contoh dari tf.Optimizer atau tali menspesifikasikan jenis optimizer. Default untuk Adagrad optimizer.
loss_reduction Salah satu tf.losses.Reduction kecuali NONE . Menjelaskan cara mengurangi kerugian pelatihan selama batch. Default untuk SUM_OVER_BATCH_SIZE .
sequence_mask Sebuah string dengan nama tensor urutan topeng. Jika sequence_mask adalah di fitur kamus, tensor disediakan digunakan, jika tidak urutan topeng dihitung dari panjang fitur berurutan. Urutan mask digunakan dalam mode evaluasi dan pelatihan untuk agregat kerugian dan metrik perhitungan sementara tidak termasuk langkah-langkah padding. Hal ini juga ditambahkan ke prediksi kamus dalam mode prediksi untuk menunjukkan langkah-langkah yang padding.
config RunConfig objek untuk mengkonfigurasi pengaturan runtime.

ValueError Jika units , cell_type , dan rnn_cell_fn tidak kompatibel.

Kompatibilitas bersemangat

Estimator tidak kompatibel dengan eksekusi bersemangat.

config

model_dir

model_fn Mengembalikan model_fn yang terikat self.params .
params

metode

eval_dir

Lihat sumber

Menunjukkan nama direktori dimana metrik evaluasi yang dibuang.

args
name Nama evaluasi jika kebutuhan pengguna untuk menjalankan beberapa evaluasi pada set data yang berbeda, seperti pada pelatihan Data vs data uji. Metrik untuk evaluasi yang berbeda disimpan dalam folder terpisah, dan muncul secara terpisah di tensorboard.

Pengembalian
Sebuah string yang merupakan jalur direktori berisi metrik evaluasi.

evaluate

Lihat sumber

Mengevaluasi model yang diberikan data evaluasi input_fn .

Untuk setiap langkah, panggilan input_fn , yang mengembalikan satu batch data. Mengevaluasi sampai:

  • steps batch diproses, atau
  • input_fn menimbulkan end-of-masukan pengecualian ( tf.errors.OutOfRangeError atau StopIteration ).

args
input_fn Sebuah fungsi yang membangun input data untuk evaluasi. Lihat Premade estimator untuk informasi lebih lanjut. Fungsi harus membangun dan kembali salah satu berikut:

  • Sebuah tf.data.Dataset objek: Output dari Dataset objek harus menjadi tuple (features, labels) dengan kendala yang sama seperti di bawah ini.
  • Sebuah tuple (features, labels) : Dimana features adalah tf.Tensor atau kamus nama fitur string untuk Tensor dan labels adalah Tensor atau kamus nama label string untuk Tensor . Kedua features dan labels dikonsumsi oleh model_fn . Mereka harus memenuhi harapan model_fn dari input.
steps Sejumlah langkah yang untuk mengevaluasi model yang. Jika None , mengevaluasi sampai input_fn menimbulkan end-of-masukan pengecualian.
hooks Daftar tf.train.SessionRunHook kasus subclass. Digunakan untuk callback dalam panggilan evaluasi.
checkpoint_path Jalan dari pos pemeriksaan tertentu untuk mengevaluasi. Jika None , pos pemeriksaan terbaru dalam model_dir digunakan. Jika tidak ada pos-pos pemeriksaan di model_dir , evaluasi dijalankan dengan baru diinisialisasi Variables bukan yang dipulihkan dari pos pemeriksaan.
name Nama evaluasi jika kebutuhan pengguna untuk menjalankan beberapa evaluasi pada set data yang berbeda, seperti pada pelatihan Data vs data uji. Metrik untuk evaluasi yang berbeda disimpan dalam folder terpisah, dan muncul secara terpisah di tensorboard.

Pengembalian
Sebuah dict yang berisi metrik evaluasi yang ditetapkan dalam model_fn mengetik dengan nama, serta entri global_step yang berisi nilai langkah global untuk yang evaluasi ini dilakukan. Untuk estimator kaleng, dict yang berisi loss (rata-rata kerugian per mini-batch) dan average_loss (rata-rata kerugian per sampel). Pengklasifikasi kaleng juga mengembalikan accuracy . Regressors kaleng juga mengembalikan label/mean dan prediction/mean .

kenaikan gaji
ValueError Jika steps <= 0 .

experimental_export_all_saved_models

Lihat sumber

Ekspor sebuah SavedModel dengan tf.MetaGraphDefs untuk setiap mode yang diminta.

Untuk setiap mode berlalu dalam melalui input_receiver_fn_map , metode ini membangun grafik baru dengan memanggil input_receiver_fn untuk mendapatkan fitur dan label Tensor s. Berikutnya, metode ini menyebut Estimator 's model_fn dalam modus lulus untuk menghasilkan grafik model yang didasarkan pada fitur-fitur dan label, dan mengembalikan pos pemeriksaan yang diberikan (atau, kurang itu, pos pemeriksaan terbaru) ke grafik. Hanya salah satu mode digunakan untuk menyimpan variabel ke SavedModel (urutan preferensi: tf.estimator.ModeKeys.TRAIN , tf.estimator.ModeKeys.EVAL , maka tf.estimator.ModeKeys.PREDICT ), sehingga sampai tiga tf.MetaGraphDefs disim