# Custom differentiation

This tutorial will show you how to define your own custom derivatives, perform derivative surgery, and implement your own gradient checkpointing API in just 5 lines of Swift.

## Declaring custom derivatives

You can define custom derivatives for any Swift function that has differentiable parameters and results. By doing that, you can even import a C function and make it differentiable.

```swift
import Glibc

func sillyExp(_ x: Float) -> Float {
    let 𝑒 = Float(M_E)
    print("Taking 𝑒(\(𝑒)) to the power of \(x)!")
    return pow(𝑒, x)
}
```

```swift
@derivative(of: sillyExp)
func sillyDerivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    let y = sillyExp(x)
    return (value: y, pullback: { v in v * y })
}
```

```swift
print("exp(3) =", sillyExp(3))
print("𝛁exp(3) =", gradient(of: sillyExp)(3))
```

```
Taking 𝑒(2.7182817) to the power of 3.0!
exp(3) = 20.085535
Taking 𝑒(2.7182817) to the power of 3.0!
𝛁exp(3) = 20.085535
```
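To see what the returned pullback means, here is a hand-rolled sketch in plain Swift (no autodiff machinery involved; the names are illustrative). For a composition g(f(x)), reverse mode chains the pullbacks in reverse order:

```swift
import Foundation

// A (value, pullback) pair written by hand, mirroring the shape that a
// @derivative(of:) function returns. Nothing here is compiler-differentiated.
let x = 3.0
let y = exp(x)                                // f(x) = exp(x)
let pullbackF = { (v: Double) in v * y }      // d/dx exp(x) = exp(x)
let pullbackG = { (v: Double) in v * 2 * y }  // g(y) = y², d/dy y² = 2y

// Chain rule for g(f(x)): run g's pullback first, then f's.
let dydx = pullbackF(pullbackG(1))
print(dydx)            // 2 · exp(x) · exp(x)
print(2 * exp(2 * x))  // same value, by algebra
```

Seeding the chained pullbacks with 1 yields exactly the derivative of the composition, which is all `gradient(of:)` does under the hood.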



## Stop derivatives from propagating

Commonly known as "stop gradient" in machine learning use cases, the `withoutDerivative(at:)` method stops derivatives from propagating.

Plus, `withoutDerivative(at:)` can sometimes help the Swift compiler identify what not to differentiate and produce more efficient derivatives. When it is detectable that the derivative of a function will always be zero, the Swift compiler produces a warning; explicitly using `withoutDerivative(at:)` silences that warning.

```swift
let x: Float = 2.0
let y: Float = 3.0
print(gradient(at: x, y) { x, y in
    sin(sin(sin(x))) + withoutDerivative(at: cos(cos(cos(y))))
})
```

```
(-0.18009877, 0.0)
```
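The two gradient components can be checked by hand with the chain rule; a plain-Swift sketch (no autodiff involved):

```swift
import Foundation

// d/dx sin(sin(sin(x))) = cos(sin(sin(x))) · cos(sin(x)) · cos(x),
// while the withoutDerivative(at:) term contributes exactly 0 to ∂/∂y.
let x = 2.0
let dfdx = cos(sin(sin(x))) * cos(sin(x)) * cos(x)
print(dfdx)  // ≈ -0.180, matching the first component above
```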



## Derivative surgery

The `withDerivative(_:)` method runs arbitrary operations (including mutation) on the gradient at a value during the enclosing function’s backpropagation.

Use this to debug or make experimental tweaks to backpropagation.

### It works anywhere

All differentiation APIs provided by the standard library are defined generically over all types that conform to the `Differentiable` protocol: `Float`, `Double`, `Float80`, SIMD vectors, and even your own types!

Read the technical document Differentiable Types for more insights on the `Differentiable` protocol.

```swift
var x: Float = 30
print(gradient(at: x) { x -> Float in
    // Print the partial derivative with respect to the result of sin(x).
    let a = sin(x).withDerivative { print("∂+/∂sin = \($0)") }
    // Force the partial derivative with respect to x to be 0.5.
    let b = log(x.withDerivative { (dx: inout Float) in
        print("∂log/∂x = \(dx), but rewritten to 0.5")
        dx = 0.5
    })
    return a + b
})
```

```
∂log/∂x = 0.033333335, but rewritten to 0.5
∂+/∂sin = 1.0
0.65425146
```

### Use it in a neural network module

Just like how we used it in a simple `Float` function, we can use it in any numerical application, like the following neural network built using the Swift for TensorFlow Deep Learning Library.

```swift
import TensorFlow

struct MLP: Layer {
    var layer1 = Dense<Float>(inputSize: 2, outputSize: 10, activation: relu)
    var layer2 = Dense<Float>(inputSize: 10, outputSize: 1, activation: relu)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        let h0 = layer1(input).withDerivative { print("∂L/∂layer1 =", $0) }
        return layer2(h0)
    }
}
```

```swift
var classifier = MLP()
let optimizer = SGD(for: classifier, learningRate: 0.02)

let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
let y: Tensor<Float> = [0, 1, 1, 0]

for _ in 0..<10 {
    let 𝛁model = gradient(at: classifier) { classifier -> Tensor<Float> in
        let ŷ = classifier(x).withDerivative { print("∂L/∂ŷ =", $0) }
        let loss = (ŷ - y).squared().mean()
        print("Loss: \(loss)")
        return loss
    }
    optimizer.update(&classifier, along: 𝛁model)
}
```

```
Loss: 0.39960712
∂L/∂ŷ = [[      -0.25],
[-0.16906497],
[ 0.24121547],
[0.018418938]]
∂L/∂layer1 = [[         0.0,          0.0,          0.0,          0.0,          0.0,          0.0,
0.0,          0.0,          0.0,          0.0],
[ -0.10081359,  -0.11300191,  0.045225676,  0.037417475,  -0.12096906,  -0.11740652,
0.092516035,  -0.11235209, -0.054534305, -0.093863584],
[    0.143837,   0.16122684, -0.064526275,  -0.05338583,   0.17259406,   0.16751115,
-0.13199836,   0.16029969,   0.07780747,     0.133921],
[  0.01098323,  0.012311097, -0.004927153, -0.004076481,  0.013179085,   0.01279096,
-0.010079244,  0.012240301, 0.0059412895,  0.010226054]]
Loss: 0.39462554
∂L/∂ŷ = [[       -0.25],
[ -0.16807508],
[  0.23180142],
[0.0120177865]]
∂L/∂layer1 = [[          0.0,           0.0,           0.0,           0.0,           0.0,           0.0,
0.0,           0.0,           0.0,           0.0],
[ -0.099650644,   -0.11184558,   0.044736966,   0.037198395,   -0.12010395,    -0.1167191,
0.09197435,   -0.11185439,  -0.054215003,  -0.093314014],
[   0.13743357,    0.15425228,  -0.061699163,   -0.05130231,    0.16564183,    0.16097361,
-0.12684679,    0.15426442,    0.07477084,    0.12869439],
[ 0.0071252687,   0.007997237, -0.0031988043, -0.0026597776,   0.008587731,   0.008345705,
-0.0065763947,   0.007997867,  0.0038765075,   0.006672184]]
Loss: 0.39019448
∂L/∂ŷ = [[      -0.25],
[-0.16671714],
[  0.2231982],
[0.009079367]]
∂L/∂layer1 = [[          0.0,           0.0,           0.0,           0.0,           0.0,           0.0,
0.0,           0.0,           0.0,           0.0],
[  -0.09830766,   -0.11049614,    0.04413618,   0.036897853,   -0.11899029,   -0.11577608,
0.09123125,   -0.11111182,   -0.05377698,   -0.09256009],
[    0.1316127,    0.14793044,  -0.059088804,  -0.049398247,    0.15930226,    0.15499914,
-0.12213892,     0.1487547,    0.07199575,    0.12391795],
[ 0.0053538065,  0.0060175876, -0.0024036437, -0.0020094463,  0.0064801765,  0.0063051316,
-0.004968427,  0.0060511176,  0.0029286786,  0.0050407955]]
Loss: 0.38616335
∂L/∂ŷ = [[      -0.25],
[  -0.165146],
[ 0.21527952],
[0.006699443]]
∂L/∂layer1 = [[          0.0,           0.0,           0.0,           0.0,           0.0,           0.0,
0.0,           0.0,           0.0,           0.0],
[ -0.096873865,   -0.10904455,   0.043476164,   0.036550127,   -0.11773739,     -0.114685,
0.09037148,   -0.11022576,  -0.053270184,    -0.0916878],
[   0.12628195,    0.14214732,  -0.056674264,  -0.047645684,    0.15347904,    0.14950003,
-0.117805645,     0.1436871,    0.06944146,    0.11952156],
[  0.003929862,  0.0044235876, -0.0017636884, -0.0014827213,  0.0047762278,  0.0046524024,
-0.003666081,   0.004471505,  0.0021610004,  0.0037194798]]
Loss: 0.38242665
∂L/∂ŷ = [[      -0.25],
[ -0.1634019],
[  0.2078557],
[0.004740596]]
∂L/∂layer1 = [[          0.0,           0.0,           0.0,           0.0,           0.0,           0.0,
0.0,           0.0,           0.0,           0.0],
[  -0.09537157,   -0.10751423,   0.042770643,   0.036164124,  -0.116372935,   -0.11347382,
0.08941708,   -0.10922218,    -0.0527076,    -0.0907195],
[   0.12131758,    0.13676369,   -0.05440648,   -0.04600264,    0.14803241,     0.1443446,
-0.11374316,    0.13893628,   0.067046806,    0.115399905],
[ 0.0027669081,  0.0031191898, -0.0012408567, -0.0010491891,   0.003376197,  0.0032920884,
-0.0025941574,  0.0031687405,  0.0015291465,  0.0026319427]]
Loss: 0.3789484
∂L/∂ŷ = [[       -0.25],
[  -0.1615152],
[  0.20087644],
[0.0031458437]]
∂L/∂layer1 = [[           0.0,            0.0,            0.0,            0.0,            0.0,
0.0,            0.0,            0.0,            0.0,            0.0],
[   -0.09381737,    -0.10592259,     0.04202999,     0.03574656,    -0.11491785,
-0.11216361,    0.088384636,    -0.10812061,    -0.05209902,   -0.089672014],
[    0.11668065,      0.1317359,   -0.052272696,   -0.044457994,     0.14292333,
0.13949788,   -0.109923966,      0.1344696,    0.064795546,     0.11152508],
[   0.001827288,    0.002063062, -0.00081862125,  -0.0006962384,   0.0022382636,
0.002184619,  -0.0017214742,   0.0021058733,   0.0010147365,   0.0017465486]]
Loss: 0.37563026
∂L/∂ŷ = [[ -0.24986199],
[  -0.1595121],
[  0.19429788],
[0.0018657446]]
∂L/∂layer1 = [[   -0.14446305,    -0.16335253,    0.064635016,     0.05529948,    -0.17761588,
-0.17351569,     0.13672993,    -0.16750911,    -0.08059653,    -0.13872148],
[   -0.09222533,    -0.10428439,     0.04126305,    0.035303235,    -0.113390125,
-0.110772565,      0.0872885,    -0.10693795,    -0.05145289,    -0.08855991],
[    0.11233747,     0.12702633,   -0.050261535,   -0.043002024,      0.1381178,
0.13492942,    -0.10632403,     0.13025856,     0.06267354,     0.10787271],
[  0.0010787201,   0.0012197698, -0.00048263618, -0.00041292678,   0.0013262756,
0.0012956591,  -0.0010209761,   0.0012508073,   0.0006018224,   0.0010358472]]
Loss: 0.3703937
∂L/∂ŷ = [[ -0.24574545],
[  -0.1536124],
[  0.19074702],
[0.0046536624]]
∂L/∂layer1 = [[  -0.14144973,   -0.16019088,     0.0631879,   0.054388408,    -0.1745423,   -0.17065698,
0.13447726,   -0.16500926,   -0.07926868,   -0.13643602],
[  -0.08841846,  -0.100133315,   0.039497964,   0.033997513,    -0.1091042,   -0.10667554,
0.08406005,   -0.10314522,   -0.04954986,   -0.08528444],
[   0.10979295,   0.124339774,  -0.049046293,   -0.04221615,     0.1354793,    0.13246353,
-0.10438093,    0.12807979,   0.061528157,     0.1059013],
[ 0.0026786227,  0.0030335223, -0.0011965843,  -0.001029949,  0.0033052939,   0.003231718,
-0.0025465854,  0.0031247674,  0.0015011048,   0.002583678]]
Loss: 0.3654527
∂L/∂ŷ = [[ -0.24179694],
[ -0.14801097],
[   0.1871593],
[0.0071467757]]
∂L/∂layer1 = [[  -0.13856791,   -0.15715094,   0.061826546,    0.05351452,   -0.17160064,   -0.16791496,
0.13231656,   -0.16261542,   -0.07799503,   -0.13424383],
[  -0.08482146,   -0.09619668,   0.037845835,     0.0327578,   -0.10504176,   -0.10278565,
0.08099483,   -0.09954165,  -0.047743037,   -0.08217457],
[   0.10725641,    0.12164033,  -0.047855914,  -0.041422114,     0.1328249,    0.12997206,
-0.10241765,    0.12587003,    0.06037089,   0.103909425],
[ 0.0040956424,  0.0046448996,  -0.001827403, -0.0015817251,  0.0050719883,   0.004963051,
-0.003910871,   0.004806413,   0.002305294,  0.0039678356]]
Loss: 0.36081767
∂L/∂ŷ = [[-0.23810837],
[-0.14284015],
[ 0.18339509],
[0.009221196]]
∂L/∂layer1 = [[  -0.13586827,   -0.15429293,    0.06056964,   0.052698165,   -0.16885515,   -0.16535343,
0.13029808,   -0.16038936,  -0.076805234,   -0.13219596],
[  -0.08150677,   -0.09255964,   0.036335457,   0.031613395,   -0.10129545,   -0.09919479,
0.07816524,  -0.096216865,  -0.046075117,  -0.079303764],
[   0.10464803,   0.118839025,  -0.046651762,   -0.04058902,    0.13005508,    0.12735802,
-0.100357786,     0.1235346,   0.059156686,    0.10181956],
[  0.005261755,   0.005975285, -0.0023456737,  -0.002040836,  0.0065392344,   0.006403624,
-0.0050460394,   0.006211381,  0.0029744275,   0.005119538]]
```



## Recomputing activations during backpropagation to save memory (checkpointing)

Checkpointing is a traditional technique in reverse-mode automatic differentiation for saving memory. Rather than saving large intermediate values in the original computation for computing derivatives, the intermediate values are instead recomputed as needed during backpropagation.
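The trade-off can be sketched with two hand-written pullbacks in plain Swift (no autodiff machinery; purely illustrative, since for a scalar the memory saving is trivial). One captures an intermediate value, the other recomputes it:

```swift
import Foundation

// f(x) = (cos(x))², with derivative f'(x) = -2 · cos(x) · sin(x).

// Capturing: the forward pass stores cos(x), which stays alive in the
// pullback closure until backpropagation runs.
func capturing(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    let c = cos(x)                        // kept alive by the closure
    return (c * c, { v in v * -2 * c * sin(x) })
}

// Recomputing: the pullback recomputes cos(x) on demand, so nothing needs
// to stay resident between the forward and backward passes.
func recomputing(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    return (cos(x) * cos(x), { v in
        let c = cos(x)                    // recomputed during backpropagation
        return v * -2 * c * sin(x)
    })
}

// Both agree on the derivative; only the memory behavior differs.
print(capturing(10).pullback(1))
print(recomputing(10).pullback(1))
```

When the intermediate value is a large activation tensor rather than a scalar, the second form trades extra compute for a much smaller memory footprint.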

This technique has been realized in modern deep learning libraries as well. In Swift, the `withRecomputationInPullbacks(_:)` API enables you to control what to recompute during backpropagation, and it is available on all `Differentiable` types.

But today, let us learn how to define our own gradient checkpointing APIs from scratch, in just a few lines of code.

We can define our own gradient checkpointing API, `makeRecomputedInGradient(_:)`, in terms of the standard library function `differentiableFunction(from:)`, which is a shorthand for creating a differentiable function directly from a derivative function (also called a "vector-Jacobian product (VJP) function").

As we have seen before, the derivative function returns a tuple of the original function's result and a pullback closure. We return `original(x)` in `value:`, and call `pullback(at:in:)` on `original` to evaluate the original function again and get a pullback.

```swift
/// Given a differentiable function, returns the same differentiable function except when
/// derivatives of this function are being computed. In that case, values in the original function needed
/// for computing the derivatives will be recomputed, instead of being captured by the differential or pullback.
///
/// - Parameter body: The body of the differentiable function.
/// - Returns: The same differentiable function whose derivatives, when computed, will recompute
///   some values from the original function.
func makeRecomputedInGradient<T: Differentiable, U: Differentiable>(
    _ original: @escaping @differentiable (T) -> U
) -> @differentiable (T) -> U {
    return differentiableFunction { x in
        (value: original(x), pullback: { v in pullback(at: x, in: original)(v) })
    }
}
```


### Verify it works

```swift
let input: Float = 10.0
print("Running original computation...")

// Differentiable multiplication with checkpointing.
let square = makeRecomputedInGradient { (x: Float) -> Float in
    print("  Computing square...")
    return x * x
}

// Differentiate f(x) = (cos(x))^2.
let (output, backprop) = valueWithPullback(at: input) { input -> Float in
    return square(cos(input))
}
print("Running backpropagation...")
let grad = backprop(1)  // Triggers recomputation of `square`.
```

```
Running original computation...
  Computing square...
Running backpropagation...
  Computing square...
```



### Extend it to neural network modules

In this example, we define a simple convolutional neural network.

```swift
struct Model: Layer {
    var conv = Conv2D<Float>(filterShape: (5, 5, 3, 6))
    var maxPool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
    var flatten = Flatten<Float>()
    var dense = Dense<Float>(inputSize: 36 * 6, outputSize: 10)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        return input.sequenced(through: conv, maxPool, flatten, dense)
    }
}
```


We want to make activations in the convolution layer (`conv`) be recomputed during backpropagation. However, using `makeRecomputedInGradient(_:)` could make the resulting code look cumbersome, especially when we want to apply layers sequentially using `sequenced(through:)`.

```swift
input.sequenced(through: conv, maxPool, flatten, dense)
```


So, why don't we define a special layer type that wraps a layer and makes its activations be recomputed during backpropagation? Let's do it.

First, we define a `makeRecomputedInGradient(_:)` overload that takes a binary function.

```swift
// Same as the previous makeRecomputedInGradient(_:), except it's for binary functions.
func makeRecomputedInGradient<T: Differentiable, U: Differentiable, V: Differentiable>(
    _ original: @escaping @differentiable (T, U) -> V
) -> @differentiable (T, U) -> V {
    return differentiableFunction { x, y in
        (value: original(x, y), pullback: { v in pullback(at: x, y, in: original)(v) })
    }
}
```


Then, we define a generic layer `ActivationDiscarding<Wrapped>`.

```swift
import TensorFlow

/// A layer wrapper that makes the underlying layer's activations be discarded during application
/// and recomputed during backpropagation.
struct ActivationDiscarding<Wrapped: Layer>: Layer {
    /// The wrapped layer.
    var wrapped: Wrapped

    @differentiable
    func callAsFunction(_ input: Wrapped.Input) -> Wrapped.Output {
        let apply = makeRecomputedInGradient { (layer: Wrapped, input: Wrapped.Input) -> Wrapped.Output in
            print("    Applying \(Wrapped.self) layer...")
            return layer(input)
        }
        return apply(wrapped, input)
    }
}
```


Finally, we can add a method on all layers that returns the same layer except its activations are discarded during application and recomputed during backpropagation.

```swift
extension Layer {
    func discardingActivations() -> ActivationDiscarding<Self> {
        return ActivationDiscarding(wrapped: self)
    }
}
```


Back in the model, all we have to change is to wrap the convolution layer into the activation-recomputing layer.

```swift
var conv = Conv2D<Float>(filterShape: (5, 5, 3, 6)).discardingActivations()
```


Now, simply use it in the model!

```swift
struct Model: Layer {
    var conv = Conv2D<Float>(filterShape: (5, 5, 3, 6)).discardingActivations()
    var maxPool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
    var flatten = Flatten<Float>()
    var dense = Dense<Float>(inputSize: 36 * 6, outputSize: 10)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        return input.sequenced(through: conv, maxPool, flatten, dense)
    }
}
```


When we run a training loop, we can see that the convolution layer's activations are computed twice: once during layer application, and once during backpropagation.

```swift
// Use random training data.
let x = Tensor<Float>(randomNormal: [10, 16, 16, 3])
let y = Tensor<Int32>(rangeFrom: 0, to: 10, stride: 1)

var model = Model()
let opt = SGD(for: model)

for i in 1...5 {
    print("Starting training step \(i)")
    print("  Running original computation...")
    let (logits, backprop) = model.appliedForBackpropagation(to: x)
    let (loss, dL_dŷ) = valueWithGradient(at: logits) { logits in
        softmaxCrossEntropy(logits: logits, labels: y)
    }
    print("  Loss: \(loss)")
    print("  Running backpropagation...")
    let (dL_dθ, _) = backprop(dL_dŷ)

    opt.update(&model, along: dL_dθ)
}
```

```
Starting training step 1
  Running original computation...
    Applying Conv2D<Float> layer...
  Loss: 3.186631
  Running backpropagation...
    Applying Conv2D<Float> layer...
Starting training step 2
  Running original computation...
    Applying Conv2D<Float> layer...
  Loss: 2.8461165
  Running backpropagation...
    Applying Conv2D<Float> layer...
Starting training step 3
  Running original computation...
    Applying Conv2D<Float> layer...
  Loss: 2.5653338
  Running backpropagation...
    Applying Conv2D<Float> layer...
Starting training step 4
  Running original computation...
    Applying Conv2D<Float> layer...
  Loss: 2.3213038
  Running backpropagation...
    Applying Conv2D<Float> layer...
Starting training step 5
  Running original computation...
    Applying Conv2D<Float> layer...
  Loss: 2.1033013
  Running backpropagation...
    Applying Conv2D<Float> layer...
```



Just like that, it is super easy to define generic differentiable programming libraries for different domains.