سخنرانی ها ، جلسات محصول ، کارگاه ها و موارد دیگر را از لیست پخش Google I / O مشاهده کنید

ایست بازرسی آموزش

مشاهده در TensorFlow.org در Google Colab اجرا کنید مشاهده منبع در GitHub دانلود دفترچه یادداشت

عبارت "صرفه جویی در مدل TensorFlow" به طور معمول یکی از دو معنی است:

  1. ایست بازرسی ، یا
  2. مدل ذخیره شده

tf.Variable بازرسی مقدار دقیق همه پارامترها (اشیا tf.Variable ) استفاده شده توسط یک مدل را ضبط می کند. ایست های بازرسی هیچ توصیفی از محاسبه تعریف شده توسط مدل ندارند و بنابراین فقط زمانی مفید هستند که کد منبع که از مقادیر پارامتر ذخیره شده در دسترس است ، در دسترس باشد.

از طرف دیگر قالب SavedModel شامل توصیف سریالی محاسبات است که علاوه بر مقادیر پارامتر (ایست بازرسی) توسط مدل تعریف شده است. مدل ها در این قالب مستقل از کد منبع ایجاد کننده مدل هستند. بنابراین آنها برای استقرار از طریق TensorFlow Serving ، TensorFlow Lite ، TensorFlow.js یا برنامه هایی به زبان های برنامه نویسی دیگر (C ، C ++ ، Java ، Go ، Rust ، C # و غیره TensorFlow API) مناسب هستند.

این راهنما شامل API های نوشتن و خواندن ایست های بازرسی است.

برپایی

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

در حال ذخیره از API های آموزشی tf.keras

به راهنمای tf.keras در مورد صرفه جویی و بازیابی مراجعه کنید.

tf.keras.Model.save_weights یک ایست بازرسی tf.keras.Model.save_weights ذخیره می کند.

net.save_weights('easy_checkpoint')

نوشتن ایست بازرسی

حالت پایدار مدل tf.Variable در اشیا tf.Variable ذخیره می شود. اینها می توانند به طور مستقیم ساخته شوند ، اما اغلب از طریق API های سطح بالا مانندtf.keras.layers یا tf.keras.Model .

ساده ترین راه برای مدیریت متغیرها پیوستن آنها به اشیا Py پایتون و سپس ارجاع دادن به آن اشیا است.

زیر کلاس های tf.train.Checkpoint ، tf.keras.layers.Layer و tf.keras.Model متغیرهای اختصاص داده شده به صفات آنها را به طور خودکار ردیابی می کند. مثال زیر یک مدل خطی ساده را می سازد ، سپس ایستگاه های بازرسی را می نویسد که حاوی مقادیر برای تمام متغیرهای مدل هستند.

با Model.save_weights به راحتی می توانید یک ایستگاه بازرسی مدل را ذخیره کنید.

بازرسی دستی

برپایی

برای کمک به نمایش تمام ویژگی های tf.train.Checkpoint ، یک مجموعه داده اسباب بازی و مرحله بهینه سازی را تعریف کنید:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

اشیاpoint ایست بازرسی ایجاد کنید

از یک شی tf.train.Checkpoint برای ایجاد دستی یک ایست بازرسی استفاده کنید ، جایی که اشیا you می خواهید به عنوان ویژگی روی شی تنظیم شوند.

tf.train.CheckpointManager همچنین می تواند برای مدیریت چندین ایست بازرسی مفید باشد.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

مدل را آموزش داده و بازرسی کنید

حلقه آموزش زیر نمونه ای از مدل و بهینه ساز را ایجاد می کند ، سپس آنها را در یک شی tf.train.Checkpoint کند. این مرحله آموزش را به صورت حلقه ای روی هر دسته از داده ها فراخوانی می کند و به صورت دوره ای ایستگاه های بازرسی را بر روی دیسک می نویسد.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 29.00
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 22.42
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 15.86
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 9.40
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 3.20

بازیابی و آموزش را ادامه دهید

بعد از اولین چرخه آموزش می توانید از یک مدل و مدیر جدید عبور کنید ، اما آموزش را دقیقاً همان جایی که ترک کردید انتخاب کنید:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.19
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.66
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.90
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.32
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.34

شی tf.train.CheckpointManager ایست های بازرسی قدیمی را حذف می کند. در بالای آن پیکربندی شده است تا فقط سه ایستگاه بازرسی اخیر را نگه دارد.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

این مسیرها ، به عنوان مثال './tf_ckpts/ckpt-10' ، پرونده هایی روی دیسک نیستند. در عوض آنها پیشوندهای یک پرونده index و یک یا چند پرونده داده هستند که حاوی مقادیر متغیر هستند. این پیشوندها در یک پرونده checkpoint ( './tf_ckpts/checkpoint' ) با هم گروه می شوند که CheckpointManager وضعیت خود را ذخیره می کند.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

مکانیک بارگیری

TensorFlow با عبور از یک نمودار جهت دار با لبه های نامگذاری شده ، متغیرها را با مقادیر کنترل شده مطابقت می دهد ، از شی object بارگیری شده نام های لبه معمولاً از نام ویژگی ها در اشیا ناشی می شوند ، به عنوان مثال "l1" در self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint از نام آرگومان کلمات کلیدی خود استفاده می کند ، مانند "step" در tf.train.Checkpoint(step=...) .

نمودار وابستگی مثال بالا به صورت زیر است:

تجسم نمودار وابستگی برای مثال حلقه آموزش

بهینه ساز با رنگ قرمز ، متغیرهای معمولی با رنگ آبی و متغیرهای شکاف بهینه با نارنجی هستند. گره های دیگر - به عنوان مثال ، نمایانگر tf.train.Checkpoint - به رنگ سیاه هستند.

متغیرهای اسلات بخشی از حالت بهینه ساز هستند ، اما برای یک متغیر خاص ایجاد می شوند. به عنوان مثال لبه های 'm' بالا مربوط به حرکت است ، که بهینه ساز Adam برای هر متغیر دنبال می کند. متغیرهای اسلات فقط در یک ایست بازرسی ذخیره می شوند در صورتی که متغیر و بهینه ساز هر دو ذخیره شوند ، بنابراین لبه های خط خورده.

فراخوانی restore در یک tf.train.Checkpoint در ترمیم های درخواستی در صف قرار می گیرد و مقادیر متغیر را به محض وجود مسیر تطبیق از شی Checkpoint ، بازیابی می کند. به عنوان مثال ، می توانید با بازسازی یک مسیر به آن از طریق شبکه و لایه ، فقط تعصب مدلی را که در بالا تعریف کردید بارگیری کنید.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.2704186 3.0526643 3.8114467 3.4453893 4.2802196]

