গুগল আই/ও একটি মোড়ক! TensorFlow সেশনগুলি দেখুন সেশনগুলি দেখুন

টিএফএফে জ্যাক্সের জন্য পরীক্ষামূলক সহায়তা support

TensorFlow.org এ দেখুন Google Colab-এ চালান GitHub এ দেখুন নোটবুক ডাউনলোড করুন

TensorFlow ইকোসিস্টেমের একটি অংশ হওয়ার পাশাপাশি, TFF-এর লক্ষ্য অন্যান্য ফ্রন্টএন্ড এবং ব্যাকএন্ড ML ফ্রেমওয়ার্কের সাথে আন্তঃকার্যযোগ্যতা সক্ষম করা। এই মুহুর্তে, অন্যান্য ML ফ্রেমওয়ার্কের জন্য সমর্থন এখনও ইনকিউবেশন পর্যায়ে রয়েছে, এবং API এবং সমর্থিত কার্যকারিতা পরিবর্তিত হতে পারে (বিশালভাবে TFF ব্যবহারকারীদের চাহিদার একটি ফাংশন হিসাবে)। এই টিউটোরিয়ালটি বর্ণনা করে যে কীভাবে JAX-এর সাথে TFF একটি বিকল্প ML ফ্রন্টএন্ড এবং XLA কম্পাইলার একটি বিকল্প ব্যাকএন্ড হিসাবে ব্যবহার করা যায়। এখানে দেখানো উদাহরণগুলি সম্পূর্ণরূপে নেটিভ JAX/XLA স্ট্যাকের উপর ভিত্তি করে, এন্ড-টু-এন্ড। ফ্রেমওয়ার্ক জুড়ে কোড মিশ্রিত করার সম্ভাবনা (যেমন, টেনসরফ্লো সহ JAX) ভবিষ্যতের টিউটোরিয়ালগুলির একটিতে আলোচনা করা হবে।

সর্বদা হিসাবে, আমরা আপনার অবদান স্বাগত জানাই. যদি JAX/XLA-এর জন্য সমর্থন বা অন্যান্য ML ফ্রেমওয়ার্কের সাথে ইন্টারঅপারেশন করার ক্ষমতা আপনার জন্য গুরুত্বপূর্ণ হয়, তাহলে অনুগ্রহ করে আমাদের এই ক্ষমতাগুলিকে TFF-এর অবশিষ্ট অংশের সাথে সমতার দিকে বিকশিত করতে সাহায্য করুন।

আমরা শুরু করার আগে

কিভাবে আপনার পরিবেশ কনফিগার করবেন তার জন্য TFF ডকুমেন্টেশনের মূল অংশের সাথে পরামর্শ করুন। আপনি এই টিউটোরিয়ালটি কোথায় চালাচ্ছেন তার উপর নির্ভর করে, আপনি নীচের কিছু বা সমস্ত কোড আনকমেন্ট করতে এবং চালাতে চাইতে পারেন।

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

এই টিউটোরিয়ালটি অনুমান করে যে আপনি TFF এর প্রাথমিক TensorFlow টিউটোরিয়াল পর্যালোচনা করেছেন এবং আপনি মূল TFF ধারণাগুলির সাথে পরিচিত। আপনি যদি এখনও এটি না করে থাকেন তবে অনুগ্রহ করে তাদের মধ্যে অন্তত একটি পর্যালোচনা করুন।

JAX গণনা

TFF-এ JAX-এর জন্য সমর্থনকে এমনভাবে ডিজাইন করা হয়েছে যেভাবে TFF TensorFlow-এর সাথে ইন্টারঅপারেটিং করে, আমদানি থেকে শুরু করে:

import jax
import numpy as np
import tensorflow_federated as tff

এছাড়াও, TensorFlow-এর মতই, যেকোনো TFF কোড প্রকাশ করার ভিত্তি হল যুক্তি যা স্থানীয়ভাবে চলে। নিচের চিত্রের ব্যবহার আপনি Jax মধ্যে এই যুক্তি প্রকাশ করতে পারেন @tff.experimental.jax_computation মোড়কের। এটা তোলে একইভাবে আচরণ করবে @tff.tf_computation যে এখন দ্বারা আপনার সাথে পরিচিত। আসুন সহজ কিছু দিয়ে শুরু করি, যেমন, একটি গণনা যা দুটি পূর্ণসংখ্যা যোগ করে:

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

