Veja no TensorFlow.org | Executar no Google Colab | Ver fonte no GitHub |
Com base no TensorFlow, o Swift for TensorFlow adota uma nova abordagem para o design de API. As APIs são cuidadosamente selecionadas de bibliotecas estabelecidas e combinadas com novos idiomas de linguagem. Isso significa que nem todas as APIs do TensorFlow estarão disponíveis diretamente como APIs Swift, e nossa curadoria de APIs precisa de tempo e esforço dedicado para evoluir. No entanto, não se preocupe se o seu operador TensorFlow favorito não estiver disponível no Swift - a biblioteca TensorFlow Swift oferece acesso transparente à maioria dos operadores TensorFlow, no namespace _Raw
.
Importe o TensorFlow
para começar.
import TensorFlow
Chamando operadores brutos
Basta encontrar a função que você precisa no namespace _Raw
por meio do preenchimento de código.
print(_Raw.mul(Tensor([2.0, 3.0]), Tensor([5.0, 6.0])))
[10.0, 18.0]
Definindo um novo operador de multiplicação
Multiply já está disponível como operador *
no Tensor
, mas vamos fingir que queríamos disponibilizá-lo com um novo nome como .*
. O Swift permite adicionar métodos ou propriedades computadas retroativamente a tipos existentes usando declarações de extension
.
Agora, vamos adicionar .*
ao Tensor
declarando uma extensão e disponibilizá-la quando o tipo Scalar
do tensor estiver em conformidade com Numeric
.
infix operator .* : MultiplicationPrecedence
extension Tensor where Scalar: Numeric {
static func .* (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
return _Raw.mul(lhs, rhs)
}
}
let x: Tensor<Double> = [[1.0, 2.0], [3.0, 4.0]]
let y: Tensor<Double> = [[8.0, 7.0], [6.0, 5.0]]
print(x .* y)
[[ 8.0, 14.0], [18.0, 20.0]]
Definindo uma derivada de uma função encapsulada
Além de definir facilmente uma API Swift para um operador bruto do TensorFlow, você também pode torná-la diferenciável para trabalhar com a diferenciação automática de primeira classe do Swift.
Para tornar .*
diferenciável, use o atributo @derivative
na função derivada e especifique a função original como um argumento de atributo sob o rótulo of:
Como o operador .*
é definido quando o tipo genérico Scalar
está em conformidade com Numeric
, ele não é suficiente para tornar Tensor<Scalar>
em conformidade com o protocolo Differentiable
. Nascido com segurança de tipo, o Swift nos lembrará de adicionar uma restrição genérica no atributo @differentiable
para exigir que Scalar
esteja em conformidade com o protocolo TensorFlowFloatingPoint
, o que tornaria Tensor<Scalar>
em conformidade com Differentiable
.
@differentiable(where Scalar: TensorFlowFloatingPoint)
infix operator .* : MultiplicationPrecedence
extension Tensor where Scalar: Numeric {
@differentiable(where Scalar: TensorFlowFloatingPoint)
static func .* (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
return _Raw.mul(lhs, rhs)
}
}
extension Tensor where Scalar : TensorFlowFloatingPoint {
@derivative(of: .*)
static func multiplyDerivative(
_ lhs: Tensor, _ rhs: Tensor
) -> (value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)) {
return (lhs * rhs, { v in
((rhs * v).unbroadcasted(to: lhs.shape),
(lhs * v).unbroadcasted(to: rhs.shape))
})
}
}
// Now, we can take the derivative of a function that calls `.*` that we just defined.
print(gradient(at: x, y) { x, y in
(x .* y).sum()
})
(0.0, 0.0)
Mais exemplos
let matrix = Tensor<Float>([[1, 2], [3, 4]])
print(_Raw.matMul(matrix, matrix, transposeA: true, transposeB: true))
print(_Raw.matMul(matrix, matrix, transposeA: true, transposeB: false))
print(_Raw.matMul(matrix, matrix, transposeA: false, transposeB: true))
print(_Raw.matMul(matrix, matrix, transposeA: false, transposeB: false))
[[ 7.0, 15.0], [10.0, 22.0]] [[10.0, 14.0], [14.0, 20.0]] [[ 5.0, 11.0], [11.0, 25.0]] [[ 7.0, 10.0], [15.0, 22.0]]