प्रशिक्षण चौकियों

TensorFlow.org पर देखें Google Colab में चलाएं GitHub पर स्रोत देखें नोटबुक डाउनलोड करें

वाक्यांश "एक TensorFlow मॉडल सहेजना" आमतौर पर दो चीजों में से एक का अर्थ है:

  1. चौकियों, OR
  2. सहेजा गया मॉडल।

चेकपॉइंट एक मॉडल द्वारा उपयोग किए जाने वाले सभी मापदंडों ( tf.Variable ) के सटीक मान को कैप्चर करते हैं। चेकपॉइंट्स में मॉडल द्वारा परिभाषित गणना का कोई विवरण नहीं होता है और इस प्रकार आमतौर पर केवल तभी उपयोगी होते हैं जब स्रोत कोड जो सहेजे गए पैरामीटर मानों का उपयोग करेगा उपलब्ध है।

दूसरी ओर सेव्डमॉडल प्रारूप में पैरामीटर मान (चेकपॉइंट) के अलावा मॉडल द्वारा परिभाषित गणना का क्रमबद्ध विवरण शामिल है। इस प्रारूप में मॉडल मॉडल बनाने वाले स्रोत कोड से स्वतंत्र होते हैं। इस प्रकार वे TensorFlow Serving, TensorFlow Lite, TensorFlow.js, या अन्य प्रोग्रामिंग भाषाओं में प्रोग्राम (C, C++, Java, Go, Rust, C# आदि) TensorFlow APIs के माध्यम से परिनियोजन के लिए उपयुक्त हैं।

इस गाइड में चौकियों को लिखने और पढ़ने के लिए एपीआई शामिल हैं।

सेट अप

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

tf.keras प्रशिक्षण API से सहेजा जा रहा है

सहेजने और पुनर्स्थापित करने पर tf.keras मार्गदर्शिका देखें।

tf.keras.Model.save_weights एक TensorFlow चेकपॉइंट बचाता है।

net.save_weights('easy_checkpoint')

लेखन चौकियां

TensorFlow मॉडल की स्थिर स्थिति को tf.Variable ऑब्जेक्ट्स में संग्रहित किया जाता है। इन्हें सीधे बनाया जा सकता है, लेकिन अक्सर tf.keras.layers या tf.keras.Model जैसे उच्च-स्तरीय API के माध्यम से बनाए जाते हैं।

वेरिएबल को प्रबंधित करने का सबसे आसान तरीका है उन्हें पायथन ऑब्जेक्ट्स से जोड़ना, फिर उन ऑब्जेक्ट्स को रेफर करना।

tf.train.Checkpoint , tf.keras.layers.Layer , और tf.keras.Model के उपवर्ग स्वचालित रूप से उनकी विशेषताओं को निर्दिष्ट चर को ट्रैक करते हैं। निम्न उदाहरण एक साधारण रैखिक मॉडल का निर्माण करता है, फिर उन चौकियों को लिखता है जिनमें मॉडल के सभी चर के लिए मान होते हैं।

आप Model.save_weights के साथ आसानी से मॉडल-चेकपॉइंट सहेज सकते हैं।

मैनुअल चेकपॉइंटिंग

सेट अप

tf.train.Checkpoint की सभी विशेषताओं को प्रदर्शित करने में मदद करने के लिए, एक खिलौना डेटासेट और अनुकूलन चरण को परिभाषित करें:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

चेकपॉइंट ऑब्जेक्ट बनाएं

मैन्युअल रूप से एक चेकपॉइंट बनाने के लिए एक tf.train.Checkpoint ऑब्जेक्ट का उपयोग करें, जहां आप जिस ऑब्जेक्ट को चेकपॉइंट करना चाहते हैं, उसे ऑब्जेक्ट पर विशेषताओं के रूप में सेट किया गया है।

एक tf.train.CheckpointManager कई चौकियों के प्रबंधन के लिए भी सहायक हो सकता है।

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

मॉडल को ट्रेन और चेकपॉइंट

निम्नलिखित प्रशिक्षण लूप मॉडल और एक अनुकूलक का एक उदाहरण बनाता है, फिर उन्हें एक tf.train.Checkpoint ऑब्जेक्ट में एकत्रित करता है। यह डेटा के प्रत्येक बैच पर प्रशिक्षण चरण को लूप में कॉल करता है, और समय-समय पर डिस्क पर चौकियों को लिखता है।

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 31.27
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 24.68
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 18.12
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 11.65
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 5.39

पुनर्स्थापित करें और प्रशिक्षण जारी रखें

पहले प्रशिक्षण चक्र के बाद आप एक नया मॉडल और प्रबंधक पास कर सकते हैं, लेकिन प्रशिक्षण ठीक वहीं से लें जहां आपने छोड़ा था:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.50
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.27
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.56
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.70
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.35

tf.train.CheckpointManager ऑब्जेक्ट पुरानी चौकियों को हटा देता है। ऊपर इसे केवल तीन सबसे हाल की चौकियों को रखने के लिए कॉन्फ़िगर किया गया है।

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

ये पथ, जैसे './tf_ckpts/ckpt-10' डिस्क पर फ़ाइलें नहीं हैं। इसके बजाय वे एक index फ़ाइल और एक या अधिक डेटा फ़ाइलों के लिए उपसर्ग हैं जिनमें चर मान होते हैं। इन उपसर्गों को एक एकल checkpoint फ़ाइल ( './tf_ckpts/checkpoint' ) में एक साथ समूहीकृत किया जाता है जहाँ CheckpointManager अपनी स्थिति को सहेजता है।

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

लोड हो रहा है यांत्रिकी

TensorFlow लोड किए जा रहे ऑब्जेक्ट से शुरू होने वाले नामित किनारों के साथ एक निर्देशित ग्राफ़ को पार करके चेकपॉइंट मानों से चर का मिलान करता है। एज नाम आमतौर पर ऑब्जेक्ट में विशेषता नामों से आते हैं, उदाहरण के लिए "l1" में self.l1 = tf.keras.layers.Dense(5)tf.train.Checkpoint अपने कीवर्ड तर्क नामों का उपयोग करता है, जैसा कि tf.train.Checkpoint(step=...) में "step" में है।

ऊपर के उदाहरण से निर्भरता ग्राफ इस तरह दिखता है:

उदाहरण प्रशिक्षण लूप के लिए निर्भरता ग्राफ का विज़ुअलाइज़ेशन

अनुकूलक लाल रंग में है, नियमित चर नीले रंग में हैं, और अनुकूलक स्लॉट चर नारंगी रंग में हैं। अन्य नोड्स—उदाहरण के लिए, tf.train.Checkpoint का प्रतिनिधित्व करते हुए—काले रंग में हैं।

स्लॉट चर अनुकूलक की स्थिति का हिस्सा होते हैं, लेकिन एक विशिष्ट चर के लिए बनाए जाते हैं। उदाहरण के लिए, ऊपर 'm' किनारे गति के अनुरूप हैं, जिसे एडम अनुकूलक प्रत्येक चर के लिए ट्रैक करता है। स्लॉट चर केवल एक चेकपॉइंट में सहेजे जाते हैं यदि चर और अनुकूलक दोनों सहेजे जाएंगे, इस प्रकार धराशायी किनारों।

एक tf.train.Checkpoint ऑब्जेक्ट पर restore को कॉल करना अनुरोधित पुनर्स्थापनों को कतारबद्ध करता है, जैसे ही Checkpoint ऑब्जेक्ट से मेल खाने वाला पथ होता है, चर मानों को पुनर्स्थापित करता है। उदाहरण के लिए, आप नेटवर्क और परत के माध्यम से इसके लिए एक पथ का पुनर्निर्माण करके ऊपर परिभाषित मॉडल से केवल पूर्वाग्रह लोड कर सकते हैं।

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.7209885 3.7588918 4.421351  4.1466427 4.0712557]

