Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

tf.keras.estimator.model_to_estimator

TensorFlow 1 versi Lihat sumber di GitHub

Membangun sebuah Estimator contoh dari model yang diberikan keras.

Digunakan di notebook

Digunakan dalam panduan Digunakan dalam tutorial

Jika Anda menggunakan infrastruktur atau perkakas lain yang bergantung pada estimator, Anda masih bisa membangun Keras Model dan penggunaan model_to_estimator untuk mengkonversi model Keras ke Pengukur untuk digunakan dengan sistem hilir.

Untuk contoh penggunaan, silakan lihat: Membuat estimator dari Model Keras .

Berat sampel:

Estimator dikembalikan oleh model_to_estimator dikonfigurasi sehingga mereka dapat menangani beban sampel (mirip dengan keras_model.fit(x, y, sample_weights) ).

Untuk lulus bobot sampel saat pelatihan atau mengevaluasi Pengukur, item pertama dikembalikan oleh fungsi input harus kamus dengan tombol features dan sample_weights . Contoh di bawah:

 keras_model = tf.keras.Model(...)
keras_model.compile(...)

estimator = tf.keras.estimator.model_to_estimator(keras_model)

def input_fn():
  return dataset_ops.Dataset.from_tensors(
      ({'features': features, 'sample_weights': sample_weights},
       targets))

estimator.train(input_fn, steps=1)
 

Untuk menyesuaikan estimator eval_metric_ops nama, Anda dapat lulus dalam metric_names_map kamus pemetaan keras output model metrik nama ke nama-nama kustom sebagai berikut:

   input_a = tf.keras.layers.Input(shape=(16,), name='input_a')
  input_b = tf.keras.layers.Input(shape=(16,), name='input_b')
  dense = tf.keras.layers.Dense(8, name='dense_1')
  interm_a = dense(input_a)
  interm_b = dense(input_b)
  merged = tf.keras.layers.concatenate([interm_a, interm_b], name='merge')
  output_a = tf.keras.layers.Dense(3, activation='softmax', name='dense_2')(
          merged)
  output_b = tf.keras.layers.Dense(2, activation='softmax', name='dense_3')(
          merged)
  keras_model = tf.keras.models.Model(
      inputs=[input_a, input_b], outputs=[output_a, output_b])
  keras_model.compile(
      loss='categorical_crossentropy',
      optimizer='rmsprop',
      metrics={
          'dense_2': 'categorical_accuracy',
          'dense_3': 'categorical_accuracy'
      })

  metric_names_map = {
      'dense_2_categorical_accuracy': 'acc_1',
      'dense_3_categorical_accuracy': 'acc_2',
  }
  keras_est = tf.keras.estimator.model_to_estimator(
      keras_model=keras_model,
      config=config,
      metric_names_map=metric_names_map)
 

keras_model Sebuah disusun Keras model objek. Argumen ini saling eksklusif dengan keras_model_path . Penaksir model_fn menggunakan struktur model untuk mengkloning model. Defaultnya None .
keras_model_path Jalan ke model Keras disusun disimpan pada disk, dalam format HDF5, yang dapat dihasilkan dengan save() metode model Keras. Argumen ini saling eksklusif dengan keras_model . Defaultnya None .
custom_objects Kamus untuk kloning obyek disesuaikan. Ini digunakan dengan kelas yang bukan merupakan bagian dari paket pip ini. Misalnya, jika pengguna mempertahankan relu6 kelas yang mewarisi dari tf.keras.layers.Layer , kemudian lulus custom_objects={'relu6': relu6} . Defaultnya None .
model_dir Direktori untuk menyimpan Estimator parameter model, grafik, file ringkasan untuk TensorBoard, dll Jika diset direktori akan dibuat dengan tempfile.mkdtemp
config RunConfig untuk konfigurasi Estimator . Memungkinkan pengaturan hal-hal di model_fn berdasarkan konfigurasi seperti num_ps_replicas , atau model_dir . Defaultnya None . Jika kedua config.model_dir dan model_dir argumen (di atas) yang ditentukan model_dir argumen diutamakan.
checkpoint_format Set format pos pemeriksaan diselamatkan oleh estimator ketika melatih. Mungkin saver atau checkpoint , tergantung pada apakah untuk menyelamatkan pos pemeriksaan dari tf.compat.v1.train.Saver atau tf.train.Checkpoint . Standarnya adalah checkpoint . Estimator menggunakan nama berbasis tf.train.Saver pos pemeriksaan, sementara model Keras menggunakan pos pemeriksaan berbasis objek dari tf.train.Checkpoint . Saat ini, tabungan pos pemeriksaan berbasis objek dari model_to_estimator hanya didukung oleh model Fungsional dan Sequential. Default untuk 'pos pemeriksaan'.
metric_names_map pemetaan kamus opsional Model Keras keluaran metrik nama ke nama kustom. Hal ini dapat digunakan untuk mengesampingkan default Keras output model metrik nama dalam kasus model multi IO penggunaan dan memberikan nama kustom untuk eval_metric_ops di Estimator. Metrik nama Model Keras dapat diperoleh dengan menggunakan model.metrics_names tidak termasuk metrik kerugian seperti kerugian total dan output kerugian. Misalnya, jika model Keras Anda memiliki dua output out_1 dan out_2 , dengan mse kerugian dan acc metrik, maka model.metrics_names akan ['loss', 'out_1_loss', 'out_2_loss', 'out_1_acc', 'out_2_acc'] . Metrik nama Model tidak termasuk metrik kerugian akan ['out_1_acc', 'out_2_acc'] .

Sebuah Pengukur dari model yang keras diberikan.

ValueError Jika tidak keras_model atau keras_model_path diberikan.
ValueError Jika kedua keras_model dan keras_model_path diberikan.
ValueError Jika keras_model_path adalah GCS URI.
ValueError Jika keras_model belum disusun.
ValueError Jika checkpoint_format tidak valid diberikan.