Google I / O là một kết quả hoàn hảo! Cập nhật các phiên TensorFlow Xem phiên

Hỗ trợ thử nghiệm cho JAX trong TFF

Xem trên TensorFlow.org Chạy trong Google Colab Xem trên GitHub Tải xuống sổ ghi chép

Ngoài việc trở thành một phần của hệ sinh thái TensorFlow, TFF còn hướng tới việc kích hoạt khả năng tương tác với các khung ML frontend và backend khác. Hiện tại, hỗ trợ cho các khuôn khổ ML khác vẫn đang trong giai đoạn ươm tạo và các API và chức năng được hỗ trợ có thể thay đổi (phần lớn là một chức năng theo yêu cầu của người dùng TFF). Hướng dẫn này mô tả cách sử dụng TFF với JAX như một giao diện người dùng ML thay thế và trình biên dịch XLA như một chương trình phụ trợ thay thế. Các ví dụ được hiển thị ở đây dựa trên một ngăn xếp JAX / XLA hoàn toàn nguyên bản, end-to-end. Khả năng trộn mã giữa các khuôn khổ (ví dụ: JAX với TensorFlow) sẽ được thảo luận trong một trong các hướng dẫn trong tương lai.

Như mọi khi, chúng tôi hoan nghênh những đóng góp của bạn. Nếu hỗ trợ cho JAX / XLA hoặc khả năng tương tác với các khuôn khổ ML khác là quan trọng đối với bạn, vui lòng xem xét giúp chúng tôi phát triển các khả năng này theo hướng ngang bằng với phần còn lại của TFF.

Trước khi chúng tôi bắt đầu

Vui lòng tham khảo phần chính của tài liệu TFF để biết cách định cấu hình môi trường của bạn. Tùy thuộc vào nơi bạn đang chạy hướng dẫn này, bạn có thể muốn bỏ ghi chú và chạy một số hoặc tất cả mã bên dưới.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

Hướng dẫn này cũng giả định rằng bạn đã xem lại các hướng dẫn TensorFlow chính của TFF và bạn đã quen thuộc với các khái niệm cốt lõi của TFF. Nếu bạn chưa làm được điều này, hãy cân nhắc xem lại ít nhất một trong số chúng.

Tính toán JAX

Hỗ trợ cho JAX trong TFF được thiết kế để đối xứng với cách thức mà TFF tương tác với TensorFlow, bắt đầu với việc nhập:

import jax
import numpy as np
import tensorflow_federated as tff

Ngoài ra, giống như với TensorFlow, nền tảng để thể hiện bất kỳ mã TFF nào là logic chạy cục bộ. Bạn có thể thể hiện logic này trong JAX, như hình dưới đây, sử dụng @tff.experimental.jax_computation wrapper. Nó hoạt động tương tự như @tff.tf_computation rằng bây giờ bạn đã quen thuộc với. Hãy bắt đầu với một cái gì đó đơn giản, ví dụ, một phép tính thêm hai số nguyên:

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

Bạn có thể sử dụng tính toán JAX được định nghĩa ở trên giống như bạn thường sử dụng tính toán TFF. Ví dụ: bạn có thể kiểm tra chữ ký kiểu của nó, như sau:

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

Lưu ý rằng chúng tôi sử dụng np.int32 để xác định loại đối số. TFF không phân biệt giữa các loại NumPy (như np.int32 ) và loại TensorFlow (như tf.int32 ). Theo quan điểm của TFF, chúng chỉ là những cách để đề cập đến cùng một thứ.

Bây giờ, hãy nhớ rằng TFF không phải là Python (và nếu điều này không gây tiếng chuông, vui lòng xem lại một số hướng dẫn trước đây của chúng tôi, ví dụ: về thuật toán tùy chỉnh). Bạn có thể sử dụng @tff.experimental.jax_computation wrapper với bất kỳ JAX mã có thể được truy tìm và tuần tự, tức là, với mã mà bạn sẽ thường chú thích với @jax.jit dự kiến sẽ được biên dịch vào XLA (nhưng bạn không cần phải thực sự sử dụng @jax.jit chú thích để nhúng mã JAX của bạn trong TFF).

Thật vậy, TFF ngay lập tức biên dịch các phép tính JAX sang XLA. Bạn có thể kiểm tra điều này cho chính mình bằng cách thủ công chiết xuất và in mã XLA đăng từ add_numbers , như sau:

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

Hãy suy nghĩ về đại diện của JAX tính toán như mã XLA như là tương đương với chức năng của tf.GraphDef cho tính toán thể hiện trong TensorFlow. Nó là xách tay và thực thi trong một loạt các môi trường có hỗ trợ XLA, giống như tf.GraphDef thể được thực hiện trên bất kỳ thời gian chạy TensorFlow.