इन नई वस्तुओं के लिए निर्भरता ग्राफ आपके द्वारा ऊपर लिखे गए बड़े चेकपॉइंट का बहुत छोटा सबग्राफ है। इसमें केवल पूर्वाग्रह और एक बचत काउंटर शामिल है जो tf.train.Checkpoint संख्या चौकियों के लिए उपयोग करता है।

पूर्वाग्रह चर के लिए एक सबग्राफ का विज़ुअलाइज़ेशन

restore एक स्थिति वस्तु देता है, जिसमें वैकल्पिक अभिकथन होता है। नए Checkpoint में बनाए गए सभी ऑब्जेक्ट बहाल कर दिए गए हैं, इसलिए status.assert_existing_objects_matched पास हो जाता है।

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>

चेकपॉइंट में कई ऑब्जेक्ट हैं जो मेल नहीं खाते हैं, जिनमें लेयर का कर्नेल और ऑप्टिमाइज़र के वेरिएबल्स शामिल हैं। status.assert_consumed केवल तभी गुजरता है जब चेकपॉइंट और प्रोग्राम बिल्कुल मेल खाते हैं, और यहां एक अपवाद फेंक देंगे।

आस्थगित बहाली

जब इनपुट आकार उपलब्ध हों, तो TensorFlow में Layer ऑब्जेक्ट अपनी पहली कॉल के लिए चर के निर्माण को स्थगित कर सकते हैं। उदाहरण के लिए, Dense परत के कर्नेल का आकार परत के इनपुट और आउटपुट आकार दोनों पर निर्भर करता है, और इसलिए एक कंस्ट्रक्टर तर्क के रूप में आवश्यक आउटपुट आकार अपने आप में चर बनाने के लिए पर्याप्त जानकारी नहीं है। चूंकि एक Layer को कॉल करना भी चर के मूल्य को पढ़ता है, चर के निर्माण और इसके पहले उपयोग के बीच एक पुनर्स्थापना होनी चाहिए।