আপনি উপরে সংজ্ঞায়িত JAX গণনাটি ব্যবহার করতে পারেন যেমন আপনি সাধারণত একটি TFF গণনা ব্যবহার করেন। উদাহরণস্বরূপ, আপনি এটির প্রকার স্বাক্ষর পরীক্ষা করতে পারেন, নিম্নরূপ:

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

মনে রাখবেন যে, আমরা ব্যবহৃত np.int32 আর্গুমেন্ট ধরণ নির্ধারণ করতে। TFF (যেমন Numpy ধরনের মধ্যে পার্থক্য নেই np.int32 ) এবং (যেমন TensorFlow টাইপ tf.int32 )। TFF এর দৃষ্টিকোণ থেকে, তারা একই জিনিস উল্লেখ করার উপায় মাত্র।

এখন, মনে রাখবেন যে TFF পাইথন নয় (এবং যদি এটি একটি ঘণ্টা না বাজে, অনুগ্রহ করে আমাদের আগের কিছু টিউটোরিয়াল পর্যালোচনা করুন, যেমন, কাস্টম অ্যালগরিদমগুলিতে)। আপনি ব্যবহার করতে পারেন @tff.experimental.jax_computation কোড সহ, মোড়কের সঙ্গে কোনো Jax কোডটি আঁকা যাবে না এবং ধারাবাহিকভাবে, অর্থাত্ যে আপনার would স্বাভাবিকভাবে সঙ্গে annotate @jax.jit XLA মধ্যে কম্পাইল করা হবে বলে আশা করা (কিন্তু আপনি প্রয়োজন হবে না আসলে ব্যবহার @jax.jit টীকা থেকে এম্বেড আপনার Jax TFF কোড)।

প্রকৃতপক্ষে, হুডের নিচে, TFF অবিলম্বে JAX কম্পিউটেশনগুলিকে XLA-তে কম্পাইল করে। আপনি নিজে আহরণের এবং প্রিন্টিং থেকে ধারাবাহিকভাবে XLA কোড দ্বারা নিজের জন্য এই পরীক্ষা করতে পারবেন add_numbers , নিম্নরূপঃ:

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

কার্যকরী সমতুল্য হচ্ছে XLA কোড হিসেবে Jax কম্পিউটেশন উপস্থাপনা চিন্তা tf.GraphDef TensorFlow প্রকাশ কম্পিউটেশনের জন্য। এটা তোলে পোর্টেবল এবং পরিবেশ রয়েছে যা XLA সমর্থন, ঠিক বিভিন্ন এক্সিকিউটেবল হয় tf.GraphDef কোনো TensorFlow রানটাইম মৃত্যুদন্ড কার্যকর করা যেতে পারে।

TFF ব্যাকএন্ড হিসাবে XLA কম্পাইলারের উপর ভিত্তি করে একটি রানটাইম স্ট্যাক প্রদান করে। আপনি নিম্নলিখিত হিসাবে এটি সক্রিয় করতে পারেন:

tff.backends.xla.set_local_python_execution_context()

এখন, আপনি আমরা উপরে সংজ্ঞায়িত গণনা চালাতে পারেন:

add_numbers(2, 3)
5

যথেষ্ট সহজ. আসুন ঘা দিয়ে যাই এবং আরও জটিল কিছু করি, যেমন MNIST।

ক্যানড API সহ MNIST প্রশিক্ষণের উদাহরণ

যথারীতি, আমরা ডেটার ব্যাচ এবং মডেলের জন্য একগুচ্ছ TFF প্রকার সংজ্ঞায়িত করে শুরু করি (মনে রাখবেন, TFF একটি দৃঢ়ভাবে টাইপ করা কাঠামো)।

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

এখন, JAX-এ মডেলের জন্য একটি লস ফাংশন সংজ্ঞায়িত করা যাক, মডেল এবং ডেটার একটি একক ব্যাচকে প্যারামিটার হিসাবে গ্রহণ করুন:

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

এখন, যাওয়ার একটি উপায় হল একটি টিনজাত API ব্যবহার করা। এইমাত্র সংজ্ঞায়িত ক্ষতি ফাংশনের উপর ভিত্তি করে একটি প্রশিক্ষণ প্রক্রিয়া তৈরি করতে আপনি কীভাবে আমাদের API ব্যবহার করতে পারেন তার একটি উদাহরণ এখানে রয়েছে।

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

আপনি শুধু উপরে ব্যবহার আপনি একটি থেকে একটি প্রশিক্ষকদের তৈরি ব্যবহার করেন পারেন tf.Keras TensorFlow মডেল। উদাহরণস্বরূপ, প্রশিক্ষণের জন্য আপনি কীভাবে প্রাথমিক মডেল তৈরি করতে পারেন তা এখানে:

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

