Google I / O 18-20 मई को लौटता है! जगह आरक्षित करें और अपना शेड्यूल बनाएं अभी रजिस्टर करें
इस पेज का अनुवाद Cloud Translation API से किया गया है.
Switch to English

ट्रांसफर लर्निंग और फाइन-ट्यूनिंग

TensorFlow.org पर देखें Google Colab में चलाएं GitHub पर स्रोत देखें नोटबुक डाउनलोड करें

सेट अप

import numpy as np
import tensorflow as tf
from tensorflow import keras

परिचय

ट्रांसफर लर्निंग में एक समस्या पर सीखी गई विशेषताओं को लेना और एक नई, समान समस्या पर उनका लाभ उठाना शामिल है। उदाहरण के लिए, एक मॉडल से सुविधाएँ जो रैकोनों की पहचान करने के लिए सीखी हैं, एक मॉडल को तनुकिस की पहचान करने के लिए किक-स्टार्ट करना उपयोगी हो सकता है।

ट्रांसफर लर्निंग आमतौर पर उन कार्यों के लिए किया जाता है जहां आपके डेटासेट में खरोंच से पूर्ण पैमाने पर मॉडल को प्रशिक्षित करने के लिए बहुत कम डेटा होता है।

गहरी शिक्षा के संदर्भ में स्थानांतरण सीखने का सबसे आम अवतार निम्नलिखित वर्कफ़्लो है:

  1. पहले से प्रशिक्षित मॉडल से परतें लें।
  2. उन्हें फ्रीज करें, ताकि भविष्य में प्रशिक्षण दौर के दौरान उनमें से किसी भी जानकारी को नष्ट करने से बचें।
  3. जमे हुए परतों के शीर्ष पर कुछ नई, ट्रेन करने योग्य परतें जोड़ें। वे पुरानी विशेषताओं को नए डेटासेट पर भविष्यवाणियों में बदलना सीखेंगे।
  4. अपने डेटासेट पर नई परतों को प्रशिक्षित करें।

एक अंतिम, वैकल्पिक कदम, ठीक-ट्यूनिंग है , जिसमें आपके द्वारा ऊपर प्राप्त पूरे मॉडल (या इसके कुछ भाग) को अनफ्रीज करना शामिल है, और बहुत कम सीखने की दर के साथ नए डेटा पर इसे फिर से प्रशिक्षित करना है। यह संभावित रूप से नए डेटा के लिए प्रीट्रेन की गई विशेषताओं को बढ़ाकर, सार्थक सुधार प्राप्त कर सकता है।

सबसे पहले, हम केरस trainable एपीआई के बारे में विस्तार से जानेंगे, जो ज्यादातर ट्रांसफर लर्निंग और फाइन-ट्यूनिंग वर्कफ़्लोज़ को पूरा करता है।

फिर, हम ImageNet डेटासेट पर दिखाये गये मॉडल को ले कर, और कगेल "बिल्लियों बनाम कुत्तों" वर्गीकरण डेटासेट पर इसे फिर से दिखाते हुए विशिष्ट वर्कफ़्लो का प्रदर्शन करेंगे।

यह दीप लर्निंग से पाइथन और 2016 के ब्लॉग पोस्ट "बहुत कम डेटा का उपयोग करके शक्तिशाली छवि वर्गीकरण मॉडल का निर्माण" से अनुकूलित है।

बर्फ़ीली परतें: trainable विशेषता को समझना

परत और मॉडल में तीन वजन विशेषताएं होती हैं:

  • weights , लेयर के सभी वेट चर की सूची है।
  • trainable_weights उन लोगों की सूची है जिन्हें प्रशिक्षण के दौरान नुकसान को कम करने के लिए अद्यतन (ढाल वंश के माध्यम से) किया जाना है।
  • non_trainable_weights उन लोगों की सूची है जो प्रशिक्षित होने के लिए नहीं हैं। आमतौर पर वे फॉरवर्ड पास के दौरान मॉडल द्वारा अपडेट किए जाते हैं।

उदाहरण: Dense परत में 2 ट्रेन वज़न (कर्नेल और पूर्वाग्रह) हैं

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

