Optimizers

The optimizer is a crucial component in training a machine learning model. Here, I will cover only a small subset of those available in PyTorch, illustrating their evolution starting from SGD.

Recap

Before diving into optimizers, let’s quickly recap gradient descent. Gradient descent computes the direction of the steepest slope of the loss function. In other words, it tells us how the loss changes with respect to each weight. If the gradient is negative, subtracting the learning rate times the gradient from the weight will move it in the direction that reduces the loss. The opposite is also true: if the gradient is positive, increasing the weight would increase the loss. Therefore, subtracting the learning rate times the positive gradient reduces the weight value, which moves the model toward minimizing the loss.

Figure 1. Gradient descent curve. The blue line represents the model’s loss as a function of its weights (y-axis: loss, x-axis: weights). The orange lines show the local slopes, or gradients: positive gradients point upward (left to right), while negative gradients point downward (right to left). The gray dots illustrate the gradient descent steps, where the learning rate determines the step size. The movement progresses toward the global minimum as the model parameters are updated. Source: adapted from ResearchGate.
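The recap above can be made concrete with a tiny one-dimensional example (a hypothetical illustration; the loss (w − 3)² and all values here are my own choices, not from any particular model):

```python
# Minimize loss(w) = (w - 3)^2 by gradient descent.
# Its gradient is d(loss)/dw = 2 * (w - 3): negative to the
# left of the minimum, positive to the right.

def grad(w):
    return 2.0 * (w - 3.0)

w = 0.0    # initial weight
lr = 0.1   # learning rate (step size)
for _ in range(100):
    w -= lr * grad(w)  # step opposite the gradient

print(w)  # approaches the minimum at w = 3
```

Each step multiplies the distance to the minimum by (1 − 2·lr), so with lr = 0.1 the weight converges geometrically toward 3.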

Stochastic gradient descent (SGD)

Stochastic Gradient Descent (SGD) is an optimization method that updates model parameters using gradients calculated from a batch.

First, the gradient of the loss with respect to the weights is computed. Here θ_t denotes the vector of weights at step t:

g_t = \nabla_\theta L(\theta_t)

Then the weights are updated: the new weight equals the previous weight minus the learning rate (η) multiplied by the gradient.

\theta_{t+1} = \theta_t - \eta g_t
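In code, one SGD update is a single element-wise operation. A bare-bones sketch in plain Python (in practice, torch.optim.SGD implements this rule for you; the parameter and gradient values are made up):

```python
def sgd_step(theta, grads, lr):
    """One SGD update: theta_{t+1} = theta_t - lr * g_t, element-wise."""
    return [t - lr * g for t, g in zip(theta, grads)]

# One update with made-up values: each weight moves
# opposite its gradient, scaled by the learning rate.
theta = sgd_step([1.0, -2.0], [0.5, -0.5], lr=0.1)
```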

The weight updates can be noisy, since each step depends only on the current batch's gradient. That's why momentum was added to SGD.

SGD with momentum

SGD with momentum adds a velocity term that accumulates past gradients to smooth the optimization path and accelerate convergence.

First, we compute the gradient:

g_t = \nabla_\theta L(\theta_t)

Second, compute the velocity term (v_t), where β (commonly 0.9) controls how much of the past gradient is retained. This helps damp oscillations in noisy directions and improves stability, accelerating gradient descent in directions where the objective decreases consistently across iterations.

v_{t+1} = \beta v_t + (1 - \beta) g_t

Then update the weights:

\theta_{t+1} = \theta_t - \eta v_{t+1}
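The two steps above can be sketched in plain Python. The example feeds in gradients that flip sign every iteration (my own synthetic values) to show how the velocity damps oscillations:

```python
def momentum_step(theta, v, grads, lr=0.1, beta=0.9):
    # v_{t+1} = beta * v_t + (1 - beta) * g_t
    v = [beta * vi + (1.0 - beta) * g for vi, g in zip(v, grads)]
    # theta_{t+1} = theta_t - lr * v_{t+1}
    theta = [t - lr * vi for t, vi in zip(theta, v)]
    return theta, v

# Oscillating gradients (+1, -1, +1, -1) largely cancel inside the
# velocity, so the parameter barely moves: momentum damps the noise.
theta, v = [0.0], [0.0]
for g in [1.0, -1.0, 1.0, -1.0]:
    theta, v = momentum_step(theta, v, [g])
```

After these four steps the velocity magnitude stays well below the raw gradient magnitude of 1, which is exactly the smoothing effect described above.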

Adam (Adaptive Moment Estimation)

Adam extends SGD by keeping exponential moving averages of the gradients (m_t) and their squared values (v_t). These moment estimates allow the optimizer to adapt the learning rate for each parameter based on the magnitude and variability of its past gradients.

First, we compute the gradient:

g_t = \nabla_\theta \left( L(\theta_t) + \frac{\lambda}{2} \|\theta_t\|_2^2 \right) = \nabla_\theta L(\theta_t) + \lambda \theta_t

Here, λθ_t is the derivative of the L2 regularization term. This term adds a penalty proportional to the weight values, nudging them toward zero during training and helping to reduce overfitting. However, since Adam adapts the learning rate for each parameter using m_t and v_t, the L2 term is also scaled by these adaptive factors. This can cause the effective weight decay to vary across parameters, making regularization inconsistent.

Second, update the moving averages of the gradients:

m_{t+1} = \beta_1 m_t + (1 - \beta_1) g_t, \qquad v_{t+1} = \beta_2 v_t + (1 - \beta_2) g_t^2

m_{t+1} is the exponentially weighted average of the gradients (first moment), similar to momentum. v_{t+1} is the exponentially weighted average of the squared gradients (second moment), which tracks the magnitude of recent gradients. These moving averages allow Adam to adapt the step size for each parameter individually, smoothing noisy updates and improving convergence.

Lastly, update the parameters (the full Adam algorithm also applies a bias correction to m and v, omitted here for simplicity):

\theta_{t+1} = \theta_t - \eta \, \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}
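Putting the three steps together in plain Python (a simplified sketch matching the equations above: no bias correction, λ = 0, and made-up input values; torch.optim.Adam does more bookkeeping):

```python
import math

def adam_step(theta, m, v, grads, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # First moment: exponential moving average of the gradients
    m = [b1 * mi + (1 - b1) * g for mi, g in zip(m, grads)]
    # Second moment: exponential moving average of squared gradients
    v = [b2 * vi + (1 - b2) * g * g for vi, g in zip(v, grads)]
    # Per-parameter step, scaled down where gradients have been large
    theta = [t - lr * mi / (math.sqrt(vi) + eps)
             for t, mi, vi in zip(theta, m, v)]
    return theta, m, v

# One update from fresh (zero) moment estimates with gradient 0.5.
theta, m, v = adam_step([1.0], [0.0], [0.0], [0.5])
```

Note how the effective step lr * m / (sqrt(v) + eps) depends on the moments, not just on the raw gradient, which is what makes the learning rate per-parameter adaptive.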

AdamW (weight decay)

AdamW improves upon Adam by decoupling weight decay from the gradient-based parameter update and applying regularization directly to the parameters. In the original Adam, as we saw, the L2 regularization was applied through the gradient, which caused the weight decay effect to interact improperly with adaptive learning rates. This decoupling prevents the step size from being unintentionally scaled by weight decay.

First, compute the gradient without the L2 regularization term:

g_t = \nabla_\theta L(\theta_t)

Second, update the moving averages of the gradients. These running averages act as per-parameter scaling factors, making the learning rate adaptive. AdamW keeps a constant global learning rate η, but automatically adjusts it locally for each weight based on how noisy or stable its gradients are.

m_{t+1} = \beta_1 m_t + (1 - \beta_1) g_t, \qquad v_{t+1} = \beta_2 v_t + (1 - \beta_2) g_t^2

Lastly, update the parameters, this time applying weight decay directly to them:

\theta_{t+1} = \theta_t - \eta \left( \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda \theta_t \right)
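The decoupling is easy to see in code: even with a zero gradient, AdamW still shrinks the weight, because the decay term λθ_t sits outside the adaptive scaling. A simplified sketch (again without bias correction, with made-up values):

```python
import math

def adamw_step(theta, m, v, grads, lr=1e-3, b1=0.9, b2=0.999,
               eps=1e-8, wd=0.01):
    m = [b1 * mi + (1 - b1) * g for mi, g in zip(m, grads)]
    v = [b2 * vi + (1 - b2) * g * g for vi, g in zip(v, grads)]
    # Weight decay (wd * t) is added OUTSIDE the adaptive
    # m / sqrt(v) scaling, unlike L2 regularization in Adam.
    theta = [t - lr * (mi / (math.sqrt(vi) + eps) + wd * t)
             for t, mi, vi in zip(theta, m, v)]
    return theta, m, v

# Zero gradient: the adaptive term is 0, yet decay still
# pulls the weight toward zero by lr * wd * theta = 1e-5.
theta, m, v = adamw_step([1.0], [0.0], [0.0], [0.0])
```

Under Adam-style L2 regularization the same zero-loss-gradient step would have been rescaled by 1/(sqrt(v) + eps), which is exactly the inconsistency AdamW removes.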

This decoupling leads to better generalization and more consistent behavior, particularly in large-scale models like transformers. AdamW is now the default optimizer in most modern deep learning frameworks.

Further reading

The paper "Decoupled Weight Decay Regularization" by Ilya Loshchilov and Frank Hutter proposed AdamW, an Adam variant with decoupled weight decay. It highlights the limitations of applying standard L2 regularization directly in Adam and how decoupling improves regularization.

A brief overview of what the paper covers:

  • L2 regularization and weight decay are not identical.

  • L2 regularization is not effective in Adam.

  • Optimal weight decay depends on the total number of batch passes/weight updates.

References

Loshchilov, Ilya, and Frank Hutter. "Decoupled weight decay regularization." arXiv preprint arXiv:1711.05101 (2017).

Sutskever, Ilya, et al. "On the importance of initialization and momentum in deep learning." International Conference on Machine Learning. PMLR, 2013.

All the equations shown here are simplified or adapted from the original sources.

SGD. PyTorch: https://docs.pytorch.org/docs/stable/generated/torch.optim.SGD.html

Adam. PyTorch: https://docs.pytorch.org/docs/stable/generated/torch.optim.Adam.html

AdamW. PyTorch: https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html

Optimizers. PyTorch: https://docs.pytorch.org/docs/main/optim.html
