Посмотреть на TensorFlow.org | Запустить в Google Colab | Посмотреть исходный код на GitHub |
Swift for TensorFlow, основанный на TensorFlow, использует новый подход к проектированию API. API тщательно создаются на основе существующих библиотек и комбинируются с новыми языковыми идиомами. Это означает, что не все API-интерфейсы TensorFlow будут напрямую доступны как API-интерфейсы Swift, и для развития нашего курирования API требуется время и целенаправленные усилия. Однако не волнуйтесь, если ваш любимый оператор TensorFlow недоступен в Swift — библиотека TensorFlow Swift предоставляет вам прозрачный доступ к большинству операторов TensorFlow в пространстве имен _Raw
.
Импортируйте TensorFlow
, чтобы начать.
import TensorFlow
Вызов сырых операторов
Просто найдите нужную функцию в пространстве имен _Raw
с помощью автодополнения кода.
print(_Raw.mul(Tensor([2.0, 3.0]), Tensor([5.0, 6.0])))
[10.0, 18.0]
Определение нового оператора умножения
Multiply уже доступен как оператор *
в Tensor
, но давайте представим, что мы хотим сделать его доступным под новым именем .*
. Swift позволяет вам задним числом добавлять методы или вычисляемые свойства к существующим типам, используя объявления extension
.
Теперь давайте добавим .*
к Tensor
, объявив расширение, и сделаем его доступным, когда Scalar
тип тензора соответствует Numeric
.
infix operator .* : MultiplicationPrecedence
extension Tensor where Scalar: Numeric {
static func .* (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
return _Raw.mul(lhs, rhs)
}
}
let x: Tensor<Double> = [[1.0, 2.0], [3.0, 4.0]]
let y: Tensor<Double> = [[8.0, 7.0], [6.0, 5.0]]
print(x .* y)
[[ 8.0, 14.0], [18.0, 20.0]]
Определение производной завернутой функции
Вы можете не только легко определить Swift API для необработанного оператора TensorFlow, но и сделать его дифференцируемым для работы с первоклассным автоматическим дифференцированием Swift.
Чтобы сделать .*
дифференцируемым, используйте атрибут @derivative
для производной функции и укажите исходную функцию в качестве аргумента атрибута под меткой of:
Поскольку оператор .*
определяется, когда универсальный тип Scalar
соответствует Numeric
, этого недостаточно для соответствия Tensor<Scalar>
протоколу Differentiable
. Созданный с безопасностью типов, Swift напомнит нам о необходимости добавить общее ограничение к атрибуту @differentiable
, чтобы потребовать Scalar
соответствия протоколу TensorFlowFloatingPoint
, что приведет к тому, что Tensor<Scalar>
будет соответствовать Differentiable
.
@differentiable(where Scalar: TensorFlowFloatingPoint)
infix operator .* : MultiplicationPrecedence
extension Tensor where Scalar: Numeric {
@differentiable(where Scalar: TensorFlowFloatingPoint)
static func .* (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
return _Raw.mul(lhs, rhs)
}
}
extension Tensor where Scalar : TensorFlowFloatingPoint {
@derivative(of: .*)
static func multiplyDerivative(
_ lhs: Tensor, _ rhs: Tensor
) -> (value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)) {
return (lhs * rhs, { v in
((rhs * v).unbroadcasted(to: lhs.shape),
(lhs * v).unbroadcasted(to: rhs.shape))
})
}
}
// Now, we can take the derivative of a function that calls `.*` that we just defined.
print(gradient(at: x, y) { x, y in
(x .* y).sum()
})
(0.0, 0.0)
Больше примеров
let matrix = Tensor<Float>([[1, 2], [3, 4]])
print(_Raw.matMul(matrix, matrix, transposeA: true, transposeB: true))
print(_Raw.matMul(matrix, matrix, transposeA: true, transposeB: false))
print(_Raw.matMul(matrix, matrix, transposeA: false, transposeB: true))
print(_Raw.matMul(matrix, matrix, transposeA: false, transposeB: false))
[[ 7.0, 15.0], [10.0, 22.0]] [[10.0, 14.0], [14.0, 20.0]] [[ 5.0, 11.0], [11.0, 25.0]] [[ 7.0, 10.0], [15.0, 22.0]]