نمودار وابستگی برای این اشیا new جدید ، یک زیرگراف بسیار کوچکتر از ایست بازرسی بزرگتری است که در بالا نوشتید. این فقط شامل بایاس و یک شمارنده ذخیره است که tf.train.Checkpoint برای شماره گذاری ایست های بازرسی استفاده می کند.

تجسم یک زیرنویس برای متغیر بایاس

restore یک وضعیت وضعیت را برمی گرداند ، که ادعاهای اختیاری دارد. همه اشیا created ایجاد شده در Checkpoint جدید بازیابی شده اند ، بنابراین status.assert_existing_objects_matched پاس می دهد.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2a4cbccb38>

اشیا زیادی در ایست بازرسی وجود دارد که با هم مطابقت ندارند ، از جمله هسته لایه و متغیرهای بهینه ساز. status.assert_consumed فقط درصورتی که ایست بازرسی و برنامه دقیقاً مطابقت داشته باشد عبور می کند و یک استثنا را در اینجا قرار می دهد.

ترمیم های تأخیری

اشیا L Layer در TensorFlow ممکن است ایجاد متغیرها را در اولین تماس خود تأخیر دهند ، در صورت وجود اشکال ورودی. به عنوان مثال شکل هسته لایه Dense به شکل ورودی و خروجی لایه بستگی دارد و بنابراین شکل خروجی مورد نیاز به عنوان آرگومان سازنده اطلاعات کافی برای ایجاد متغیر به تنهایی نیست. از آنجا که فراخوانی یک Layer مقدار متغیر را نیز می خواند ، بازسازی باید بین ایجاد متغیر و اولین استفاده از آن انجام شود.

برای پشتیبانی از این اصطلاح ، صف های tf.train.Checkpoint بازیابی می کند که هنوز متغیر مطابق ندارند.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.6544    4.6866627 4.729344  4.9574785 4.8010526]]

بازرسی دستی از ایست های بازرسی

tf.train.load_checkpoint یک CheckpointReader برمی گرداند که دسترسی سطح پایین تری را به محتوای ایست بازرسی می دهد. این شامل نگاشت هایی از کلید هر متغیر ، به شکل و نوع نوع هر متغیر در ایست بازرسی است. کلید یک متغیر مسیر شی its آن است ، مانند نمودارهای نمایش داده شده در بالا.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

بنابراین اگر به ارزش net.l1.kernel علاقه مند هستید ، می توانید مقدار آن را با کد زیر بدست آورید:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

این همچنین یک روش get_tensor فراهم می کند که به شما امکان می دهد مقدار یک متغیر را بررسی کنید:

reader.get_tensor(key)
array([[4.6544   , 4.6866627, 4.729344 , 4.9574785, 4.8010526]],
      dtype=float32)

ردیابی فهرست و فرهنگ لغت

همانند انتساب ویژگی های مستقیم مانند self.l1 = tf.keras.layers.Dense(5) ، اختصاص لیست ها و فرهنگ لغت ها به ویژگی ها ، محتوای آنها را ردیابی می کند.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

ممکن است اشیا wra بسته بندی شده را برای لیست ها و دیکشنری ها مشاهده کنید. این بسته بندی ها نسخه های قابل بازبینی از ساختارهای داده ای اساسی هستند. درست مانند بارگذاری بر اساس ویژگی ، این بسته بندی ها به محض اینکه یک متغیر به ظرف اضافه شود مقدار آن را بازیابی می کنند.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

همان ردیابی به طور خودکار در زیر کلاس های tf.keras.Model اعمال می شود و ممکن است به عنوان مثال برای ردیابی لیست لایه ها استفاده شود.

خلاصه

اشیا T TensorFlow مکانیسم خودکار آسان برای صرفه جویی و بازیابی مقادیر متغیرهایی را که استفاده می کنند ، فراهم می کنند.