Bermigrasi dari TF1 ke TF2 dengan TensorFlow Hub

Halaman ini menjelaskan cara tetap menggunakan TensorFlow Hub saat memigrasikan kode TensorFlow Anda dari TensorFlow 1 ke TensorFlow 2. Halaman ini melengkapi panduan migrasi umum TensorFlow .

Untuk TF2, TF Hub telah beralih dari hub.Module API lama untuk membuat tf.compat.v1.Graph seperti yang dilakukan tf.contrib.v1.layers . Sebagai gantinya, sekarang ada hub.KerasLayer untuk digunakan bersama lapisan Keras lainnya untuk membangun tf.keras.Model (biasanya di lingkungan eksekusi bersemangat baru TF2) dan metode hub.load() yang mendasarinya untuk kode TensorFlow tingkat rendah.

API hub.Module tetap tersedia di pustaka tensorflow_hub untuk digunakan di TF1 dan dalam mode kompatibilitas TF1 di TF2. Itu hanya dapat memuat model dalam format TF1 Hub .

API baru hub.load() dan hub.KerasLayer berfungsi untuk TensorFlow 1.15 (dalam mode bersemangat dan grafik) dan di TensorFlow 2. API baru ini dapat memuat aset TF2 SavedModel baru, dan, dengan batasan yang ditetapkan dalam model panduan kompatibilitas , model lawas dalam format TF1 Hub.

Secara umum, disarankan untuk menggunakan API baru jika memungkinkan.

Ringkasan API baru

hub.load() adalah fungsi tingkat rendah baru untuk memuat SavedModel dari TensorFlow Hub (atau layanan yang kompatibel). Ini membungkus tf.saved_model.load() TF2; Panduan SavedModel TensorFlow menjelaskan apa yang dapat Anda lakukan dengan hasilnya.

m = hub.load(handle)
outputs = m(inputs)

Kelas hub.KerasLayer memanggil hub.load() dan mengadaptasi hasilnya untuk digunakan di Keras bersama lapisan Keras lainnya. (Ini bahkan mungkin merupakan pembungkus yang nyaman untuk SavedModels yang dimuat yang digunakan dengan cara lain.)

model = tf.keras.Sequential([
    hub.KerasLayer(handle),
    ...])

Banyak tutorial yang menunjukkan cara kerja API ini. Berikut beberapa contohnya:

Menggunakan API baru dalam pelatihan Estimator

Jika Anda menggunakan TF2 SavedModel di Estimator untuk pelatihan dengan server parameter (atau sebaliknya dalam Sesi TF1 dengan variabel yang ditempatkan pada perangkat jarak jauh), Anda perlu menyetel experimental.share_cluster_devices_in_session di ConfigProto tf.Session, atau Anda akan mendapatkan kesalahan seperti "Perangkat yang ditetapkan '/job:ps/replica:0/task:0/device:CPU:0' tidak cocok dengan perangkat mana pun."

Opsi yang diperlukan dapat diatur sesuai keinginan

session_config = tf.compat.v1.ConfigProto()
session_config.experimental.share_cluster_devices_in_session = True
run_config = tf.estimator.RunConfig(..., session_config=session_config)
estimator = tf.estimator.Estimator(..., config=run_config)

Dimulai dengan TF2.2, opsi ini tidak lagi bersifat eksperimental, dan bagian .experimental dapat dihilangkan.

Memuat model lama dalam format TF1 Hub

Bisa saja TF2 SavedModel baru belum tersedia untuk kasus penggunaan Anda dan Anda perlu memuat model lama dalam format TF1 Hub. Mulai tensorflow_hub rilis 0.7, Anda dapat menggunakan model lama dalam format TF1 Hub bersama dengan hub.KerasLayer seperti yang ditunjukkan di bawah ini:

m = hub.KerasLayer(handle)
tensor_out = m(tensor_in)

Selain itu KerasLayer memperlihatkan kemampuan untuk menentukan tags , signature , output_key dan signature_outputs_as_dict untuk penggunaan model lama yang lebih spesifik dalam format TF1 Hub dan SavedModels lama.

Untuk informasi lebih lanjut tentang kompatibilitas format TF1 Hub, lihat panduan kompatibilitas model .

Menggunakan API tingkat yang lebih rendah

Model format Hub TF1 lama dapat dimuat melalui tf.saved_model.load . Alih-alih

# DEPRECATED: TensorFlow 1
m = hub.Module(handle, tags={"foo", "bar"})
tensors_out_dict = m(dict(x1=..., x2=...), signature="sig", as_dict=True)

disarankan untuk menggunakan:

# TensorFlow 2
m = hub.load(path, tags={"foo", "bar"})
tensors_out_dict = m.signatures["sig"](x1=..., x2=...)

Dalam contoh ini m.signatures adalah dict fungsi konkret TensorFlow yang diberi kunci berdasarkan nama tanda tangan. Memanggil fungsi seperti itu akan menghitung semua outputnya, meskipun tidak digunakan. (Ini berbeda dari evaluasi malas mode grafik TF1.)