امروز برای رویداد محلی TensorFlow خود در همه جا پاسخ دهید!
این صفحه به‌وسیله ‏Cloud Translation API‏ ترجمه شده است.
Switch to English

ایست بازرسی آموزش

مشاهده در TensorFlow.org در Google Colab اجرا کنید مشاهده منبع در GitHub دانلود دفترچه یادداشت

عبارت "صرفه جویی در مدل TensorFlow" به طور معمول یکی از دو معنی است:

  1. ایست بازرسی ، یا
  2. مدل ذخیره شده

tf.Variable بازرسی مقدار دقیق تمام پارامترها (اشیا tf.Variable ) استفاده شده توسط یک مدل را ضبط می کند. ایست های بازرسی هیچ توصیفی از محاسبه تعریف شده توسط مدل ندارند و بنابراین فقط زمانی مفید هستند که کد منبع ای که از مقادیر پارامتر ذخیره شده در دسترس است ، در دسترس باشد.

از طرف دیگر قالب SavedModel شامل توصیف سریالی محاسبات تعریف شده توسط مدل علاوه بر مقادیر پارامتر (ایست بازرسی) است. مدل ها در این قالب مستقل از کد منبع ایجاد کننده مدل هستند. بنابراین آنها برای استقرار از طریق TensorFlow Serving ، TensorFlow Lite ، TensorFlow.js یا برنامه هایی به زبانهای برنامه نویسی دیگر (C ، C ++ ، Java ، Go ، Rust ، C # و غیره TensorFlow API) مناسب هستند.

این راهنما شامل API های نوشتن و خواندن ایست های بازرسی است.

برپایی

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

در حال ذخیره از API های آموزش tf.keras

به راهنمای tf.keras در مورد صرفه جویی و بازیابی مراجعه کنید.

tf.keras.Model.save_weights یک ایست بازرسی tf.keras.Model.save_weights ذخیره می کند.

net.save_weights('easy_checkpoint')

نوشتن ایست بازرسی

حالت ماندگار یک مدل tf.Variable در اشیا tf.Variable ذخیره می شود. اینها می توانند به طور مستقیم ساخته شوند ، اما اغلب از طریق API های سطح بالا مانندtf.keras.layers یا tf.keras.Model .

ساده ترین راه برای مدیریت متغیرها پیوستن آنها به اشیا Py پایتون و سپس ارجاع دادن به آن اشیا است.

زیر کلاس های tf.train.Checkpoint ، tf.keras.layers.Layer و tf.keras.Model متغیرهای اختصاص داده شده به صفات آنها را به طور خودکار ردیابی می کند. مثال زیر یک مدل خطی ساده می سازد ، سپس ایستگاه های بازرسی را می نویسد که حاوی مقادیر برای همه متغیرهای مدل هستند.

با Model.save_weights به راحتی می توانید یک ایستگاه بازرسی مدل را ذخیره کنید.

بازرسی دستی

برپایی

برای کمک به نمایش تمام ویژگی های tf.train.Checkpoint ، یک مجموعه داده اسباب بازی و مرحله بهینه سازی را تعریف کنید:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

اشیاpoint ایست بازرسی ایجاد کنید

برای ایجاد دستی یک ایست بازرسی از یک شی tf.train.Checkpoint استفاده کنید ، جایی که اشیا you می خواهید به عنوان صفات موجود در شی تنظیم شوند.

tf.train.CheckpointManager همچنین می تواند برای مدیریت چندین ایست بازرسی مفید باشد.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

مدل را آموزش داده و بازرسی کنید

حلقه آموزش زیر نمونه ای از مدل و یک بهینه ساز را ایجاد می کند ، سپس آنها را در یک شی tf.train.Checkpoint کند. این مرحله آموزش را به صورت حلقه ای روی هر دسته از داده ها فراخوانی می کند و به صورت دوره ای ایستگاه های بازرسی را بر روی دیسک می نویسد.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 30.42
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.83
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 17.27
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.81
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.74

بازیابی و ادامه آموزش

بعد از اولین چرخه آموزش می توانید از یک مدل و مدیر جدید عبور کنید ، اما آموزش را دقیقاً همان جایی که ترک کردید انتخاب کنید:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 0.85
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.87
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.71
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.46
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.21

شی tf.train.CheckpointManager ایست های بازرسی قدیمی را حذف می کند. در بالای آن پیکربندی شده است تا فقط سه ایستگاه بازرسی اخیر را نگه دارد.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

این مسیرها ، به عنوان مثال './tf_ckpts/ckpt-10' ، پرونده هایی روی دیسک نیستند. در عوض آنها پیشوندهای یک پرونده index و یک یا چند پرونده داده هستند که حاوی مقادیر متغیر هستند. این پیشوندها در یک پرونده checkpoint ( './tf_ckpts/checkpoint' ) با هم گروه می شوند که CheckpointManager وضعیت خود را ذخیره می کند.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

مکانیک بارگیری

TensorFlow با عبور از یک نمودار جهت دار با لبه های نامگذاری شده ، متغیرها را با مقادیر کنترل شده مطابقت می دهد ، از شی object در حال بارگیری. نام های لبه معمولاً از اسم ویژگی ها در اشیا ناشی می شوند ، به عنوان مثال "l1" در self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint از نام آرگومان کلمات کلیدی خود استفاده می کند ، مانند "step" در tf.train.Checkpoint(step=...) .

نمودار وابستگی مثال بالا به صورت زیر است:

تجسم نمودار وابستگی برای مثال حلقه آموزش

بهینه ساز با رنگ قرمز ، متغیرهای معمولی با رنگ آبی و متغیرهای شکاف بهینه با نارنجی هستند. گره های دیگر - به عنوان مثال ، نمایانگر tf.train.Checkpoint - به رنگ سیاه هستند.

متغیرهای اسلات بخشی از حالت بهینه ساز هستند ، اما برای یک متغیر خاص ایجاد می شوند. به عنوان مثال لبه های 'm' بالا مربوط به حرکت است ، که بهینه ساز Adam برای هر متغیر دنبال می کند. متغیرهای اسلات فقط در یک ایست بازرسی ذخیره می شوند اگر متغیر و بهینه ساز هر دو ذخیره شوند ، بنابراین لبه های خط خورده.

فراخوانی restore در یک tf.train.Checkpoint در ترمیم های درخواستی در صف قرار می گیرد و به محض وجود مسیر تطبیق از شی Checkpoint ، مقادیر متغیر را بازیابی می کند. به عنوان مثال ، می توانید با بازسازی یک مسیر به آن از طریق شبکه و لایه ، فقط تعصب مدلی را که در بالا تعریف کردید بارگیری کنید.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.831489  3.7156947 2.5892444 3.8669944 4.749503 ]

