Functions

The following functions are available globally.

  • Returns the L1 loss between predictions and expectations.

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func l1Loss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.
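
    For example, a minimal usage sketch (the tensor values are illustrative):

    let predicted: Tensor<Float> = [0.9, 0.2, 0.3]
    let expected: Tensor<Float> = [1, 0, 0]
    let loss = l1Loss(predicted: predicted, expected: expected)
    // aggregates abs(expected - predicted) over all elements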

  • Returns the L2 loss between predictions and expectations.

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func l2Loss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

  • Returns the hinge loss between predictions and expectations.

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func hingeLoss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

  • Returns the squared hinge loss between predictions and expectations.

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func squaredHingeLoss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

  • Returns the categorical hinge loss between predictions and expectations.

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func categoricalHingeLoss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

  • Returns the logarithm of the hyperbolic cosine of the error between predictions and expectations.

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func logCoshLoss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

  • Returns the Poisson loss between predictions and expectations.

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func poissonLoss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

  • Returns the Kullback-Leibler divergence (KL divergence) between expectations and predictions. Given two distributions p and q, KL divergence computes p * log(p / q).

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func kullbackLeiblerDivergence<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

  • Returns the softmax cross entropy (categorical cross entropy) between logits and labels.

    Declaration

    @differentiable(wrt: logits)
    public func softmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
      logits: Tensor<Scalar>,
      probabilities: Tensor<Scalar>
    ) -> Tensor<Scalar>

    Parameters

    logits

    Unscaled log probabilities from a neural network.

    probabilities

    Probability values that correspond to the correct output. Each row must be a valid probability distribution.
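
    For example, a minimal sketch with a batch of one example over three classes:

    let logits: Tensor<Float> = [[2.0, 1.0, 0.1]]
    let probabilities: Tensor<Float> = [[1.0, 0.0, 0.0]]
    let loss = softmaxCrossEntropy(logits: logits, probabilities: probabilities)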

  • Returns the sigmoid cross entropy (binary cross entropy) between logits and labels.

    Declaration

    @differentiable(wrt: logits)
    @differentiable(wrt: (logits, labels))
    public func sigmoidCrossEntropy<Scalar: TensorFlowFloatingPoint>(
      logits: Tensor<Scalar>,
      labels: Tensor<Scalar>
    ) -> Tensor<Scalar>

    Parameters

    logits

    The unscaled output of a neural network.

    labels

    Integer values that correspond to the correct output.
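
    For example, a minimal sketch for two binary predictions:

    let logits: Tensor<Float> = [1.2, -0.4]
    let labels: Tensor<Float> = [1, 0]
    let loss = sigmoidCrossEntropy(logits: logits, labels: labels)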

  • Returns a tensor with the same shape and scalars as the specified tensor.

    Declaration

    @differentiable
    public func identity<Scalar>(_ x: Tensor<Scalar>) -> Tensor<Scalar> where Scalar : TensorFlowScalar
  • Calls the given closure within the given context, which is set before the closure gets called and restored after the closure returns.

    Declaration

    public func withContext<R>(_ context: Context, _ body: () throws -> R) rethrows -> R

    Parameters

    context

    A context that will be set before the closure gets called and restored after the closure returns.

    body

    A nullary closure. If the closure has a return value, that value is also used as the return value of the withContext(_:_:) function.

    Return Value

    The return value, if any, of the body closure.

  • Calls the given closure within a context that has everything identical to the current context except for the given learning phase.

    Declaration

    public func withLearningPhase<R>(
      _ learningPhase: LearningPhase,
      _ body: () throws -> R
    ) rethrows -> R

    Parameters

    learningPhase

    A learning phase that will be set before the closure gets called and restored after the closure returns.

    body

    A nullary closure. If the closure has a return value, that value is also used as the return value of the withLearningPhase(_:_:) function.

    Return Value

    The return value, if any, of the body closure.
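
    For example, a sketch that evaluates a hypothetical model in inference mode (model and input are placeholders, not part of this API):

    let predictions = withLearningPhase(.inference) {
      model(input)  // dropout, batch norm, etc. observe the inference phase
    }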

  • Calls the given closure within a context that has everything identical to the current context except for the given random seed.

    Declaration

    public func withRandomSeedForTensorFlow<R>(
      _ randomSeed: TensorFlowSeed,
      _ body: () throws -> R
    ) rethrows -> R

    Parameters

    randomSeed

    A random seed that will be set before the closure gets called and restored after the closure returns.

    body

    A nullary closure. If the closure has a return value, that value is also used as the return value of the withRandomSeedForTensorFlow(_:_:) function.

    Return Value

    The return value, if any, of the body closure.

  • Calls the given closure within a context that has everything identical to the current context except for the given random number generator.

    Declaration

    public func withRandomNumberGeneratorForTensorFlow<G: RandomNumberGenerator, R>(
      _ randomNumberGenerator: inout G,
      _ body: () throws -> R
    ) rethrows -> R

    Parameters

    randomNumberGenerator

    A random number generator that will be set before the closure gets called and restored after the closure returns.

    body

    A nullary closure. If the closure has a return value, that value is also used as the return value of the withRandomNumberGeneratorForTensorFlow(_:_:) function.

    Return Value

    The return value, if any, of the body closure.

  • Declaration

    public func zip<T: TensorGroup, U: TensorGroup>(
      _ dataset1: Dataset<T>, _ dataset2: Dataset<U>
    ) -> Dataset<Zip2TensorGroup<T, U>>
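
    This zips two datasets element-wise, analogous to zip over Swift sequences. A minimal sketch, assuming the Zip2TensorGroup element exposes its halves as first and second:

    let xs = Dataset(elements: Tensor<Float>([1, 2, 3]))
    let ys = Dataset(elements: Tensor<Float>([4, 5, 6]))
    for pair in zip(xs, ys) {
      print(pair.first, pair.second)  // (1, 4), (2, 5), (3, 6)
    }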
  • LazyTensorBarrier ensures all live tensors (on device if provided) are scheduled and running. If wait is set to true, this call blocks until the computation is complete.

    Declaration

    public func LazyTensorBarrier(on device: Device? = nil, devices: [Device] = [], wait: Bool = false)
  • Declaration

    public func valueWithGradient<T, R>(
      at x: T,
      in f: @differentiable (T) -> Tensor<R>
    ) -> (value: Tensor<R>, gradient: T.TangentVector)
    where T: Differentiable, R: TensorFlowFloatingPoint
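
    This overload lacks a description; as a minimal sketch, it evaluates a function and its gradient at a point:

    let x = Tensor<Float>(3)
    let (value, grad) = valueWithGradient(at: x) { x in x * x }
    // value == 9.0, grad == 6.0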
  • Declaration

    public func valueWithGradient<T, U, R>(
      at x: T,
      _ y: U,
      in f: @differentiable (T, U) -> Tensor<R>
    ) -> (value: Tensor<R>, gradient: (T.TangentVector, U.TangentVector))
    where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint
  • Declaration

    public func valueWithGradient<T, U, V, R>(
      at x: T,
      _ y: U,
      _ z: V,
      in f: @differentiable (T, U, V) -> Tensor<R>
    ) -> (value: Tensor<R>, gradient: (T.TangentVector, U.TangentVector, V.TangentVector))
    where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint
  • Declaration

    public func valueWithGradient<T, R>(
      of f: @escaping @differentiable (T) -> Tensor<R>
    ) -> (T) -> (value: Tensor<R>, gradient: T.TangentVector)
    where T: Differentiable, R: TensorFlowFloatingPoint
  • Declaration

    public func valueWithGradient<T, U, R>(
      of f: @escaping @differentiable (T, U) -> Tensor<R>
    ) -> (T, U) -> (value: Tensor<R>, gradient: (T.TangentVector, U.TangentVector))
    where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint
  • Declaration

    public func valueWithGradient<T, U, V, R>(
      of f: @escaping @differentiable (T, U, V) -> Tensor<R>
    ) -> (T, U, V) -> (
      value: Tensor<R>,
      gradient: (T.TangentVector, U.TangentVector, V.TangentVector)
    )
    where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint
  • Declaration

    public func gradient<T, R>(
      at x: T,
      in f: @differentiable (T) -> Tensor<R>
    ) -> T.TangentVector where T: Differentiable, R: TensorFlowFloatingPoint
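
    As a minimal sketch of this single-argument overload:

    let dydx = gradient(at: Tensor<Float>(3)) { x in x * x * x }
    // dydx == 27.0, the derivative 3x² evaluated at x = 3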
  • Declaration

    public func gradient<T, U, R>(
      at x: T,
      _ y: U,
      in f: @differentiable (T, U) -> Tensor<R>
    ) -> (T.TangentVector, U.TangentVector)
    where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint
  • Declaration

    public func gradient<T, U, V, R>(
      at x: T,
      _ y: U,
      _ z: V,
      in f: @differentiable (T, U, V) -> Tensor<R>
    ) -> (T.TangentVector, U.TangentVector, V.TangentVector)
    where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint
  • Declaration

    public func gradient<T, R>(
      of f: @escaping @differentiable (T) -> Tensor<R>
    ) -> (T) -> T.TangentVector where T: Differentiable, R: TensorFlowFloatingPoint
  • Declaration

    public func gradient<T, U, R>(
      of f: @escaping @differentiable (T, U) -> Tensor<R>
    ) -> (T, U) -> (T.TangentVector, U.TangentVector)
    where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint
  • Declaration

    public func gradient<T, U, V, R>(
      of f: @escaping @differentiable (T, U, V) -> Tensor<R>
    ) -> (T, U, V) -> (T.TangentVector, U.TangentVector, V.TangentVector)
    where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint
  • Makes a function be recomputed in its pullback, known as “checkpointing” in traditional automatic differentiation.

    Declaration

    public func withRecomputationInPullbacks<T, U>(
      _ body: @escaping @differentiable (T) -> U
    ) -> @differentiable (T) -> U where T : Differentiable, U : Differentiable
  • Creates a differentiable function from a vector-Jacobian products function.

    Declaration

    public func differentiableFunction<T : Differentiable, R : Differentiable>(
      from vjp: @escaping (T)
               -> (value: R, pullback: (R.TangentVector) -> T.TangentVector)
    ) -> @differentiable (T) -> R
  • Creates a differentiable function from a vector-Jacobian products function.

    Declaration

    public func differentiableFunction<T, U, R>(
      from vjp: @escaping (T, U)
               -> (value: R, pullback: (R.TangentVector)
                 -> (T.TangentVector, U.TangentVector))
    ) -> @differentiable (T, U) -> R
  • Returns x like an identity function. When used in a context where x is being differentiated with respect to, this function will not produce any derivative at x.

    Declaration

    @_semantics("autodiff.nonvarying")
    public func withoutDerivative<T>(at x: T) -> T
  • Applies the given closure body to x. When used in a context where x is being differentiated with respect to, this function will not produce any derivative at x.

    Declaration

    @_semantics("autodiff.nonvarying")
    public func withoutDerivative<T, R>(at x: T, in body: (T) -> R) -> R
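
    For example, a sketch of the effect on a derivative:

    let dydx = gradient(at: Tensor<Float>(3)) { x in
      x * withoutDerivative(at: x)  // the second factor is treated as a constant
    }
    // dydx == 3.0 rather than 6.0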
  • Executes a closure, making TensorFlow operations run on a specific kind of device.

    Declaration

    public func withDevice<R>(
      _ kind: DeviceKind,
      _ index: UInt = 0,
      perform body: () throws -> R
    ) rethrows -> R

    Parameters

    kind

    A kind of device to run TensorFlow operations on.

    index

    The device to run the ops on.

    body

    A closure whose TensorFlow operations are to be executed on the specified kind of device.
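
    For example, a minimal sketch pinning a computation to the CPU:

    let result = withDevice(.cpu) {
      Tensor<Float>(ones: [2, 2]).sum()  // ops in this closure run on the CPU
    }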

  • Executes a closure, making TensorFlow operations run on a device with a specific name.

    Some examples of device names:

    • “/device:CPU:0”: The CPU of your machine.
    • “/GPU:0”: Short-hand notation for the first GPU of your machine that is visible to TensorFlow.
    • “/job:localhost/replica:0/task:0/device:GPU:1”: Fully qualified name of the second GPU of your machine that is visible to TensorFlow.

    Declaration

    public func withDevice<R>(named name: String, perform body: () throws -> R) rethrows -> R

    Parameters

    name

    Device name.

    body

    A closure whose TensorFlow operations are to be executed on the specified kind of device.

  • Executes a closure, allowing TensorFlow to place TensorFlow operations on any device. This should restore the default placement behavior.

    Declaration

    public func withDefaultDevice<R>(perform body: () throws -> R) rethrows -> R

    Parameters

    body

    A closure whose TensorFlow operations are to be executed on the specified kind of device.

  • Resize images to size using the specified method.

    Precondition

    The images must have rank 3 or 4.

    Precondition

    The size must be positive.

    Declaration

    @differentiable(wrt: images)
    public func resize(
      images: Tensor<Float>,
      size: (newHeight: Int, newWidth: Int),
      method: ResizeMethod = .bilinear,
      antialias: Bool = false
    ) -> Tensor<Float>

    Parameters

    images

    4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of shape [height, width, channels].

    size

    The new size of the images.

    method

    The resize method. The default value is .bilinear.

    antialias

    If true, an anti-aliasing filter is used when downsampling an image.
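
    For example, a minimal sketch upscaling a batch of images:

    let batch = Tensor<Float>(zeros: [8, 32, 32, 3])  // NHWC
    let resized = resize(images: batch, size: (newHeight: 64, newWidth: 64))
    // resized.shape == [8, 64, 64, 3]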

  • Resize images to size using area interpolation.

    Precondition

    The images must have rank 3 or 4.

    Precondition

    The size must be positive.

    Declaration

    public func resizeArea<Scalar: TensorFlowNumeric>(
      images: Tensor<Scalar>,
      size: (newHeight: Int, newWidth: Int),
      alignCorners: Bool = false
    ) -> Tensor<Float>

    Parameters

    images

    4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of shape [height, width, channels].

    size

    The new size of the images.

  • Returns a 2-D dilation with the specified input, filter, strides, and padding.

    Precondition

    input must have rank 4.

    Precondition

    filter must have rank 3.

    Declaration

    @differentiable(wrt: (input, filter))
    public func dilation2D<Scalar: TensorFlowFloatingPoint>(
      _ input: Tensor<Scalar>,
      filter: Tensor<Scalar>,
      strides: (Int, Int, Int, Int) = (1, 1, 1, 1),
      rates: (Int, Int, Int, Int) = (1, 1, 1, 1),
      padding: Padding = .valid
    ) -> Tensor<Scalar>

    Parameters

    input

    The input.

    filter

    The dilation filter.

    strides

    The strides of the sliding filter for each dimension of the input.

    rates

    The dilation rates for each dimension of the input.

    padding

    The padding for the operation.

  • Returns a 2-D erosion with the specified input, filter, strides, and padding.

    Precondition

    input must have rank 4.

    Precondition

    filter must have rank 3.

    Declaration

    @differentiable(wrt: (input, filter))
    public func erosion2D<Scalar: TensorFlowFloatingPoint>(
      _ input: Tensor<Scalar>,
      filter: Tensor<Scalar>,
      strides: (Int, Int, Int, Int) = (1, 1, 1, 1),
      rates: (Int, Int, Int, Int) = (1, 1, 1, 1),
      padding: Padding = .valid
    ) -> Tensor<Scalar>

    Parameters

    input

    The input.

    filter

    The erosion filter.

    strides

    The strides of the sliding filter for each dimension of the input.

    rates

    The dilation rates for each dimension of the input.

    padding

    The padding for the operation.

  • Returns a function that creates a tensor by initializing all its values to zeros.

    Declaration

    public func zeros<Scalar>() -> ParameterInitializer<Scalar> where Scalar : TensorFlowFloatingPoint
  • Returns a function that creates a tensor by initializing all its values to the provided value.

    Declaration

    public func constantInitializer<Scalar: TensorFlowFloatingPoint>(
      value: Scalar
    ) -> ParameterInitializer<Scalar>
  • Returns a function that creates a tensor by initializing it to the provided value. Note that broadcasting of the provided value is not supported.

    Declaration

    public func constantInitializer<Scalar: TensorFlowFloatingPoint>(
      value: Tensor<Scalar>
    ) -> ParameterInitializer<Scalar>
  • Returns a function that creates a tensor by performing Glorot (Xavier) uniform initialization for the specified shape, randomly sampling scalar values from a uniform distribution between -limit and limit, generated by the default random number generator, where limit is sqrt(6 / (fanIn + fanOut)), and fanIn/fanOut represent the number of input and output features multiplied by the receptive field, if present.

    Declaration

    public func glorotUniform<Scalar: TensorFlowFloatingPoint>(
      seed: TensorFlowSeed = Context.local.randomSeed
    ) -> ParameterInitializer<Scalar>
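
    Since a ParameterInitializer<Scalar> is a function from a shape to a tensor, a minimal usage sketch is:

    let initializer: ParameterInitializer<Float> = glorotUniform()
    let weights = initializer([784, 256])  // sample a [784, 256] weight matrix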
  • Returns a function that creates a tensor by performing Glorot (Xavier) normal initialization for the specified shape, randomly sampling scalar values from a truncated normal distribution centered on 0 with standard deviation sqrt(2 / (fanIn + fanOut)), where fanIn/fanOut represent the number of input and output features multiplied by the receptive field size, if present.

    Declaration

    public func glorotNormal<Scalar: TensorFlowFloatingPoint>(
      seed: TensorFlowSeed = Context.local.randomSeed
    ) -> ParameterInitializer<Scalar>
  • Returns a function that creates a tensor by performing He (Kaiming) uniform initialization for the specified shape, randomly sampling scalar values from a uniform distribution between -limit and limit, generated by the default random number generator, where limit is sqrt(6 / fanIn), and fanIn represents the number of input features multiplied by the receptive field, if present.

    Declaration

    public func heUniform<Scalar: TensorFlowFloatingPoint>(
      seed: TensorFlowSeed = Context.local.randomSeed
    ) -> ParameterInitializer<Scalar>
  • Returns a function that creates a tensor by performing He (Kaiming) normal initialization for the specified shape, randomly sampling scalar values from a truncated normal distribution centered on 0 with standard deviation sqrt(2 / fanIn), where fanIn represents the number of input features multiplied by the receptive field size, if present.

    Declaration

    public func heNormal<Scalar: TensorFlowFloatingPoint>(
      seed: TensorFlowSeed = Context.local.randomSeed
    ) -> ParameterInitializer<Scalar>
  • Returns a function that creates a tensor by performing LeCun uniform initialization for the specified shape, randomly sampling scalar values from a uniform distribution between -limit and limit, generated by the default random number generator, where limit is sqrt(3 / fanIn), and fanIn represents the number of input features multiplied by the receptive field, if present.

    Declaration

    public func leCunUniform<Scalar: TensorFlowFloatingPoint>(
      seed: TensorFlowSeed = Context.local.randomSeed
    ) -> ParameterInitializer<Scalar>
  • Returns a function that creates a tensor by performing LeCun normal initialization for the specified shape, randomly sampling scalar values from a truncated normal distribution centered on 0 with standard deviation sqrt(1 / fanIn), where fanIn represents the number of input features multiplied by the receptive field size, if present.

    Declaration

    public func leCunNormal<Scalar: TensorFlowFloatingPoint>(
      seed: TensorFlowSeed = Context.local.randomSeed
    ) -> ParameterInitializer<Scalar>
  • Returns a function that creates a tensor by initializing all its values randomly from a truncated Normal distribution. The generated values follow a Normal distribution with mean mean and standard deviation standardDeviation, except that values whose magnitude is more than two standard deviations from the mean are dropped and resampled.

    Declaration

    public func truncatedNormalInitializer<Scalar: TensorFlowFloatingPoint>(
      mean: Tensor<Scalar> = Tensor<Scalar>(0),
      standardDeviation: Tensor<Scalar> = Tensor<Scalar>(1),
      seed: TensorFlowSeed = Context.local.randomSeed
    ) -> ParameterInitializer<Scalar>

    Parameters

    mean

    Mean of the Normal distribution.

    standardDeviation

    Standard deviation of the Normal distribution.

    Return Value

    A truncated normal parameter initializer function.

  • Declaration

    public func == (lhs: TFETensorHandle, rhs: TFETensorHandle) -> Bool
  • Returns an identity matrix or a batch of matrices.

    Declaration

    public func eye<Scalar: Numeric>(
      rowCount: Int,
      columnCount: Int? = nil,
      batchShape: [Int] = [],
      on device: Device = .default
    ) -> Tensor<Scalar>

    Parameters

    rowCount

    The number of rows in each batch matrix.

    columnCount

    The number of columns in each batch matrix.

    batchShape

    The leading batch dimensions of the returned tensor.

  • Computes the trace of an optionally batched matrix. The trace is the sum along the main diagonal of each inner-most matrix.

    The input is a tensor with shape [..., M, N]. The output is a tensor with shape [...].

    Precondition

    matrix must be a tensor with shape [..., M, N].

    Declaration

    @differentiable(wrt: matrix)
    public func trace<T>(_ matrix: Tensor<T>) -> Tensor<T> where T : Numeric, T : TensorFlowScalar

    Parameters

    matrix

    A tensor of shape [..., M, N].

  • Returns the Cholesky decomposition of one or more square matrices.

    The input is a tensor of shape [..., M, M] whose inner-most 2 dimensions form square matrices.

    The input has to be symmetric and positive definite. Only the lower-triangular part of the input will be used for this operation. The upper-triangular part will not be read.

    The output is a tensor of the same shape as the input containing the Cholesky decompositions for all input submatrices [..., :, :].

    Declaration

    @differentiable
    public func cholesky<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint

    Parameters

    input

    A tensor of shape [..., M, M].

  • Returns the solution x to the system of linear equations represented by Ax = b.

    Precondition

    matrix must be a tensor with shape [..., M, M].

    Precondition

    rhs must be a tensor with shape [..., M, K].

    Declaration

    @differentiable
    public func triangularSolve<T: TensorFlowFloatingPoint>(
      matrix: Tensor<T>,
      rhs: Tensor<T>,
      lower: Bool = true,
      adjoint: Bool = false
    ) -> Tensor<T>

    Parameters

    matrix

    The input triangular coefficient matrix, representing A in Ax = b.

    rhs

    Right-hand side values, representing b in Ax = b.

    lower

    Whether matrix is lower triangular (true) or upper triangular (false). The default value is true.

    adjoint

    If true, solve with the adjoint of matrix instead of matrix. The default value is false.

    Return Value

    The solution x to the system of linear equations represented by Ax = b. x has the same shape as b.
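
    For example, a sketch combining cholesky(_:) and triangularSolve(matrix:rhs:lower:adjoint:) to solve a symmetric positive definite system Ax = b:

    let a: Tensor<Float> = [[4, 2], [2, 3]]  // symmetric positive definite
    let l = cholesky(a)                      // lower-triangular factor
    let b: Tensor<Float> = [[1], [2]]
    let y = triangularSolve(matrix: l, rhs: b)                 // solves L y = b
    let x = triangularSolve(matrix: l, rhs: y, adjoint: true)  // solves Lᵀ x = y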

  • Computes the L1 loss between expected and predicted. loss = reduction(abs(expected - predicted))

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func l1Loss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>,
      reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _sum
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

    reduction

    Reduction to apply on the computed element-wise loss values.

  • Computes the L2 loss between expected and predicted. loss = reduction(square(expected - predicted))

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func l2Loss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>,
      reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _sum
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

    reduction

    Reduction to apply on the computed element-wise loss values.

  • Computes the mean of absolute difference between labels and predictions. loss = mean(abs(expected - predicted))

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func meanAbsoluteError<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

  • Computes the mean of squares of errors between labels and predictions. loss = mean(square(expected - predicted))

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func meanSquaredError<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

  • Computes the mean squared logarithmic error between predicted and expected. loss = square(log(expected) - log(predicted))

    Note

    Negative tensor entries will be clamped at 0 to avoid undefined logarithmic behavior, as log(_:) is undefined for negative reals.

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func meanSquaredLogarithmicError<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

  • Computes the mean absolute percentage error between predicted and expected. loss = 100 * mean(abs((expected - predicted) / abs(expected)))

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func meanAbsolutePercentageError<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

  • Computes the hinge loss between predicted and expected. loss = reduction(max(0, 1 - predicted * expected)), where expected values are expected to be -1 or 1.

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func hingeLoss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>,
      reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

    reduction

    Reduction to apply on the computed element-wise loss values.

  • Computes the squared hinge loss between predicted and expected. loss = reduction(square(max(0, 1 - predicted * expected))), where expected values are expected to be -1 or 1.

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func squaredHingeLoss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>,
      reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

    reduction

    Reduction to apply on the computed element-wise loss values.

  • Computes the categorical hinge loss between predicted and expected. loss = maximum(negative - positive + 1, 0), where negative = max((1 - expected) * predicted) and positive = sum(predicted * expected).

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func categoricalHingeLoss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>,
      reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

    reduction

    Reduction to apply on the computed element-wise loss values.

  • Computes the logarithm of the hyperbolic cosine of the prediction error. logcosh = log((exp(x) + exp(-x))/2), where x is the error predicted - expected.

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func logCoshLoss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>,
      reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

    reduction

    Reduction to apply on the computed element-wise loss values.

  • Computes the Poisson loss between predicted and expected. The Poisson loss is the mean of the elements of the Tensor predicted - expected * log(predicted).

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func poissonLoss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>,
      reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

    reduction

    Reduction to apply on the computed element-wise loss values.

  • Computes Kullback-Leibler divergence loss between expected and predicted. loss = reduction(expected * log(expected / predicted))

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func kullbackLeiblerDivergence<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>,
      reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _sum
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

    reduction

    Reduction to apply on the computed element-wise loss values.

  • Computes the sparse softmax cross entropy (categorical cross entropy) between logits and labels. Use this cross-entropy loss function when there are two or more label classes. We expect labels to be provided as integers. There should be # classes floating-point values per feature for logits and a single integer value per feature for labels.

    Declaration

    @differentiable(wrt: logits)
    public func softmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
      logits: Tensor<Scalar>,
      labels: Tensor<Int32>,
      reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
    ) -> Tensor<Scalar>

    Parameters

    logits

    Unscaled log probabilities from a neural network.

    labels

    Indices (zero-indexed) of the correct outputs.

    reduction

    Reduction to apply on the computed element-wise loss values.

  • Computes the softmax cross entropy (categorical cross entropy) between logits and labels. Use this cross-entropy loss function when there are two or more label classes. We expect labels to be provided in a one-hot representation. There should be # classes floating-point values per feature.

    Declaration

    @differentiable(wrt: logits)
    public func softmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
      logits: Tensor<Scalar>,
      probabilities: Tensor<Scalar>,
      reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
    ) -> Tensor<Scalar>

    Parameters

    logits

    Unscaled log probabilities from a neural network.

    probabilities

    Probability values that correspond to the correct output. Each row must be a valid probability distribution.

    reduction

    Reduction to apply on the computed element-wise loss values.

  • Computes the sigmoid cross entropy (binary cross entropy) between logits and labels. Use this cross-entropy loss when there are only two label classes (assumed to be 0 and 1). For each example, there should be a single floating-point value per prediction.

    Declaration

    @differentiable(wrt: logits)
    @differentiable(wrt: (logits, labels))
    public func sigmoidCrossEntropy<Scalar: TensorFlowFloatingPoint>(
      logits: Tensor<Scalar>,
      labels: Tensor<Scalar>,
      reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _mean
    ) -> Tensor<Scalar>

    Parameters

    logits

    The unscaled output of a neural network.

    labels

    Integer values that correspond to the correct output.

    reduction

    Reduction to apply on the computed element-wise loss values.

  • Computes the Huber loss between predicted and expected.

    For each value x in error = expected - predicted:

    • 0.5 * x^2 if |x| <= δ.
    • 0.5 * δ^2 + δ * (|x| - δ) otherwise.

    Source: Wikipedia article.

    Declaration

    @differentiable(wrt: predicted)
    @differentiable(wrt: (predicted, expected))
    public func huberLoss<Scalar: TensorFlowFloatingPoint>(
      predicted: Tensor<Scalar>,
      expected: Tensor<Scalar>,
      delta: Scalar,
      reduction: @differentiable (Tensor<Scalar>) -> Tensor<Scalar> = _sum
    ) -> Tensor<Scalar>

    Parameters

    predicted

    Predicted outputs from a neural network.

    expected

    Expected values, i.e. targets, that correspond to the correct output.

    delta

    A floating point scalar representing the point where the Huber loss function changes from quadratic to linear.

    reduction

    Reduction to apply on the computed element-wise loss values.
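
    For example, a minimal usage sketch (delta marks the quadratic-to-linear transition):

    let predicted: Tensor<Float> = [0.9, 0.2, 3.0]
    let expected: Tensor<Float> = [1.0, 0.0, 0.0]
    let loss = huberLoss(predicted: predicted, expected: expected, delta: 1)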

  • Returns the absolute value of the specified tensor element-wise.

    Declaration

    @differentiable
    public func abs<T>(_ x: Tensor<T>) -> Tensor<T> where T : SignedNumeric, T : TensorFlowScalar
  • Returns the natural logarithm of the specified tensor element-wise.

    Declaration

    @differentiable
    public func log<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the base-two logarithm of the specified tensor element-wise.

    Declaration

    @differentiable
    public func log2<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the base-ten logarithm of the specified tensor element-wise.

    Declaration

    @differentiable
    public func log10<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the logarithm of 1 + x element-wise.

    Declaration

    @differentiable
    public func log1p<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns log(1 - exp(x)) using a numerically stable approach.

    Declaration

    @differentiable
    public func log1mexp<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the sine of the specified tensor element-wise.

    Declaration

    @differentiable
    public func sin<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the cosine of the specified tensor element-wise.

    Declaration

    @differentiable
    public func cos<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the tangent of the specified tensor element-wise.

    Declaration

    @differentiable
    public func tan<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the hyperbolic sine of the specified tensor element-wise.

    Declaration

    @differentiable
    public func sinh<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the hyperbolic cosine of the specified tensor element-wise.

    Declaration

    @differentiable
    public func cosh<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the hyperbolic tangent of the specified tensor element-wise.

    Declaration

    @differentiable
    public func tanh<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the inverse cosine of the specified tensor element-wise.

    Declaration

    @differentiable
    public func acos<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the inverse sine of the specified tensor element-wise.

    Declaration

    @differentiable
    public func asin<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the inverse tangent of the specified tensor element-wise.

    Declaration

    @differentiable
    public func atan<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the inverse hyperbolic cosine of the specified tensor element-wise.

    Declaration

    @differentiable
    public func acosh<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the inverse hyperbolic sine of the specified tensor element-wise.

    Declaration

    @differentiable
    public func asinh<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the inverse hyperbolic tangent of the specified tensor element-wise.

    Declaration

    @differentiable
    public func atanh<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the square root of the specified tensor element-wise.

    Declaration

    @differentiable
    public func sqrt<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the inverse square root of the specified tensor element-wise.

    Declaration

    @differentiable
    public func rsqrt<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the exponential of the specified tensor element-wise.

    Declaration

    @differentiable
    public func exp<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns two raised to the power of the specified tensor element-wise.

    Declaration

    @differentiable
    public func exp2<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns ten raised to the power of the specified tensor element-wise.

    Declaration

    @differentiable
    public func exp10<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the exponential of x - 1 element-wise.

    Declaration

    @differentiable
    public func expm1<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the values of the specified tensor rounded to the nearest integer, element-wise.

    Declaration

    @differentiable
    public func round<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the ceiling of the specified tensor element-wise.

    Declaration

    @differentiable
    public func ceil<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the floor of the specified tensor element-wise.

    Declaration

    @differentiable
    public func floor<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns an indication of the sign of the specified tensor element-wise. Specifically, computes y = sign(x) = -1 if x < 0; 0 if x == 0; 1 if x > 0.

    Declaration

    @differentiable
    public func sign<T>(_ x: Tensor<T>) -> Tensor<T> where T : Numeric, T : TensorFlowScalar
  • Returns the sigmoid of the specified tensor element-wise. Specifically, computes 1 / (1 + exp(-x)).

    Declaration

    @differentiable
    public func sigmoid<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the log-sigmoid of the specified tensor element-wise. Specifically, log(1 / (1 + exp(-x))). For numerical stability, we use -softplus(-x).

    Declaration

    @differentiable
    public func logSigmoid<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the softplus of the specified tensor element-wise. Specifically, computes log(exp(features) + 1).

    Declaration

    @differentiable
    public func softplus<T>(_ features: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the softsign of the specified tensor element-wise. Specifically, computes features / (abs(features) + 1).

    Declaration

    @differentiable
    public func softsign<T>(_ features: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the softmax of the specified tensor along the last axis. Specifically, computes exp(x) / exp(x).sum(alongAxes: -1).

    Declaration

    @differentiable
    public func softmax<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
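
    For example:

    let logits: Tensor<Float> = [1, 2, 3]
    let probabilities = softmax(logits)
    // probabilities ≈ [0.09, 0.245, 0.665] and sums to 1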
  • Returns the softmax of the specified tensor along the specified axis. Specifically, computes exp(x) / exp(x).sum(alongAxes: axis).

    Declaration

    @differentiable
    public func softmax<T>(_ x: Tensor<T>, alongAxis axis: Int) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the log-softmax of the specified tensor element-wise.

    Declaration

    @differentiable
    public func logSoftmax<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns a tensor by applying an exponential linear unit. Specifically, computes exp(x) - 1 if x < 0, x otherwise. See Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs).

    Declaration

    @differentiable
    public func elu<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the Gaussian Error Linear Unit (GELU) activations of the specified tensor element-wise.

    Specifically, gelu approximates xP(X <= x), where P(X <= x) is the Standard Gaussian cumulative distribution, by computing: x * [0.5 * (1 + tanh[√(2/π) * (x + 0.044715 * x^3)])].

    See Gaussian Error Linear Units.

    Declaration

    @differentiable
    public func gelu<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns a tensor by applying the ReLU activation function to the specified tensor element-wise. Specifically, computes max(0, x).

    Declaration

    @differentiable
    public func relu<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns a tensor by applying the ReLU6 activation function, namely min(max(0, x), 6).

    Declaration

    @differentiable
    public func relu6<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns a tensor by applying the leaky ReLU activation function to the specified tensor element-wise. Specifically, computes max(x, x * alpha).

    Declaration

    @differentiable(wrt: x)
    public func leakyRelu<T: TensorFlowFloatingPoint>(
      _ x: Tensor<T>,
      alpha: Double = 0.2
    ) -> Tensor<T>
  • Returns a tensor by applying the SeLU activation function, namely scale * alpha * (exp(x) - 1) if x < 0, and scale * x otherwise.

    Note

    This is designed to be used together with the variance scaling layer initializers. Please refer to Self-Normalizing Neural Networks for more information.

    Declaration

    @differentiable
    public func selu<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns a tensor by applying the swish activation function, namely x * sigmoid(x).

    Source: “Searching for Activation Functions” (Ramachandran et al. 2017) https://arxiv.org/abs/1710.05941

    Declaration

    @differentiable
    public func swish<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns a tensor by applying the hard sigmoid activation function, namely Relu6(x+3)/6.

    Source: “Searching for MobileNetV3” (Howard et al. 2019) https://arxiv.org/abs/1905.02244

    Declaration

    @differentiable
    public func hardSigmoid<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns a tensor by applying the hard swish activation function, namely x * Relu6(x+3)/6.

    Source: “Searching for MobileNetV3” (Howard et al. 2019) https://arxiv.org/abs/1905.02244

    Declaration

    @differentiable
    public func hardSwish<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns a tensor by applying the mish activation function, namely x * tanh(softplus(x)).

    Source: “Mish: A Self Regularized Non-Monotonic Neural Activation Function” https://arxiv.org/abs/1908.08681

    Declaration

    @differentiable
    public func mish<T>(_ x: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the power of the first tensor to the second tensor.

    Declaration

    @differentiable
    public func pow<T>(_ lhs: Tensor<T>, _ rhs: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the power of the scalar to the tensor, broadcasting the scalar.

    Declaration

    @differentiable(wrt: rhs)
    public func pow<T>(_ lhs: T, _ rhs: Tensor<T>) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the power of the tensor to the scalar, broadcasting the scalar.

    Declaration

    @differentiable(wrt: lhs)
    public func pow<T>(_ lhs: Tensor<T>, _ rhs: T) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the tensor raised to the integer power n, element-wise.

    Declaration

    @differentiable
    public func pow<T>(_ x: Tensor<T>, _ n: Int) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the element-wise nth root of the tensor.

    Declaration

    @differentiable
    public func root<T>(_ x: Tensor<T>, _ n: Int) -> Tensor<T> where T : TensorFlowFloatingPoint
  • Returns the squared difference between x and y.

    Declaration

    @differentiable
    public func squaredDifference<T>(_ x: Tensor<T>, _ y: Tensor<T>) -> Tensor<T> where T : Numeric, T : TensorFlowScalar

    Return Value

    (x - y) ^ 2.

  • Returns the element-wise maximum of two tensors.

    Note

    max supports broadcasting.

    Declaration

    @differentiable
    public func max<T>(_ lhs: Tensor<T>, _ rhs: Tensor<T>) -> Tensor<T> where T : Comparable, T : Numeric, T : TensorFlowScalar
  • Returns the element-wise maximum of the scalar and the tensor, broadcasting the scalar.

    Declaration

    @differentiable(wrt: rhs)
    public func max<T>(_ lhs: T, _ rhs: Tensor<T>) -> Tensor<T> where T : Comparable, T : Numeric, T : TensorFlowScalar
  • Returns the element-wise maximum of the scalar and the tensor, broadcasting the scalar.

    Declaration

    @differentiable(wrt: lhs)
    public func max<T>(_ lhs: Tensor<T>, _ rhs: T) -> Tensor<T> where T : Comparable, T : Numeric, T : TensorFlowScalar
  • Returns the element-wise minimum of two tensors.

    Note

    min supports broadcasting.

    Declaration

    @differentiable
    public func min<T>(_ lhs: Tensor<T>, _ rhs: Tensor<T>) -> Tensor<T> where T : Comparable, T : Numeric, T : TensorFlowScalar
  • Returns the element-wise minimum of the scalar and the tensor, broadcasting the scalar.

    Declaration

    @differentiable(wrt: rhs)
    public func min<T>(_ lhs: T, _ rhs: Tensor<T>) -> Tensor<T> where T : Comparable, T : Numeric, T : TensorFlowScalar
  • Returns the element-wise minimum of the scalar and the tensor, broadcasting the scalar.

    Declaration

    @differentiable(wrt: lhs)
    public func min<T>(_ lhs: Tensor<T>, _ rhs: T) -> Tensor<T> where T : Comparable, T : Numeric, T : TensorFlowScalar
  • Returns the cosine similarity between x and y.

    Declaration

    @differentiable
    public func cosineSimilarity<Scalar: TensorFlowFloatingPoint>(
      _ x: Tensor<Scalar>,
      _ y: Tensor<Scalar>
    ) -> Tensor<Scalar>
  • Returns the cosine distance between x and y. Cosine distance is defined as 1 - cosineSimilarity(x, y).

    Declaration

    @differentiable
    public func cosineDistance<Scalar: TensorFlowFloatingPoint>(
      _ x: Tensor<Scalar>,
      _ y: Tensor<Scalar>
    ) -> Tensor<Scalar>
  • Performs matrix multiplication of two tensors and produces the result.

    Declaration

    @differentiable
    public func matmul<Scalar: Numeric>(
      _ lhs: Tensor<Scalar>,
      transposed transposeLhs: Bool = false,
      _ rhs: Tensor<Scalar>,
      transposed transposeRhs: Bool = false
    ) -> Tensor<Scalar>
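
    For example, a minimal sketch (each transposed label applies to the adjacent operand):

    let a: Tensor<Float> = [[1, 2], [3, 4]]
    let b: Tensor<Float> = [[0, 1], [1, 0]]
    let c = matmul(a, b)                    // [[2, 1], [4, 3]]
    let d = matmul(a, transposed: true, b)  // multiplies aᵀ by b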
  • Returns a 1-D convolution with the specified input, filter, stride, and padding.

    Precondition

    input must have rank 3.

    Precondition

    filter must have rank 3.

    Declaration

    @differentiable(wrt: (input, filter))
    public func conv1D<Scalar: TensorFlowFloatingPoint>(
      _ input: Tensor<Scalar>,
      filter: Tensor<Scalar>,
      stride: Int = 1,
      padding: Padding = .valid,
      dilation: Int = 1
    ) -> Tensor<Scalar>

    Parameters

    input

    The input.

    filter

    The convolution filter.

    stride

    The stride of the sliding filter.

    padding

    The padding for the operation.

    dilation

    The dilation factor.

  • Returns a 2-D convolution with the specified input, filter, strides, and padding.

    Precondition

    input must have rank 4.

    Precondition

    filter must have rank 4.

    Declaration

    @differentiable(wrt: (input, filter))
    public func conv2D<Scalar: TensorFlowFloatingPoint>(
      _ input: Tensor<Scalar>,
      filter: Tensor<Scalar>,
      strides: (Int, Int, Int, Int) = (1, 1, 1, 1),
      padding: Padding = .valid,
      dilations: (Int, Int, Int, Int) = (1, 1, 1, 1)
    ) -> Tensor<Scalar>

    Parameters

    input

    The input.

    filter

    The convolution filter.

    strides

    The strides of the sliding filter for each dimension of the input.

    padding

    The padding for the operation.

    dilations

    The dilation factor for each dimension of the input.
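
    For example, a minimal sketch convolving a batch of images (shapes are illustrative):

    let input = Tensor<Float>(randomNormal: [1, 28, 28, 3])  // NHWC
    let filter = Tensor<Float>(randomNormal: [3, 3, 3, 8])   // [height, width, inChannels, outChannels]
    let output = conv2D(input, filter: filter, strides: (1, 1, 1, 1), padding: .same)
    // output.shape == [1, 28, 28, 8]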

  • Returns a 2-D transposed convolution with the specified input, filter, strides, and padding.

    Precondition

    input must have rank 4.

    Precondition

    filter must have rank 4.

    Declaration

    @differentiable(wrt: (input, filter))
    public func transposedConv2D<Scalar: TensorFlowFloatingPoint>(
      _ input: Tensor<Scalar>,
      shape: [Int64],
      filter: Tensor<Scalar>,
      strides: (Int, Int, Int, Int) = (1, 1, 1, 1),
      padding: Padding = .valid,
      dilations: (Int, Int, Int, Int) = (1, 1, 1, 1)
    ) -> Tensor<Scalar>

    Parameters

    input

    The input.

    shape

    The output shape of the deconvolution operation.

    filter

    The convolution filter.

    strides

    The strides of the sliding filter for each dimension of the input.

    padding

    The padding for the operation.

    dilations

    The dilation factor for each dimension of the input.

  • Returns a 3-D convolution with the specified input, filter, strides, padding and dilations.

    Precondition

    input must have rank 5.

    Precondition

    filter must have rank 5.

    Declaration

    @differentiable(wrt: (input, filter))
    public func conv3D<Scalar: TensorFlowFloatingPoint>(
      _ input: Tensor<Scalar>,
      filter: Tensor<Scalar>,
      strides: (Int, Int, Int, Int, Int) = (1, 1, 1, 1, 1),
      padding: Padding = .valid,
      dilations: (Int, Int, Int, Int, Int) = (1, 1, 1, 1, 1)
    ) -> Tensor<Scalar>

    Parameters

    input

    The input.

    filter

    The convolution filter.

    strides

    The strides of the sliding filter for each dimension of the input.

    padding

    The padding for the operation.

    dilations

    The dilation factor for each dimension of the input.

  • Returns a 2-D depthwise convolution with the specified input, filter, strides, and padding.

    Precondition

    input must have rank 4.

    Precondition

    filter must have rank 4.

    Declaration

    @differentiable(wrt: (input, filter))
    public func depthwiseConv2D<Scalar: TensorFlowFloatingPoint>(
      _ input: Tensor<Scalar>,
      filter: Tensor<Scalar>,
      strides: (Int, Int, Int, Int),
      padding: Padding
    ) -> Tensor<Scalar>

    Parameters

    input

    The input.

    filter

    The depthwise convolution filter.

    strides

    The strides of the sliding filter for each dimension of the input.

    padding

    The padding for the operation.

  • Returns a 2-D max pooling, with the specified filter sizes, strides, and padding.

    Declaration

    @differentiable(wrt: input)
    public func maxPool2D<Scalar: TensorFlowFloatingPoint>(
      _ input: Tensor<Scalar>,
      filterSize: (Int, Int, Int, Int),
      strides: (Int, Int, Int, Int),
      padding: Padding
    ) -> Tensor<Scalar>

    Parameters

    input

    The input.

    filterSize

    The dimensions of the pooling kernel.

    strides

    The strides of the sliding filter for each dimension of the input.

    padding

    The padding for the operation.

  • Returns a 3-D max pooling, with the specified filter sizes, strides, and padding.

    Declaration

    @differentiable(wrt: input)
    public func maxPool3D<Scalar: TensorFlowFloatingPoint>(
      _ input: Tensor<Scalar>,
      filterSize: (Int, Int, Int, Int, Int),
      strides: (Int, Int, Int, Int, Int),
      padding: Padding
    ) -> Tensor<Scalar>

    Parameters

    input

    The input.

    filterSize

    The dimensions of the pooling kernel.

    strides

    The strides of the sliding filter for each dimension of the input.

    padding

    The padding for the operation.

  • Returns a 2-D average pooling, with the specified filter sizes, strides, and padding.

    Declaration

    @differentiable(wrt: input)
    public func avgPool2D<Scalar: TensorFlowFloatingPoint>(
      _ input: Tensor<Scalar>,
      filterSize: (Int, Int, Int, Int),
      strides: (Int, Int, Int, Int),
      padding: Padding
    ) -> Tensor<Scalar>

    Parameters

    input

    The input.

    filterSize

    The dimensions of the pooling kernel.

    strides

    The strides of the sliding filter for each dimension of the input.

    padding

    The padding for the operation.

  • Returns a 3-D average pooling, with the specified filter sizes, strides, and padding.

    Declaration

    @differentiable(wrt: input)
    public func avgPool3D<Scalar: TensorFlowFloatingPoint>(
      _ input: Tensor<Scalar>,
      filterSize: (Int, Int, Int, Int, Int),
      strides: (Int, Int, Int, Int, Int),
      padding: Padding
    ) -> Tensor<Scalar>

    Parameters

    input

    The input.

    filterSize

    The dimensions of the pooling kernel.

    strides

    The strides of the sliding filter for each dimension of the input.

    padding

    The padding for the operation.
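
    As with maxPool3D, the filter size and strides are 5-tuples; a minimal sketch (shapes assumed for illustration):

    import TensorFlow

    let volume = Tensor<Float>(randomNormal: [1, 4, 4, 4, 1])
    let averaged = avgPool3D(volume, filterSize: (1, 2, 2, 2, 1), strides: (1, 2, 2, 2, 1), padding: .valid)
    // averaged.shape == [1, 2, 2, 2, 1]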

  • Returns a 2-D fractional max pooling, with the specified pooling ratios.

    Note: fractionalMaxPool does not have an XLA implementation, so using it may hurt performance on XLA-backed devices.

    Declaration

    @differentiable(wrt: input)
    public func fractionalMaxPool2D<Scalar: TensorFlowFloatingPoint>(
      _ input: Tensor<Scalar>,
      poolingRatio: (Double, Double, Double, Double),
      pseudoRandom: Bool = false,
      overlapping: Bool = false,
      deterministic: Bool = false,
      seed: Int64 = 0,
      seed2: Int64 = 0
    ) -> Tensor<Scalar>

    Parameters

    input

    A Tensor. 4-D with shape [batch, height, width, channels].

    poolingRatio

    A tuple of Doubles specifying the pooling ratio for each dimension of input. Pooling is currently supported only for the row and column dimensions, and each ratio must be >= 1.0.

    pseudoRandom

    An optional Bool. Defaults to false. When set to true, generates the pooling sequence in a pseudorandom fashion; otherwise, in a random fashion.

    overlapping

    An optional Bool. Defaults to false. When set to true, the values at the boundary of adjacent pooling cells are used by both cells.

    deterministic

    An optional Bool. Defaults to false. When set to true, a fixed pooling region is used when iterating over a fractionalMaxPool2D node in the computation graph.

    seed

    An optional Int64. Defaults to 0. If set to be non-zero, the random number generator is seeded by the given seed.

    seed2

    An optional Int64. Defaults to 0. A second seed to avoid seed collision.
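
    A minimal sketch (the shape and ratios are assumptions for illustration); the batch and channel ratios stay at 1.0 because only the row and column dimensions support pooling:

    import TensorFlow

    let input = Tensor<Float>(randomNormal: [1, 10, 10, 3])
    let pooled = fractionalMaxPool2D(
      input,
      poolingRatio: (1.0, 1.44, 1.44, 1.0),
      pseudoRandom: true,
      overlapping: true)
    // Height and width shrink by roughly the pooling ratio (about 10 / 1.44 ≈ 7 here);
    // the exact output size depends on the generated pooling sequence.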

  • Returns a copy of input where values from the depth dimension are moved in spatial blocks to the height and width dimensions.

    For example, given an input of shape [1, 1, 1, 4], data_format = “NHWC” and block_size = 2:

    x = [[[[1, 2, 3, 4]]]]

    This operation will output a tensor of shape [1, 2, 2, 1]:

    [[[[1], [2]],
      [[3], [4]]]]

    Here, the input has a batch of 1 and each batch element has shape [1, 1, 4]; the corresponding output has width and height of 2 and a depth of 1 channel (4 / (block_size * block_size)). The output element shape is [2, 2, 1].

    For an input tensor with larger depth, here of shape [1, 1, 1, 12]:

    x = [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]]

    This operation, for a block_size of 2, will return the following tensor of shape [1, 2, 2, 3]:

    [[[[1, 2, 3], [4, 5, 6]],
      [[7, 8, 9], [10, 11, 12]]]]

    Similarly, for the following input of shape [1, 2, 2, 4], and a block size of 2:

    x = [[[[1, 2, 3, 4],
           [5, 6, 7, 8]],
          [[9, 10, 11, 12],
           [13, 14, 15, 16]]]]

    the operator will return the following tensor of shape [1, 4, 4, 1]:

    [[[[1],   [2],  [5],  [6]],
      [[3],   [4],  [7],  [8]],
      [[9],  [10], [13],  [14]],
      [[11], [12], [15],  [16]]]]
    

    Precondition

    input.rank == 4 && b >= 2.

    Precondition

    The number of features must be divisible by the square of b.

    Declaration

    @differentiable(wrt: input)
    public func depthToSpace<Scalar>(_ input: Tensor<Scalar>, blockSize b: Int) -> Tensor<Scalar> where Scalar : TensorFlowScalar
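
    A minimal sketch matching the first example above, with the scalars spelled out:

    import TensorFlow

    let input = Tensor<Float>(shape: [1, 1, 1, 4], scalars: [1, 2, 3, 4])
    let output = depthToSpace(input, blockSize: 2)
    // output.shape == [1, 2, 2, 1]; the 4 channels become a 2x2 spatial block.
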
  • Returns a copy of input where values from the height and width dimensions are moved to the depth dimension.

    For example, given an input of shape [1, 2, 2, 1], data_format = “NHWC” and block_size = 2:

    x = [[[[1], [2]],
          [[3], [4]]]]
    

    This operation will output a tensor of shape [1, 1, 1, 4]:

    [[[[1, 2, 3, 4]]]]
    

    Here, the input has a batch of 1 and each batch element has shape [2, 2, 1]; the corresponding output will have a single element (i.e. width and height are both 1) and a depth of 4 channels (1 * block_size * block_size). The output element shape is [1, 1, 4].

    For an input tensor with larger depth, here of shape [1, 2, 2, 3]:

    x = [[[[1, 2, 3], [4, 5, 6]],
          [[7, 8, 9], [10, 11, 12]]]]
    

    This operation, for a block_size of 2, will return the following tensor of shape [1, 1, 1, 12]:

    [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]]
    

    Similarly, for the following input of shape [1, 4, 4, 1], and a block size of 2:

    x = [[[[1],   [2],  [5],  [6]],
          [[3],   [4],  [7],  [8]],
          [[9],  [10], [13],  [14]],
          [[11], [12], [15],  [16]]]]
    

    the operator will return the following tensor of shape [1, 2, 2, 4]:

    [[[[1, 2, 3, 4],
       [5, 6, 7, 8]],
      [[9, 10, 11, 12],
       [13, 14, 15, 16]]]]
    

    Precondition

    input.rank == 4 && b >= 2.

    Precondition

    The height of the input must be divisible by b.

    Precondition

    The width of the input must be divisible by b.

    Declaration

    @differentiable(wrt: input)
    public func spaceToDepth<Scalar>(_ input: Tensor<Scalar>, blockSize b: Int) -> Tensor<Scalar> where Scalar : TensorFlowScalar
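
    A minimal sketch matching the first example above; the call is the inverse of the depthToSpace sketch:

    import TensorFlow

    let input = Tensor<Float>(shape: [1, 2, 2, 1], scalars: [1, 2, 3, 4])
    let output = spaceToDepth(input, blockSize: 2)
    // output.shape == [1, 1, 1, 4]; each 2x2 spatial block becomes 4 channels.
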
  • Builds a per-weight optimizer for LARS (https://arxiv.org/pdf/1708.03888.pdf).

    Declaration

    public func makeLARS(
      learningRate: Float = 0.01,
      momentum: Float = 0.9,
      trustCoefficient: Float = 0.001,
      nesterov: Bool = false,
      epsilon: Float = 0.0,
      weightDecay: Float = 0.0
    ) -> ParameterGroupOptimizer
  • Builds an SGD-based per-weight optimizer.

    Declaration

    public func makeSGD(
      learningRate: Float = 0.01,
      momentum: Float = 0,
      weightDecay: Float = 0,
      nesterov: Bool = false
    ) -> ParameterGroupOptimizer
  • Builds a per-weight optimizer for Adam with weight decay.

    Reference: “Adam - A Method for Stochastic Optimization”

    Declaration

    public func makeAdam(
      learningRate: Float = 0.01,
      beta1: Float = 0.9,
      beta2: Float = 0.999,
      weightDecayRate: Float = 0.01,
      epsilon: Float = 1e-6
    ) -> ParameterGroupOptimizer
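
    A minimal sketch of invoking the three builders above; the hyperparameter values are illustrative, and wiring the resulting ParameterGroupOptimizer into a training loop is not shown:

    import TensorFlow

    let lars = makeLARS(learningRate: 0.1, momentum: 0.9, trustCoefficient: 0.001)
    let sgd = makeSGD(learningRate: 0.01, momentum: 0.9, nesterov: true)
    let adamW = makeAdam(learningRate: 1e-3, weightDecayRate: 0.01)
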
  • Generates a new random seed for TensorFlow.

    Declaration

    public func randomSeedForTensorFlow(using seed: TensorFlowSeed? = nil) -> TensorFlowSeed
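
    A minimal sketch; feeding the seed to a random-tensor initializer is an assumed use, shown only for illustration:

    import TensorFlow

    let seed = randomSeedForTensorFlow()
    let noise = Tensor<Float>(randomNormal: [2, 3], seed: seed)
    // With `using:`, the new seed is derived from the one provided.
    let derived = randomSeedForTensorFlow(using: seed)
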
  • Concatenates two values.

    Declaration

    @differentiable
    public func concatenate<T: Mergeable>(
      _ first: T,
      _ second: T
    ) -> T
  • Adds two values and produces their sum.

    Declaration

    @differentiable
    public func sum<T: Mergeable>(
      _ first: T,
      _ second: T
    ) -> T
  • Averages two values.

    Declaration

    @differentiable
    public func average<T: Mergeable>(
      _ first: T,
      _ second: T
    ) -> T
  • Multiplies two values.

    Declaration

    @differentiable
    public func multiply<T: Mergeable>(
      _ first: T,
      _ second: T
    ) -> T
  • Stacks two values.

    Declaration

    @differentiable
    public func stack<T: Mergeable>(
      _ first: T,
      _ second: T
    ) -> T
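
    The five combinators above are generic over Mergeable. Assuming Tensor's conditional conformance to Mergeable for floating-point scalars (as used by bidirectional recurrent layers), a minimal sketch:

    import TensorFlow

    let a = Tensor<Float>([1, 2])
    let b = Tensor<Float>([3, 4])
    let joined = concatenate(a, b)  // [1, 2, 3, 4]
    let added = sum(a, b)           // [4, 6]
    let mean = average(a, b)        // [2, 3]
    let product = multiply(a, b)    // [3, 8]
    let stacked = stack(a, b)       // shape [2, 2]
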
  • Prints X10 metrics.

    Declaration

    public func PrintX10Metrics()
  • Creates a string summary of a pair of training and testing stats.

    Declaration

    public func formatStatistics(_ stats: (train: HostStatistics, test: HostStatistics)) -> String
  • Creates a string summary of training and testing stats passed as separate arguments.

    Declaration

    public func formatStatistics(train trainStats: HostStatistics, test testStats: HostStatistics)
      -> String
  • Maps a function over n threads.

    Declaration

    public func runOnNThreads<R>(_ nThreads: Int, _ body: @escaping (Int) -> R) -> [R]
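
    A minimal sketch; the closure receives the thread index, and the result array is assumed to be ordered by that index:

    import TensorFlow

    let squares = runOnNThreads(4) { threadIndex in threadIndex * threadIndex }
    // squares == [0, 1, 4, 9]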