Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

Pelatihan multi-pekerja dengan Keras

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Gambaran

Tutorial ini memperagakan pelatihan terdistribusi multi-pekerja dengan model Keras menggunakan tf.distribute.Strategy API, khususnya tf.distribute.experimental.MultiWorkerMirroredStrategy . Dengan bantuan strategi ini, model Keras yang dirancang untuk dijalankan pada pekerja tunggal dapat bekerja dengan mulus pada banyak pekerja dengan perubahan kode minimal.

Pelatihan Terdistribusi dalam panduan TensorFlow tersedia untuk tinjauan umum tentang strategi distribusi yang didukung TensorFlow bagi mereka yang tertarik dengan pemahaman yang lebih dalam tentang API tf.distribute.Strategy .

Mempersiapkan

Pertama, setup TensorFlow dan impor yang diperlukan.

 import os
import tensorflow as tf
import numpy as np
 

Mempersiapkan dataset

Sekarang, mari kita siapkan dataset MNIST. Dataset MNIST terdiri dari 60.000 contoh pelatihan dan 10.000 contoh uji dari digit tulisan tangan 0–9, yang diformat sebagai gambar monokrom 28x28-piksel. Dalam contoh ini, kita akan mengambil bagian pelatihan dari dataset untuk diperagakan.

 def mnist_dataset(batch_size):
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  # The `x` arrays are in uint8 and have values in the range [0, 255].
  # We need to convert them to float32 with values in the range [0, 1]
  x_train = x_train / np.float32(255)
  y_train = y_train.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
  return train_dataset
 

Bangun model Keras

