SGD

public class SGD<Model: Differentiable>: Optimizer
where
  Model.TangentVector: VectorProtocol & ElementaryFunctions & KeyPathIterable,
  Model.TangentVector.VectorSpaceScalar == Float

Um otimizador de descida gradiente estocástico (SGD).

Implementa o algoritmo de descida gradiente estocástico com suporte para momentum, decaimento da taxa de aprendizagem e momentum de Nesterov. Momentum e momentum de Nesterov (também conhecido como método de gradiente acelerado de Nesterov) são métodos de otimização de primeira ordem que podem melhorar a velocidade de treinamento e a taxa de convergência da descida do gradiente.

Referências:

  • Declaração

    public typealias Model = Model
  • A taxa de aprendizagem.

    Declaração

    public var learningRate: Float
  • O fator de momentum. Ele acelera a descida do gradiente estocástico na direção relevante e amortece as oscilações.

    Declaração

    public var momentum: Float
  • O declínio da taxa de aprendizagem.

    Declaração

    public var decay: Float
  • Use o momento de Nesterov se verdadeiro.

    Declaração

    public var nesterov: Bool
  • O estado de velocidade do modelo.

    Declaração

    public var velocity: Model.TangentVector
  • O conjunto de etapas realizadas.

    Declaração

    public var step: Int
  • Cria uma instância de model .

    Declaração

    public init(
      for model: __shared Model,
      learningRate: Float = 0.01,
      momentum: Float = 0,
      decay: Float = 0,
      nesterov: Bool = false
    )

    Parâmetros

    learningRate

    A taxa de aprendizagem. O valor padrão é 0.01 .

    momentum

    O fator de momentum que acelera a descida do gradiente estocástico na direção relevante e amortece as oscilações. O valor padrão é 0 .

    decay

    O declínio da taxa de aprendizagem. O valor padrão é 0 .

    nesterov

    Use Nesterov impulso sse true . O valor padrão é true .

  • Declaração

    public func update(_ model: inout Model, along direction: Model.TangentVector)
  • Declaração

    public required init(copying other: SGD, to device: Device)