इस मुहावरे का समर्थन करने के लिए, tf.train.Checkpoint पुनर्स्थापनाओं को स्थगित कर देता है जिनका अभी तक मिलान करने वाला चर नहीं है।

deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5854754 4.607731  4.649179  4.8474874 5.121    ]]
प्लेसहोल्डर22

चेकपॉइंट्स का मैन्युअल रूप से निरीक्षण

tf.train.load_checkpoint एक CheckpointReader देता है जो चेकपॉइंट सामग्री को निचले स्तर तक पहुंच प्रदान करता है। इसमें चेकपॉइंट में प्रत्येक चर के लिए प्रत्येक चर की कुंजी से आकृति और dtype के लिए मैपिंग शामिल है। एक चर की कुंजी उसका ऑब्जेक्ट पथ है, जैसे ऊपर प्रदर्शित ग्राफ़ में।

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

इसलिए यदि आप net.l1.kernel के मूल्य में रुचि रखते हैं तो आप निम्न कोड के साथ मान प्राप्त कर सकते हैं:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32
प्लेसहोल्डर26

यह एक get_tensor विधि भी प्रदान करता है जिससे आप एक चर के मान का निरीक्षण कर सकते हैं:

reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121    ]],
      dtype=float32)

वस्तु ट्रैकिंग

चेकपॉइंट्स tf.Variable ऑब्जेक्ट्स के मानों को सहेजते हैं और पुनर्स्थापित करते हैं, किसी भी वैरिएबल या ट्रैक करने योग्य ऑब्जेक्ट को उसकी विशेषताओं में से एक में "ट्रैकिंग" करके। एक सेव निष्पादित करते समय, सभी पहुंच योग्य ट्रैक किए गए ऑब्जेक्ट्स से चर को पुनरावर्ती रूप से एकत्र किया जाता है।

जैसा कि self.l1 = tf.keras.layers.Dense(5) जैसे प्रत्यक्ष विशेषता असाइनमेंट के साथ होता है, विशेषताओं के लिए सूचियाँ और शब्दकोश निर्दिष्ट करना उनकी सामग्री को ट्रैक करेगा।

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

आप सूचियों और शब्दकोशों के लिए रैपर ऑब्जेक्ट देख सकते हैं। ये रैपर अंतर्निहित डेटा-संरचनाओं के चेकपॉइंट करने योग्य संस्करण हैं। विशेषता आधारित लोडिंग की तरह, ये रैपर कंटेनर में जोड़े जाने पर एक चर के मान को बहाल कर देते हैं।

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

ट्रैक करने योग्य वस्तुओं में शामिल हैं tf.train.Checkpoint , tf.Module और इसके उपवर्ग (जैसे keras.layers.Layer और keras.Model ), और मान्यता प्राप्त पायथन कंटेनर:

  • dict (और collections.OrderedDict । ऑर्डर किए गए डिक्ट)
  • list
  • tuple (और collections.namedtuple , टाइपिंग. typing.NamedTuple )

अन्य कंटेनर प्रकार समर्थित नहीं हैं, जिनमें शामिल हैं:

  • collections.defaultdict
  • set

अन्य सभी पायथन वस्तुओं को अनदेखा किया जाता है, जिनमें शामिल हैं:

  • int
  • string
  • float

सारांश

TensorFlow ऑब्जेक्ट उनके द्वारा उपयोग किए जाने वाले चर के मानों को सहेजने और पुनर्स्थापित करने के लिए एक आसान स्वचालित तंत्र प्रदान करते हैं।