Source code for pylo.optim.MuLO_naive

"""MuLO_CUDA: An Cuda-accelerated MLP learned optimizer in μP.

This is a PyTorch implementation of μLO from: https://arxiv.org/abs/2406.00153

The code below is adapted from the following JAX implementation: https://github.com/bentherien/mu_learned_optimization/blob/main/src/mup_adafac_mlp_lopt.py
"""
from mup.optim import process_param_groups
from collections import defaultdict
from pylo.optim import AdafacLO_naive


def MuLO_naive(params, impl=AdafacLO_naive, **kwargs):
    """μP (Maximal Update Parametrization) wrapper for the native PyTorch
    implementation of the Adafac learned optimizer.

    This function applies the μP parametrization to the Adafac learned
    optimizer, scaling the learning rate of each matrix-like parameter by the
    inverse of its width multiplier. Parameters are organized into groups
    based on their infinite-width shape properties.

    Note:
        This implementation requires that all parameters have been processed
        with ``mup.set_base_shapes()`` to establish their infinite-width
        behavior.

    Example:
        >>> model = MyModel()
        >>> mup.set_base_shapes(model, base_model)
        >>> optimizer = MuLO_naive(model.parameters())
    """
    new_param_groups = []
    for param_group in process_param_groups(params, **kwargs):
        # Split every existing param group into several new groups.
        def new_group():
            new_g = {k: v for k, v in param_group.items() if k != "params"}
            new_g["params"] = []
            return new_g

        # Matrix-like weights may need multiple groups, since different
        # weights can have different width multipliers.
        matrix_like_p = defaultdict(new_group)  # key is width_mult
        vector_like_p = new_group()
        for p in param_group["params"]:
            assert hasattr(p, "infshape"), (
                f"A parameter with shape {p.shape} does not have `infshape` attribute. "
                "Did you forget to call `mup.set_base_shapes` on the model?"
            )
            if p.infshape.ninf() == 2:
                matrix_like_p[p.infshape.width_mult()]["params"].append(p)
            elif p.infshape.ninf() > 2:
                raise NotImplementedError("more than 2 inf dimensions")
            else:
                vector_like_p["params"].append(p)
        for width_mult, group in matrix_like_p.items():
            # Scale the learning rate by 1 / width_mult, per μP.
            group["lr"] /= width_mult
        new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p])
    return impl(new_param_groups, **kwargs)
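
A minimal usage sketch, not part of the library: it builds a toy MLP, attaches `infshape` to every parameter via `mup.set_base_shapes`, and inspects the param groups produced by `MuLO_naive`. The `tiny_mlp` helper is hypothetical, and passing `lr` is an assumption based on it being forwarded to `AdafacLO_naive` through `**kwargs`.

if __name__ == "__main__":
    import torch.nn as nn
    from mup import set_base_shapes

    def tiny_mlp(width):
        # Hypothetical toy model. The hidden-to-hidden layer gives at least
        # one weight with two infinite dimensions (ninf == 2), so it lands in
        # a matrix-like group. A real μP model would also use mup.MuReadout
        # for the output layer.
        return nn.Sequential(
            nn.Linear(16, width), nn.ReLU(),
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, 10),
        )

    base, model = tiny_mlp(width=8), tiny_mlp(width=64)
    set_base_shapes(model, base)  # hidden dims are 8x wider than the base

    # `lr=1.0` is assumed to be accepted by AdafacLO_naive via **kwargs.
    opt = MuLO_naive(model.parameters(), lr=1.0)
    for g in opt.param_groups:
        # The matrix-like group has lr divided by its width multiplier
        # (1.0 / 8 here); vector-like parameters keep the base lr of 1.0.
        print(len(g["params"]), g["lr"])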