Google I / O 18-20 मई को लौटता है! जगह आरक्षित करें और अपना शेड्यूल बनाएं अभी रजिस्टर करें
इस पेज का अनुवाद Cloud Translation API से किया गया है.
Switch to English

केरस के साथ MNIST पर एक तंत्रिका नेटवर्क का प्रशिक्षण

यह सरल उदाहरण प्रदर्शित करता है कि टीएफडीएस को केरस मॉडल में कैसे प्लग किया जाए।

TensorFlow.org पर देखें Google Colab में चलाएं GitHub पर स्रोत देखें नोटबुक डाउनलोड करें
import tensorflow as tf
import tensorflow_datasets as tfds

चरण 1: अपनी इनपुट पाइपलाइन बनाएं

निम्नलिखित से सलाह लेकर कुशल इनपुट पाइपलाइन का निर्माण करें:

MNIST लोड करें

निम्नलिखित तर्कों के साथ लोड करें:

  • shuffle_files : MNIST डेटा केवल एक फ़ाइल में संग्रहीत होता है, लेकिन डिस्क पर एकाधिक फ़ाइलों के साथ बड़े डेटासेट के लिए, प्रशिक्षण के दौरान उन्हें फेरबदल करना अच्छा होता है।
  • as_supervised : रिटर्न टपल (img, label) dict के बजाय {'image': img, 'label': label}
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

प्रशिक्षण पाइपलाइन का निर्माण

निम्नलिखित रूपांतरण लागू करें:

  • ds.map : TFDS छवियों को tf.uint8 के रूप में प्रदान करते हैं, जबकि मॉडल tf.float32 की अपेक्षा करता है, इसलिए छवियों को सामान्य करता है
  • ds.cache बेहतर प्रदर्शन के लिए फेरबदल से पहले डेटासेट में मेमोरी फिट होती है।
    नोट: कैशिंग के बाद रैंडम ट्रांसफॉर्मेशन लागू किया जाना चाहिए
  • ds.shuffle : सत्य यादृच्छिकता के लिए, संपूर्ण डेटासेट आकार में फेरबदल बफ़र सेट करें।
    नोट: बड़े डेटासेट के लिए जो मेमोरी में फिट नहीं होते हैं, एक मानक मान 1000 है यदि आपका सिस्टम इसे अनुमति देता है।
  • ds.batch : प्रत्येक युग में अद्वितीय बैचों को प्राप्त करने के लिए फेरबदल के बाद बैच।
  • ds.prefetch : प्रदर्शन के लिए ds.prefetch द्वारा पाइपलाइन को समाप्त करने का अच्छा अभ्यास।
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

मूल्यांकन पाइपलाइन बनाएँ

परीक्षण पाइपलाइन प्रशिक्षण पाइपलाइन के समान है, जिसमें छोटे अंतर हैं:

  • कोई ds.shuffle() कॉल नहीं
  • बैचिंग के बाद कैशिंग किया जाता है (चूंकि बैच युग के बीच समान हो सकते हैं)
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

चरण 2: मॉडल बनाएं और प्रशिक्षित करें

करस में इनपुट पाइपलाइन प्लग करें।

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128,activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 4s 4ms/step - loss: 0.6240 - sparse_categorical_accuracy: 0.8288 - val_loss: 0.2043 - val_sparse_categorical_accuracy: 0.9424
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1796 - sparse_categorical_accuracy: 0.9499 - val_loss: 0.1395 - val_sparse_categorical_accuracy: 0.9598
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1215 - sparse_categorical_accuracy: 0.9642 - val_loss: 0.1137 - val_sparse_categorical_accuracy: 0.9678
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0968 - sparse_categorical_accuracy: 0.9724 - val_loss: 0.0974 - val_sparse_categorical_accuracy: 0.9707
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0774 - sparse_categorical_accuracy: 0.9775 - val_loss: 0.0852 - val_sparse_categorical_accuracy: 0.9766
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0631 - sparse_categorical_accuracy: 0.9811 - val_loss: 0.0868 - val_sparse_categorical_accuracy: 0.9735
<tensorflow.python.keras.callbacks.History at 0x7f70782baa58>