tfc.entropy_models.ContinuousBatchedEntropyModel

Batched entropy model for continuous random variables.

This entropy model handles quantization of a bottleneck tensor and helps with training of the parameters of the probability distribution modeling the tensor (a shared "prior" between sender and receiver). It also pre-computes integer probability tables, which can then be used to compress and decompress bottleneck tensors reliably across different platforms.
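
Range coders work on integer frequency tables rather than floating-point probabilities, so sender and receiver perform identical integer arithmetic regardless of platform. The hypothetical helpers below are a rough, stdlib-only sketch of what "integer probability tables" means. They turn a probability mass function into integer frequencies that sum to 2**precision and then into a cumulative table. Every symbol keeps a frequency of at least 1. The precision parameter plays a role analogous to range_coder_precision. The rounding-repair strategy is an assumption made for illustration, not the procedure the library uses.

```python
def integer_pmf(pmf, precision=16):
    """Quantize a pmf to integer frequencies summing to exactly 2**precision.

    Hypothetical helper: each symbol keeps a frequency of at least 1 so that it
    stays encodable; the rounding error is absorbed by the largest entry.
    """
    total = 1 << precision
    freqs = [max(1, round(p * total)) for p in pmf]
    # Repair the rounding error so the table sums to exactly `total`.
    diff = total - sum(freqs)
    i = max(range(len(freqs)), key=lambda k: freqs[k])
    freqs[i] += diff
    if freqs[i] < 1:
        raise ValueError("precision too low for this many symbols")
    return freqs


def to_cdf(freqs):
    """Cumulative integer table: cdf[k] is the sum of freqs[:k]."""
    cdf = [0]
    for f in freqs:
        cdf.append(cdf[-1] + f)
    return cdf
```

For example, integer_pmf([0.5, 0.25, 0.125, 0.125], precision=16) returns [32768, 16384, 8192, 8192], whose cumulative table is [0, 32768, 49152, 57344, 65536].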

A typical workflow looks like this:

  • Train a model using an instance of this entropy model as a bottleneck, passing the bottleneck tensor through it. With training=True, the model computes a differentiable upper bound on the number of bits needed to compress the bottleneck tensor.
  • For evaluation, get a closer estimate of the number of compressed bits using training=False.
  • Instantiate an entropy model with compression=True (and the same parameters as during training), and share the model between a sender and a receiver.
  • On the sender side, compute the bottleneck tensor and call compress() on it. The output is a compressed string representation of the tensor. Transmit the string to the receiver, and call decompress() there. The output is the quantized bottleneck tensor. Continue processing the tensor on the receiving side.
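
A toy scalar model can illustrate the difference between the two rate estimates. The sketch below is stdlib-only and assumes a standard logistic prior with zero quantization offset. noisy_prob stands in for the density of that prior convolved with unit-width uniform noise; at integer points, this equals the probability mass of the rounded value. With training=True, the input is perturbed with uniform noise and the rate is the negative log-density at the perturbed point. With training=False, the input is rounded and the rate is the negative log-probability of the rounded value. The names logistic_cdf, noisy_prob and rate are hypothetical, and the actual model operates on tensors with its own prior.

```python
import math
import random


def logistic_cdf(x, loc=0.0, scale=1.0):
    return 1.0 / (1.0 + math.exp(-(x - loc) / scale))


def noisy_prob(y, loc=0.0, scale=1.0):
    # Density of the logistic prior convolved with U(-1/2, 1/2). At integer y,
    # this is also the probability mass of the quantization bin around y.
    return logistic_cdf(y + 0.5, loc, scale) - logistic_cdf(y - 0.5, loc, scale)


def rate(values, training, rng=random):
    """Return (perturbed_values, total_bits) for a list of scalars."""
    bits = 0.0
    perturbed = []
    for x in values:
        # Training: additive uniform noise. Evaluation: hard rounding.
        y = x + rng.uniform(-0.5, 0.5) if training else float(round(x))
        perturbed.append(y)
        bits += -math.log2(noisy_prob(y))
    return perturbed, bits
```

For instance, rate([0.2, -1.7, 3.1], training=False) rounds the inputs to [0.0, -2.0, 3.0] and sums their information content in bits.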

This class assumes that all scalar elements of the encoded tensor are statistically independent, and that the parameters of their scalar distributions do not depend on data. The innermost dimensions of the bottleneck tensor must be broadcastable to the batch shape of prior. Any dimensions to the left of the batch shape are assumed to be i.i.d., i.e. the likelihoods are broadcast to the bottleneck tensor accordingly.
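
The shape convention above can be sketched in plain Python. The hypothetical helper split_dims splits a bottleneck shape into two parts: the leading i.i.d. dimensions and the innermost dimensions that are matched against the prior's batch shape. It checks that each innermost dimension either equals the corresponding prior dimension or is 1. Requiring the bottleneck to have at least as many dimensions as the prior batch shape is a simplification assumed here.

```python
def split_dims(bottleneck_shape, prior_shape):
    """Return (iid_dims, inner_dims) for a bottleneck shape (hypothetical helper).

    inner_dims are the innermost len(prior_shape) dimensions; each must equal
    the matching prior dimension or be 1 (i.e. broadcast to it). Everything to
    the left is treated as i.i.d.
    """
    shape = tuple(bottleneck_shape)
    n = len(prior_shape)
    if len(shape) < n:
        raise ValueError("bottleneck has fewer dimensions than prior_shape")
    # Use len(shape) - n rather than -n so that n == 0 works correctly.
    iid = shape[:len(shape) - n]
    inner = shape[len(shape) - n:]
    for b, p in zip(inner, prior_shape):
        if b != p and b != 1:
            raise ValueError(f"dimension {b} is not broadcastable to {p}")
    return iid, inner
```

For a bottleneck of shape (8, 16, 16, 192) and a prior with batch shape (192,), split_dims returns ((8, 16, 16), (192,)). Each of the 192 scalar distributions is then shared by 8 * 16 * 16 = 2048 elements.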

A more detailed description (and motivation) of this way of performing quantization and range coding can be found in the following paper. Please cite the paper when using this code for derivative work.

"End-to-end Optimized Image Compression"
J. Ballé, V. Laparra, E.P. Simoncelli
https://openreview.net/forum?id=rJxdQ3jeg

Entropy models which contain range coding tables (i.e. with compression=True) can be instantiated in three ways:

  • By providing a continuous "prior" distribution object. The range coding tables are then derived from that continuous distribution.
  • From a config as returned by get_config, followed by a call to set_weights. This implements the Keras serialization protocol. In this case, the initializer creates empty state variables for the range coding tables, which are then filled by set_weights. As a consequence, this method requires stateless=False.
  • In a more low-level way, by directly providing the range coding tables to __init__, for use cases where the Keras protocol can't be used (e.g., when the entropy model must not create variables).

The quantization_offset and offset_heuristic arguments control whether quantization is performed with respect to integer values, or potentially non-integer offsets (i.e., y = tf.round(x - o) + o). There are three modes of operation:

  • If quantization_offset is provided manually (not None), these values are used and offset_heuristic is ineffective.
  • Otherwise, if offset_heuristic and compression, the offsets are computed once on initialization, and then fixed. If the entropy model is serialized, they are preserved.
  • Otherwise, if offset_heuristic and not compression, the offsets are recomputed every time quantization is performed. Note this may be computationally expensive when the prior does not have a mode that is computable in closed form (e.g. for NoisyDeepFactorized).
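
The following is a scalar sketch of the three modes, assuming the prior's mode is available in closed form. heuristic_offset maps the mode to an offset o in [-0.5, 0.5], so that y = round(x - o) + o always lands on the mode plus an integer. The function names are hypothetical, and the real heuristic may fall back to the median when no mode is available. The compression flag is not modeled here, because it only determines whether the heuristic offset is computed once at initialization or on every call.

```python
def heuristic_offset(mode):
    """Offset o in [-0.5, 0.5] so that round(x - o) + o equals mode + an integer."""
    return mode - round(mode)


def quantize(x, offset=0.0):
    # y = round(x - o) + o
    return round(x - offset) + offset


def resolve_offset(quantization_offset, offset_heuristic, prior_mode):
    if quantization_offset is not None:
        return quantization_offset            # manual offsets take precedence
    if offset_heuristic:
        return heuristic_offset(prior_mode)   # derived from the prior
    return 0.0                                # plain integer rounding
```

With a mode of 0.3, quantize(2.7, heuristic_offset(0.3)) gives 2.3 rather than 3.0.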

This offset heuristic is discussed in Section III.A of:

"Nonlinear Transform Coding"
J. Ballé, P.A. Chou, D. Minnen, S. Singh, N. Johnston, E. Agustsson, S.J. Hwang, G. Toderici
https://doi.org/10.1109/JSTSP.2020.3034501

Args
prior A tfp.distributions.Distribution object. A density model fitting the marginal distribution of the bottleneck data with additive uniform noise, which is shared a priori between the sender and the receiver. For best results, the distribution should be flexible enough to have a unit-width uniform distribution as a special case, since this is the marginal distribution for bottleneck dimensions that are constant. The distribution parameters may not depend on data (they must be either variables or constants).
coding_rank Integer. Number of innermost dimensions considered a coding unit. Each coding unit is compressed to its own bit string, and the bits in the call method are summed over each coding unit.
compression Boolean. If set to True, the range coding tables used by compress() and decompress() will be built on instantiation. If set to False, these two methods will not be accessible.
stateless Boolean. If False, range coding tables are created as Variables. This allows the entropy model to be serialized using the SavedModel protocol, so that both the encoder and the decoder use identical tables when loading the stored model. If True, creates range coding tables as Tensors. This makes the entropy model stateless and allows it to be constructed within a tf.function body, for when the range coding tables are provided manually. If compression=False, then stateless=True is implied and the provided value is ignored.
expected_grads If True, will use analytical expected gradients during backpropagation w.r.t. additive uniform noise.
tail_mass Float. Approximate probability mass which is encoded using an Elias gamma code embedded into the range coder.
range_coder_precision Integer. Precision passed to the range coding op.
bottleneck_dtype tf.dtypes.DType. Data type of bottleneck tensor. Defaults to tf.keras.mixed_precision.global_policy().compute_dtype.
prior_shape Batch shape of the prior (dimensions which are not assumed i.i.d.). Must be provided if prior is omitted.
cdf tf.Tensor or None. If provided, is used for range coding rather than tables built from the prior.
cdf_offset tf.Tensor or None. Must be provided along with cdf.
cdf_shapes Shapes of cdf and cdf_offset. If provided, empty range coding tables are created, which can then be restored using set_weights. Requires compression=True and stateless=False.
offset_heuristic Boolean. Whether to quantize to non-integer offsets heuristically determined from mode/median of prior. Set this to False if you are using soft quantization during training.
quantization_offset tf.Tensor or None. The quantization offsets to use. If provided (not None), then offset_heuristic is ineffective.
decode_sanity_check Boolean. If True, raises an error if the binary strings passed into decompress are not completely decoded.
laplace_tail_mass Float. If positive, will augment the prior with a Laplace mixture for training stability. (experimental)
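
The expected_grads argument relies on a general identity: averaging a gradient over unit-width uniform noise yields a finite difference. For u uniform on [-1/2, 1/2], E[g'(x + u)] = g(x + 1/2) - g(x - 1/2), by the fundamental theorem of calculus. The sketch below checks this identity numerically, taking g to be the number of bits under a standard logistic density. It compares the closed form with a midpoint-rule average of the pointwise gradient. It illustrates the identity only; how the library wires it into backpropagation is not shown, and all names are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bits(x):
    # -log2 of the standard logistic density sigmoid(x) * sigmoid(-x).
    return -math.log2(sigmoid(x) * sigmoid(-x))


def bits_grad(x):
    # Pointwise derivative: d/dx bits(x) = (2 * sigmoid(x) - 1) / ln 2.
    return (2.0 * sigmoid(x) - 1.0) / math.log(2.0)


def expected_grad_analytic(x):
    # E_u[bits'(x + u)], u ~ U(-1/2, 1/2), as a unit-width finite difference.
    return bits(x + 0.5) - bits(x - 0.5)


def expected_grad_numeric(x, n=10_000):
    # Midpoint-rule average of the pointwise gradient over the noise interval.
    h = 1.0 / n
    return sum(bits_grad(x - 0.5 + (i + 0.5) * h) for i in range(n)) * h
```

The analytic form needs two function evaluations instead of many samples of the noise, which is the practical appeal of using expected gradients.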

Attributes
bottleneck_dtype Data type of the bottleneck tensor.
cdf The CDFs used by range coding.
cdf_offset The CDF offsets used by range coding.
coding_rank Number of innermost dimensions considered a coding unit.
compression Whether this entropy model is prepared for compression.
expected_grads Whether to use analytical expected gradients during backpropagation.
laplace_tail_mass Whether to augment the prior with a Laplace mixture.
name Returns the name of this module as passed or determined in the ctor.

name_scope Returns a tf.name_scope instance for this class.
non_trainable_variables Sequence of non-trainable variables owned by this module and its submodules.
offset_heuristic Whether to use heuristic to determine quantization offsets.
prior Prior distribution, used for deriving range coding tables.
prior_shape Batch shape of prior (dimensions which are not assumed i.i.d.).
prior_shape_tensor Batch shape of prior as a Tensor.
quantization_offset The quantization offset used in quantize and compress.
range_coder_precision Precision used in range coding op.
stateless Whether range coding tables are created as Tensors or Variables.
submodules Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

>>> import tensorflow as tf
>>> a = tf.Module()
>>> b = tf.Module()
>>> c = tf.Module()
>>> a.b = b
>>> b.c = c
>>> list(a.submodules) == [b, c]
True
>>> list(b.submodules) == [c]
True
>>> list(c.submodules) == []
True

tail_mass Approximate probability mass which is range encoded with overflow.
trainable_variables Sequence of trainable variables owned by this module and its submodules.

variables Sequence of variables owned by this module and its submodules.
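
Per the description of tail_mass, values in the tails of the prior are encoded with an Elias gamma code embedded into the range coder. For reference, the sketch below implements the plain, standalone Elias gamma code for positive integers. It writes n as floor(log2 n) zeros followed by the binary representation of n, for a total of 2 * floor(log2 n) + 1 bits. This sketch does not show how the range coder embeds the code or maps overflow values to positive integers, and the helper names are hypothetical.

```python
def elias_gamma_encode(n):
    """Return the Elias gamma code of a positive integer as a '0'/'1' string."""
    if n < 1:
        raise ValueError("Elias gamma codes positive integers only")
    binary = bin(n)[2:]                        # e.g. 5 -> '101'
    return "0" * (len(binary) - 1) + binary    # unary length prefix + binary


def elias_gamma_decode(bits):
    """Decode one code word from the front of `bits`; return (n, remaining).

    Assumes a well-formed code word is present (no validation).
    """
    zeros = 0
    while bits[zeros] == "0":
        zeros += 1
    end = 2 * zeros + 1
    return int(bits[zeros:end], 2), bits[end:]
```

For example, 5 is encoded as "00101", and decoding "001011" returns (5, "1"), leaving the trailing bit for the next code word.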

Methods

compress

Compresses a floating-point tensor.

Compresses the tensor to bit strings. bottleneck is first quantized as in quantize(), and then compressed using the probability tables in self.cdf (derived from self.prior). The quantized tensor can later be recovered by calling decompress().

The innermost self.coding_rank dimensions are treated as one coding unit, i.e. are compressed into one string each. Any additional dimensions to the left are treated as batch dimensions.

Args
bottleneck tf.Tensor containing the data to be compressed. Must have at least self.coding_rank dimensions, and the innermost dimensions must be broadcastable to self.prior_shape.

Returns
A tf.Tensor having the same shape as bottleneck without the self.coding_rank innermost dimensions, containing a string for each coding unit.

decompress

Decompresses a tensor.

Reconstructs the quantized tensor from bit strings produced by compress(). It is necessary to provide a part of the output shape in broadcast_shape.

Args
strings tf.Tensor containing the compressed bit strings.
broadcast_shape Iterable of ints. The part of the output tensor shape between the shape of strings on the left and self.prior_shape on the right. This must match the shape of the input to compress().

Returns
A tf.Tensor of shape strings.shape + broadcast_shape + self.prior_shape.
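
The shape bookkeeping of compress() and decompress() can be summarized with a small, hypothetical helper that works on plain tuples. compress() drops the self.coding_rank innermost dimensions. decompress() needs the dimensions between those of strings and self.prior_shape to be passed back as broadcast_shape. The check that coding_rank is at least the rank of prior_shape is an assumption made for this sketch.

```python
def compress_shapes(bottleneck_shape, coding_rank, prior_shape):
    """Return (strings_shape, broadcast_shape) for a bottleneck (hypothetical).

    strings_shape: one string per coding unit (coding_rank innermost dims dropped).
    broadcast_shape: what decompress() needs between strings and prior_shape.
    """
    shape = tuple(bottleneck_shape)
    n = len(shape)
    # Assumed constraint for this sketch: len(prior_shape) <= coding_rank <= rank.
    if n < coding_rank or coding_rank < len(prior_shape):
        raise ValueError("need len(prior_shape) <= coding_rank <= rank")
    strings_shape = shape[:n - coding_rank]
    broadcast_shape = shape[n - coding_rank:n - len(prior_shape)]
    return strings_shape, broadcast_shape
```

Take a bottleneck of shape (8, 16, 16, 192) with coding_rank=3 and prior_shape=(192,). compress() would then yield 8 strings, and decompress(strings, broadcast_shape=(16, 16)) would return a tensor of shape (8, 16, 16, 192). The round trip restores the original shape only if the innermost dimensions equal self.prior_shape, rather than merely broadcasting to it.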

get_config

Returns the configuration of the entropy model.

Returns
A JSON-serializable Python dict.

get_weights

quantize

Quantizes a floating-point bottleneck tensor.

The tensor is rounded to integer values potentially shifted by offsets (if self.quantization_offset is not None). These offsets can depend on self.prior. For instance, for a Gaussian distribution, when self.offset_heuristic == True, the returned values would be rounded to the location of the mode of the distribution plus or minus an integer.

The gradient of this rounding operation is overridden with the identity (straight-through gradient estimator).
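
The sketch below illustrates the straight-through estimator without a TensorFlow dependency. It uses a tiny hypothetical forward-mode Dual type that carries a value and its derivative. round_ste rounds the value (to the offset plus an integer) but passes the derivative through unchanged. round_plain instead uses the true derivative of rounding, which is zero almost everywhere. In TensorFlow, a common way to express the same idea is x + tf.stop_gradient(tf.round(x) - x); this is offered as a general idiom, not as the library's implementation.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """Forward-mode autodiff value: the value and its derivative w.r.t. the input."""
    value: float
    grad: float

    def __mul__(self, other):
        # Product rule.
        return Dual(self.value * other.value,
                    self.grad * other.value + self.value * other.grad)


def round_ste(x, offset=0.0):
    # Forward: round to offset + integer. Backward: identity (straight-through).
    return Dual(round(x.value - offset) + offset, x.grad)


def round_plain(x, offset=0.0):
    # Same forward value, but with the true derivative of rounding (zero).
    return Dual(round(x.value - offset) + offset, 0.0)
```

For x = Dual(2.3, 1.0), both functions produce the forward value 2.0. After squaring, however, only round_ste yields a non-zero derivative (4.0).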

Args
bottleneck tf.Tensor containing the data to be quantized. The innermost dimensions must be broadcastable to self.prior_shape.

Returns
A tf.Tensor containing the quantized values.

set_weights

with_name_scope

Decorator to automatically enter the module name scope.

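import tensorflow as tf

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    # The weight is created lazily on the first call; because of the
    # decorator, it is created inside the module's name scope.
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)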

Using the above module produces tf.Variables and tf.Tensors whose names include the module name:

>>> mod = MyModule()
>>> mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>
>>> mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=...>

Args
method The method to wrap.

Returns
The original method wrapped such that it enters the module's name scope.

__call__

Perturbs a tensor with (quantization) noise and estimates rate.

Args
bottleneck tf.Tensor containing the data to be compressed. Must have at least self.coding_rank dimensions, and the innermost dimensions must be broadcastable to self.prior_shape.
training Boolean. If False, computes the Shannon information of bottleneck under the distribution self.prior, which is a non-differentiable, tight lower bound on the number of bits needed to compress bottleneck using compress(). If True, returns a somewhat looser, but differentiable upper bound on this quantity.

Returns
A tuple (bottleneck_perturbed, bits) where bottleneck_perturbed is bottleneck perturbed with (quantization) noise, and bits is the rate. bits has the same shape as bottleneck without the self.coding_rank innermost dimensions.