public struct TensorVisitorPlan<Base>
TensorVisitorPlan approximates [WritableKeyPath<Base, Tensor<Float>>] but is more efficient.
This is useful for writing generic optimizers that need to map over the gradients, the
existing weights, and an index that can be used to find auxiliary stored weights. The gain
is modest (~2x) and could be better: the plan trades slightly higher overheads (an extra
pointer dereference) for not having to do the O(depth_of_tree) work that a plain list
requires to track down each individual KeyPath.
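The access pattern described above can be sketched in plain Swift. This is only an illustration of the idea, not the library's implementation: `Dense`, `Model`, `keyPaths`, and `mapWeights` are hypothetical names, and a plain `Float` stands in for `Tensor<Float>` so the snippet runs without any dependencies.

```swift
// Stand-in model: plain Float fields play the role of Tensor<Float> (assumption).
struct Dense {
    var weight: Float
    var bias: Float
}

struct Model {
    var first: Dense
    var second: Dense
}

// A hand-written flat list of key paths, analogous to a flattened plan.
let keyPaths: [WritableKeyPath<Model, Float>] = [
    \Model.first.weight, \Model.first.bias,
    \Model.second.weight, \Model.second.bias,
]

// Hypothetical helper: visits each weight with its gradient and its index,
// then returns the number of values visited.
func mapWeights(
    _ weights: inout Model, _ grads: Model,
    _ fn: (inout Float, Float, Int) -> Void
) -> Int {
    for (i, kp) in keyPaths.enumerated() {
        fn(&weights[keyPath: kp], grads[keyPath: kp], i)
    }
    return keyPaths.count
}

var model = Model(
    first: Dense(weight: 1, bias: 0),
    second: Dense(weight: 2, bias: 0))
let grads = Model(
    first: Dense(weight: 0.5, bias: 0.25),
    second: Dense(weight: 1, bias: 0.5))

// Auxiliary per-value optimizer state, addressed by the index argument.
var velocity = [Float](repeating: 0, count: keyPaths.count)

let count = mapWeights(&model, grads) { w, g, i in
    velocity[i] = 0.5 * velocity[i] + g  // momentum accumulation
    w -= 0.5 * velocity[i]               // gradient step
}
```

The index passed to the closure is what lets the optimizer keep its own state (here `velocity`) in a flat array rather than in a second copy of the model structure.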
-
Flatten out the plan as a single [WritableKeyPath<Base, Tensor<Float>>].
Declaration
public var allTensorKeyPaths: [WritableKeyPath<Base, Tensor<Float>>] { get }
-
Efficiently collect all the tensors.
Declaration
public func allTensors(_ v: Base) -> [Tensor<Float>]
-
Efficiently map over two values of type Base and apply a mapping function. Returns the number of tensors. The extra Int
argument is provided to allow indexing into an auxiliary list of Tensors with the same Tensor count as the plan.
-
Declaration
func populateMask<Base>(_ mask: inout [Bool], _ kp: WritableKeyPath<Base, Tensor<Float>>)
-
Find all keys ending with a particular key-path.
Declaration
public func keysEnding<Base>(with kp: WritableKeyPath<Base, Tensor<Float>>) -> [Bool]
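As a rough illustration of how such a `[Bool]` mask might be used, the sketch below applies a decay factor only to entries that are not biases. It is an assumption-laden stand-in: dotted strings replace key paths, `tensorNames`, `maskEnding`, and `values` are hypothetical, and suffix matching on strings only approximates matching on the end of a key path.

```swift
// Dotted names stand in for the plan's key paths, in plan order (assumption).
let tensorNames = [
    "layers.0.weight", "layers.0.bias",
    "layers.1.weight", "layers.1.bias",
]

// Hypothetical analog of keysEnding(with:): one Bool per entry in the plan.
func maskEnding(with suffix: String) -> [Bool] {
    return tensorNames.map { $0.hasSuffix(suffix) }
}

let isBias = maskEnding(with: ".bias")

// One Float per entry stands in for each Tensor<Float>.
var values: [Float] = [1, 1, 1, 1]

// Apply decay to everything except the biases.
for i in values.indices where !isBias[i] {
    values[i] *= 0.5
}
```

Because the mask has one entry per tensor in the plan, it can be computed once and reused on every optimizer step.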
-
Declaration
func findFirstIndex<TrueBase, T>( _ rootKeyPath: WritableKeyPath<TrueBase, Base>, _ prefix: WritableKeyPath<TrueBase, T>, _ i: inout Int ) -> Bool
-
Find the index of the first keypath starting with a particular prefix. Note: All array layers support 1-past-the-end indexing.
Declaration
func firstIndex<T>(withPrefix prefix: WritableKeyPath<Base, T>) -> Int
-
Find all key indices in the half-open range defined by two KeyPath prefixes: [lower, upper)
Declaration
public func allKeysBetween<T, U>(lower: WritableKeyPath<Base, T>, upper: WritableKeyPath<Base, U>) -> [Bool]
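The half-open range semantics can be sketched as follows. This is a simplified stand-in rather than the library's code: dotted strings replace key paths, and `planNames`, `firstIndexSketch`, and `maskBetween` are hypothetical helpers introduced only for illustration.

```swift
// Dotted names stand in for the plan's key paths, in plan order (assumption).
let planNames = [
    "embed.weight",
    "layers.0.weight", "layers.0.bias",
    "layers.1.weight", "layers.1.bias",
    "head.weight",
]

// Hypothetical analog of firstIndex(withPrefix:): index of the first entry
// that starts with the prefix, or the entry count if nothing matches.
func firstIndexSketch(withPrefix prefix: String) -> Int {
    return planNames.firstIndex(where: { $0.hasPrefix(prefix) }) ?? planNames.count
}

// Mask is true for indices in [lower, upper): lower included, upper excluded.
func maskBetween(lower: String, upper: String) -> [Bool] {
    let lo = firstIndexSketch(withPrefix: lower)
    let hi = firstIndexSketch(withPrefix: upper)
    return planNames.indices.map { $0 >= lo && $0 < hi }
}

let allLayers = maskBetween(lower: "layers.0", upper: "head")
let lastLayer = maskBetween(lower: "layers.1", upper: "head")
```

Excluding the upper bound means adjacent ranges never overlap, so masks built from consecutive prefixes partition the tensors cleanly.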
-
Creates a plan to visit all the tensors in a particular instance of Base. This plan is transferable to structurally equivalent versions of Base.
Declaration
public init(_ obj: Base)