Sharp edges in Differentiable Swift

Differentiable Swift has come a long way in terms of usability. Here is a heads-up about the parts that are still a little non-obvious. As progress continues, this guide will get smaller and smaller, and you'll be able to write differentiable code without needing special syntax.


Loops

Loops are differentiable; there's just one detail to know about. When you write the loop, wrap the part that specifies what you're looping over in withoutDerivative(at:).

For example, given

var a: [Float] = [1,2,3]

instead of this:

for _ in a.indices {}

write this:

for _ in withoutDerivative(at: a.indices) {}

and instead of this:

for _ in 0..<a.count {}

write this:

for _ in 0..<withoutDerivative(at: a.count) {}

This is necessary because neither a.count nor a.indices contributes to the derivative with respect to the array; only the array's actual elements do.

If you've got a loop where you manually use an integer as the upper bound, there's no need to use withoutDerivative(at:):

let iterations: Int = 10
for _ in 0..<iterations {} //this is fine as-is.
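Putting the pieces together, here's a minimal sketch of a differentiable loop over an array. sumOfSquares is a hypothetical helper made up for this example, and the sketch assumes a toolchain where the differentiation APIs live in the _Differentiation module. It calls gradient with a trailing closure, which sidesteps the closure's argument label (in: in this guide; newer toolchains may spell it of:).

```swift
import _Differentiation  // differentiation APIs on standard Swift toolchains

// Hypothetical helper: sums the squares of the elements with an explicit loop.
// It lives in the same file as its caller, so it doesn't need a @differentiable attribute.
func sumOfSquares(_ values: [Float]) -> Float {
    var total: Float = 0
    // The loop bound comes from the array, so it's wrapped in withoutDerivative(at:).
    for i in 0..<withoutDerivative(at: values.count) {
        total += values[i] * values[i]
    }
    return total
}

// The derivative of x_i * x_i is 2 * x_i, so at [1, 2, 3] the gradient should be [2, 4, 6].
let loopGrad = gradient(at: [1, 2, 3] as [Float]) { sumOfSquares($0) }
print("loop gradient:", loopGrad)
```

The gradient comes back as the array's TangentVector (Array.DifferentiableView); its .base property holds the plain [Float].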

Map and Reduce

map and reduce have special differentiable versions, differentiableMap and differentiableReduce, that work just like the ones you're used to:

a = [1,2,3]
let aPlusOne = a.differentiableMap {$0 + 1}
let aSum = a.differentiableReduce(0, +)
print("aPlusOne", aPlusOne)
print("aSum", aSum)
aPlusOne [2.0, 3.0, 4.0]
aSum 6.0
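As a minimal sketch of these inside a function you actually differentiate, here's a hypothetical sumOfShiftedSquares that chains two differentiableMap calls into a differentiableReduce. Like the loop example, it assumes the _Differentiation module is available.

```swift
import _Differentiation

// Hypothetical helper: the sum of (x + 1)^2 over the array, written with
// differentiableMap and differentiableReduce instead of map and reduce.
func sumOfShiftedSquares(_ values: [Float]) -> Float {
    let shifted = values.differentiableMap { $0 + 1 }
    let squared = shifted.differentiableMap { $0 * $0 }
    return squared.differentiableReduce(0, +)
}

// The derivative of (x + 1)^2 is 2 * (x + 1), so at [1, 2, 3] the gradient should be [4, 6, 8].
let mapReduceGrad = gradient(at: [1, 2, 3] as [Float]) { sumOfShiftedSquares($0) }
print("map/reduce gradient:", mapReduceGrad)
```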

Array subscript sets

Array subscript sets (array[0] = 0) aren't differentiable out of the box, but you can paste this extension:

extension Array where Element: Differentiable {
    // Newer toolchains spell this attribute @differentiable(reverse).
    @differentiable(where Element: Differentiable)
    mutating func updated(at index: Int, with newValue: Element) {
        self[index] = newValue
    }

    // Custom derivative (VJP) for updated(at:with:).
    @derivative(of: updated)
    mutating func vjpUpdated(at index: Int, with newValue: Element)
      -> (value: Void, pullback: (inout TangentVector) -> (Element.TangentVector))
    {
        self.updated(at: index, with: newValue)
        return ((), { v in
            // The overwritten slot's gradient goes to newValue; the old element gets none.
            let dElement = v[index]
            v.base[index] = .zero
            return dElement
        })
    }
}

Then use the workaround syntax like this:

var b: [Float] = [1,2,3]

instead of this:

b[0] = 17

write this:

b.updated(at: 0, with: 17)

Let's make sure it works:

func plusOne(array: [Float]) -> Float {
  var array = array
  array.updated(at: 0, with: array[0] + 1)
  return array[0]
}

let plusOneValAndGrad = valueWithGradient(at: [2], in: plusOne)
print(plusOneValAndGrad)
(value: 3.0, gradient: [1.0])

Without this workaround, you'll get the error "Differentiation of coroutine calls is not yet supported". Here is the link to track progress on making this workaround unnecessary: (it talks about Array.subscript._modify, which is what's called behind the scenes when you do an array subscript set).

Float <-> Double conversions

If you're converting between Float and Double, note that their initializers aren't differentiable out of the box. Here's a function that converts a Float to a Double differentiably.

(Swap Float and Double in the code below, and you've got a function that converts from Double to Float.)

You can make similar converters for any other real numeric types.

func convertToDouble(_ a: Float) -> Double {
    return Double(a)
}

@derivative(of: convertToDouble)
func convertToDoubleVJP(_ a: Float) -> (value: Double, pullback: (Double) -> Float) {
    // The conversion is the identity up to precision, so the pullback just converts back.
    func pullback(_ v: Double) -> Float {
        return Float(v)
    }
    return (value: Double(a), pullback: pullback)
}

Here's an example usage:

func timesTwo(a: Float) -> Double {
  return convertToDouble(a * 2)
}
let input: Float = 3
let valAndGrad = valueWithGradient(at: input, in: timesTwo)
print("grad", valAndGrad.gradient)
print("type of input:", type(of: input))
print("type of output:", type(of: valAndGrad.value))
print("type of gradient:", type(of: valAndGrad.gradient))
grad 2.0
type of input: Float
type of output: Double
type of gradient: Float
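Following the parenthetical above, here's what swapping Float and Double looks like, as a sketch. convertToFloat and timesThree are hypothetical names chosen for this example, and the _Differentiation module is assumed to be available. The gradient of Float(a * 3) with respect to a should come out as 3.0, typed as Double like the input.

```swift
import _Differentiation

// Hypothetical mirror image of convertToDouble: Double -> Float.
func convertToFloat(_ a: Double) -> Float {
    return Float(a)
}

@derivative(of: convertToFloat)
func convertToFloatVJP(_ a: Double) -> (value: Float, pullback: (Float) -> Double) {
    // The conversion is the identity up to precision, so the pullback converts
    // the incoming Float cotangent back to Double.
    func pullback(_ v: Float) -> Double {
        return Double(v)
    }
    return (value: Float(a), pullback: pullback)
}

// Hypothetical usage: d/da of Float(a * 3) is 3.
func timesThree(a: Double) -> Float {
    return convertToFloat(a * 3)
}

let floatGrad = gradient(at: 2.0) { timesThree(a: $0) }
print("grad", floatGrad, "of type", type(of: floatGrad))
```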

Transcendental and other functions (sin, cos, abs, max)

A lot of transcendentals and other common built-in functions have already been made differentiable for Float and Double, though there are fewer for Double than for Float, and some aren't available for either. So here are a few manual derivative definitions to give you an idea of how to write what you need, in case it isn't already provided:

pow (see link for derivative explanation)

import Foundation

@derivative(of: pow)
func powVJP(_ base: Double, _ exponent: Double) -> (value: Double, pullback: (Double) -> (Double, Double)) {
    let output: Double = pow(base, exponent)
    func pullback(_ vector: Double) -> (Double, Double) {
        // d/dbase = exponent * base^(exponent - 1)
        let baseDerivative = vector * (exponent * pow(base, exponent - 1))
        // d/dexponent = base^exponent * ln(base)
        let exponentDerivative = vector * output * log(base)
        return (baseDerivative, exponentDerivative)
    }
    return (value: output, pullback: pullback)
}


max

@derivative(of: max)
func maxVJP<T: Comparable & Differentiable>(_ x: T, _ y: T) -> (value: T, pullback: (T.TangentVector)
  -> (T.TangentVector, T.TangentVector))
{
    func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
        // The gradient goes to whichever argument was the max (x wins ties).
        if x < y {
            return (.zero, v)
        } else {
            return (v, .zero)
        }
    }
    return (value: max(x, y), pullback: pullback)
}


abs

@derivative(of: abs)
func absVJP<T: Comparable & SignedNumeric & Differentiable>(_ x: T)
  -> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
{
    func pullback(_ v: T.TangentVector) -> T.TangentVector {
        // Negative inputs flip the sign of the gradient; x == 0 is treated as positive.
        if x < 0 {
            return .zero - v
        } else {
            return v
        }
    }
    return (value: abs(x), pullback: pullback)
}

sqrt (see link for derivative explanation)

@derivative(of: sqrt)
func sqrtVJP(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    let output = sqrt(x)
    func pullback(_ v: Double) -> Double {
        // d/dx sqrt(x) = 1 / (2 * sqrt(x))
        return v / (2 * output)
    }
    return (value: output, pullback: pullback)
}

Let's check that these work:

let powGrad = gradient(at: 2, 2, in: pow)
print("pow gradient: ", powGrad, "which is", powGrad == (4.0, 2.772588722239781) ? "correct" : "incorrect")

let maxGrad = gradient(at: 1, 2, in: max)
print("max gradient: ", maxGrad, "which is", maxGrad == (0.0, 1.0) ? "correct" : "incorrect")

let absGrad = gradient(at: 2, in: abs)
print("abs gradient: ", absGrad, "which is", absGrad == 1.0 ? "correct" : "incorrect")

let sqrtGrad = gradient(at: 4, in: sqrt)
print("sqrt gradient: ", sqrtGrad, "which is", sqrtGrad == 0.25 ? "correct" : "incorrect")
pow gradient:  (4.0, 2.772588722239781) which is correct
max gradient:  (0.0, 1.0) which is correct
abs gradient:  1.0 which is correct
sqrt gradient:  0.25 which is correct

The compiler error that alerts you to the need for something like this is: "Expression is not differentiable. Cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files".
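The same pattern also works for your own functions whenever you want to supply a derivative by hand instead of letting the compiler derive one. Here's a sketch using a hypothetical softplus(x) = log(1 + exp(x)), whose derivative is the logistic sigmoid 1 / (1 + exp(-x)). Because the registered derivative is used in its place, the body of softplus never has to be differentiable itself. As with the other examples, the _Differentiation module is assumed.

```swift
import Foundation
import _Differentiation

// Hypothetical function for this sketch: softplus(x) = log(1 + exp(x)).
// Its body isn't differentiated, because the custom derivative below is used instead.
func softplus(_ x: Double) -> Double {
    return log(1 + exp(x))
}

// Hand-written VJP in the same shape as powVJP and sqrtVJP.
// d/dx softplus(x) = 1 / (1 + exp(-x)), the logistic sigmoid.
@derivative(of: softplus)
func softplusVJP(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    let sigmoid = 1 / (1 + exp(-x))
    return (value: softplus(x), pullback: { v in v * sigmoid })
}

// At x = 0 the sigmoid is exactly 0.5.
let softplusGrad = gradient(at: 0.0) { softplus($0) }
print("softplus gradient:", softplusGrad)
```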

KeyPath subscripting

KeyPath subscripting (get or set) doesn't work out of the box, but once again, you can add some extensions and then use a workaround syntax. Here it is:

This workaround is a little uglier than the others. It only works for your own custom types, which must conform to both Differentiable and AdditiveArithmetic. You have to add a .tmp member and a .read() function, and you use the .tmp member as intermediate storage when doing KeyPath subscript gets (there is an example in the linked code). KeyPath subscript sets are simpler: they work with a .write() function.
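The linked code isn't reproduced in this guide, so what follows is not that workaround. It's a hypothetical sketch of the same idea that swaps in a simpler design: read(_:) and write(_:_:) each get a hand-written derivative, and there's no .tmp member. Point and overwriteXAndSum are names made up for this example, and the sketch assumes your toolchain accepts custom derivatives on methods that take a key path argument. The AdditiveArithmetic conformance matters because, for a struct like this, it makes the synthesized TangentVector the type itself, so the same key path can be applied to the gradient.

```swift
import _Differentiation

// Hypothetical type. Conforming to AdditiveArithmetic as well as Differentiable makes
// the synthesized TangentVector equal to Point, so key paths into Point work on gradients.
struct Point: Differentiable, AdditiveArithmetic {
    var x: Double
    var y: Double
}

extension Point {
    // Differentiable stand-in for a get: self[keyPath: keyPath]
    func read(_ keyPath: WritableKeyPath<Point, Double>) -> Double {
        return self[keyPath: keyPath]
    }

    @derivative(of: read)
    func vjpRead(_ keyPath: WritableKeyPath<Point, Double>)
      -> (value: Double, pullback: (Double) -> TangentVector)
    {
        return (self[keyPath: keyPath], { v in
            // The gradient flows only into the property the key path points at.
            var d = TangentVector.zero
            d[keyPath: keyPath] = v
            return d
        })
    }

    // Differentiable stand-in for a set: self[keyPath: keyPath] = newValue
    mutating func write(_ keyPath: WritableKeyPath<Point, Double>, _ newValue: Double) {
        self[keyPath: keyPath] = newValue
    }

    @derivative(of: write)
    mutating func vjpWrite(_ keyPath: WritableKeyPath<Point, Double>, _ newValue: Double)
      -> (value: Void, pullback: (inout TangentVector) -> Double)
    {
        self.write(keyPath, newValue)
        return ((), { d in
            // The overwritten property's gradient moves to newValue; the old value gets none.
            let dNewValue = d[keyPath: keyPath]
            d[keyPath: keyPath] = 0
            return dNewValue
        })
    }
}

// Hypothetical usage: overwrite x, then read both coordinates back.
func overwriteXAndSum(_ p: Point, _ newX: Double) -> Double {
    var q = p
    q.write(\.x, newX)
    return 2 * q.read(\.x) + 3 * q.read(\.y)
}

// p.x is overwritten, so d/dp should be Point(x: 0, y: 3), and d/dnewX should be 2.
let (dPoint, dNewX) = gradient(at: Point(x: 1, y: 1), 5.0) { overwriteXAndSum($0, $1) }
print("gradient w.r.t. point:", dPoint, "w.r.t. newX:", dNewX)
```

The write derivative follows the same shape as vjpUpdated for arrays: the incoming gradient at the overwritten slot is handed to the new value and zeroed out for the old one.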