इस पेज का अनुवाद Cloud Translation API से किया गया है.
Switch to English

ट्रांसफर लर्निंग और फाइन-ट्यूनिंग

TensorFlow.org पर देखें Google Colab में चलाएं GitHub पर स्रोत देखें नोटबुक डाउनलोड करें

सेट अप

import numpy as np
import tensorflow as tf
from tensorflow import keras

परिचय

ट्रांसफर लर्निंग में एक समस्या पर सीखी गई विशेषताओं को लेना और एक नई, समान समस्या पर उनका लाभ उठाना शामिल है। उदाहरण के लिए, एक मॉडल से सुविधाएँ जो रैकोनों की पहचान करने के लिए सीखी हैं, एक मॉडल को तनुकिस की पहचान करने के लिए किक-स्टार्ट करना उपयोगी हो सकता है।

ट्रांसफर लर्निंग आमतौर पर उन कार्यों के लिए किया जाता है जहां आपके डेटासेट में खरोंच से पूर्ण पैमाने पर मॉडल को प्रशिक्षित करने के लिए बहुत कम डेटा होता है।

गहरी शिक्षा के संदर्भ में स्थानांतरण सीखने का सबसे आम अवतार निम्नलिखित वर्कफ़्लो है:

  1. पहले से प्रशिक्षित मॉडल से परतें लें।
  2. उन्हें फ्रीज करें, ताकि भविष्य में प्रशिक्षण दौर के दौरान उनमें से किसी भी जानकारी को नष्ट करने से बचें।
  3. जमे हुए परतों के शीर्ष पर कुछ नई, ट्रेन करने योग्य परतें जोड़ें। वे पुरानी विशेषताओं को नए डेटासेट पर भविष्यवाणियों में बदलना सीखेंगे।
  4. अपने डेटासेट पर नई परतों को प्रशिक्षित करें।

एक अंतिम, वैकल्पिक कदम, ठीक-ट्यूनिंग है , जिसमें आपके द्वारा ऊपर प्राप्त पूरे मॉडल (या इसके कुछ भाग) को अनफ्रेंड करना शामिल है, और बहुत कम सीखने की दर के साथ नए डेटा पर इसे फिर से प्रशिक्षित करना है। यह संभावित रूप से नए डेटा के लिए प्रीपेन्डेड सुविधाओं को बढ़ाकर, सार्थक सुधार प्राप्त कर सकता है।

सबसे पहले, हम केरस trainable एपीआई के बारे में विस्तार से जानेंगे, जो कि ज्यादातर ट्रांसफर लर्निंग और फाइन-ट्यूनिंग वर्कफ़्लोज़ को पूरा करता है।

फिर, हम ImageNet डेटासेट पर प्रेटेंड किए गए मॉडल को ले कर, और कागेल "बिल्लियों बनाम कुत्तों" वर्गीकरण डेटासेट पर इसे फिर से दिखाते हुए विशिष्ट वर्कफ़्लो का प्रदर्शन करेंगे।

यह दीप लर्निंग से पाइथन और 2016 के ब्लॉग पोस्ट "बहुत कम डेटा का उपयोग करके शक्तिशाली छवि वर्गीकरण मॉडल का निर्माण" से अनुकूलित है।

बर्फ़ीली परतें: trainable विशेषता को समझना

परत और मॉडल में तीन वजन विशेषताएं होती हैं:

  • weights परत के सभी भार चर की सूची है।
  • trainable_weights उन लोगों की सूची है जिन्हें प्रशिक्षण के दौरान नुकसान को कम करने के लिए अद्यतन (ढाल वंश के माध्यम से) किया जाना है।
  • non_trainable_weights उन लोगों की सूची है जो प्रशिक्षित होने के लिए नहीं हैं। आमतौर पर वे फॉरवर्ड पास के दौरान मॉडल द्वारा अपडेट किए जाते हैं।

उदाहरण: Dense परत में 2 ट्रेन वज़न (कर्नेल और पूर्वाग्रह) हैं

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

सामान्य तौर पर, सभी वजन ट्रेन योग्य वजन होते हैं। एकमात्र निर्मित परत जिसमें गैर-ट्रेन करने योग्य भार है, BatchNormalization परत है। यह प्रशिक्षण के दौरान अपने आदानों के माध्य और विचरण पर नज़र रखने के लिए गैर-प्रशिक्षित वजन का उपयोग करता है। अपने स्वयं के कस्टम परतों में गैर-प्रशिक्षित वजन का उपयोग करने का तरीका जानने के लिए, खरोंच से नई परतें लिखने के लिए मार्गदर्शिका देखें।