প্রকৃত প্রশিক্ষণ সঞ্চালন করার জন্য, আমাদের কিছু তথ্য প্রয়োজন। আসুন এটিকে সহজ রাখতে র্যান্ডম ডেটা তৈরি করি। যেহেতু ডেটা এলোমেলো, তাই আমরা প্রশিক্ষণ ডেটার উপর মূল্যায়ন করতে যাচ্ছি, যেহেতু অন্যথায়, র্যান্ডম ইভাল ডেটা সহ, মডেলটি সম্পাদন করার আশা করা কঠিন হবে। এছাড়াও, এই ছোট আকারের ডেমোর জন্য, আমরা এলোমেলোভাবে ক্লায়েন্টদের নমুনা নেওয়ার বিষয়ে চিন্তা করব না (আমরা এটিকে ব্যবহারকারীর জন্য অন্যান্য টিউটোরিয়ালের টেমপ্লেটগুলি অনুসরণ করে এই ধরণের পরিবর্তনগুলি অন্বেষণ করার জন্য একটি অনুশীলন হিসাবে ছেড়ে দিই):

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

এর সাথে, আমরা প্রশিক্ষণের একটি একক ধাপ সম্পাদন করতে পারি, নিম্নরূপ:

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

আসুন প্রশিক্ষণ ধাপের ফলাফল মূল্যায়ন করা যাক। এটি সহজ রাখতে, আমরা এটিকে কেন্দ্রীভূত ফ্যাশনে মূল্যায়ন করতে পারি:

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

লোকসান কমছে। দারুণ! এখন, একাধিক রাউন্ডে এটি চালানো যাক:

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

যেমন আপনি দেখতে পাচ্ছেন, TFF-এর সাথে JAX ব্যবহার করা তেমন আলাদা নয়, যদিও পরীক্ষামূলক APIগুলি এখনও TensorFlow API-এর কার্যকারিতা অনুসারে সমান নয়।

ফণা অধীনে

আপনি যদি আমাদের ক্যানড এপিআই ব্যবহার না করতে পছন্দ করেন, আপনি আপনার নিজস্ব কাস্টম কম্পিউটেশন বাস্তবায়ন করতে পারেন, যেভাবে আপনি টেনসরফ্লো-এর কাস্টম অ্যালগরিদম টিউটোরিয়ালগুলিতে দেখেছেন যেভাবে আপনি গ্রেডিয়েন্ট ডিসেন্টের জন্য JAX-এর মেকানিজম ব্যবহার করবেন। উদাহরণস্বরূপ, নীচে আপনি কীভাবে একটি JAX গণনা সংজ্ঞায়িত করতে পারেন যা একটি একক মিনিব্যাচে মডেল আপডেট করে:

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

এখানে আপনি কিভাবে পরীক্ষা করতে পারেন যে এটি কাজ করে:

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

Jax সঙ্গে কাজ এক সতর্কীকরণ এটি সমতুল্য অফার করে না tf.data.Dataset । সুতরাং, ডেটাসেটগুলির উপর পুনরাবৃত্তি করার জন্য, আপনাকে ক্রমগুলির উপর ক্রিয়াকলাপের জন্য TFF-এর ঘোষণামূলক চুক্তিগুলি ব্যবহার করতে হবে, যেমন নীচে দেখানো হয়েছে:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

আসুন দেখি যে এটি কাজ করে:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

যে গণনাটি প্রশিক্ষণের একটি একক রাউন্ড সঞ্চালন করে তা দেখতে ঠিক যেমন আপনি টেনসরফ্লো টিউটোরিয়ালগুলিতে দেখেছেন:

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

আসুন দেখি যে এটি কাজ করে:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

যেমন আপনি দেখতে পাচ্ছেন, টিএফএফ-এ JAX ব্যবহার করা, তা ক্যানড API-এর মাধ্যমে হোক বা সরাসরি নিম্ন-স্তরের TFF কনস্ট্রাক্ট ব্যবহার করা, টেনসরফ্লো-এর সাথে TFF ব্যবহার করার মতো। ভবিষ্যতের আপডেটের জন্য সাথে থাকুন, এবং আপনি যদি ML ফ্রেমওয়ার্ক জুড়ে আন্তঃকার্যযোগ্যতার জন্য আরও ভাল সমর্থন দেখতে চান তবে নির্দ্বিধায় আমাদের একটি পুল অনুরোধ পাঠান!