
Standardized representation of logic deployable to MapReduce-like systems.

This standardized representation can be used to describe a range of iterative processes representable as a single round of MapReduce-like processing, and deployable to MapReduce-like systems that are only capable of executing plain TensorFlow code.

Non-iterative processes, or processes that do not originate at the server, can be described as degenerate cases. We standardize on the iterative process as the basis for the canonical representation because this type of processing is the most common in federated learning scenarios, where different rounds typically involve different subsets of a potentially very large number of participating clients. The iterative form can model processes that evolve over time, as well as processes over a very large client population in which not all participants (clients, data shards, etc.) are present at the same time, so that iteration is dictated by data availability or scalability considerations. Because the set of clients involved in a federated computation will usually (though not always) vary from round to round, the server state is necessarily what connects subsequent rounds into a single contiguous logical sequence.

Instances of this class can be generated by TFF's transformation pipeline and consumed by a variety of backends that have the ability to orchestrate their execution in a MapReduce-like fashion. The latter can include systems that run static data pipelines, such as Apache Beam or Hadoop, but also platforms like the one described in the following paper:

"Towards Federated Learning at Scale: System Design"

It should be noted that not every computation that proceeds in synchronous rounds is representable as an instance of this class. In particular, this representation is not suitable for computations that involve multiple phases of processing, and does not generalize to arbitrary static data pipelines. Generalized representations that can take advantage of the full expressiveness of Apache Beam-like systems may emerge at a later time, and will be supported by a separate set of tools, with a more expressive canonical representation.

We say that a tff.templates.IterativeProcess is in the canonical form for a simple MapReduce-like platform if its iterative component (next) can be converted into a semantically equivalent instance of the computation template shown below, with all the variable constituents of the template representable as simple sections of TensorFlow logic. A process in such form can then be equivalently represented by listing only the variable constituents of the template (the tuple of TensorFlow computations that appear as parameters of the federated operators). Such compact representations can be encapsulated as instances of this class.

The requirement that the variable constituents of the template be in the form of pure TensorFlow code (not arbitrary TFF constructs) reflects the intent for instances of this class to be easily converted into a representation that can be compiled into a system that does not have the ability to interpret the full TFF language (as defined in computation.proto), but that does have the ability to run TensorFlow. Client-side logic in such systems could be deployed in a number of ways, e.g., as shards in a MapReduce job, to mobile or embedded devices, etc.

Conceptually, next, as the iterator part of an iterative process, is modeled in the same way as any stateful computation in TFF, i.e., one that takes the server state as the first component of the input, and returns updated server state as the first component of the output. If there is no need for server state, the input/output state should be modeled as an empty tuple.

In addition to updating state, next takes client-side data as input, and can produce results on the server side and/or on the client side. As is the case for the server state, if either of these is undesired, it should be modeled as an empty tuple.

The type signature of next, in the concise TFF type notation (as defined in TFF's computation.proto), is as follows:

(<S@SERVER,{D}@CLIENTS> -> <S@SERVER,X@SERVER,{Y}@CLIENTS>)

The above type signature involves the following abstract types:

  • S is the type of the state that is passed at the server between rounds of processing. For example, in the context of federated training, the server state would typically include the weights of the model being trained. The weights would be updated in each round as the model is trained on more and more of the clients' data, and hence the server state would evolve as well.

  • D represents the type of per-client units of data that serve as the input to the computation. Often, this would be a sequence type, i.e., a data set in TensorFlow's parlance, although strictly speaking, this does not have to always be the case.

  • X represents the type of server-side outputs generated by the server after each round.

  • Y represents the type of client-side outputs generated on the clients in each round. Most computations would not involve client-side outputs.

One can think of the process based on this representation as being equivalent to the following loop:

client_data = ...
server_state = initialize()
while True:
  server_state, server_outputs, client_outputs = (
    next(server_state, client_data))
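
In practice the loop is bounded, and each round usually sees a different subset of clients. The following is a minimal plain-Python sketch of that pattern, not TFF code: `next_fn`, the `all_clients` dictionary, and the sampling scheme are hypothetical stand-ins chosen for illustration. The server state here is just a running count of examples, which shows how the state is the only thing carried from one round to the next.

```python
import random


def initialize():
    # Server state: total number of examples processed so far.
    return 0


def next_fn(server_state, client_data):
    # Hypothetical stand-in for `next`: counts the examples seen this round.
    round_total = sum(len(data) for data in client_data)
    # No client-side outputs are needed, so each one is an empty tuple.
    client_outputs = [() for _ in client_data]
    return server_state + round_total, round_total, client_outputs


# Hypothetical client population; only a subset participates in each round.
all_clients = {"a": [1, 2], "b": [3], "c": [4, 5, 6], "d": []}

rng = random.Random(0)
server_state = initialize()
history = []
for _ in range(3):
    participants = rng.sample(sorted(all_clients), k=2)
    client_data = [all_clients[name] for name in participants]
    server_state, server_output, client_outputs = next_fn(
        server_state, client_data)
    history.append(server_output)
```

After the loop, `server_state` equals the sum of the per-round server outputs collected in `history`, regardless of which clients were sampled.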

The logic of next in this form is defined as follows in terms of the seven variable components prepare, work, zero, accumulate, merge, report, and update (in addition to initialize that produces the server state component for the initial round). The code below uses common syntactic shortcuts (such as federated zipping and unzipping) for brevity.

def next(server_state, client_data):

  # The server prepares an input to be broadcast to all clients that controls
  # what will happen in this round.

  client_input = (
    tff.federated_broadcast(tff.federated_map(prepare, server_state)))

  # The clients all independently do local work and produce updates, plus the
  # optional client-side outputs.

  client_updates, client_outputs = (
    tff.federated_map(work, [client_data, client_input]))

  # The client updates are aggregated across the system into a single global
  # update at the server.

  global_update = (
    tff.federated_aggregate(client_updates, zero, accumulate, merge, report))

  # Finally, the server produces a new state as well as server-side output to
  # emit from this round.

  new_server_state, server_output = (
    tff.federated_map(update, [server_state, global_update]))

  # The updated server state, server- and client-side outputs are returned as
  # results of this round.

  return new_server_state, server_output, client_outputs
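
To see how data flows through the template, it can be mimicked in plain Python. The sketch below is not TFF code and makes several assumptions: federated values placed at the clients are modeled as ordinary lists (one entry per client), `run_round` is a hypothetical helper, aggregation is done in a single tier, and the toy components compute a mean correction to a scalar model.

```python
from functools import reduce


def run_round(server_state, client_data, prepare, work, zero, accumulate,
              merge, report, update):
    """Plain-Python stand-in for one call to `next` (not TFF code)."""
    # The server prepares one input; every client receives the same value.
    client_input = prepare(server_state)
    # Each client works independently, yielding (update, client_output).
    results = [work(data, client_input) for data in client_data]
    client_updates = [u for u, _ in results]
    client_outputs = [y for _, y in results]
    # Single-tier aggregation: fold every update into one accumulator.
    # `merge` is unused here because only one accumulator exists.
    accumulator = reduce(accumulate, client_updates, zero())
    global_update = report(accumulator)
    new_server_state, server_output = update(server_state, global_update)
    return new_server_state, server_output, client_outputs


# Toy components (assumptions for illustration only).
def prepare(server_state):
    return server_state  # broadcast the current scalar model


def work(data, model):
    # Update: (sum of residuals, example count); no client-side output.
    return (sum(x - model for x in data), len(data)), ()


def zero():
    return (0.0, 0)


def accumulate(acc, upd):
    return (acc[0] + upd[0], acc[1] + upd[1])


def merge(a, b):
    return (a[0] + b[0], a[1] + b[1])


def report(acc):
    return acc[0] / acc[1]


def update(server_state, global_update):
    return server_state + global_update, {"mean_delta": global_update}


new_state, server_output, client_outputs = run_round(
    0.0, [[1.0, 2.0], [3.0]],
    prepare, work, zero, accumulate, merge, report, update)
```

With these inputs the model moves from 0.0 to the mean of all client examples, 2.0, in a single round.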

The above characterization of next depends on seven pieces of pure TensorFlow logic, defined as follows. Please also consult the documentation for the related federated operators for more detail (particularly tff.federated_aggregate(), as several of the components below correspond directly to the parameters of that operator).

  • prepare represents the preparatory steps taken by the server to generate inputs that will be broadcast to the clients and that, together with the client data, will drive the client-side work in this round. It takes the state of the server at the beginning of the round, and produces the input for use by the clients. Its type signature is (S -> C).

  • work represents the totality of client-side processing, again all as a single section of TensorFlow code. It takes a tuple of client data and the client input that was broadcast by the server, and returns a tuple that consists of a client update to be aggregated (across all the clients) along with an optional local per-client output (to be consumed locally by clients, as it is not aggregated). Its type signature is (<D,C> -> <U,Y>). As noted earlier, the per-client local outputs will typically be absent. An example use of such outputs might be debugging metrics that are preserved locally (but not aggregated globally).

  • zero is the TensorFlow computation that produces the initial state of accumulators that are used to combine updates collected from subsets of the client population. In some systems, all accumulation may happen at the server, but for scalability reasons, it is often desirable to structure aggregation in multiple tiers. Its type signature is A, or when represented as a tff.Computation in Python, ( -> A).

  • accumulate is the TensorFlow computation that updates the state of an update accumulator (initialized with zero above) with a single client's update. Its type signature is (<A,U> -> A). Typically, a single accumulator would be used to combine the updates from multiple clients, but this does not have to be the case (it's up to the target deployment platform to choose how to use this logic in a particular deployment scenario).

  • merge is the TensorFlow computation that merges two accumulators holding the results of aggregation over two disjoint subsets of clients. Its type signature is (<A,A> -> A).

  • report is the TensorFlow computation that transforms the state of the top-most accumulator (after accumulating updates from all clients and merging all the resulting accumulators into a single one at the top level of the system hierarchy) into the final result of aggregation. Its type signature is (A -> R).

  • update is the TensorFlow computation that applies the aggregate of all clients' updates (the output of report), also referred to above as the global update, to the server state, to produce a new server state to feed into the next round, and that additionally outputs a server-side output, to be reported externally as one of the results of this round. In federated learning scenarios, the server-side outputs might include things like loss and accuracy metrics, and the server state to be carried over, as noted above, may include the model weights to be trained further in a subsequent round. The type signature of this computation is (<S,R> -> <S,X>).
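
The split of aggregation into zero, accumulate, merge, and report is what lets a platform aggregate in multiple tiers. The following plain-Python sketch (not TFF code; the weighted-mean components and the sample `updates` are assumptions for illustration) checks that folding all updates into one accumulator gives the same result as accumulating over two disjoint subsets and merging.

```python
from functools import reduce


def zero():
    # Initial accumulator: (weighted sum, total weight).
    return (0.0, 0)


def accumulate(acc, update):
    # Fold a single client's (value, weight) update into an accumulator.
    total, count = acc
    value, weight = update
    return (total + value * weight, count + weight)


def merge(a, b):
    # Combine accumulators built over two disjoint subsets of clients.
    return (a[0] + b[0], a[1] + b[1])


def report(acc):
    # Turn the top-level accumulator into the final weighted mean.
    total, count = acc
    return total / count if count else 0.0


# Hypothetical per-client updates as (value, weight) pairs.
updates = [(1.0, 2), (4.0, 1), (2.0, 1), (3.0, 4)]

# Single tier: one accumulator sees every update.
single = report(reduce(accumulate, updates, zero()))

# Two tiers: accumulate per subset, then merge at the top level.
left = reduce(accumulate, updates[:2], zero())
right = reduce(accumulate, updates[2:], zero())
tiered = report(merge(left, right))
```

Both paths yield 2.5 here. In general this equivalence requires merge to be consistent with accumulate, which is a property of the chosen components rather than something the template enforces.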

The type signatures of the above TensorFlow computations involve the following abstract types in addition to those defined earlier:

  • C is the type of the inputs for the clients, to be supplied by the server at the beginning of each round (or an empty tuple if not needed).

  • U is the type of the per-client update to be produced in each round and fed into the cross-client aggregation protocol.

  • A is the type of the accumulators used to combine updates from subsets of clients.

  • R is the type of the final result of aggregating all client updates, the global update to be incorporated into the server state at the end of the round (and to produce the server-side output).
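
The eight abstract types and the component signatures can be written down with Python type hints, which makes it easier to check that a set of components fits together. This is a sketch only: `RoundComponents` is a hypothetical container, not a TFF class, and plain Python callables stand in for TensorFlow computations. The example instance counts client examples and uses empty tuples where no client input or client output is needed.

```python
from dataclasses import dataclass
from typing import Callable, Generic, Tuple, TypeVar

S = TypeVar("S")  # server state
D = TypeVar("D")  # per-client data
X = TypeVar("X")  # server-side output
Y = TypeVar("Y")  # client-side output
C = TypeVar("C")  # client input broadcast by the server
U = TypeVar("U")  # per-client update
A = TypeVar("A")  # accumulator
R = TypeVar("R")  # final result of aggregation


@dataclass(frozen=True)
class RoundComponents(Generic[S, D, X, Y, C, U, A, R]):
    """Hypothetical container mirroring the signatures listed above."""
    initialize: Callable[[], S]
    prepare: Callable[[S], C]                # (S -> C)
    work: Callable[[D, C], Tuple[U, Y]]      # (<D,C> -> <U,Y>)
    zero: Callable[[], A]                    # ( -> A)
    accumulate: Callable[[A, U], A]          # (<A,U> -> A)
    merge: Callable[[A, A], A]               # (<A,A> -> A)
    report: Callable[[A], R]                 # (A -> R)
    update: Callable[[S, R], Tuple[S, X]]    # (<S,R> -> <S,X>)


# Example instance: count the examples held by the participating clients.
components = RoundComponents(
    initialize=lambda: 0,
    prepare=lambda state: (),                # clients need no input
    work=lambda data, _: (len(data), ()),    # no client-side output
    zero=lambda: 0,
    accumulate=lambda acc, upd: acc + upd,
    merge=lambda a, b: a + b,
    report=lambda acc: acc,
    update=lambda state, total: (state + total, total),
)
```

Type hints are not enforced at run time, so a static checker would be needed to catch mismatched signatures in such a container.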

The individual TensorFlow computations that constitute an iterative process in this form are supplied as constructor arguments.

  • initialize: The computation that produces the initial server state.

  • prepare: The computation that prepares the input for the clients.

  • work: The client-side work computation.

  • zero: The computation that produces the initial state for accumulators.

  • accumulate: The computation that adds a client update to an accumulator.

  • merge: The computation to use for merging pairs of accumulators.

  • report: The computation that produces the final server-side aggregate for the top level accumulator (the global update).

  • bitwidth: The computation that produces the bitwidth for secure sum.

  • update: The computation that takes the global update and the server state and produces the new server state, as well as server-side output.

The constructor raises the following errors:

  • TypeError: If the Python or TFF types of the arguments are invalid or not compatible with each other.

  • AssertionError: If the manner in which the given TensorFlow computations are represented by TFF does not match what this code is expecting (this is an internal error that requires a code update).













Prints a string summary of the CanonicalForm.

  • print_fn: Print function to use. It will be called on each line of the summary in order to capture the string summary.