उदाहरण: BatchNormalization लेयर में 2 ट्रेन वेट और 2 नॉन-ट्रेनेबल वेट हैं

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

परतें और मॉडल भी एक बूलियन विशेषता सुविधा trainable । इसका मूल्य बदला जा सकता है। स्थापना layer.trainable को False गैर trainable को trainable से सभी परत के वजन ले जाता है। इसे "फ्रीजिंग" परत कहा जाता है: प्रशिक्षण के दौरान एक स्थिर परत की स्थिति को अपडेट नहीं किया जाएगा (या तो जब fit() साथ प्रशिक्षण हो fit() या किसी भी कस्टम लूप के साथ प्रशिक्षण जब ढाल अद्यतन लागू करने के लिए trainable_weights पर निर्भर करता है)।

उदाहरण: False लिए trainable स्थापना

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

जब एक ट्रेन का वजन गैर-प्रशिक्षित हो जाता है, तो प्रशिक्षण के दौरान इसका मूल्य अपडेट नहीं किया जाता है।

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 0s 1ms/step - loss: 0.1275

layer.__call__() training में layer.__call__() training साथ layer.trainable विशेषता को भ्रमित न layer.__call__() (जो यह नियंत्रित करता है कि लेयर अपने फॉरवर्ड पास को इनविज़न मोड या ट्रेनिंग मोड में चलाए)। अधिक जानकारी के लिए, केयर्स FAQ देखें।

trainable विशेषता की पुनरावर्ती सेटिंग

यदि आप किसी मॉडल पर या सबलेयर्स वाली किसी भी लेयर पर trainable = False करने trainable = False सेट करते हैं, तो सभी बच्चों की लेयर्स नॉन-ट्रेनेबल भी हो जाती हैं।

उदाहरण:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

सामान्य हस्तांतरण-अधिगम वर्कफ़्लो

यह हमें बताता है कि कैसेर्स में एक विशिष्ट ट्रांसफर लर्निंग वर्कफ़्लो लागू किया जा सकता है:

  1. एक बेस मॉडल को इंस्टेंट करना और उसमें पहले से प्रशिक्षित वेट लोड करना।
  2. बेस मॉडल में सभी परतों को trainable = False करने trainable = False सेट करके फ्रीज करें।
  3. बेस मॉडल से एक (या कई) परतों के आउटपुट के शीर्ष पर एक नया मॉडल बनाएं।
  4. अपने नए डेटासेट पर अपने नए मॉडल को प्रशिक्षित करें।

ध्यान दें कि एक वैकल्पिक, अधिक हल्के वर्कफ़्लो भी हो सकते हैं:

  1. एक बेस मॉडल को इंस्टेंट करना और उसमें पहले से प्रशिक्षित वज़न लोड करना।
  2. अपने नए डेटासेट को इसके माध्यम से चलाएं और बेस मॉडल से एक (या कई) परतों के आउटपुट को रिकॉर्ड करें। इसे फीचर निष्कर्षण कहा जाता है।
  3. नए, छोटे मॉडल के लिए इनपुट डेटा के रूप में उस आउटपुट का उपयोग करें।

उस दूसरे वर्कफ़्लो का एक मुख्य लाभ यह है कि आप केवल एक बार अपने डेटा पर आधार मॉडल चलाते हैं, बल्कि एक बार प्रशिक्षण के अनुसार। तो यह बहुत तेज़ और सस्ता है।

उस दूसरे वर्कफ़्लो के साथ एक समस्या, हालांकि, यह है कि यह आपको प्रशिक्षण के दौरान अपने नए मॉडल के इनपुट डेटा को गतिशील रूप से संशोधित करने की अनुमति नहीं देता है, जो कि उदाहरण के लिए, डेटा वृद्धि करते समय आवश्यक है। स्थानांतरण सीखने का उपयोग आमतौर पर उन कार्यों के लिए किया जाता है जब आपके नए डेटासेट में खरोंच से पूर्ण पैमाने पर मॉडल को प्रशिक्षित करने के लिए बहुत कम डेटा होता है, और ऐसे परिदृश्यों में डेटा वृद्धि बहुत महत्वपूर्ण है। तो इस प्रकार, हम पहले वर्कफ़्लो पर ध्यान केंद्रित करेंगे।

यहाँ Kirs में पहले वर्कफ़्लो कैसा दिखता है:

सबसे पहले, पूर्व-प्रशिक्षित भार के साथ एक बेस मॉडल को तत्काल करें।

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

फिर, बेस मॉडल को फ्रीज करें।

base_model.trainable = False

शीर्ष पर एक नया मॉडल बनाएं।

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

नए डेटा पर मॉडल को प्रशिक्षित करें।

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

फ़ाइन ट्यूनिंग

एक बार जब आपका मॉडल नए डेटा में परिवर्तित हो जाता है, तो आप बेस मॉडल के सभी या कुछ हिस्सों को अनफ्रीज करने की कोशिश कर सकते हैं और बहुत कम सीखने की दर के साथ पूरे मॉडल को एंड-टू-एंड रिटेन कर सकते हैं।

यह एक वैकल्पिक अंतिम चरण है जो संभावित रूप से आपको वृद्धिशील सुधार दे सकता है। यह संभावित रूप से त्वरित ओवरफिटिंग को भी जन्म दे सकता है - इसे ध्यान में रखें।

जमे हुए परतों वाले मॉडल को अभिसरण के लिए प्रशिक्षित करने के बाद यह कदम उठाना महत्वपूर्ण है। यदि आप पहले से प्रशिक्षित सुविधाओं को रखने वाली ट्रेन की परतों के साथ बेतरतीब ढंग से प्रारंभिक रेल योग्य परतों को मिलाते हैं, तो बेतरतीब ढंग से आरंभ की गई परतें प्रशिक्षण के दौरान बहुत बड़े ग्रेडिएंट अपडेट का कारण बनेंगी, जो आपके पूर्व-प्रशिक्षित सुविधाओं को नष्ट कर देगा।

इस स्तर पर बहुत कम सीखने की दर का उपयोग करना भी महत्वपूर्ण है, क्योंकि आप प्रशिक्षण के पहले दौर की तुलना में बहुत बड़े मॉडल का प्रशिक्षण ले रहे हैं, जो कि आमतौर पर बहुत छोटा होता है। परिणामस्वरूप, यदि आप बड़े वजन अपडेट लागू करते हैं, तो आपको बहुत जल्दी ओवरफिटिंग का खतरा है। यहां, आप केवल एक वृद्धिशील तरीके से प्रिटेंडेड वेट को फिर से पढ़ना चाहते हैं।

यह है कि पूरे बेस मॉडल की फाइन-ट्यूनिंग को कैसे लागू किया जाए:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

compile() और trainable बारे में महत्वपूर्ण नोट

एक मॉडल पर compile() कॉलिंग का मतलब उस मॉडल के व्यवहार को "फ्रीज" करना है। इसका तात्पर्य यह है कि जिस समय मॉडल संकलित किया जाता है उस समय में trainable विशेषता मान उस मॉडल के जीवन भर संरक्षित किया जाना चाहिए, जब तक कि compile फिर से नहीं कहा जाता है। इसलिए, यदि आप किसी भी trainable मूल्य को बदलते हैं, तो अपने परिवर्तनों को ध्यान में रखने के लिए अपने मॉडल पर फिर से compile() कॉल करना सुनिश्चित करें।

BatchNormalization लेयर के बारे में महत्वपूर्ण नोट्स

कई छवि मॉडल में BatchNormalization परतें होती हैं। वह परत हर कल्पनीय गणना पर एक विशेष मामला है। यहां कुछ बातों का ध्यान रखना चाहिए।

  • BatchNormalization में 2 गैर-प्रशिक्षित वजन होते हैं जो प्रशिक्षण के दौरान अद्यतन होते हैं। ये इनपुट के माध्य और विचरण को ट्रैक करने वाले चर हैं।
  • जब आप bn_layer.trainable = False सेट bn_layer.trainable = False , तो BatchNormalization परत BatchNormalization मोड में चलेगी, और इसके माध्य और विचरण के आँकड़े को अपडेट नहीं करेगी। यह सामान्य रूप से अन्य परतों के लिए मामला नहीं है, क्योंकि वजन प्रशिक्षण और अनुमान / प्रशिक्षण मोड दो ऑर्थोगोनल अवधारणाएं हैं । लेकिन दोनों BatchNormalization परत के मामले में बंधे हैं।
  • जब आप ठीक-ट्यूनिंग करने के लिए किसी मॉडल को BatchNormalization लेयर्स से BatchNormalization करते हैं, तो बेस मॉडल को कॉल करते समय आपको training=False पास करके BatchNormalization मोड में BatchNormalization लेयर्स को रखना चाहिए। अन्यथा गैर-प्रशिक्षित वजन के लिए लागू किए गए अपडेट अचानक मॉडल ने जो सीखा है उसे नष्ट कर देगा।