सामान्य तौर पर, सभी भार ट्रेन योग्य वजन होते हैं। एकमात्र अंतर्निहित परत जिसमें गैर-ट्रेन करने योग्य भार है, BatchNormalization परत है। यह प्रशिक्षण के दौरान अपने आदानों के माध्य और विचरण का ध्यान रखने के लिए गैर-प्रशिक्षित वजन का उपयोग करता है। अपने स्वयं के कस्टम परतों में गैर-प्रशिक्षित वजन का उपयोग करने का तरीका जानने के लिए, खरोंच से नई परतें लिखने के लिए मार्गदर्शिका देखें।

उदाहरण: BatchNormalization लेयर में 2 ट्रेन वेट और 2 नॉन-ट्रेनेबल वेट हैं

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

परतें और मॉडल भी एक बूलियन विशेषता सुविधा trainable । इसका मूल्य बदला जा सकता है। स्थापना layer.trainable को False गैर trainable को trainable से सभी परत के वजन ले जाता है। इसे "फ्रीजिंग" परत कहा जाता है: एक स्थिर परत की स्थिति को प्रशिक्षण के दौरान अपडेट नहीं किया जाएगा (या तो जब fit() साथ प्रशिक्षण हो fit() या किसी भी कस्टम लूप के साथ प्रशिक्षण जब ढाल अद्यतन लागू करने के लिए trainable_weights पर निर्भर करता है)।

उदाहरण: False लिए trainable स्थापना

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

जब एक ट्रेन का वजन गैर-प्रशिक्षित हो जाता है, तो प्रशिक्षण के दौरान इसका मूल्य अपडेट नहीं किया जाता है।

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 664ms/step - loss: 0.1025

layer.__call__() training में layer.__call__() training साथ layer.trainablelayer.trainable विशेषता को भ्रमित न layer.__call__() । अधिक जानकारी के लिए, केयर्स FAQ देखें।

trainable विशेषता की पुनरावर्ती सेटिंग

यदि आप किसी मॉडल पर या सबलेयर्स वाली किसी भी लेयर पर trainable = False करने trainable = False सेट करते हैं, तो सभी बच्चों की लेयर्स नॉन-ट्रेनेबल भी हो जाती हैं।

उदाहरण:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively
है

सामान्य हस्तांतरण-अधिगम वर्कफ़्लो

यह हमें बताता है कि कैसेरे में एक विशिष्ट ट्रांसफर लर्निंग वर्कफ़्लो लागू किया जा सकता है:

  1. एक बेस मॉडल को इंस्टेंट करना और उसमें पहले से प्रशिक्षित वेट लोड करना।
  2. बेस मॉडल में सभी परतों को trainable = False करने trainable = False सेट करके फ्रीज करें।
  3. बेस मॉडल से एक (या कई) परतों के आउटपुट के शीर्ष पर एक नया मॉडल बनाएं।
  4. अपने नए मॉडल को अपने नए डेटासेट पर प्रशिक्षित करें।

ध्यान दें कि एक वैकल्पिक, अधिक हल्के वर्कफ़्लो भी हो सकते हैं:

  1. एक बेस मॉडल को इंस्टेंट करना और उसमें पहले से प्रशिक्षित वेट लोड करना।
  2. अपने नए डेटासेट को इसके माध्यम से चलाएं और बेस मॉडल से एक (या कई) परतों के आउटपुट को रिकॉर्ड करें। इसे फीचर निष्कर्षण कहा जाता है।
  3. नए, छोटे मॉडल के लिए इनपुट डेटा के रूप में उस आउटपुट का उपयोग करें।

उस दूसरे वर्कफ़्लो का एक महत्वपूर्ण लाभ यह है कि आप केवल एक बार अपने डेटा पर आधार मॉडल चलाते हैं, बजाय एक बार प्रशिक्षण के। तो यह बहुत तेज़ और सस्ता है।

उस दूसरे वर्कफ़्लो के साथ एक समस्या, हालांकि, यह है कि यह आपको प्रशिक्षण के दौरान अपने नए मॉडल के इनपुट डेटा को गतिशील रूप से संशोधित करने की अनुमति नहीं देता है, जो कि उदाहरण के लिए, डेटा वृद्धि करते समय आवश्यक है। स्थानांतरण सीखने का उपयोग आमतौर पर उन कार्यों के लिए किया जाता है जब आपके नए डेटासेट में खरोंच से पूर्ण पैमाने पर मॉडल को प्रशिक्षित करने के लिए बहुत कम डेटा होता है, और ऐसे परिदृश्यों में डेटा वृद्धि बहुत महत्वपूर्ण है। तो इस प्रकार, हम पहले वर्कफ़्लो पर ध्यान केंद्रित करेंगे।

