Masking the gradients before the update


I was reading the paper "Learning explanations that are hard to vary" and found the related GitHub repo. To keep it short: before updating the parameters with `theta = theta - lr * final_grads`, PyTorch (CUDA) by default computes the arithmetic mean of the gradients, whereas I want to compute the geometric mean, or apply a mask, as shown in the code below.

Is there a way to do this with PyTorch autograd + CUDA, without having to write a custom training loop?

Code taken from the linked notebook

def opt(x, y, method, lr, weight_decay, n_iters, verbose=False):
    """Run gradient descent on a 5-element ``theta`` using per-domain gradients.

    For every index ``e`` along the first axis of ``x`` the gradient of
    ``loss_fn(x[e], theta, y[e])`` with respect to ``theta`` is computed
    separately; the per-domain gradients are then combined according to
    ``method`` before the parameter update.

    Args:
        x: Inputs; iterated over ``x.shape[0]`` and indexed as ``x[e]``.
        y: Targets, indexed in lockstep with ``x``.
        method: One of ``'geom_mean'``, ``'and_mask'`` or ``'arithm_mean'``.
        lr: Initial learning rate; multiplied by 0.9995 at every iteration.
        weight_decay: Coefficient of the multiplicative weight-decay step.
        n_iters: The loop runs for ``n_iters + 1`` iterations.
        verbose: When true, print loss/theta at roughly five checkpoints.

    Returns:
        A tuple ``(thetas, iters, losses)``: the stacked recorded parameter
        vectors, the iteration indices at which they were recorded, and the
        loss on the full ``(x, y)`` at those points.

    Raises:
        ValueError: If ``method`` is not one of the supported names.

    Note:
        Relies on a ``loss_fn(x, theta, y)`` defined elsewhere in the file.
    """
    thetas, iters, losses = [], [0], []
    # Same random draw as before, but theta is a leaf tensor so that no
    # autograd history is attached to it.
    theta = (torch.randn(5) * 0.1).requires_grad_()

    # Record the starting point (matches the pre-seeded ``iters = [0]``).
    with torch.no_grad():
        loss = loss_fn(x, theta, y)
    losses.append(loss.item())
    thetas.append(theta.detach().numpy().copy())

    # Guard against a zero modulus when n_iters is small.
    record_every = max(1, n_iters // 200)
    report_every = max(1, n_iters // 5)

    for i in range(n_iters + 1):
        lr *= 0.9995

        # One gradient per domain, stacked along a trailing axis.
        grads = [
            torch.autograd.grad(loss_fn(x[e], theta, y[e]), theta)[0]
            for e in range(x.shape[0])
        ]
        grad = torch.stack(grads, dim=-1)

        if method == 'geom_mean':
            n_agreement_domains = len(grads)
            signs = torch.sign(grad)
            # Keep only components whose sign agrees across all domains.
            mask = torch.abs(signs.mean(dim=-1)) == 1
            avg_grad = grad.mean(dim=-1) * mask
            # Geometric mean of magnitudes, signed by the masked average
            # (zero wherever the domains disagree).
            prod_grad = torch.sign(avg_grad) * \
                torch.exp(torch.sum(torch.log(torch.abs(grad) + 1e-10), dim=-1)) \
                ** (1. / n_agreement_domains)
            final_grads = prod_grad
        elif method == 'and_mask':
            signs = torch.sign(grad)
            mask = torch.abs(signs.mean(dim=-1)) == 1
            final_grads = grad.mean(dim=-1) * mask
        elif method == 'arithm_mean':
            final_grads = grad.mean(dim=-1)
        else:
            raise ValueError(f"unknown method: {method!r}")

        # Update outside of autograd so the graph does not grow with every
        # iteration; theta becomes a fresh leaf for the next step.
        with torch.no_grad():
            theta = theta - lr * final_grads
            # weight decay
            theta -= weight_decay * lr * theta
        theta.requires_grad_()

        if not i % record_every:
            thetas.append(theta.detach().numpy().copy())
            iters.append(i)
            with torch.no_grad():
                loss = loss_fn(x, theta, y)
            losses.append(loss.item())
        if not i % report_every:
            print(".", end="")
            if verbose:
                with torch.no_grad():
                    loss = loss_fn(x, theta, y)
                print(f"loss: {loss.item():.6f}, theta: {theta.data.numpy()}, it: {i}")

    with torch.no_grad():
        loss = loss_fn(x, theta, y)
        print(f"loss: {loss.item():.6f}, theta: {theta.data.numpy()}, it: {i}")
    return np.stack(thetas), iters, losses