आप इस गाइड को एंड-टू-एंड उदाहरण में इस गाइड के अंत में एक्शन में देखेंगे।

एक कस्टम ट्रेनिंग लूप के साथ लर्निंग और फाइन-ट्यूनिंग ट्रांसफर करें

यदि fit() बजाय, आप अपने स्वयं के निम्न-स्तरीय प्रशिक्षण लूप का उपयोग कर रहे हैं, तो वर्कफ़्लो अनिवार्य रूप से समान है। जब आप model.trainable_weights अद्यतन लागू करते हैं तो आपको केवल सूची model.trainable_weights को ध्यान में रखना चाहिए:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

इसी तरह फाइन-ट्यूनिंग के लिए।

अंत-से-अंत का उदाहरण: बिल्लियों बनाम कुत्तों पर एक छवि वर्गीकरण मॉडल को ठीक करना

डाटासेट

इन अवधारणाओं को ठोस बनाने के लिए, चलिए आपको एक ठोस अंत-टू-एंड हस्तांतरण सीखने और ठीक-ट्यूनिंग उदाहरण के माध्यम से चलते हैं। हम Xception मॉडल को लोड करेंगे, ImageNet पर पूर्व-प्रशिक्षित, और इसे कागले "बिल्लियों बनाम कुत्तों" वर्गीकरण डेटासेट पर उपयोग करें।

डेटा प्राप्त करना

पहले, चलो TFDS का उपयोग करके बिल्लियों बनाम कुत्तों के डेटासेट प्राप्त करते हैं। यदि आपके पास अपना स्वयं का डेटासेट है, तो आप संभवतः क्लास-विशिष्ट फ़ोल्डरों में दर्ज डिस्क पर छवियों के एक सेट से समान लेबल डेटासेट वस्तुओं को उत्पन्न करने के लिए उपयोगिता tf.keras.preprocessing.image_dataset_from_directory का उपयोग करना चाहते हैं।

बहुत छोटे डेटासेट के साथ काम करते समय ट्रांसफर लर्निंग सबसे उपयोगी है। हमारे डेटासेट को छोटा रखने के लिए, हम प्रशिक्षण के लिए मूल प्रशिक्षण डेटा (25,000 चित्र) का 40%, सत्यापन के लिए 10% और परीक्षण के लिए 10% का उपयोग करेंगे।

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...

Warning:absl:1738 images were corrupted and were skipped

Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteIL7NQA/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

प्रशिक्षण डेटासेट में ये पहले 9 चित्र हैं - जैसा कि आप देख सकते हैं, वे सभी अलग-अलग आकार के हैं।

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

हम यह भी देख सकते हैं कि लेबल 1 "कुत्ता" है और लेबल 0 "बिल्ली" है।

डेटा का मानकीकरण

हमारी कच्ची छवियों में कई प्रकार के आकार होते हैं। इसके अलावा, प्रत्येक पिक्सेल में 0 और 255 (RGB स्तर मान) के बीच 3 पूर्णांक मान होते हैं। यह एक तंत्रिका नेटवर्क को खिलाने के लिए एक महान फिट नहीं है। हमें 2 चीजें करने की जरूरत है:

  • एक निश्चित छवि आकार के लिए मानकीकृत करें। हम 150x150 लेते हैं।
  • -1 और 1. के बीच पिक्सेल मानों को सामान्य करें। हम Normalization परत का उपयोग मॉडल के हिस्से के रूप में ही करेंगे।

सामान्य तौर पर, यह उन मॉडल को विकसित करने के लिए एक अच्छा अभ्यास है जो कच्चे डेटा को इनपुट के रूप में लेते हैं, जैसा कि उन मॉडल के विपरीत है जो पहले से ही प्रीप्रोसेस किए गए डेटा लेते हैं। इसका कारण यह है कि, यदि आपका मॉडल प्रीप्रोसेड डेटा की अपेक्षा करता है, तो किसी भी समय आप अपने मॉडल को किसी अन्य जगह (एक वेब ब्राउज़र में, मोबाइल ऐप में) उपयोग करने के लिए निर्यात करते हैं, आपको सटीक समान प्रीप्रोसेसिंग पाइपलाइन को फिर से लागू करना होगा। यह बहुत जल्दी बहुत मुश्किल हो जाता है। इसलिए हमें मॉडल को मारने से पहले प्रीप्रोसेसिंग की कम से कम संभव मात्रा में करना चाहिए।