TFF cung cấp một ngăn xếp thời gian chạy dựa trên trình biên dịch XLA như một chương trình phụ trợ. Bạn có thể kích hoạt nó như sau:

tff.backends.xla.set_local_python_execution_context()

Bây giờ, bạn có thể thực hiện phép tính mà chúng tôi đã xác định ở trên:

add_numbers(2, 3)
5

Vừa đủ dễ. Hãy bắt đầu với cú đánh và làm một cái gì đó phức tạp hơn, chẳng hạn như MNIST.

Ví dụ về đào tạo MNIST với API đóng hộp

Như thường lệ, chúng tôi bắt đầu bằng cách xác định một loạt các kiểu TFF cho các lô dữ liệu và cho mô hình (hãy nhớ rằng, TFF là một khuôn khổ được đánh máy mạnh).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

Bây giờ, hãy xác định một hàm mất mát cho mô hình trong JAX, lấy mô hình và một lô dữ liệu duy nhất làm tham số:

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

Bây giờ, một cách để thực hiện là sử dụng một API đóng hộp. Dưới đây là ví dụ về cách bạn có thể sử dụng API của chúng tôi để tạo quy trình đào tạo dựa trên hàm mất mát vừa được xác định.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

Bạn có thể sử dụng ở trên cũng giống như bạn sẽ sử dụng một huấn luyện viên xây dựng từ một tf.Keras mô hình trong TensorFlow. Ví dụ: đây là cách bạn có thể tạo mô hình ban đầu để đào tạo:

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

Để thực hiện đào tạo thực tế, chúng tôi cần một số dữ liệu. Hãy tạo dữ liệu ngẫu nhiên để giữ cho nó đơn giản. Vì dữ liệu là ngẫu nhiên, chúng tôi sẽ đánh giá trên dữ liệu huấn luyện, vì nếu không, với dữ liệu đánh giá ngẫu nhiên, sẽ khó có thể mong đợi mô hình hoạt động. Ngoài ra, đối với bản demo quy mô nhỏ này, chúng tôi sẽ không lo lắng về việc lấy mẫu khách hàng ngẫu nhiên (chúng tôi để nó như một bài tập cho người dùng để khám phá các loại thay đổi đó bằng cách làm theo các mẫu từ các hướng dẫn khác):

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

Với điều đó, chúng ta có thể thực hiện một bước đào tạo duy nhất, như sau:

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

Hãy đánh giá kết quả của bước đào tạo. Để dễ dàng, chúng tôi có thể đánh giá nó theo cách tập trung:

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

Sự mất mát ngày càng giảm. Tuyệt quá! Bây giờ, hãy chạy điều này qua nhiều vòng:

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

Như bạn thấy, việc sử dụng JAX với TFF không có gì khác biệt, mặc dù các API thử nghiệm vẫn chưa ngang bằng với chức năng của các API TensorFlow.

Dưới mui xe

Nếu bạn không muốn sử dụng API đóng hộp của chúng tôi, bạn có thể triển khai các tính toán tùy chỉnh của riêng mình, giống như cách bạn đã thấy nó được thực hiện trong các hướng dẫn thuật toán tùy chỉnh cho TensorFlow, ngoại trừ việc bạn sẽ sử dụng cơ chế của JAX để giảm độ dốc. Ví dụ: dưới đây là cách bạn có thể xác định tính toán JAX cập nhật mô hình trên một minibatch duy nhất:

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

Đây là cách bạn có thể kiểm tra xem nó có hoạt động không:

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

Một báo trước khi làm việc với JAX là nó không cung cấp tương đương với tf.data.Dataset . Do đó, để lặp lại các tập dữ liệu, bạn sẽ cần sử dụng các cấu trúc khai báo của TFF cho các hoạt động trên chuỗi, chẳng hạn như cấu trúc được hiển thị bên dưới:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

Hãy xem nó hoạt động:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

Máy tính thực hiện một vòng đào tạo giống như máy tính mà bạn có thể đã thấy trong hướng dẫn của TensorFlow:

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

Hãy xem nó hoạt động:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

Như bạn thấy, việc sử dụng JAX trong TFF, cho dù thông qua các API đóng hộp, hay trực tiếp sử dụng các cấu trúc TFF cấp thấp, đều tương tự như sử dụng TFF với TensorFlow. Hãy theo dõi các bản cập nhật trong tương lai và nếu bạn muốn được hỗ trợ tốt hơn cho khả năng tương tác trên các khuôn khổ ML, vui lòng gửi yêu cầu kéo cho chúng tôi!