Arêtes vives dans Differentiable Swift

Differentiable Swift a parcouru un long chemin en termes de convivialité. Voici un aperçu des parties qui sont encore un peu peu évidentes. Au fur et à mesure des progrès, ce guide deviendra de plus en plus petit et vous pourrez écrire du code différenciable sans avoir besoin de syntaxe spéciale.

Boucles

Les boucles sont différenciables, il n'y a qu'un détail à connaître. Lorsque vous écrivez la boucle, enveloppez le bit où vous spécifiez ce sur quoi vous bouclez withoutDerivative(at:)

var a: [Float] = [1,2,3]

Par exemple:

for _ in a.indices 
{}

devient

for _ in withoutDerivative(at: a.indices) 
{}

ou:

for _ in 0..<a.count 
{}

devient

for _ in 0..<withoutDerivative(at: a.count) 
{}

Cela est nécessaire car le membre Array.count ne contribue pas à la dérivée par rapport au tableau. Seuls les éléments réels du tableau contribuent à la dérivée.

Si vous avez une boucle dans laquelle vous utilisez manuellement un entier comme limite supérieure, il n'est pas nécessaire de l'utiliser withoutDerivative(at:) :

let iterations: Int = 10
for _ in 0..<iterations {} //this is fine as-is.

Cartographier et réduire

map et reduce ont des versions différenciables spéciales qui fonctionnent exactement comme ce à quoi vous êtes habitué :

a = [1,2,3]
let aPlusOne = a.differentiableMap {$0 + 1}
let aSum = a.differentiableReduce(0, +)
print("aPlusOne", aPlusOne)
print("aSum", aSum)
aPlusOne [2.0, 3.0, 4.0]
aSum 6.0

Ensembles d'indices de tableau

Les ensembles d'indices de tableau ( array[0] = 0 ) ne sont pas différenciables dès le départ, mais vous pouvez coller cette extension :

extension Array where Element: Differentiable {
    @differentiable(where Element: Differentiable)
    mutating func updated(at index: Int, with newValue: Element) {
        self[index] = newValue
    }

    @derivative(of: updated)
    mutating func vjpUpdated(at index: Int, with newValue: Element)
      -> (value: Void, pullback: (inout TangentVector) -> (Element.TangentVector))
    {
        self.updated(at: index, with: newValue)
        return ((), { v in
            let dElement = v[index]
            v.base[index] = .zero
            return dElement
        })
    }
}

puis la syntaxe de solution de contournement est la suivante :

var b: [Float] = [1,2,3]

au lieu de cela:

b[0] = 17

écrire cela:

b.updated(at: 0, with: 17)

Assurons-nous que cela fonctionne :

func plusOne(array: [Float]) -> Float{
  var array = array
  array.updated(at: 0, with: array[0] + 1)
  return array[0]
}

let plusOneValAndGrad = valueWithGradient(at: [2], in: plusOne)
print(plusOneValAndGrad)
(value: 3.0, gradient: [1.0])

L'erreur que vous obtiendrez sans cette solution de contournement est Differentiation of coroutine calls is not yet supported . Voici le lien pour voir les progrès réalisés pour rendre cette solution de contournement inutile : https://bugs.swift.org/browse/TF-1277 (il parle d'Array.subscript._modify, qui est ce qu'on appelle dans les coulisses lorsque vous créez un tableau ensemble d'indices).

Float <-> Double conversions

Si vous basculez entre Float et Double , leurs constructeurs ne sont pas déjà différenciables. Voici une fonction qui vous permettra de passer d'un Float à un Double de manière différenciée.

