Differentiable Swift의 날카로운 모서리

Differentiable Swift는 유용성 측면에서 큰 발전을 이루었습니다. 아직 명확하지 않은 부분에 대해 미리 알려드립니다. 계속 진행됨에 따라 이 가이드는 점점 더 작아질 것이며 특별한 구문 없이도 미분 가능한 코드를 작성할 수 있을 것입니다.

루프

루프는 미분 가능합니다. 알아야 할 세부 사항은 단 하나뿐입니다. 루프를 작성할 때, withoutDerivative(at:) 에서 루프할 내용을 지정하는 비트를 래핑하세요.

var a: [Float] = [1,2,3]

예를 들어:

for _ in a.indices 
{}

된다

for _ in withoutDerivative(at: a.indices) 
{}

또는:

for _ in 0..<a.count 
{}

된다

for _ in 0..<withoutDerivative(at: a.count) 
{}

이는 Array.count 멤버가 배열과 관련된 파생 항목에 기여하지 않기 때문에 필요합니다. 배열의 실제 요소만 도함수에 기여합니다.

수동으로 정수를 상한으로 사용하는 루프가 있는 경우에는 withoutDerivative(at:) 사용할 필요가 없습니다.

let iterations: Int = 10
for _ in 0..<iterations {} //this is fine as-is.

매핑 및 축소

mapreduce 에는 익숙한 것과 똑같이 작동하는 특수한 미분 가능한 버전이 있습니다.

a = [1,2,3]
let aPlusOne = a.differentiableMap {$0 + 1}
let aSum = a.differentiableReduce(0, +)
print("aPlusOne", aPlusOne)
print("aSum", aSum)
aPlusOne [2.0, 3.0, 4.0]
aSum 6.0

배열 아래 첨자 세트

배열 첨자 세트( array[0] = 0 )는 기본적으로 미분할 수 없지만 다음 확장자를 붙여넣을 수 있습니다.

extension Array where Element: Differentiable {
    @differentiable(where Element: Differentiable)
    mutating func updated(at index: Int, with newValue: Element) {
        self[index] = newValue
    }

    @derivative(of: updated)
    mutating func vjpUpdated(at index: Int, with newValue: Element)
      -> (value: Void, pullback: (inout TangentVector) -> (Element.TangentVector))
    {
        self.updated(at: index, with: newValue)
        return ((), { v in
            let dElement = v[index]
            v.base[index] = .zero
            return dElement
        })
    }
}

해결 방법 구문은 다음과 같습니다.

var b: [Float] = [1,2,3]

대신에:

b[0] = 17

이것을 쓰세요:

b.updated(at: 0, with: 17)

제대로 작동하는지 확인해 봅시다:

func plusOne(array: [Float]) -> Float{
  var array = array
  array.updated(at: 0, with: array[0] + 1)
  return array[0]
}

let plusOneValAndGrad = valueWithGradient(at: [2], in: plusOne)
print(plusOneValAndGrad)
(value: 3.0, gradient: [1.0])

이 해결 방법을 사용하지 않으면 오류는 Differentiation of coroutine calls is not yet supported 입니다. 다음은 이 해결 방법을 불필요하게 만드는 진행 상황을 볼 수 있는 링크입니다: https://bugs.swift.org/browse/TF-1277 (배열을 수행할 때 뒤에서 호출되는 Array.subscript._modify에 대해 설명합니다. 아래 첨자 세트).

Float <-> Double 변환

FloatDouble 사이를 전환하는 경우 해당 생성자는 아직 미분 가능하지 않습니다. 다음은 Float 에서 Double 로 차별화 가능하게 해주는 함수입니다.

(아래 코드에서 FloatDouble 전환하면 Double 에서 Float 로 변환되는 함수가 있습니다.)

다른 실제 숫자 유형에 대해서도 유사한 변환기를 만들 수 있습니다.

@differentiable
func convertToDouble(_ a: Float) -> Double {
    return Double(a)
}

@derivative(of: convertToDouble)
func convertToDoubleVJP(_ a: Float) -> (value: Double, pullback: (Double) -> Float) {
    func pullback(_ v: Double) -> Float{
        return Float(v)
    }
    return (value: Double(a), pullback: pullback)
}

사용 예는 다음과 같습니다.