यहां, हम डेटा पाइपलाइन में आकार बदलने वाली छवि करेंगे (क्योंकि एक गहरे तंत्रिका नेटवर्क केवल डेटा के सन्निहित बैचों को संसाधित कर सकते हैं), और हम मॉडल के भाग के रूप में इनपुट मान स्केलिंग करेंगे, जब हम इसे बनाते हैं।

आइए छवियों का आकार 150x150 करें:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

इसके अलावा, चलो डेटा को बैचते हैं और लोडिंग की गति को अनुकूलित करने के लिए कैशिंग और प्रीफेटिंग का उपयोग करते हैं।

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

यादृच्छिक डेटा वृद्धि का उपयोग करना

जब आपके पास एक बड़ी छवि डेटासेट नहीं है, तो प्रशिक्षण छवियों में यादृच्छिक अभी तक यथार्थवादी परिवर्तनों को लागू करके कृत्रिम विविधता का नमूना पेश करना एक अच्छा अभ्यास है, जैसे कि यादृच्छिक क्षैतिज फ़्लिपिंग या छोटे यादृच्छिक घुमाव। यह ओवरफिटिंग को धीमा करते हुए मॉडल को प्रशिक्षण डेटा के विभिन्न पहलुओं को उजागर करने में मदद करता है।

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(0.1),
    ]
)

आइए देखें कि विभिन्न यादृच्छिक परिवर्तनों के बाद पहले बैच की पहली छवि कैसी दिखती है:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[i]))
        plt.axis("off")

png

मॉडल बनाना

अब हमने एक मॉडल बनाया है जो पहले बताए गए ब्लूप्रिंट का अनुसरण करता है।

ध्यान दें कि:

  • हम इनपुट मानों को स्केल करने के लिए एक Normalization लेयर (शुरुआत में [0, 255] रेंज में) को [-1, 1] रेंज में जोड़ते हैं।
  • हम वर्गीकरण परत से पहले Dropout परत जोड़ते हैं, नियमितीकरण के लिए।
  • हम बेस मॉडल को कॉल करते समय training=False से पास करना सुनिश्चित करते हैं, ताकि यह एक इनवेंशन मोड में चले, ताकि हम ठीक-ट्यूनिंग के लिए बेस मॉडल को अनफ्रीज करने के बाद भी बैटमैन आंकड़े अपडेट न हों।
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be normalized
# from (0, 255) to a range (-1., +1.), the normalization layer
# does the following, outputs = (inputs - mean) / sqrt(var)
norm_layer = keras.layers.experimental.preprocessing.Normalization()
mean = np.array([127.5] * 3)
var = mean ** 2
# Scale inputs to [-1, +1]
x = norm_layer(x)
norm_layer.set_weights([mean, var])

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 2,049
Non-trainable params: 20,861,487
_________________________________________________________________