(Changez Float et Double dans le code ci-dessous, et vous disposez d'une fonction qui convertit Double en Float .)

Vous pouvez créer des convertisseurs similaires pour tout autre type numérique réel.

@differentiable
func convertToDouble(_ a: Float) -> Double {
    return Double(a)
}

@derivative(of: convertToDouble)
func convertToDoubleVJP(_ a: Float) -> (value: Double, pullback: (Double) -> Float) {
    func pullback(_ v: Double) -> Float{
        return Float(v)
    }
    return (value: Double(a), pullback: pullback)
}

Voici un exemple d'utilisation :

@differentiable
func timesTwo(a: Float) -> Double {
  return convertToDouble(a * 2)
}
let input: Float = 3
let valAndGrad = valueWithGradient(at: input, in: timesTwo)
print("grad", valAndGrad.gradient)
print("type of input:", type(of: input))
print("type of output:", type(of: valAndGrad.value))
print("type of gradient:", type(of: valAndGrad.gradient))
grad 2.0
type of input: Float
type of output: Double
type of gradient: Float

Fonctions transcendantales et autres (sin, cos, abs, max)

De nombreuses fonctions transcendantales et autres fonctions intégrées courantes ont déjà été différenciables pour Float et Double . Il y en a moins pour Double que pour Float . Certains ne sont pas disponibles non plus. Voici donc quelques définitions dérivées manuelles pour vous donner une idée de la façon de créer ce dont vous avez besoin, au cas où cela ne serait pas déjà fourni :

pow (voir le lien pour l'explication dérivée)

import Foundation

@usableFromInline
@derivative(of: pow) 
func powVJP(_ base: Double, _ exponent: Double) -> (value: Double, pullback: (Double) -> (Double, Double)) {
    let output: Double = pow(base, exponent)
    func pullback(_ vector: Double) -> (Double, Double) {
        let baseDerivative = vector * (exponent * pow(base, exponent - 1))
        let exponentDerivative = vector * output * log(base)
        return (baseDerivative, exponentDerivative)
    }

    return (value: output, pullback: pullback)
}

maximum

@usableFromInline
@derivative(of: max)
func maxVJP<T: Comparable & Differentiable>(_ x: T, _ y: T) -> (value: T, pullback: (T.TangentVector)
  -> (T.TangentVector, T.TangentVector))
{
    func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
        if x < y {
            return (.zero, v)
        } else {
            return (v, .zero)
        }
    }
    return (value: max(x, y), pullback: pullback)
}

abdos

@usableFromInline
@derivative(of: abs)
func absVJP<T: Comparable & SignedNumeric & Differentiable>(_ x: T)
  -> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
{
    func pullback(_ v: T.TangentVector) -> T.TangentVector{
        if x < 0 {
            return .zero - v
        }
        else {
            return v
        }
    }
    return (value: abs(x), pullback: pullback)
}

sqrt (voir le lien pour l'explication dérivée)

@usableFromInline
@derivative(of: sqrt) 
func sqrtVJP(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    let output = sqrt(x)
    func pullback(_ v: Double) -> Double {
        return v / (2 * output)
    }
    return (value: output, pullback: pullback)
}

Vérifions que ceux-ci fonctionnent :

let powGrad = gradient(at: 2, 2, in: pow)
print("pow gradient: ", powGrad, "which is", powGrad == (4.0, 2.772588722239781) ? "correct" : "incorrect")

let maxGrad = gradient(at: 1, 2, in: max)
print("max gradient: ", maxGrad, "which is", maxGrad == (0.0, 1.0) ? "correct" : "incorrect")

let absGrad = gradient(at: 2, in: abs)
print("abs gradient: ", absGrad, "which is", absGrad == 1.0 ? "correct" : "incorrect")

let sqrtGrad = gradient(at: 4, in: sqrt)
print("sqrt gradient: ", sqrtGrad, "which is", sqrtGrad == 0.25 ? "correct" : "incorrect")
pow gradient:  (4.0, 2.772588722239781) which is correct
max gradient:  (0.0, 1.0) which is correct
abs gradient:  1.0 which is correct
sqrt gradient:  0.25 which is correct

L'erreur du compilateur qui vous avertit de la nécessité de quelque chose comme ceci est la suivante : Expression is not differentiable. Cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files

Abonnement KeyPath

L'abonnement KeyPath (get ou set) ne fonctionne pas immédiatement, mais encore une fois, vous pouvez ajouter quelques extensions, puis utiliser une syntaxe de contournement. C'est ici:

https://github.com/tensorflow/swift/issues/530#issuecomment -687400701

Cette solution de contournement est un peu plus moche que les autres. Cela ne fonctionne que pour les objets personnalisés, qui doivent être conformes à Differentiable et AdditiveArithmetic. Vous devez ajouter un membre .tmp et une fonction .read() , et vous utilisez le membre .tmp comme stockage intermédiaire lors de l'obtention d'indices KeyPath (il y a un exemple dans le code lié). Les ensembles d'indices KeyPath fonctionnent assez simplement avec une fonction .write() .