@differentiable
func timesTwo(a: Float) -> Double {
  return convertToDouble(a * 2)
}
let input: Float = 3
let valAndGrad = valueWithGradient(at: input, in: timesTwo)
print("grad", valAndGrad.gradient)
print("type of input:", type(of: input))
print("type of output:", type(of: valAndGrad.value))
print("type of gradient:", type(of: valAndGrad.gradient))
grad 2.0
type of input: Float
type of output: Double
type of gradient: Float

초월 및 기타 함수(sin, cos, abs, max)

많은 초월 함수와 기타 일반적인 내장 함수가 이미 FloatDouble 에 대해 차별화 가능해졌습니다. Double Float 보다 적습니다. 일부는 둘 다 사용할 수 없습니다. 따라서 아직 제공되지 않은 경우 필요한 것을 만드는 방법에 대한 아이디어를 제공하는 몇 가지 수동 파생 정의는 다음과 같습니다.

pow (파생 설명은 링크 참조)

import Foundation

@usableFromInline
@derivative(of: pow) 
func powVJP(_ base: Double, _ exponent: Double) -> (value: Double, pullback: (Double) -> (Double, Double)) {
    let output: Double = pow(base, exponent)
    func pullback(_ vector: Double) -> (Double, Double) {
        let baseDerivative = vector * (exponent * pow(base, exponent - 1))
        let exponentDerivative = vector * output * log(base)
        return (baseDerivative, exponentDerivative)
    }

    return (value: output, pullback: pullback)
}

최대

@usableFromInline
@derivative(of: max)
func maxVJP<T: Comparable & Differentiable>(_ x: T, _ y: T) -> (value: T, pullback: (T.TangentVector)
  -> (T.TangentVector, T.TangentVector))
{
    func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
        if x < y {
            return (.zero, v)
        } else {
            return (v, .zero)
        }
    }
    return (value: max(x, y), pullback: pullback)
}

복근

@usableFromInline
@derivative(of: abs)
func absVJP<T: Comparable & SignedNumeric & Differentiable>(_ x: T)
  -> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
{
    func pullback(_ v: T.TangentVector) -> T.TangentVector{
        if x < 0 {
            return .zero - v
        }
        else {
            return v
        }
    }
    return (value: abs(x), pullback: pullback)
}

sqrt (미분 설명은 링크 참조)

@usableFromInline
@derivative(of: sqrt) 
func sqrtVJP(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    let output = sqrt(x)
    func pullback(_ v: Double) -> Double {
        return v / (2 * output)
    }
    return (value: output, pullback: pullback)
}

다음이 작동하는지 확인해 보겠습니다.

let powGrad = gradient(at: 2, 2, in: pow)
print("pow gradient: ", powGrad, "which is", powGrad == (4.0, 2.772588722239781) ? "correct" : "incorrect")

let maxGrad = gradient(at: 1, 2, in: max)
print("max gradient: ", maxGrad, "which is", maxGrad == (0.0, 1.0) ? "correct" : "incorrect")

let absGrad = gradient(at: 2, in: abs)
print("abs gradient: ", absGrad, "which is", absGrad == 1.0 ? "correct" : "incorrect")

let sqrtGrad = gradient(at: 4, in: sqrt)
print("sqrt gradient: ", sqrtGrad, "which is", sqrtGrad == 0.25 ? "correct" : "incorrect")
pow gradient:  (4.0, 2.772588722239781) which is correct
max gradient:  (0.0, 1.0) which is correct
abs gradient:  1.0 which is correct
sqrt gradient:  0.25 which is correct

이와 같은 필요성을 경고하는 컴파일러 오류는 다음과 같습니다. Expression is not differentiable. Cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files

KeyPath 구독

KeyPath 첨자(가져오기 또는 설정)는 기본적으로 작동하지 않지만 다시 한 번 추가할 수 있는 몇 가지 확장이 있으며 해결 방법 구문을 사용할 수 있습니다. 여기있어:

https://github.com/tensorflow/swift/issues/530#issuecomment-687400701

이 해결 방법은 다른 해결 방법보다 약간 추악합니다. Differentiable 및 AdditiveArithmetic을 준수해야 하는 사용자 정의 개체에 대해서만 작동합니다. .tmp 멤버와 .read() 함수를 추가해야 하며 KeyPath 첨자 가져오기를 수행할 때 .tmp 멤버를 중간 저장소로 사용합니다(링크된 코드에 예제가 있음). KeyPath 첨자 세트는 .write() 함수를 사용하여 매우 간단하게 작동합니다.