Di sini kita menggunakan tf.keras.Sequential API untuk membangun dan menyusun jaringan saraf convolutional sederhana model Keras untuk berlatih dengan dataset MNIST kami.

 def build_and_compile_cnn_model():
  model = tf.keras.Sequential([
      tf.keras.Input(shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'])
  return model
 

Pertama-tama mari kita coba melatih model untuk sejumlah kecil zaman dan amati hasilnya dalam satu pekerja untuk memastikan semuanya bekerja dengan benar. Anda harus berharap untuk melihat penurunan kerugian dan akurasi mendekati 1.0 seiring kemajuan zaman.

 per_worker_batch_size = 64
single_worker_dataset = mnist_dataset(per_worker_batch_size)
single_worker_model = build_and_compile_cnn_model()
single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)
 
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Epoch 1/3
70/70 [==============================] - 0s 2ms/step - loss: 2.2701 - accuracy: 0.2451
Epoch 2/3
70/70 [==============================] - 0s 2ms/step - loss: 2.1827 - accuracy: 0.4777
Epoch 3/3
70/70 [==============================] - 0s 2ms/step - loss: 2.0865 - accuracy: 0.5955

<tensorflow.python.keras.callbacks.History at 0x7fc59381ac50>

Konfigurasi Multi-Pekerja

Sekarang mari kita memasuki dunia pelatihan multi-pekerja. Dalam TensorFlow, variabel lingkungan TF_CONFIG diperlukan untuk pelatihan pada beberapa mesin, yang masing-masing mungkin memiliki peran yang berbeda. TF_CONFIG adalah string JSON yang digunakan untuk menentukan konfigurasi cluster pada setiap pekerja yang merupakan bagian dari cluster.

Ada dua komponen TF_CONFIG : cluster dan task . cluster memberikan informasi tentang klaster pelatihan, yang merupakan dikte yang terdiri dari berbagai jenis pekerjaan seperti worker . Dalam pelatihan multi-pekerja dengan MultiWorkerMirroredStrategy , biasanya ada satu worker yang mengambil sedikit lebih banyak tanggung jawab seperti menyimpan pos pemeriksaan dan menulis file ringkasan untuk TensorBoard di samping apa yang worker biasa lakukan. Pekerja tersebut disebut sebagai pekerja chief , dan merupakan kebiasaan bahwa worker dengan index 0 ditunjuk sebagai worker kepala (sebenarnya inilah cara tf.distribute.Strategy diterapkan). task di sisi lain memberikan informasi tugas saat ini. cluster komponen pertama adalah sama untuk semua pekerja, dan task komponen kedua berbeda pada setiap pekerja dan menentukan type dan index pekerja itu.

Dalam contoh ini, kami menetapkan type tugas ke "worker" dan index tugas ke 0 . Ini berarti mesin yang memiliki pengaturan seperti itu adalah pekerja pertama, yang akan ditunjuk sebagai pekerja utama dan melakukan lebih banyak pekerjaan daripada pekerja lain. Perhatikan bahwa mesin lain akan perlu memiliki TF_CONFIG variabel lingkungan set juga, dan harus memiliki sama cluster dict, tetapi berbeda tugas type atau tugas index tergantung pada apa peran mesin-mesin yang.

Untuk tujuan ilustrasi, tutorial ini menunjukkan bagaimana seseorang dapat mengatur TF_CONFIG dengan 2 pekerja di localhost . Dalam praktiknya, pengguna akan membuat beberapa pekerja pada alamat / port IP eksternal, dan mengatur TF_CONFIG pada setiap pekerja dengan tepat.

 os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
 

Perhatikan bahwa meskipun laju pembelajaran ditetapkan dalam contoh ini, secara umum mungkin perlu untuk menyesuaikan tingkat pembelajaran berdasarkan pada ukuran kumpulan global.

Pilih strategi yang tepat

Dalam TensorFlow, pelatihan terdistribusi terdiri dari pelatihan sinkron, di mana langkah-langkah pelatihan disinkronkan di seluruh pekerja dan replika, dan pelatihan asinkron, di mana langkah-langkah pelatihan tidak disinkronkan secara ketat.

MultiWorkerMirroredStrategy , yang merupakan strategi yang disarankan untuk pelatihan multi-pekerja yang sinkron, akan diperlihatkan dalam panduan ini. Untuk melatih model, gunakan instance dari tf.distribute.experimental.MultiWorkerMirroredStrategy . MultiWorkerMirroredStrategy membuat salinan dari semua variabel di lapisan model pada setiap perangkat di semua pekerja. Ia menggunakan CollectiveOps , sebuah op TensorFlow untuk komunikasi kolektif, untuk mengagregasi gradien dan menjaga variabel tetap sinkron. Panduan tf.distribute.Strategy memiliki detail lebih lanjut tentang strategi ini.

 strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0',), communication = CollectiveCommunication.AUTO

MultiWorkerMirroredStrategy menyediakan beberapa implementasi melalui parameter CollectiveCommunication . RING mengimplementasikan kolektif berbasis cincin menggunakan gRPC sebagai lapisan komunikasi antar host. NCCL menggunakan NCCL Nvidia untuk mengimplementasikan kolektif. AUTO menolak pilihan untuk runtime. Pilihan terbaik implementasi kolektif tergantung pada jumlah dan jenis GPU, dan interkoneksi jaringan di cluster.

Latih model dengan MultiWorkerMirroredStrategy

Dengan integrasi tf.distribute.Strategy API ke tf.keras , satu-satunya perubahan yang akan Anda lakukan untuk mendistribusikan pelatihan kepada multi-pekerja adalah dengan melampirkan pembuatan model dan model.compile() panggil inside strategy.scope() . Ruang lingkup strategi distribusi menentukan bagaimana dan di mana variabel dibuat, dan dalam kasus MultiWorkerMirroredStrategy , variabel yang dibuat adalah MirroredVariable , dan mereka direplikasi pada masing-masing pekerja.

 num_workers = 4

# Here the batch size scales up by number of workers since 
# `tf.data.Dataset.batch` expects the global batch size. Previously we used 64, 
# and now this becomes 128.
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)

with strategy.scope():
  # Model building/compiling need to be within `strategy.scope()`.
  multi_worker_model = build_and_compile_cnn_model()

# Keras' `model.fit()` trains the model with specified number of epochs and
# number of steps per epoch. Note that the numbers here are for demonstration
# purposes only and may not sufficiently produce a model with good quality.
multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
 
Epoch 1/3
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
70/70 [==============================] - 0s 3ms/step - loss: 2.2682 - accuracy: 0.2265
Epoch 2/3
70/70 [==============================] - 0s 3ms/step - loss: 2.1714 - accuracy: 0.4954
Epoch 3/3
70/70 [==============================] - 0s 3ms/step - loss: 2.0638 - accuracy: 0.6232

<tensorflow.python.keras.callbacks.History at 0x7fc5f4f062e8>

Dataset sharding dan ukuran batch

Dalam pelatihan multi-pekerja dengan MultiWorkerMirroredStrategy , sharding dataset diperlukan untuk memastikan konvergensi dan kinerja. Namun, perhatikan bahwa dalam cuplikan kode di atas, kumpulan data secara langsung diteruskan ke model.fit() tanpa perlu shard; ini karena tf.distribute.Strategy API menangani penyimpangan data secara otomatis. Itu pecahan dataset pada tingkat file yang dapat membuat pecahan miring. Dalam kasus ekstrim di mana hanya ada satu file, hanya beling pertama (yaitu pekerja) akan mendapatkan data pelatihan atau evaluasi dan sebagai hasilnya semua pekerja akan mendapatkan kesalahan.

Jika Anda lebih suka sharding manual untuk pelatihan Anda, sharding otomatis dapat dimatikan melalui tf.data.experimental.DistributeOptions api. Secara konkret,

 options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
dataset_no_auto_shard = multi_worker_dataset.with_options(options)
 

Hal lain yang perlu diperhatikan adalah ukuran batch untuk datasets . Dalam cuplikan kode di atas, kami menggunakan global_batch_size = per_worker_batch_size * num_workers , yang num_workers kali lebih besar dari kasus itu untuk pekerja tunggal, karena ukuran per pekerja yang efektif adalah ukuran batch global (parameter dilewatkan dalam tf.data.Dataset.batch() ) dibagi dengan jumlah pekerja, dan dengan perubahan ini kami menjaga ukuran batch per pekerja sama seperti sebelumnya.

Evaluasi

Jika Anda meneruskan validation_data ke model.fit , itu akan bergantian antara pelatihan dan evaluasi untuk setiap zaman. validation_data pengambilan data evaluasi didistribusikan di antara pekerja yang sama dan hasil evaluasi dikumpulkan dan tersedia untuk semua pekerja. Mirip dengan pelatihan, dataset validasi secara otomatis sharded di tingkat file. Anda perlu mengatur ukuran kumpulan global dalam dataset validasi dan mengatur validation_steps . Dataset berulang juga direkomendasikan untuk evaluasi.

Atau, Anda juga dapat membuat tugas lain yang secara berkala membaca pos pemeriksaan dan menjalankan evaluasi. Inilah yang dilakukan Estimator. Tetapi ini bukan cara yang disarankan untuk melakukan evaluasi dan dengan demikian rinciannya dihilangkan.

Ramalan

Saat ini model.predict tidak berfungsi dengan MultiWorkerMirroredStrategy.

Performa

Anda sekarang memiliki model Keras yang semuanya siap dijalankan di banyak pekerja dengan MultiWorkerMirroredStrategy . Anda dapat mencoba teknik-teknik berikut untuk menyesuaikan kinerja pelatihan multi-pekerja dengan MultiWorkerMirroredStrategy .

  • MultiWorkerMirroredStrategy menyediakan beberapa implementasi komunikasi kolektif . RING mengimplementasikan kolektif berbasis cincin menggunakan gRPC sebagai lapisan komunikasi antar host. NCCL menggunakan NCCL Nvidia untuk mengimplementasikan kolektif. AUTO menolak pilihan untuk runtime. Pilihan terbaik implementasi kolektif tergantung pada jumlah dan jenis GPU, dan interkoneksi jaringan di cluster. Untuk mengganti pilihan otomatis, tentukan nilai yang valid ke parameter communication konstruktor MultiWorkerMirroredStrategy , misalnya communication=tf.distribute.experimental.CollectiveCommunication.NCCL .
  • tf.float variabel ke tf.float jika memungkinkan. Model ResNet resmi mencakup contoh bagaimana hal ini dapat dilakukan.

Toleransi kesalahan

Dalam pelatihan sinkron, kluster akan gagal jika salah satu pekerja gagal dan tidak ada mekanisme pemulihan kegagalan. Menggunakan Keras dengan tf.distribute.Strategy datang dengan keuntungan dari toleransi kesalahan dalam kasus di mana pekerja meninggal atau tidak stabil. Kami melakukan ini dengan mempertahankan status pelatihan dalam sistem file terdistribusi pilihan Anda, sehingga setelah memulai kembali instance yang sebelumnya gagal atau preempt, status pelatihan dipulihkan.

Karena semua pekerja disinkronkan dalam hal masa pelatihan dan langkah-langkah, pekerja lain perlu menunggu pekerja yang gagal atau yang disuruh untuk memulai kembali untuk melanjutkan.

Callback ModelCheckpoint

Callback ModelCheckpoint tidak lagi menyediakan fungsionalitas toleransi kesalahan, silakan gunakan callback BackupAndRestore sebagai gantinya.

Callback ModelCheckpoint masih dapat digunakan untuk menyimpan pos-pos pemeriksaan. Tetapi dengan ini, jika pelatihan terganggu atau berhasil diselesaikan, untuk melanjutkan pelatihan dari pos pemeriksaan, pengguna bertanggung jawab untuk memuat model secara manual. Secara opsional, pengguna dapat memilih untuk menyimpan dan mengembalikan model / bobot di luar callback ModelCheckpoint .

Penghematan dan pemuatan model

Untuk menyimpan model Anda menggunakan model.save atau tf.saved_model.save , tujuan untuk menyimpan harus berbeda untuk setiap pekerja. Pada pekerja non-kepala, Anda harus menyimpan model ke direktori sementara, dan pada kepala, Anda harus menyimpan ke direktori model yang disediakan. Direktori sementara pada pekerja harus unik untuk mencegah kesalahan yang dihasilkan dari beberapa pekerja yang mencoba menulis ke lokasi yang sama. Model yang disimpan di semua direktori adalah identik dan biasanya hanya model yang disimpan oleh kepala harus dirujuk untuk memulihkan atau melayani. Kami menyarankan Anda memiliki beberapa logika pembersihan yang menghapus direktori sementara yang dibuat oleh para pekerja setelah pelatihan Anda selesai.

Alasan Anda perlu menghemat pada kepala dan pekerja pada saat yang sama, adalah karena Anda mungkin menjumlahkan variabel selama pos pemeriksaan yang mengharuskan kepala dan pekerja untuk berpartisipasi dalam protokol komunikasi allreduce. Di sisi lain, membiarkan kepala dan pekerja menyimpan ke direktori model yang sama akan menghasilkan kesalahan karena pertikaian.

Dengan MultiWorkerMirroredStrategy , program dijalankan pada setiap pekerja, dan untuk mengetahui apakah pekerja saat ini adalah kepala, kami mengambil keuntungan dari objek cluster resolver yang memiliki atribut task_type dan task_id . task_type memberi tahu Anda apa pekerjaan saat ini (misalnya 'pekerja'), dan task_id memberi tahu Anda pengidentifikasi pekerja tersebut. Pekerja dengan id 0 ditunjuk sebagai pekerja kepala.

Dalam cuplikan kode di bawah ini, write_filepath menyediakan path file untuk ditulis, yang tergantung pada id pekerja. Dalam kasus chief (pekerja dengan id 0), ia menulis ke path file asli; bagi yang lain, ia membuat direktori sementara (dengan id di jalur direktori) untuk menulis:

 model_path = '/tmp/keras-model'

def _is_chief(task_type, task_id):
  # If `task_type` is None, this may be operating as single worker, which works 
  # effectively as chief.
  return task_type is None or task_type == 'chief' or (
            task_type == 'worker' and task_id == 0)

def _get_temp_dir(dirpath, task_id):
  base_dirpath = 'workertemp_' + str(task_id)
  temp_dir = os.path.join(dirpath, base_dirpath)
  tf.io.gfile.makedirs(temp_dir)
  return temp_dir

def write_filepath(filepath, task_type, task_id):
  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  if not _is_chief(task_type, task_id):
    dirpath = _get_temp_dir(dirpath, task_id)
  return os.path.join(dirpath, base)

task_type, task_id = (strategy.cluster_resolver.task_type,
                      strategy.cluster_resolver.task_id)
write_model_path = write_filepath(model_path, task_type, task_id)
 

Dengan itu, Anda sekarang siap untuk menyimpan:

 multi_worker_model.save(write_model_path)
 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: /tmp/keras-model/assets

Seperti yang kami jelaskan di atas, nanti model hanya boleh dimuat dari path path save to, jadi mari kita hapus yang sementara yang disimpan oleh pekerja non-chief:

 if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(os.path.dirname(write_model_path))
 

Sekarang, ketika tiba waktunya untuk memuat, mari gunakan API tf.keras.models.load_model nyaman, dan lanjutkan dengan pekerjaan lebih lanjut. Di sini, kami mengasumsikan hanya menggunakan satu pekerja untuk memuat dan melanjutkan pelatihan, dalam hal ini Anda tidak memanggil tf.keras.models.load_model dalam strategy.scope() .

 loaded_model = tf.keras.models.load_model(model_path)

# Now that we have the model restored, and can continue with the training.
loaded_model.fit(single_worker_dataset, epochs=2, steps_per_epoch=20)
 
Epoch 1/2
20/20 [==============================] - 0s 2ms/step - loss: 1.9825 - accuracy: 0.1102
Epoch 2/2
20/20 [==============================] - 0s 2ms/step - loss: 1.9367 - accuracy: 0.1117

<tensorflow.python.keras.callbacks.History at 0x7fc5f4b0d8d0>

Penyimpanan dan pemulihan pos pemeriksaan

Di sisi lain, pos pemeriksaan memungkinkan Anda untuk menyimpan bobot model dan mengembalikannya tanpa harus menyimpan keseluruhan model. Di sini, Anda akan membuat satu tf.train.Checkpoint yang melacak model, yang dikelola oleh tf.train.CheckpointManager sehingga hanya pos pemeriksaan terbaru yang dipertahankan.

 checkpoint_dir = '/tmp/ckpt'

checkpoint = tf.train.Checkpoint(model=multi_worker_model)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
checkpoint_manager = tf.train.CheckpointManager(
  checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
 

Setelah CheckpointManager diatur, Anda sekarang siap untuk menyimpan, dan menghapus pos-pos pemeriksaan pekerja non-kepala yang disimpan.

 checkpoint_manager.save()
if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(write_checkpoint_dir)
 

Sekarang, ketika Anda perlu mengembalikan, Anda dapat menemukan pos pemeriksaan terakhir disimpan menggunakan fungsi tf.train.latest_checkpoint nyaman. Setelah memulihkan pos pemeriksaan, Anda dapat melanjutkan dengan pelatihan.

 latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint)
multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)
 
