Custom differentiation

This tutorial will show you how to define your own custom derivatives, perform derivative surgery, and implement your own gradient checkpointing API in just 5 lines of Swift.

Declaring custom derivatives

You can define custom derivatives for any Swift function that has differentiable parameters and results. This means you can even import a C function and make it differentiable.

import Glibc

func sillyExp(_ x: Float) -> Float {
    let 𝑒 = Float(M_E)
    print("Taking 𝑒(\(𝑒)) to the power of \(x)!")
    return pow(𝑒, x)
}

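// Register `sillyDerivative` as the derivative of `sillyExp`. It returns the original
// result and a pullback that scales the incoming gradient `v` by d/dx 𝑒ˣ = 𝑒ˣ.
@derivative(of: sillyExp)
func sillyDerivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    let y = sillyExp(x)
    return (value: y, pullback: { v in v * y })
}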

print("exp(3) =", sillyExp(3))
print("𝛁exp(3) =", gradient(of: sillyExp)(3))
Taking 𝑒(2.7182817) to the power of 3.0!
exp(3) = 20.085535
Taking 𝑒(2.7182817) to the power of 3.0!
𝛁exp(3) = 20.085535
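
The same pattern applies to functions imported from C. As a minimal sketch, the following wraps the C function `tanhf` in a Swift function and registers a derivative for it. The names `cTanh` and `cTanhDerivative` are hypothetical and chosen for illustration only. The `_Differentiation` import is an assumption that holds for recent toolchains; on older Swift for TensorFlow toolchains it can likely be dropped.

```swift
import Foundation        // Re-exports the C math library (`tanhf`) on Linux and macOS.
import _Differentiation  // Assumption: only needed on recent toolchains.

// Hypothetical wrapper around the C function `tanhf`.
func cTanh(_ x: Float) -> Float {
    return tanhf(x)
}

// d/dx tanh(x) = 1 - tanh(x)², so the pullback scales `v` by that factor.
@derivative(of: cTanh)
func cTanhDerivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    let y = cTanh(x)
    return (value: y, pullback: { v in v * (1 - y * y) })
}

let tanhGradient = gradient(at: Float(0.5)) { x in cTanh(x) }
print("𝛁tanh(0.5) =", tanhGradient)
```

Reusing the forward result `y` inside the pullback avoids calling the C function a second time during backpropagation.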

Stop derivatives from propagating

Commonly known as "stop gradient" in machine learning, the function withoutDerivative(at:) stops derivatives from propagating through a value.

In addition, withoutDerivative(at:) can sometimes help the Swift compiler identify what not to differentiate and produce more efficient derivatives. When the compiler can detect that the derivative of a function will always be zero, it produces a warning. Explicitly using withoutDerivative(at:) silences that warning.

let x: Float = 2.0
let y: Float = 3.0
let xyGradient = gradient(at: x, y) { x, y in
    sin(sin(sin(x))) + withoutDerivative(at: cos(cos(cos(y))))
}
print(xyGradient)
(-0.18009877, 0.0)
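
To see the effect in isolation, here is a small sketch that differentiates the same product with and without the barrier. The variable names are illustrative, and the `_Differentiation` import is an assumption that applies to recent toolchains.

```swift
import _Differentiation  // Assumption: only needed on recent toolchains.

// Without the barrier, d/dx (x * x) = 2x.
let fullGradient = gradient(at: Float(3)) { x in x * x }

// With the barrier, the second factor is treated as a constant c, so d/dx (x * c) = c.
let stoppedGradient = gradient(at: Float(3)) { x in x * withoutDerivative(at: x) }

print(fullGradient, stoppedGradient)  // 6.0 3.0
```

The forward value is the same in both cases; only the derivative changes.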

Derivative surgery

The method withDerivative(_:) runs an arbitrary operation (including mutation) on the gradient at a value during the enclosing function's backpropagation.

Use this to debug or make experimental tweaks to backpropagation.

It works anywhere

All differentiation APIs provided by the standard library are defined generically over all types that conform to the Differentiable protocol: Float, Double, Float80, SIMD vectors, and even your own types!

Read the technical document Differentiable Types for more insight into the Differentiable protocol.

var x: Float = 30
let xGradient = gradient(at: x) { x -> Float in
    // Print the partial derivative with respect to the result of `sin(x)`.
    let a = sin(x).withDerivative { print("∂+/∂sin = \($0)") }
    // Force the partial derivative with respect to `x` to be `0.5`.
    let b = log(x.withDerivative { (dx: inout Float) in
        print("∂log/∂x = \(dx), but rewritten to 0.5")
        dx = 0.5
    })
    return a + b
}
print(xGradient)
∂log/∂x = 0.033333335, but rewritten to 0.5
∂+/∂sin = 1.0
0.65425146
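
As one example of such an experimental tweak, the following sketch clamps the gradient flowing back through a value to the range [-1, 1], a simple form of gradient clipping. The names are illustrative, and the `_Differentiation` import is an assumption that applies to recent toolchains.

```swift
import _Differentiation  // Assumption: only needed on recent toolchains.

let clippedGradient = gradient(at: Float(10)) { x -> Float in
    // During backpropagation, clamp the gradient arriving at `x` to [-1, 1].
    let clipped = x.withDerivative { (dx: inout Float) in
        dx = max(-1, min(1, dx))
    }
    return clipped * clipped
}
print(clippedGradient)  // 1.0; the unclipped gradient would be 2 * 10 = 20.0
```

Because the closure only runs during backpropagation, the forward result is still 100.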

Use it in a neural network module

Just as we used withDerivative(_:) in a simple Float function, we can use it in any numerical application, such as the following neural network built with the Swift for TensorFlow Deep Learning Library.

import TensorFlow

struct MLP: Layer {
    var layer1 = Dense<Float>(inputSize: 2, outputSize: 10, activation: relu)
    var layer2 = Dense<Float>(inputSize: 10, outputSize: 1, activation: relu)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        let h0 = layer1(input).withDerivative { print("∂L/∂layer1 =", $0) }
        return layer2(h0)
    }
}

var classifier = MLP()
let optimizer = SGD(for: classifier, learningRate: 0.02)

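// XOR inputs and targets. Note that `y` has shape [4] while the model output has shape
// [4, 1], so `ŷ - y` in the loss below broadcasts to shape [4, 4].
let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
let y: Tensor<Float> = [0, 1, 1, 0]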

for _ in 0..<10 {
    let 𝛁model = gradient(at: classifier) { classifier -> Tensor<Float> in
        let ŷ = classifier(x).withDerivative { print("∂L/∂ŷ =", $0) }
        let loss = (ŷ - y).squared().mean()
        print("Loss: \(loss)")
        return loss
    }
    optimizer.update(&classifier, along: 𝛁model)
}
Loss: 0.45304087
∂L/∂ŷ = [[     -0.25],
 [     -0.25],
 [-0.2143442],
 [-0.1791575]]
∂L/∂layer1 = [[         0.0,          0.0,          0.0,          0.0,          0.0,          0.0,
           0.0,          0.0,          0.0,          0.0],
 [         0.0,          0.0,          0.0,          0.0,          0.0,          0.0,
           0.0,          0.0,          0.0,          0.0],
 [-0.046330024,  -0.07919147, -0.077494234,  -0.07907715,   0.14447221,  -0.07965051,
     0.0873662, -0.016764779,    0.1293755,  0.027867926],
 [-0.038724493, -0.066191405,   -0.0647728,  -0.06609586,   0.12075568,  -0.06657509,
    0.07302418, -0.014012676,  0.108137235,  0.023293132]]
Loss: 0.43502235
∂L/∂ŷ = [[-0.24459878],
 [-0.24358931],
 [-0.19911093],
 [-0.16190395]]
∂L/∂layer1 = [[-0.053103957,  -0.09203638,   -0.0885385,  -0.09065656,   0.16429774, -0.090893134,
    0.09901551, -0.019131118,   0.14763679,   0.03180147],
 [-0.052884795,  -0.09165655,   -0.0881731,  -0.09028242,   0.16361968,  -0.09051801,
    0.09860687, -0.019052165,   0.14702748,  0.031670224],
 [-0.043228254, -0.074920446, -0.072073065, -0.073797226,   0.13374342,   -0.0739898,
    0.08060167, -0.015573319,   0.12018088,  0.025887374],
 [-0.035150383, -0.060920395, -0.058605086,  -0.06000707,   0.10875137, -0.060163658,
    0.06553999, -0.012663202,   0.09772321,  0.021049915]]
Loss: 0.40576553
∂L/∂ŷ = [[-0.23289952],
 [-0.22639728],
 [-0.17728773],
 [-0.13724682]]
∂L/∂layer1 = [[-0.050774142,  -0.08952092, -0.084402055, -0.086720824,   0.15596299, -0.086545676,
    0.09358021,  -0.01821607,    0.1403872,  0.030280393],
 [-0.049356595,  -0.08702162,  -0.08204567,   -0.0842997,    0.1516087,  -0.08412944,
    0.09096757, -0.017707502,   0.13646778,  0.029435005],
 [ -0.03865028,   -0.0681451,  -0.06424852,  -0.06601361,   0.11872211,  -0.06588028,
   0.071235105, -0.013866433,  0.106865525,  0.023050034],
 [-0.029921012, -0.052754343, -0.049737815,  -0.05110426,    0.0919084, -0.051001046,
   0.055146467, -0.010734662,   0.08272966,  0.017844122]]
Loss: 0.38182113
∂L/∂ŷ = [[ -0.22214013],
 [ -0.21068493],
 [ -0.15761846],
 [-0.115079075]]
∂L/∂layer1 = [[-0.048611242,  -0.08700116,  -0.08059354,  -0.08307868,   0.14837542,  -0.08254748,
    0.08869235, -0.017374532,   0.13374089,  0.028881513],
 [ -0.04610448,  -0.08251473,  -0.07643753, -0.078794524,   0.14072408, -0.078290716,
    0.08411872, -0.016478572,    0.1268442,  0.027392166],
 [ -0.03449187, -0.061731257,  -0.05718476,  -0.05894808,  0.105279066, -0.058571167,
    0.06293123, -0.012328016,    0.0948952,  0.020492738],
 [-0.025182918, -0.045070708, -0.041751258,  -0.04303868,   0.07686547,  -0.04276349,
   0.045946825, -0.009000828,   0.06928409,  0.014961987]]
Loss: 0.36222494
∂L/∂ŷ = [[ -0.2122466],
 [-0.19632757],
 [-0.13990551],
 [-0.09517485]]
∂L/∂layer1 = [[ -0.046605036,   -0.08450727,  -0.077087075,   -0.07970615,    0.14145951,  -0.078871034,
     0.08428629,  -0.016600717,    0.12764633,   0.027595207],
 [ -0.043109544,   -0.07816901,   -0.07130535,   -0.07372799,    0.13084969,  -0.072955504,
    0.077964604, -0.0153556205,    0.11807254,   0.025525497],
 [ -0.030720405,   -0.05570423,  -0.050813094,  -0.052539498,    0.09324514,   -0.05198902,
    0.055558562,   -0.01094261,    0.08413999,   0.018189792],
 [ -0.020898461,  -0.037894443,  -0.034567107,   -0.03574154,    0.06343276,   -0.03536706,
     0.03779535,  -0.007444033,   0.057238705,   0.012374142]]
Loss: 0.34618416
∂L/∂ŷ = [[-0.20314947],
 [ -0.1832107],
 [-0.12396976],
 [-0.07732913]]
∂L/∂layer1 = [[  -0.04474547,  -0.082062505,   -0.07385858,   -0.07658187,    0.13514856,   -0.07549053,
     0.08030583,   -0.01588919,   0.122056164,   0.026412444],
 [  -0.04035378,   -0.07400821,   -0.06660949,    -0.0690655,   0.121883966,   -0.06808127,
     0.07242396,  -0.014329694,    0.11007657,    0.02382011],
 [  -0.02730544,  -0.050077755,    -0.0450714,  -0.046733256,    0.08247295,   -0.04606728,
    0.049005765,  -0.009696207,   0.074483454,   0.016117908],
 [ -0.017032426,  -0.031237207,  -0.028114373,  -0.029150996,    0.05144449,  -0.028735576,
    0.030568527, -0.0060482426,   0.046460852,  0.0100539345]]
Loss: 0.33304712
∂L/∂ŷ = [[ -0.19478384],
 [  -0.1712287],
 [ -0.10964805],
 [-0.061354905]]
∂L/∂layer1 = [[ -0.04302273,  -0.07968434,  -0.07088566,   -0.0736866,   0.12938349, -0.072381854,
   0.076702625, -0.015234879,   0.11692673,  0.025324788],
 [ -0.03782001, -0.070048146, -0.062313486,  -0.06477571,   0.11373719,  -0.06362875,
      0.067427, -0.013392531,   0.10278683,  0.022262271],
 [-0.024218429,  -0.04485604, -0.039903075, -0.041479785,   0.07283277, -0.040745318,
    0.04317757, -0.008576044,    0.0658206,  0.014255873],
 [-0.013551718, -0.025099747, -0.022328254, -0.023210522,  0.040754467,  -0.02279954,
   0.024160538,  -0.00479883,   0.03683072,  0.007977048]]
Loss: 0.32227832
∂L/∂ŷ = [[  -0.187089],
 [-0.16028392],
 [-0.09679102],
 [-0.04708069]]
∂L/∂layer1 = [[ -0.041427277,   -0.07738533,   -0.06814741,  -0.071002685,   0.124111414,   -0.06952245,
     0.07343468, -0.0146330325,    0.11221778,   0.024324344],
 [  -0.03549181,  -0.066297986,   -0.05838363,  -0.060829815,    0.10632942,  -0.059561655,
    0.062913366,  -0.012536493,    0.09613983,   0.020839289],
 [  -0.02143252,   -0.04003552,  -0.035256255,  -0.036733437,   0.064209394,  -0.035967633,
     0.03799164,  -0.007570441,   0.058056183,   0.012584269],
 [ -0.010425118,   -0.01947391,  -0.017149203,  -0.017867727,   0.031232467,  -0.017495228,
    0.018479737, -0.0036823824,   0.028239448,   0.006121188]]
Loss: 0.3134383
∂L/∂ŷ = [[ -0.18000817],
 [ -0.15028599],
 [ -0.08526195],
 [-0.034349076]]
∂L/∂layer1 = [[ -0.039949864,   -0.07517394,  -0.065624304,   -0.06851376,   0.119284846,    -0.0668912,
     0.07046529,  -0.014079211,     0.1078921,   0.023403734],
 [ -0.033353515,   -0.06276154,  -0.054788698,   -0.05720106,    0.09958904,   -0.05584641,
     0.05883036,  -0.011754512,   0.090077415,   0.019539408],
 [ -0.018922493,  -0.035606585,  -0.031083344,  -0.032451954,   0.056499984,   -0.03168342,
    0.033376306, -0.0066687027,   0.051103737,   0.011085318],
 [-0.0076232147,  -0.014344656, -0.0125223985,  -0.013073765,    0.02276188,  -0.012764148,
    0.013446154, -0.0026865886,   0.020587921,  0.0044658897]]
Loss: 0.30616698
∂L/∂ŷ = [[ -0.17348853],
 [ -0.14115131],
 [-0.074935496],
 [-0.023015507]]
∂L/∂layer1 = [[ -0.038581613,   -0.07305531,  -0.063298136,   -0.06620461,    0.11486097,  -0.064468496,
    0.067762226,  -0.013569281,   0.103915446,   0.022556083],
 [ -0.031390235,  -0.059438244,  -0.051499747,  -0.053864464,   0.093451574,  -0.052451957,
    0.055131756,  -0.011040049,    0.08454623,   0.018351763],
 [  -0.01666469,  -0.031555034,  -0.027340584,  -0.028595984,    0.04961229,    -0.0278461,
    0.029268773, -0.0058610262,   0.044884555,   0.009742727],
 [-0.0051183524,  -0.009691737,   -0.00839732,  -0.008782901,   0.015237799,  -0.008552584,
     0.00898954, -0.0018001414,   0.013785734,  0.0029923574]]

Recomputing activations during backpropagation to save memory (checkpointing)

Checkpointing is a traditional technique in reverse-mode automatic differentiation for saving memory. Rather than saving large intermediate values in the original computation for computing derivatives, the intermediate values are instead recomputed as needed during backpropagation.

This technique is also implemented in modern deep learning libraries. In Swift, the withRecomputationInPullbacks(_:) API lets you control what to recompute during backpropagation, and it is available on all Differentiable types.
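
A minimal sketch of calling that API on a Float follows. It assumes the method takes a differentiable closure over the receiver, as the name above suggests; the variable names are illustrative, and the `_Differentiation` import is an assumption that applies to recent toolchains. The print statement is there only to make any recomputation visible.

```swift
import _Differentiation  // Assumption: only needed on recent toolchains.

let (squareValue, squarePullback) = valueWithPullback(at: Float(3)) { x -> Float in
    // The closure runs in the forward pass and is expected to run again
    // when the pullback is called, instead of storing intermediate values.
    x.withRecomputationInPullbacks { (y: Float) -> Float in
        print("  Computing square...")
        return y * y
    }
}
print("value =", squareValue)
print("gradient =", squarePullback(1))
```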

But today, let us learn how to define our own gradient checkpointing APIs from scratch, in just a few lines of code.

Our gradient checkpointing API

We can define our own gradient checkpointing API, makeRecomputedInGradient(_:), in terms of the standard library function differentiableFunction(from:), which is shorthand for creating a differentiable function directly from a derivative function (also called a "vector-Jacobian product (VJP) function").

As we have seen before, the derivative function returns a tuple of the original function's result and a pullback closure. We return original(x) in value:, and call pullback(at:in:) on original to evaluate the original function again and get a pullback.

/// Given a differentiable function, returns the same differentiable function except when
/// derivatives of this function are being computed. In that case, values in the original
/// function needed for computing the derivatives will be recomputed, instead of being
/// captured by the differential or pullback.
///
/// - Parameter original: The differentiable function to wrap.
/// - Returns: The same differentiable function whose derivatives, when computed, will recompute
///   some values from the original function.
func makeRecomputedInGradient<T: Differentiable, U: Differentiable>(
    _ original: @escaping @differentiable (T) -> U
) -> @differentiable (T) -> U {
    return differentiableFunction { x in
        (value: original(x), pullback: { v in pullback(at: x, in: original)(v) })
    }
}

Verify it works

let input: Float = 10.0
print("Running original computation...")

// Differentiable multiplication with checkpointing.
let square = makeRecomputedInGradient { (x: Float) -> Float in
    print("  Computing square...")
    return x * x
}

// Differentiate `f(x) = (cos(x))^2`.
let (output, backprop) = valueWithPullback(at: input) { input -> Float in
    return square(cos(input))
}
print("Running backpropagation...")
let grad = backprop(1)
print("Gradient = \(grad)")
Running original computation...
  Computing square...
Running backpropagation...
  Computing square...
Gradient = -0.9129453
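
The printed gradient can be checked by hand with the chain rule, using the same function and input as in the code above:

```latex
f(x) = \cos^2(x), \qquad
f'(x) = 2\cos(x)\cdot\bigl(-\sin(x)\bigr) = -\sin(2x), \qquad
f'(10) = -\sin(20) \approx -0.91295
```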

Extend it to neural network modules

In this example, we define a simple convolutional neural network.

struct Model: Layer {
    var conv = Conv2D<Float>(filterShape: (5, 5, 3, 6))
    var maxPool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
    var flatten = Flatten<Float>()
    var dense = Dense<Float>(inputSize: 36 * 6, outputSize: 10)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        return input.sequenced(through: conv, maxPool, flatten, dense)
    }
}

We want the activations in the convolution layer (conv) to be recomputed during backpropagation. However, using makeRecomputedInGradient(_:) directly could make the resulting code look cumbersome, especially when we want to apply layers sequentially using sequenced(through:_:_:_:).

input.sequenced(through: conv, maxPool, flatten, dense)

So, why don't we define a special layer type that wraps a layer and makes its activations be recomputed during backpropagation? Let's do it.

First, we define a makeRecomputedInGradient(_:) function that takes a binary function.

// Same as the previous `makeRecomputedInGradient(_:)`, except it's for binary functions.
func makeRecomputedInGradient<T: Differentiable, U: Differentiable, V: Differentiable>(
    _ original: @escaping @differentiable (T, U) -> V
) -> @differentiable (T, U) -> V {
    return differentiableFunction { x, y in
        (value: original(x, y), pullback: { v in pullback(at: x, y, in: original)(v) })
    }
}

Then, we define a generic layer ActivationDiscarding<Wrapped>.

import TensorFlow

/// A layer wrapper that makes the underlying layer's activations be discarded during application
/// and recomputed during backpropagation.
struct ActivationDiscarding<Wrapped: Layer>: Layer {
    /// The wrapped layer.
    var wrapped: Wrapped

    @differentiable
    func callAsFunction(_ input: Wrapped.Input) -> Wrapped.Output {
        let apply = makeRecomputedInGradient { (layer: Wrapped, input: Input) -> Wrapped.Output in
            print("    Applying \(Wrapped.self) layer...")
            return layer(input)
        }
        return apply(wrapped, input)
    }
}

Finally, we can add a method on all layers that returns the same layer except its activations are discarded during application and recomputed during backpropagation.

extension Layer {
    func discardingActivations() -> ActivationDiscarding<Self> {
        return ActivationDiscarding(wrapped: self)
    }
}

Back in the model, all we have to do is wrap the convolution layer in the activation-recomputing layer.

var conv = Conv2D<Float>(filterShape: (5, 5, 3, 6)).discardingActivations()

Now, simply use it in the model!

struct Model: Layer {
    var conv = Conv2D<Float>(filterShape: (5, 5, 3, 6)).discardingActivations()
    var maxPool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
    var flatten = Flatten<Float>()
    var dense = Dense<Float>(inputSize: 36 * 6, outputSize: 10)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        return input.sequenced(through: conv, maxPool, flatten, dense)
    }
}

When we run a training loop, we can see that the convolution layer's activations are computed twice: once during layer application, and once during backpropagation.

// Use random training data.
let x = Tensor<Float>(randomNormal: [10, 16, 16, 3])
let y = Tensor<Int32>(rangeFrom: 0, to: 10, stride: 1)

var model = Model()
let opt = SGD(for: model)

for i in 1...5 {
    print("Starting training step \(i)")
    print("  Running original computation...")
    let (logits, backprop) = model.appliedForBackpropagation(to: x)
    let (loss, dL_dŷ) = valueWithGradient(at: logits) { logits in
        softmaxCrossEntropy(logits: logits, labels: y)
    }
    print("  Loss: \(loss)")
    print("  Running backpropagation...")
    let (dL_dθ, _) = backprop(dL_dŷ)

    opt.update(&model, along: dL_dθ)
}
Starting training step 1
  Running original computation...
    Applying Conv2D<Float> layer...
  Loss: 2.6726463
  Running backpropagation...
    Applying Conv2D<Float> layer...
Starting training step 2
  Running original computation...
    Applying Conv2D<Float> layer...
  Loss: 2.3370266
  Running backpropagation...
    Applying Conv2D<Float> layer...
Starting training step 3
  Running original computation...
    Applying Conv2D<Float> layer...
  Loss: 2.0828948
  Running backpropagation...
    Applying Conv2D<Float> layer...
Starting training step 4
  Running original computation...
    Applying Conv2D<Float> layer...
  Loss: 1.8765408
  Running backpropagation...
    Applying Conv2D<Float> layer...
Starting training step 5
  Running original computation...
    Applying Conv2D<Float> layer...
  Loss: 1.701678
  Running backpropagation...
    Applying Conv2D<Float> layer...

Just like that, it is super easy to define generic differentiable programming libraries for different domains.