यहाँ Kirs में पहली वर्कफ़्लो कैसा दिखता है:

सबसे पहले, पूर्व-प्रशिक्षित भार के साथ एक आधार मॉडल को तत्काल।

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

फिर, बेस मॉडल को फ्रीज करें।

base_model.trainable = False

शीर्ष पर एक नया मॉडल बनाएँ।

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

मॉडल को नए डेटा पर प्रशिक्षित करें।

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

फ़ाइन ट्यूनिंग

एक बार जब आपका मॉडल नए डेटा में परिवर्तित हो जाता है, तो आप बेस मॉडल के सभी या कुछ हिस्सों को अनफ्रीज करने की कोशिश कर सकते हैं और बहुत कम लर्निंग रेट के साथ पूरे मॉडल को एंड-टू-एंड रिटेन कर सकते हैं।

यह एक वैकल्पिक अंतिम चरण है जो संभावित रूप से आपको वृद्धिशील सुधार दे सकता है। यह संभावित रूप से त्वरित ओवरफिटिंग को भी जन्म दे सकता है - इसे ध्यान में रखें।

यह बाद जमे हुए परतों के साथ मॉडल अभिसरण करने के लिए प्रशिक्षित किया गया है केवल इस कदम करने के लिए महत्वपूर्ण है। यदि आप पहले से प्रशिक्षित सुविधाओं को रखने वाली ट्रेन की परतों के साथ बेतरतीब ढंग से प्रारंभिक रेल योग्य परतों को मिलाते हैं, तो बेतरतीब ढंग से आरंभ की गई परतें प्रशिक्षण के दौरान बहुत बड़े ग्रेडिएंट अपडेट का कारण बनेंगी, जो आपके पूर्व-प्रशिक्षित सुविधाओं को नष्ट कर देगा।

इस चरण में बहुत कम सीखने की दर का उपयोग करना भी महत्वपूर्ण है, क्योंकि आप प्रशिक्षण के पहले दौर की तुलना में बहुत बड़े मॉडल को प्रशिक्षित कर रहे हैं, जो कि आमतौर पर बहुत छोटा होता है। परिणामस्वरूप, यदि आप बड़े वजन अपडेट लागू करते हैं, तो आपको बहुत जल्दी ओवरफिटिंग का खतरा है। यहां, आप केवल एक वृद्धिशील तरीके से प्रिटेंडेड वेट को फिर से पढ़ना चाहते हैं।

यह है कि पूरे बेस मॉडल की फाइन-ट्यूनिंग को कैसे लागू किया जाए:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

compile() और trainable बारे में महत्वपूर्ण नोट

एक मॉडल पर compile() कॉलिंग का मतलब उस मॉडल के व्यवहार को "फ्रीज" करना है। इसका तात्पर्य यह है कि जिस समय मॉडल संकलित किया जाता है उस समय में trainable विशेषता मान उस मॉडल के जीवन भर संरक्षित किया जाना चाहिए, जब तक कि compile फिर से नहीं कहा जाता है। इसलिए, यदि आप किसी भी trainable मूल्य को बदलते हैं, तो अपने परिवर्तनों को ध्यान में रखने के लिए अपने मॉडल पर फिर से compile() कॉल करना सुनिश्चित करें।

BatchNormalization परत के बारे में महत्वपूर्ण नोट्स

कई इमेज मॉडल में BatchNormalization परतें होती हैं। यह परत प्रत्येक कल्पनाशील गिनती पर एक विशेष मामला है। यहाँ कुछ बातों का ध्यान रखना है।

  • BatchNormalization में 2 गैर-प्रशिक्षित वजन होते हैं जो प्रशिक्षण के दौरान अद्यतन होते हैं। ये इनपुट के माध्य और विचरण को ट्रैक करने वाले चर हैं।
  • जब आप bn_layer.trainable = False सेट bn_layer.trainable = False , तो BatchNormalization लेयर BatchNormalization मोड में चलेगा, और इसके माध्य और विचरण के आँकड़े को अपडेट नहीं करेगा। यह सामान्य रूप से अन्य परतों के लिए मामला नहीं है, क्योंकि वजन प्रशिक्षण और अनुमान / प्रशिक्षण मोड दो ऑर्थोगोनल अवधारणाएं हैं । लेकिन दोनों BatchNormalization परत के मामले में बंधे हैं।
  • जब आप एक मॉडल को ठीक करते हैं जिसमें ठीक-ट्यूनिंग करने के लिए BatchNormalization परतें होती हैं, तो आपको बेस मॉडल को कॉल करते समय training=False पास करने से training=False BatchNormalization परतों training=False तरीके से रखना चाहिए। अन्यथा गैर-ट्रेन योग्य भार पर लागू किए गए अपडेट अचानक मॉडल ने जो सीखा है उसे नष्ट कर देगा।

