एक सवाल है? TensorFlow फ़ोरम विज़िट फ़ोरम पर समुदाय से जुड़ें

XLA: मशीन लर्निंग के लिए कंपाइलर ऑप्टिमाइज़ करना

XLA (त्वरित रेखीय बीजगणित) रैखिक बीजगणित के लिए एक डोमेन-विशिष्ट संकलक है जो संभावित रूप से कोई स्रोत कोड परिवर्तन के साथ TensorFlow मॉडल को तेज कर सकता है।

परिणाम गति और स्मृति उपयोग में सुधार कर रहे हैं: उदाहरण के लिए BERT MLPerf सबमिशन में 8 वोल्टा V100 GPU का उपयोग करते हुए XLA का उपयोग करके ~ 7x प्रदर्शन में सुधार और ~ 5x बैच आकार में सुधार हुआ है:

परिचय

जब TensorFlow प्रोग्राम चलाया जाता है, तो सभी ऑपरेशन TensorFlow निष्पादक द्वारा व्यक्तिगत रूप से निष्पादित किए जाते हैं। प्रत्येक TensorFlow ऑपरेशन में एक precompiled GPU कर्नेल कार्यान्वयन है जिसे निष्पादक भेजता है।

XLA रनिंग मॉडल का एक वैकल्पिक मोड प्रदान करता है: यह TensorFlow ग्राफ को दिए गए मॉडल के लिए विशेष रूप से उत्पन्न गणना कर्नेल के अनुक्रम में संकलित करता है। क्योंकि ये गुठली मॉडल के लिए अद्वितीय हैं, वे अनुकूलन के लिए मॉडल-विशिष्ट जानकारी का उपयोग कर सकते हैं। उदाहरण के लिए, आइए एक एक्सएलए एक साधारण TensorFlow संगणना के संदर्भ में एक अनुकूलन को देखें:

def model_fn(x, y, z):
  return tf.reduce_sum(x + y * z)

एक्सएलए के बिना चलाएं, ग्राफ तीन गुठली लॉन्च करता है: एक गुणा के लिए, एक जोड़ के लिए और एक कमी के लिए। हालाँकि, XLA ग्राफ़ को अनुकूलित कर सकता है ताकि यह एकल कर्नेल लॉन्च में परिणाम की गणना करे। यह एक एकल GPU कर्नेल में जोड़, गुणा और कमी को "फ्यूज" करके करता है। इसके अलावा, यह फ़्यूज़्ड ऑपरेशन मेमोरी में y*z और x+y*z द्वारा निर्मित मध्यवर्ती मूल्यों को नहीं लिखता है; इसके बजाय यह पूरी तरह से GPU रजिस्टरों में रखते हुए अपने उपयोगकर्ताओं के लिए सीधे इन मध्यवर्ती संगणनाओं के परिणामों को "स्ट्रीम" करता है। फ्यूजन XLA का सबसे महत्वपूर्ण अनुकूलन है। मेमोरी बैंडविड्थ आमतौर पर हार्डवेयर एक्सेलेरेटर पर स्कार् टी संसाधन है, इसलिए मेमोरी ऑपरेशंस को हटाना प्रदर्शन को बेहतर बनाने के सर्वोत्तम तरीकों में से एक है।

TensorFlow मॉडल के लिए XLA सक्षम करें

tf.function(jit_compile=True) साथ स्पष्ट संकलन

स्पष्ट संकलन एपीआई यह चुनने के लिए एक ठीक-ठाक नियंत्रण प्रदान करता है कि किन कार्यों को संकलित किया जाना चाहिए। उदाहरण के लिए, निम्नलिखित TensorFlow फ़ंक्शन जो MNIST प्रशिक्षण करता है, XLA के साथ संकलित किया जाता है:

@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

jit_compile API के पास शब्दार्थ संकलन होना चाहिए : या तो पूरे फ़ंक्शन को errors.InvalidArgumentError , या errors.InvalidArgumentError साथ संकलित किया गया है। errors.InvalidArgumentError अपवाद को फेंक दिया गया है। एक्सएलए वर्तमान में उन कार्यों को संकलित नहीं कर सकता है जहां आयाम बांझ नहीं हैं: अर्थात, यदि संपूर्ण गणना को चलाने के बिना सभी टेंसरों के आयामों का अनुमान लगाना संभव नहीं है। उदाहरण के लिए, निम्नलिखित फ़ंक्शन संकलित नहीं करेगा:

