Необработанные операторы TensorFlow

Посмотреть на TensorFlow.org Запустить в Google Colab Посмотреть исходный код на GitHub

Swift for TensorFlow, основанный на TensorFlow, использует новый подход к проектированию API. API тщательно создаются на основе существующих библиотек и комбинируются с новыми языковыми идиомами. Это означает, что не все API-интерфейсы TensorFlow будут напрямую доступны как API-интерфейсы Swift, и для развития нашего курирования API требуется время и целенаправленные усилия. Однако не волнуйтесь, если ваш любимый оператор TensorFlow недоступен в Swift — библиотека TensorFlow Swift предоставляет вам прозрачный доступ к большинству операторов TensorFlow в пространстве имен _Raw .

Импортируйте TensorFlow , чтобы начать.

import TensorFlow

Вызов сырых операторов

Просто найдите нужную функцию в пространстве имен _Raw с помощью автодополнения кода.

print(_Raw.mul(Tensor([2.0, 3.0]), Tensor([5.0, 6.0])))
[10.0, 18.0]

Определение нового оператора умножения

Multiply уже доступен как оператор * в Tensor , но давайте представим, что мы хотим сделать его доступным под новым именем .* . Swift позволяет вам задним числом добавлять методы или вычисляемые свойства к существующим типам, используя объявления extension .

Теперь давайте добавим .* к Tensor , объявив расширение, и сделаем его доступным, когда Scalar тип тензора соответствует Numeric .

infix operator .* : MultiplicationPrecedence

extension Tensor where Scalar: Numeric {
    static func .* (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
        return _Raw.mul(lhs, rhs)
    }
}

let x: Tensor<Double> = [[1.0, 2.0], [3.0, 4.0]]
let y: Tensor<Double> = [[8.0, 7.0], [6.0, 5.0]]
print(x .* y)
[[ 8.0, 14.0],
 [18.0, 20.0]]

Определение производной завернутой функции

Вы можете не только легко определить Swift API для необработанного оператора TensorFlow, но и сделать его дифференцируемым для работы с первоклассным автоматическим дифференцированием Swift.

Чтобы сделать .* дифференцируемым, используйте атрибут @derivative для производной функции и укажите исходную функцию в качестве аргумента атрибута под меткой of: Поскольку оператор .* определяется, когда универсальный тип Scalar соответствует Numeric , этого недостаточно для соответствия Tensor<Scalar> протоколу Differentiable . Созданный с безопасностью типов, Swift напомнит нам о необходимости добавить общее ограничение к атрибуту @differentiable , чтобы потребовать Scalar соответствия протоколу TensorFlowFloatingPoint , что приведет к тому, что Tensor<Scalar> будет соответствовать Differentiable .

@differentiable(where Scalar: TensorFlowFloatingPoint)
infix operator .* : MultiplicationPrecedence

extension Tensor where Scalar: Numeric {
    @differentiable(where Scalar: TensorFlowFloatingPoint)
    static func .* (_ lhs: Tensor,  _ rhs: Tensor) -> Tensor {
        return _Raw.mul(lhs, rhs)
    }
}

extension Tensor where Scalar : TensorFlowFloatingPoint { 
    @derivative(of: .*)
    static func multiplyDerivative(
        _ lhs: Tensor, _ rhs: Tensor
    ) -> (value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)) {
        return (lhs * rhs, { v in
            ((rhs * v).unbroadcasted(to: lhs.shape),
            (lhs * v).unbroadcasted(to: rhs.shape))
        })
    }
}

// Now, we can take the derivative of a function that calls `.*` that we just defined.
print(gradient(at: x, y) { x, y in
    (x .* y).sum()
})
(0.0, 0.0)

Больше примеров

let matrix = Tensor<Float>([[1, 2], [3, 4]])

print(_Raw.matMul(matrix, matrix, transposeA: true, transposeB: true))
print(_Raw.matMul(matrix, matrix, transposeA: true, transposeB: false))
print(_Raw.matMul(matrix, matrix, transposeA: false, transposeB: true))
print(_Raw.matMul(matrix, matrix, transposeA: false, transposeB: false))
[[ 7.0, 15.0],
 [10.0, 22.0]]
[[10.0, 14.0],
 [14.0, 20.0]]
[[ 5.0, 11.0],
 [11.0, 25.0]]
[[ 7.0, 10.0],
 [15.0, 22.0]]