आप इस मार्गदर्शिका के अंत में एंड-टू-एंड उदाहरण एक्शन में इस पैटर्न को देखेंगे।

एक कस्टम ट्रेनिंग लूप के साथ लर्निंग और फाइन-ट्यूनिंग ट्रांसफर करें

यदि fit() बजाय, आप अपने स्वयं के निम्न-स्तरीय प्रशिक्षण लूप का उपयोग कर रहे हैं, तो वर्कफ़्लो अनिवार्य रूप से समान है। जब आप ग्रेडिएंट अपडेट लागू करते हैं तो आपको केवल सूची model.trainable_weights को ध्यान में रखना चाहिए।

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

इसी तरह फाइन-ट्यूनिंग के लिए।

एंड-टू-एंड उदाहरण: बिल्लियों बनाम कुत्तों के डेटासेट पर एक छवि वर्गीकरण मॉडल को ठीक-ट्यूनिंग

इन अवधारणाओं को ठोस बनाने के लिए, आइए आपको एक ठोस अंत-टू-एंड हस्तांतरण सीखने और ठीक-ट्यूनिंग उदाहरण के माध्यम से चलते हैं। हम Xception मॉडल को लोड करेंगे, ImageNet पर पूर्व-प्रशिक्षित, और इसे कागले "बिल्लियों बनाम कुत्तों" वर्गीकरण डेटासेट पर उपयोग करें।

डेटा प्राप्त करना

पहले, चलो TFDS का उपयोग करके बिल्लियों बनाम कुत्तों के डेटासेट प्राप्त करते हैं। यदि आपके पास अपना स्वयं का डेटासेट है, तो आप संभवतः क्लास-विशिष्ट फ़ोल्डरों में दर्ज डिस्क पर छवियों के सेट से इसी तरह के लेबल डेटासेट वस्तुओं को उत्पन्न करने के लिए उपयोगिता tf.keras.preprocessing.image_dataset_from_directory का उपयोग करना चाहते हैं।

बहुत छोटे डेटासेट के साथ काम करते समय ट्रांसफर लर्निंग सबसे उपयोगी है। हमारे डेटासेट को छोटा रखने के लिए, हम प्रशिक्षण के लिए मूल प्रशिक्षण डेटा (25,000 चित्र) का 40%, सत्यापन के लिए 10% और परीक्षण के लिए 10% का उपयोग करेंगे।

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

प्रशिक्षण डाटासेट में ये पहले 9 चित्र हैं - जैसा कि आप देख सकते हैं, वे सभी अलग-अलग आकार हैं।

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

पींग

हम यह भी देख सकते हैं कि लेबल 1 "कुत्ता" है और लेबल 0 "बिल्ली" है।

डेटा का मानकीकरण

हमारी कच्ची छवियों में विभिन्न आकार हैं। इसके अलावा, प्रत्येक पिक्सेल में 0 और 255 (RGB स्तर मान) के बीच 3 पूर्णांक मान होते हैं। यह एक तंत्रिका नेटवर्क को खिलाने के लिए एक महान फिट नहीं है। हमें 2 चीजें करने की जरूरत है:

  • एक निश्चित छवि आकार के लिए मानकीकृत करें। हम 150x150 लेते हैं।
  • -1 और 1. के बीच पिक्सेल मानों को सामान्य करें। हम Normalization परत का उपयोग मॉडल के हिस्से के रूप में करेंगे।

सामान्य तौर पर, यह उन मॉडल को विकसित करने के लिए एक अच्छा अभ्यास है जो कच्चे डेटा को इनपुट के रूप में लेते हैं, जैसा कि उन मॉडलों के विपरीत है जो पहले से ही प्रीप्रोसेस किए गए डेटा लेते हैं। इसका कारण यह है कि, यदि आपका मॉडल प्रीप्रोसेड डेटा की अपेक्षा करता है, तो किसी भी समय आप अपने मॉडल को कहीं और (एक वेब ब्राउज़र में, मोबाइल ऐप में) उपयोग करने के लिए निर्यात करते हैं, तो आपको ठीक उसी प्रीप्रोसेसिंग पाइपलाइन को फिर से लागू करना होगा। यह बहुत जल्दी बहुत मुश्किल हो जाता है। इसलिए हमें मॉडल को मारने से पहले प्रीप्रोसेसिंग की कम से कम संभव मात्रा में करना चाहिए।

