Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

tf.keras.callbacks.EarlyStopping

TensorFlow 1 versi Lihat sumber di GitHub

pelatihan berhenti ketika dipantau metrik telah berhenti meningkatkan.

Mewarisi Dari: Callback

Digunakan di notebook

Digunakan dalam panduan Digunakan dalam tutorial

Dengan asumsi tujuan pelatihan adalah untuk meminimalkan kerugian. Dengan ini, metrik dipantau akan 'loss' , dan modus akan menjadi 'min' . A model.fit() lingkaran pelatihan akan memeriksa di akhir setiap zaman apakah kerugian tidak lagi menurun, mengingat min_delta dan patience jika berlaku. Setelah itu ditemukan tidak lagi menurun, model.stop_training ditandai Benar dan berakhir pelatihan.

Kuantitas yang akan dipantau kebutuhan akan tersedia di logs dict. Untuk membuatnya begitu, lulus kehilangan atau metrik di model.compile() .

monitor Kuantitas dipantau.
min_delta perubahan kuantitas yang dipantau minimum untuk memenuhi syarat sebagai perbaikan, yaitu perubahan mutlak kurang dari min_delta, akan dihitung sebagai tidak ada perbaikan.
patience Jumlah zaman dengan tidak ada perbaikan setelah pelatihan akan dihentikan.
verbose Modus bertele-tele.
mode Salah satu {"auto", "min", "max"} . Dalam min modus, pelatihan akan berhenti ketika kuantitas dipantau telah berhenti menurun; di "max" modus itu akan berhenti ketika kuantitas dipantau telah berhenti meningkat; di "auto" modus, arah secara otomatis disimpulkan dari nama kuantitas dipantau.
baseline nilai dasar untuk kuantitas dipantau. Pelatihan akan berhenti jika model tidak menunjukkan perbaikan dari baseline.
restore_best_weights Apakah untuk mengembalikan bobot model dari zaman dengan nilai terbaik dari kuantitas dipantau. Jika False, bobot model yang diperoleh pada langkah terakhir dari pelatihan yang digunakan.

Contoh:

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
# This callback will stop the training when there is no improvement in
# the validation loss for three consecutive epochs.
model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss='mse')
history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
                    epochs=10, batch_size=1, callbacks=[callback],
                    verbose=0)
len(history.history['loss'])  # Only 4 epochs are run.
4

metode

get_monitor_value

Lihat sumber

set_model

Lihat sumber

set_params

Lihat sumber