KeyError while trying to follow Sharpness-Aware Minimization

I recently read about “Sharpness-Aware Minimization” and found a PyTorch implementation of it:
https://github.com/davda54/sam

Following the file sam.py and its usage example, here is a very simple example (the model and data are straight from the Deep Learning with PyTorch book):

import torch
from sam import SAM  # assuming sam.py from the repo is importable (or paste the class shown further down)

def model(t_u, w, b):
    return w * t_u + b

def loss_fn(t_p, t_c):
    squared_diffs = (t_p - t_c)**2
    return squared_diffs.mean()

input_ = [0.5, 14.0, 15.0, 28.0, 11.0, 8.0, 3.0, -4.0, 6.0, 13.0, 21.0]
output_ = [35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4]
input_ = torch.tensor(input_)
output_ = torch.tensor(output_)

params = torch.tensor([1.0, 0.0], requires_grad=True)

base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
optimizer = SAM([params], base_optimizer, lr=0.1, momentum=0.9)

pred_ = model(input_, *params)

#  first forward-backward pass
loss = loss_fn(pred_, output_)
loss.backward(retain_graph=True)
optimizer.first_step(zero_grad=True)

# second forward-backward pass
loss_fn(model(input_, *params), output_).backward()  # make sure to do a full forward pass
optimizer.second_step(zero_grad=True)
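
For completeness, the repo also drives both passes through optimizer.step(closure) (see the step() method in the class further down). If I understand it correctly, the closure-style equivalent of the above would look roughly like this, reusing the same model, loss_fn, data and optimizer:

def closure():
    # full forward-backward pass at the perturbed weights "w + e(w)"
    loss = loss_fn(model(input_, *params), output_)
    loss.backward()
    return loss

loss = loss_fn(model(input_, *params), output_)  # first forward-backward pass at "w"
loss.backward()
optimizer.step(closure)  # internally: first_step, closure(), second_step
optimizer.zero_grad()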

However, with the explicit two-step version, the “second step” raises a KeyError when self.state[p]["e_w"] is accessed.
To work around it, I added a few lines to second_step in the SAM class (marked ### ADDITION below) that recreate the "e_w" key by simply recalculating grad_norm, scale, and e_w.

import torch


class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        grad_norm = self._grad_norm()  # <--- ### ADDITION
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)  # <--- ### ADDITION
            for p in group["params"]:
                if p.grad is None: continue
                e_w = p.grad * scale.to(p)  # <--- ### ADDITION
                self.state[p]["e_w"] = e_w  # <--- ### ADDITION
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        p.grad.norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm
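
For what it is worth, this is the kind of sanity check I expected to pass between the two steps (just a sketch on the same toy data; as far as I understand, torch.optim.Optimizer keeps self.state as a per-parameter dict keyed by the tensor itself):

loss = loss_fn(model(input_, *params), output_)
loss.backward()
optimizer.first_step(zero_grad=True)
# I expected the entry written inside first_step to still be visible here
print("e_w" in optimizer.state[params])  # expecting True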

Questions:

  1. Since two forward/backward passes are being made, is there a conceivable downside to recalculating grad_norm, scale, and e_w in second_step?

  2. Since both steps run on the same class instance, why doesn't the "e_w" entry that first_step stores in self.state persist until second_step?