Epoch 1/2
20/20 [==============================] - 0s 3ms/step - loss: 1.9841 - accuracy: 0.6561
Epoch 2/2
20/20 [==============================] - 0s 3ms/step - loss: 1.9445 - accuracy: 0.6805

<tensorflow.python.keras.callbacks.History at 0x7fc5f49d9d30>

Callback BackupAndRestore

BackupAndRestore callback memberikan kesalahan fungsi toleransi, dengan back up model dan nomor zaman saat ini di file pos pemeriksaan sementara di bawah backup_dir argumen untuk BackupAndRestore . Ini dilakukan pada akhir setiap zaman.

Setelah pekerjaan terganggu dan restart, callback mengembalikan pos pemeriksaan terakhir, dan pelatihan berlanjut dari awal zaman yang terganggu. Pelatihan parsial apa pun yang sudah dilakukan di zaman yang belum selesai sebelum interupsi akan dibuang, sehingga tidak mempengaruhi keadaan model akhir.

Untuk menggunakannya, berikan instance dari tf.keras.callbacks.experimental.BackupAndRestore di panggilan tf.keras.Model.fit() .

Dengan MultiWorkerMirroredStrategy, jika seorang pekerja terganggu, seluruh cluster berhenti sampai pekerja yang terganggu itu dimulai kembali. Pekerja lain juga akan memulai kembali, dan pekerja yang terganggu bergabung kembali dengan cluster. Kemudian, setiap pekerja membaca file pos pemeriksaan yang sebelumnya disimpan dan mengambil status sebelumnya, sehingga memungkinkan kluster untuk kembali sinkron. Kemudian pelatihan berlanjut.