यहां, हम डेटा पाइपलाइन में आकार बदलने वाली छवि करेंगे (क्योंकि एक गहरे तंत्रिका नेटवर्क केवल डेटा के सन्निहित बैचों को संसाधित कर सकते हैं), और हम मॉडल के भाग के रूप में इनपुट मान स्केलिंग करेंगे, जब हम इसे बनाते हैं।

आइए छवियों का आकार 150x150 करें:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

इसके अलावा, चलो डेटा को बैचते हैं और लोडिंग की गति को अनुकूलित करने के लिए कैशिंग और प्रीफेटिंग का उपयोग करते हैं।

03 ए 3202780

यादृच्छिक डेटा वृद्धि का उपयोग करना

जब आपके पास एक बड़ी छवि डेटासेट नहीं है, तो प्रशिक्षण छवियों में यादृच्छिक अभी तक यथार्थवादी परिवर्तनों को लागू करके नमूना विविधता को कृत्रिम रूप से प्रस्तुत करना एक अच्छा अभ्यास है, जैसे कि यादृच्छिक क्षैतिज फ़्लिपिंग या छोटे यादृच्छिक घुमाव। यह ओवरफिटिंग को धीमा करते हुए मॉडल को प्रशिक्षण डेटा के विभिन्न पहलुओं को उजागर करने में मदद करता है।

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(0.1),
    ]
)

आइए देखें कि विभिन्न यादृच्छिक परिवर्तनों के बाद पहले बैच की पहली छवि कैसी दिखती है:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")

पींग

मॉडल बनाना

अब हमने एक मॉडल बनाया है जो पहले बताए गए ब्लूप्रिंट का अनुसरण करता है।

ध्यान दें कि:

  • हम इनपुट मानों को स्केल करने के लिए एक Normalization लेयर (शुरुआत में [0, 255] रेंज में) को [-1, 1] रेंज में जोड़ते हैं।
  • हम वर्गीकरण परत से पहले एक Dropout परत जोड़ते हैं, नियमितीकरण के लिए।
  • हम बेस मॉडल को कॉल करते समय training=False से पास करना सुनिश्चित करते हैं, ताकि यह एक इनवेंशन मोड में चले, ताकि हम ठीक-ट्यूनिंग के लिए बेस मॉडल को अनफ्रीज करने के बाद भी बैटमॉर्न आँकड़े अपडेट न हों।
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be normalized
# from (0, 255) to a range (-1., +1.), the normalization layer
# does the following, outputs = (inputs - mean) / sqrt(var)
norm_layer = keras.layers.experimental.preprocessing.Normalization()
mean = np.array([127.5] * 3)
var = mean ** 2
# Scale inputs to [-1, +1]
x = norm_layer(x)
norm_layer.set_weights([mean, var])

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 1s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 2,049
Non-trainable params: 20,861,487
_________________________________________________________________