शीर्ष परत को प्रशिक्षित करें

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
291/291 [==============================] - 9s 32ms/step - loss: 0.1758 - binary_accuracy: 0.9226 - val_loss: 0.0897 - val_binary_accuracy: 0.9660
Epoch 2/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1211 - binary_accuracy: 0.9497 - val_loss: 0.0870 - val_binary_accuracy: 0.9686
Epoch 3/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1166 - binary_accuracy: 0.9503 - val_loss: 0.0814 - val_binary_accuracy: 0.9712
Epoch 4/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1125 - binary_accuracy: 0.9534 - val_loss: 0.0825 - val_binary_accuracy: 0.9695
Epoch 5/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1073 - binary_accuracy: 0.9569 - val_loss: 0.0763 - val_binary_accuracy: 0.9703
Epoch 6/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1041 - binary_accuracy: 0.9573 - val_loss: 0.0812 - val_binary_accuracy: 0.9686
Epoch 7/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1023 - binary_accuracy: 0.9567 - val_loss: 0.0820 - val_binary_accuracy: 0.9669
Epoch 8/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1005 - binary_accuracy: 0.9597 - val_loss: 0.0779 - val_binary_accuracy: 0.9695
Epoch 9/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1019 - binary_accuracy: 0.9580 - val_loss: 0.0813 - val_binary_accuracy: 0.9699
Epoch 10/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0940 - binary_accuracy: 0.9651 - val_loss: 0.0762 - val_binary_accuracy: 0.9729
Epoch 11/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0974 - binary_accuracy: 0.9613 - val_loss: 0.0752 - val_binary_accuracy: 0.9725
Epoch 12/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0965 - binary_accuracy: 0.9591 - val_loss: 0.0760 - val_binary_accuracy: 0.9721
Epoch 13/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0962 - binary_accuracy: 0.9598 - val_loss: 0.0785 - val_binary_accuracy: 0.9712
Epoch 14/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0966 - binary_accuracy: 0.9616 - val_loss: 0.0831 - val_binary_accuracy: 0.9699
Epoch 15/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1000 - binary_accuracy: 0.9574 - val_loss: 0.0741 - val_binary_accuracy: 0.9725
Epoch 16/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0940 - binary_accuracy: 0.9628 - val_loss: 0.0781 - val_binary_accuracy: 0.9686
Epoch 17/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0915 - binary_accuracy: 0.9634 - val_loss: 0.0843 - val_binary_accuracy: 0.9678
Epoch 18/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0937 - binary_accuracy: 0.9620 - val_loss: 0.0829 - val_binary_accuracy: 0.9669
Epoch 19/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0988 - binary_accuracy: 0.9601 - val_loss: 0.0862 - val_binary_accuracy: 0.9686
Epoch 20/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0928 - binary_accuracy: 0.9644 - val_loss: 0.0798 - val_binary_accuracy: 0.9703

<tensorflow.python.keras.callbacks.History at 0x7f6104f04518>

पूरे मॉडल की ठीक-ठीक ट्यूनिंग करें

अंत में, चलो बेस मॉडल को अनफ्रीज करते हैं और कम सीखने की दर के साथ पूरे मॉडल को एंड-टू-एंड ट्रेन करते हैं।

महत्वपूर्ण रूप से, हालांकि बेस मॉडल ट्रेन योग्य हो जाता है, यह अभी भी एक अनुमान मोड में चल रहा है जब से हमने training=False पारित training=False मॉडल बनाते समय इसे कॉल करते समय training=False । इसका मतलब यह है कि बैच के सामान्यीकरण की परतें अपने बैच के आंकड़ों को अपडेट नहीं करेंगी। यदि वे करते हैं, तो वे अब तक मॉडल द्वारा सीखे गए अभ्यावेदन पर कहर ढाएंगे।

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 20,809,001
Non-trainable params: 54,535
_________________________________________________________________
Epoch 1/10
  2/291 [..............................] - ETA: 17s - loss: 0.1439 - binary_accuracy: 0.9219WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0329s vs `on_train_batch_end` time: 0.0905s). Check your callbacks.

Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0329s vs `on_train_batch_end` time: 0.0905s). Check your callbacks.

291/291 [==============================] - 38s 132ms/step - loss: 0.0786 - binary_accuracy: 0.9706 - val_loss: 0.0631 - val_binary_accuracy: 0.9772
Epoch 2/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0553 - binary_accuracy: 0.9790 - val_loss: 0.0537 - val_binary_accuracy: 0.9781
Epoch 3/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0442 - binary_accuracy: 0.9829 - val_loss: 0.0532 - val_binary_accuracy: 0.9819
Epoch 4/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0369 - binary_accuracy: 0.9858 - val_loss: 0.0460 - val_binary_accuracy: 0.9832
Epoch 5/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0335 - binary_accuracy: 0.9870 - val_loss: 0.0561 - val_binary_accuracy: 0.9794
Epoch 6/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0253 - binary_accuracy: 0.9910 - val_loss: 0.0559 - val_binary_accuracy: 0.9819
Epoch 7/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0232 - binary_accuracy: 0.9920 - val_loss: 0.0432 - val_binary_accuracy: 0.9845
Epoch 8/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0185 - binary_accuracy: 0.9930 - val_loss: 0.0396 - val_binary_accuracy: 0.9854
Epoch 9/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0147 - binary_accuracy: 0.9948 - val_loss: 0.0439 - val_binary_accuracy: 0.9832
Epoch 10/10
291/291 [==============================] - 37s 129ms/step - loss: 0.0117 - binary_accuracy: 0.9954 - val_loss: 0.0538 - val_binary_accuracy: 0.9819

<tensorflow.python.keras.callbacks.History at 0x7f611c26e438>

10 युगों के बाद, ठीक-ट्यूनिंग हमें यहाँ एक अच्छा सुधार हासिल करता है।