RMSProp

public class RMSProp<Model: Differentiable>: Optimizer
where
  Model.TangentVector: VectorProtocol & PointwiseMultiplicative
    & ElementaryFunctions & KeyPathIterable,
  Model.TangentVector.VectorSpaceScalar == Float

Un ottimizzatore RMSProp.

Implementa l'algoritmo di ottimizzazione RMSProp. RMSProp è una forma di discesa del gradiente stocastica in cui i gradienti sono divisi per una media corrente della loro grandezza recente. RMSProp mantiene una media mobile del gradiente quadrato per ciascun peso.

Riferimenti:

  • Dichiarazione

    public typealias Model = Model
  • Il tasso di apprendimento.

    Dichiarazione

    public var learningRate: Float
  • Rho

    Il fattore di decadimento della media mobile del gradiente.

    Dichiarazione

    public var rho: Float
  • Un piccolo scalare aggiunto al denominatore per migliorare la stabilità numerica.

    Dichiarazione

    public var epsilon: Float
  • Il decadimento del tasso di apprendimento.

    Dichiarazione

    public var decay: Float
  • Il conteggio dei passi.

    Dichiarazione

    public var step: Float
  • I valori alfa per tutte le variabili differenziabili del modello.

    Dichiarazione

    public var alpha: Model.TangentVector
  • Crea un'istanza per model .

    Dichiarazione

    public init(
      for model: __shared Model,
      learningRate: Float = 1e-3,
      rho: Float = 0.9,
      epsilon: Float = 1e-8,
      decay: Float = 0
    )

    Parametri

    learningRate

    Il tasso di apprendimento. Il valore predefinito è 1e-3 .

    rho

    Il fattore di decadimento della media mobile del gradiente. Il valore predefinito è 0.9 .

    epsilon

    Un piccolo scalare aggiunto al denominatore per migliorare la stabilità numerica. Il valore predefinito è 1e-8 .

    decay

    Il decadimento del tasso di apprendimento. Il valore predefinito è 0 .

  • Dichiarazione

    public func update(_ model: inout Model, along direction: Model.TangentVector)
  • Dichiarazione

    public required init(copying other: RMSProp, to device: Device)