शीर्ष परत को प्रशिक्षित करें

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
291/291 [==============================] - 20s 49ms/step - loss: 0.2226 - binary_accuracy: 0.8972 - val_loss: 0.0805 - val_binary_accuracy: 0.9703
Epoch 2/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1246 - binary_accuracy: 0.9464 - val_loss: 0.0757 - val_binary_accuracy: 0.9712
Epoch 3/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1153 - binary_accuracy: 0.9480 - val_loss: 0.0724 - val_binary_accuracy: 0.9733
Epoch 4/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1055 - binary_accuracy: 0.9575 - val_loss: 0.0753 - val_binary_accuracy: 0.9721
Epoch 5/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1026 - binary_accuracy: 0.9589 - val_loss: 0.0750 - val_binary_accuracy: 0.9703
Epoch 6/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1022 - binary_accuracy: 0.9587 - val_loss: 0.0723 - val_binary_accuracy: 0.9716
Epoch 7/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1009 - binary_accuracy: 0.9570 - val_loss: 0.0731 - val_binary_accuracy: 0.9708
Epoch 8/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0947 - binary_accuracy: 0.9576 - val_loss: 0.0726 - val_binary_accuracy: 0.9716
Epoch 9/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0872 - binary_accuracy: 0.9624 - val_loss: 0.0720 - val_binary_accuracy: 0.9712
Epoch 10/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0892 - binary_accuracy: 0.9622 - val_loss: 0.0711 - val_binary_accuracy: 0.9716
Epoch 11/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0987 - binary_accuracy: 0.9608 - val_loss: 0.0752 - val_binary_accuracy: 0.9712
Epoch 12/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0962 - binary_accuracy: 0.9595 - val_loss: 0.0715 - val_binary_accuracy: 0.9738
Epoch 13/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0972 - binary_accuracy: 0.9606 - val_loss: 0.0700 - val_binary_accuracy: 0.9725
Epoch 14/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1019 - binary_accuracy: 0.9568 - val_loss: 0.0779 - val_binary_accuracy: 0.9690
Epoch 15/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0929 - binary_accuracy: 0.9614 - val_loss: 0.0700 - val_binary_accuracy: 0.9729
Epoch 16/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0937 - binary_accuracy: 0.9610 - val_loss: 0.0698 - val_binary_accuracy: 0.9742
Epoch 17/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0945 - binary_accuracy: 0.9613 - val_loss: 0.0671 - val_binary_accuracy: 0.9759
Epoch 18/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0868 - binary_accuracy: 0.9612 - val_loss: 0.0692 - val_binary_accuracy: 0.9738
Epoch 19/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0871 - binary_accuracy: 0.9647 - val_loss: 0.0691 - val_binary_accuracy: 0.9746
Epoch 20/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0922 - binary_accuracy: 0.9603 - val_loss: 0.0721 - val_binary_accuracy: 0.9738
<tensorflow.python.keras.callbacks.History at 0x7fb73f231860>

पूरे मॉडल के फाइन-ट्यूनिंग का एक दौर करें

अंत में, चलो बेस मॉडल को अनफ्रीज करते हैं और कम सीखने की दर के साथ पूरे मॉडल को एंड-टू-एंड ट्रेन करते हैं।

महत्वपूर्ण रूप से, हालांकि बेस मॉडल ट्रेन योग्य हो जाता है, यह अभी भी एक अनुमान मोड में चल रहा है क्योंकि हमने training=False तब पारित training=False है जब हम मॉडल का निर्माण करते समय इसे training=False कहते हैं। इसका मतलब यह है कि बैच के सामान्यीकरण की परतें अपने बैच के आंकड़ों को अपडेट नहीं करेंगी। यदि वे करते हैं, तो वे अब तक मॉडल द्वारा सीखे गए अभ्यावेदन पर कहर ढाएंगे।

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 20,809,001
Non-trainable params: 54,535
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 43s 133ms/step - loss: 0.0814 - binary_accuracy: 0.9677 - val_loss: 0.0527 - val_binary_accuracy: 0.9776
Epoch 2/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0544 - binary_accuracy: 0.9796 - val_loss: 0.0537 - val_binary_accuracy: 0.9776
Epoch 3/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0481 - binary_accuracy: 0.9822 - val_loss: 0.0471 - val_binary_accuracy: 0.9789
Epoch 4/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0324 - binary_accuracy: 0.9871 - val_loss: 0.0551 - val_binary_accuracy: 0.9807
Epoch 5/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0298 - binary_accuracy: 0.9899 - val_loss: 0.0447 - val_binary_accuracy: 0.9807
Epoch 6/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0262 - binary_accuracy: 0.9901 - val_loss: 0.0469 - val_binary_accuracy: 0.9824
Epoch 7/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0242 - binary_accuracy: 0.9918 - val_loss: 0.0539 - val_binary_accuracy: 0.9798
Epoch 8/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0153 - binary_accuracy: 0.9935 - val_loss: 0.0644 - val_binary_accuracy: 0.9794
Epoch 9/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0175 - binary_accuracy: 0.9934 - val_loss: 0.0496 - val_binary_accuracy: 0.9819
Epoch 10/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0171 - binary_accuracy: 0.9936 - val_loss: 0.0496 - val_binary_accuracy: 0.9828
<tensorflow.python.keras.callbacks.History at 0x7fb74f74f940>

10 युगों के बाद, ठीक-ट्यूनिंग हमें यहाँ एक अच्छा सुधार हासिल करता है।