Ver en TensorFlow.org | Ejecutar en Google Colab | Ver fuente en GitHub |
Sobre la base de TensorFlow, Swift for TensorFlow adopta un nuevo enfoque para el diseño de API. Las API se seleccionan cuidadosamente de bibliotecas establecidas y se combinan con nuevos lenguajes lingüísticos. Esto significa que no todas las API de TensorFlow estarán disponibles directamente como API de Swift, y nuestra curación de API necesita tiempo y esfuerzo dedicado para evolucionar. Sin embargo, no se preocupe si su operador de TensorFlow favorito no está disponible en Swift: la biblioteca de TensorFlow Swift le brinda acceso transparente a la mayoría de los operadores de TensorFlow, bajo el espacio de nombres _Raw
.
Importa TensorFlow
para comenzar.
import TensorFlow
Llamar a operadores sin formato
Simplemente busque la función que necesita en el espacio de nombres _Raw
completando el código.
print(_Raw.mul(Tensor([2.0, 3.0]), Tensor([5.0, 6.0])))
[10.0, 18.0]
Definición de un nuevo operador de multiplicación
Multiply ya está disponible como operador *
en Tensor
, pero supongamos que queremos que esté disponible con un nuevo nombre como .*
. Swift le permite agregar retroactivamente métodos o propiedades calculadas a tipos existentes mediante declaraciones de extension
.
Ahora, agreguemos .*
a Tensor
declarando una extensión y hagámoslo disponible cuando el tipo Scalar
del tensor se ajuste a Numeric
.
infix operator .* : MultiplicationPrecedence
extension Tensor where Scalar: Numeric {
static func .* (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
return _Raw.mul(lhs, rhs)
}
}
let x: Tensor<Double> = [[1.0, 2.0], [3.0, 4.0]]
let y: Tensor<Double> = [[8.0, 7.0], [6.0, 5.0]]
print(x .* y)
[[ 8.0, 14.0], [18.0, 20.0]]
Definición de una derivada de una función envuelta
No solo puede definir fácilmente una API de Swift para un operador de TensorFlow sin procesar, sino que también puede hacer que sea diferenciable para trabajar con la diferenciación automática de primera clase de Swift.
Para hacer que .*
sea diferenciable, use el atributo @derivative
en la función derivada y especifique la función original como un argumento de atributo bajo la etiqueta of:
Dado que el operador .*
se define cuando el tipo genérico Scalar
se ajusta a Numeric
, no es suficiente para hacer que Tensor<Scalar>
se ajuste al protocolo Differentiable
. Nacido con seguridad de tipos, Swift nos recordará agregar una restricción genérica en el atributo @differentiable
para requerir que Scalar
se ajuste al protocolo TensorFlowFloatingPoint
, lo que haría que Tensor<Scalar>
se ajustara a Differentiable
.
@differentiable(where Scalar: TensorFlowFloatingPoint)
infix operator .* : MultiplicationPrecedence
extension Tensor where Scalar: Numeric {
@differentiable(where Scalar: TensorFlowFloatingPoint)
static func .* (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
return _Raw.mul(lhs, rhs)
}
}
extension Tensor where Scalar : TensorFlowFloatingPoint {
@derivative(of: .*)
static func multiplyDerivative(
_ lhs: Tensor, _ rhs: Tensor
) -> (value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)) {
return (lhs * rhs, { v in
((rhs * v).unbroadcasted(to: lhs.shape),
(lhs * v).unbroadcasted(to: rhs.shape))
})
}
}
// Now, we can take the derivative of a function that calls `.*` that we just defined.
print(gradient(at: x, y) { x, y in
(x .* y).sum()
})
(0.0, 0.0)
Más ejemplos
let matrix = Tensor<Float>([[1, 2], [3, 4]])
print(_Raw.matMul(matrix, matrix, transposeA: true, transposeB: true))
print(_Raw.matMul(matrix, matrix, transposeA: true, transposeB: false))
print(_Raw.matMul(matrix, matrix, transposeA: false, transposeB: true))
print(_Raw.matMul(matrix, matrix, transposeA: false, transposeB: false))
[[ 7.0, 15.0], [10.0, 22.0]] [[10.0, 14.0], [14.0, 20.0]] [[ 5.0, 11.0], [11.0, 25.0]] [[ 7.0, 10.0], [15.0, 22.0]]