Callback BackupAndRestore menggunakan CheckpointManager untuk menyimpan dan mengembalikan status pelatihan, yang menghasilkan file yang disebut pos pemeriksaan yang melacak pos pemeriksaan yang ada bersama dengan yang terbaru. Karena alasan ini, backup_dir tidak boleh digunakan kembali untuk menyimpan pos pemeriksaan lain untuk menghindari benturan nama.

Saat ini, panggilan balik BackupAndRestore mendukung pekerja tunggal tanpa strategi, MirroredStrategy, dan multi-pekerja dengan MultiWorkerMirroredStrategy. Di bawah ini adalah dua contoh untuk pelatihan multi-pekerja dan pelatihan pekerja tunggal.

 # Multi-worker training with MultiWorkerMirroredStrategy.

callbacks = [tf.keras.callbacks.experimental.BackupAndRestore(backup_dir='/tmp/backup')]
with strategy.scope():
  multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset,
                       epochs=3,
                       steps_per_epoch=70,
                       callbacks=callbacks)
 
Epoch 1/3
70/70 [==============================] - 0s 3ms/step - loss: 2.2837 - accuracy: 0.1836
Epoch 2/3
70/70 [==============================] - 0s 3ms/step - loss: 2.2131 - accuracy: 0.4091
Epoch 3/3
70/70 [==============================] - 0s 3ms/step - loss: 2.1310 - accuracy: 0.5485

<tensorflow.python.keras.callbacks.History at 0x7fc5f49a3080>

Jika Anda memeriksa direktori backup_dir Anda tentukan di BackupAndRestore , Anda mungkin melihat beberapa file pos pemeriksaan sementara. File-file itu diperlukan untuk memulihkan instance yang sebelumnya hilang, dan mereka akan dihapus oleh perpustakaan di akhir tf.keras.Model.fit() setelah berhasil keluar dari pelatihan Anda.

Lihat juga

  1. Pelatihan Terdistribusi dalam panduan TensorFlow memberikan gambaran umum tentang strategi distribusi yang tersedia.
  2. Model resmi , banyak di antaranya dapat dikonfigurasi untuk menjalankan beberapa strategi distribusi.
  3. Bagian Kinerja dalam panduan ini memberikan informasi tentang strategi dan alat lain yang dapat Anda gunakan untuk mengoptimalkan kinerja model TensorFlow Anda.