نمودار وابستگی برای این اشیا new جدید ، یک زیرگراف بسیار کوچکتر از ایست بازرسی بزرگتری است که در بالا نوشتید. این فقط شامل بایاس و یک شمارنده ذخیره است که tf.train.Checkpoint برای شماره گذاری ایست های بازرسی استفاده می کند.

تجسم یک زیرنویس برای متغیر بایاس

restore یک وضعیت وضعیت را برمی گرداند ، که ادعاهای اختیاری دارد. همه اشیا created ایجاد شده در Checkpoint جدید بازیابی شده اند ، بنابراین status.assert_existing_objects_matched پاس می دهد.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f1644447b70>

بسیاری از اشیا many در ایست بازرسی وجود دارد که با هم مطابقت ندارند ، از جمله هسته لایه و متغیرهای بهینه ساز. status.assert_consumed فقط درصورتی که ایست بازرسی و برنامه دقیقاً مطابقت داشته باشد عبور می کند و یک استثنا را در اینجا قرار می دهد.

ترمیم های تأخیری

اشیا L Layer در TensorFlow ممکن است ایجاد متغیرها را در اولین تماس خود تأخیر دهند ، در صورت وجود اشکال ورودی. به عنوان مثال شکل هسته لایه Dense به شکل ورودی و خروجی لایه بستگی دارد و بنابراین شکل خروجی مورد نیاز به عنوان آرگومان سازنده اطلاعات کافی برای ایجاد متغیر به تنهایی نیست. از آنجا که فراخوانی یک Layer مقدار متغیر را نیز می خواند ، بازسازی باید بین ایجاد متغیر و اولین استفاده از آن انجام شود.

برای پشتیبانی از این اصطلاح ، صف های tf.train.Checkpoint بازیابی می کند که هنوز متغیر مطابق ندارند.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5719748 4.6099544 4.931875  4.836442  4.8496275]]

بازرسی دستی از ایست های بازرسی

tf.train.load_checkpoint یک CheckpointReader برمی گرداند که دسترسی سطح پایین تری را به محتوای ایست بازرسی می دهد. این شامل نگاشت هایی از هر کلید قابل تغییر به شکل و نوع نوع هر متغیر در ایست بازرسی است. کلید یک متغیر مسیر شی its آن است ، مانند نمودارهای نمایش داده شده در بالا.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

بنابراین اگر به ارزش net.l1.kernel علاقه مند هستید ، می توانید مقدار آن را با کد زیر بدست آورید:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

این همچنین یک روش get_tensor فراهم می کند که به شما امکان می دهد مقدار یک متغیر را بررسی کنید:

reader.get_tensor(key)
array([[4.5719748, 4.6099544, 4.931875 , 4.836442 , 4.8496275]],
      dtype=float32)

ردیابی لیست و فرهنگ لغت

همانند انتساب ویژگی های مستقیم مانند self.l1 = tf.keras.layers.Dense(5) ، اختصاص لیست ها و فرهنگ نامه ها به ویژگی ها ، محتوای آنها را ردیابی می کند.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

ممکن است اشیاpper بسته بندی شده را برای لیست ها و دیکشنری ها مشاهده کنید. این بسته بندی ها نسخه های قابل بازبینی از ساختارهای داده ای اساسی هستند. درست مانند بارگذاری بر اساس ویژگی ، این بسته بندی ها به محض اینکه یک متغیر به ظرف اضافه شود مقدار آن را بازیابی می کنند.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

همان ردیابی به طور خودکار در زیر کلاس های tf.keras.Model اعمال می شود ، و ممکن است به عنوان مثال برای ردیابی لیست لایه ها استفاده شود.

خلاصه

اشیا T TensorFlow مکانیسم خودکار آسان برای صرفه جویی و بازیابی مقادیر متغیرهایی را که استفاده می کنند ، فراهم می کنند.