@tf.function
def not_compilable(x):
  return tf.unique(x)

आकार हालांकि रनों में भिन्न हो सकते हैं:

@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
  return a + b

recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10]))
recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))

अधिक विस्तृत उपयोग उदाहरण के लिए ट्यूटोरियल कोलाब देखें।

ऑटो-क्लस्टरिंग

बिना किसी बदलाव के TensorFlow मॉडल में XLA का उपयोग शुरू करने का एक सरल तरीका ऑटो-क्लस्टरिंग को सक्षम करना है , जो TensorFlow कार्यों के भीतर स्वचालित रूप से क्लस्टर (जुड़े उपसमूह) पाता है जिसे XLA द्वारा संकलित और निष्पादित किया जा सकता है। GPU पर ऑटो-क्लस्टरिंग को TF_XLA_FLAGS पर्यावरण चर सेट करके सक्षम किया जा सकता है:

$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program

वर्तमान में ऑटो-क्लस्टरिंग को GPU वर्कलोड के लिए अनुकूलित किया गया है, लेकिन यह सीपीयू पर अतिरिक्त रूप से ध्वज --tf_xla_cpu_global_jit का उपयोग करके भी सक्षम किया जा सकता है:

$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program

एक विस्तृत उपयोग उदाहरण के लिए ऑटो-क्लस्टरिंग ट्यूटोरियल कोलाब देखें

एओटी (अहेड-ऑफ-टाइम) tfcompile लिए tfcompile साथ संकलन

आप एक स्टैंडअलोन tfcompile टूल का भी उपयोग कर सकते हैं, जो TensorFlow ग्राफ को निष्पादन योग्य कोड में परिवर्तित करता है (केवल x86-64 CPU के लिए)।

संकलित कार्यक्रमों का निरीक्षण करें

XLA आत्मनिरीक्षण की सुविधा प्रदान करता है जो आपको उत्पन्न कार्यक्रमों का निरीक्षण करने देता है। उत्पन्न कार्यक्रमों को डंप करने के लिए, पर्यावरण चर XLA_FLAGS उपयोग करें:

$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program

डंपिंग करने के बाद, आप निम्न फ़ाइलें /tmp/generated में पा सकते हैं:

  • module_XXXX.*_optimizations.txt एक्स्ट्रा module_XXXX.*_optimizations.txt प्रोग्राम , प्रत्येक एक संकलित क्लस्टर। XLA बग रिपोर्ट सबमिट करते समय उन्हें संलग्न करना बेहद मददगार होता है!

  • module_XXXX.ir-*.ll में उत्पन्न फ़ाइलों LLVM मध्यवर्ती प्रतिनिधित्व, साथ NVPTX intrinsics।

  • module_XXXX.ptx जनरेट किया PTX फ़ाइलें।

आप TensorFlow ग्राफ के अंदर XLA क्लस्टर के एम्बेडिंग के दृश्य ग्राफ को भी डंप कर सकते हैं:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug"

प्रतिगामी बग रिपोर्ट

बग रिपोर्ट को पुन: उत्पन्न करने के लिए बहुत आसान है यदि इसमें उत्पन्न XLA कार्यक्रमों और प्रयुक्त ऑटो-क्लस्टरिंग एम्बेडिंग के लिए डंप शामिल हैं। ऑटो-क्लस्टरिंग के साथ चलने वाले TensorFlow कार्यक्रम के लिए उन्हें उत्पन्न करने के लिए, लॉन्च करें:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \
  TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
  XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
    my/tensorflow/program"

बग दर्ज करते समय, /tmp/generated निर्देशिका की सामग्री संलग्न करें (ऊपर संदर्भित)।

यदि संभव हो, तो replay_computation का उपयोग करके किसी एकल replay_computation प्रोग्राम में बग को अलग करने की कोशिश करें और इसे उत्पन्न कार्यक्रमों पर चलने दें।

अग्रिम पठन

XLA का मोर्चा

TensorFlow के अलावा, XLA कार्यक्रमों द्वारा उत्पन्न किया जा सकता है:

  • JAX : पायथन + न्यूमपी कार्यक्रमों के संगत परिवर्तन
  • जूलिया : वैज्ञानिक कंप्यूटिंग के लिए जूलिया भाषा
  • PyTorch : PyTorch की रूपरेखा
  • Nx : अमृत प्रोग्रामिंग भाषा के लिए संख्यात्मक कंप्यूटिंग लाइब्रेरी