امروز برای رویداد محلی TensorFlow خود در همه جا پاسخ دهید!
این صفحه به‌وسیله ‏Cloud Translation API‏ ترجمه شده است.
Switch to English

آموزش شبکه عصبی در MNIST با کراس

این مثال ساده نحوه اتصال TFDS به مدل Keras را نشان می دهد.

مشاهده در TensorFlow.org در Google Colab اجرا کنید مشاهده منبع در GitHub دانلود دفترچه یادداشت
import tensorflow as tf
import tensorflow_datasets as tfds

مرحله 1: خط لوله ورودی خود را ایجاد کنید

خط لوله ورودی کارآمد را با استفاده از توصیه های زیر ایجاد کنید:

MNIST را بارگیری کنید

با استدلال های زیر بارگیری کنید:

  • shuffle_files : داده های MNIST فقط در یک فایل ذخیره می شوند ، اما برای مجموعه داده های بزرگتر با چندین پرونده روی دیسک ، روش خوبی است که آنها را هنگام آموزش مرتب کنید.
  • as_supervised : به جای {'image': img, 'label': label} tuple (img, label) as_supervised گرداند
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

خط لوله آموزش بسازید

تغییر شکل زیر را اعمال کنید:

  • ds.map : TFDS تصاویر را به صورت tf.uint8 ارائه می دهد ، در حالی که مدل انتظار tf.float32 را دارد ، بنابراین تصاویر را عادی کنید
  • ds.cache همانطور که مجموعه داده در حافظه جا می گیرد ، قبل از اینکه برای عملکرد بهتر آنرا مخفی کنید ، حافظه پنهان را ذخیره کنید.
    توجه: تحولات تصادفی باید پس از ذخیره انجام شود
  • ds.shuffle : برای تصادفی واقعی ، بافر shuffle را روی اندازه کامل مجموعه داده تنظیم کنید.
    توجه: برای مجموعه داده های بزرگتر که در حافظه جای نمی گیرند ، اگر سیستم شما اجازه دهد ، مقدار استاندارد 1000 است.
  • ds.batch : دسته ای پس از تغییر وضعیت برای بدست آوردن دسته های منحصر به فرد در هر دوره.
  • ds.prefetch : روش خوبی برای پایان دادن به خط لوله با پیش تنظیم برای اجرا است .
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

خط لوله ارزیابی بسازید

آزمایش خط لوله شبیه خط لوله آموزش است ، با تفاوت های کوچک:

  • تماس ds.shuffle()
  • ذخیره سازی پس از دسته بندی انجام می شود (زیرا دسته ها می توانند بین دوره یکسان باشند)
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

مرحله 2: مدل را ایجاد و آموزش دهید

خط لوله ورودی را به Keras وصل کنید.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128,activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 4s 4ms/step - loss: 0.6240 - sparse_categorical_accuracy: 0.8288 - val_loss: 0.2043 - val_sparse_categorical_accuracy: 0.9424
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1796 - sparse_categorical_accuracy: 0.9499 - val_loss: 0.1395 - val_sparse_categorical_accuracy: 0.9598
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1215 - sparse_categorical_accuracy: 0.9642 - val_loss: 0.1137 - val_sparse_categorical_accuracy: 0.9678
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0968 - sparse_categorical_accuracy: 0.9724 - val_loss: 0.0974 - val_sparse_categorical_accuracy: 0.9707
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0774 - sparse_categorical_accuracy: 0.9775 - val_loss: 0.0852 - val_sparse_categorical_accuracy: 0.9766
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0631 - sparse_categorical_accuracy: 0.9811 - val_loss: 0.0868 - val_sparse_categorical_accuracy: 0.9735

<tensorflow.python.keras.callbacks.History at 0x7f70782baa58>