Masking the gradients before the update

Hello

I was reading the paper Learning explanations that are hard to vary and found the accompanying GitHub repo. To keep it short: before the parameter update theta = theta - lr * final_grads, PyTorch (CUDA) by default reduces the per-example gradients with an arithmetic mean, whereas I want to compute their geometric mean, or apply a mask to them, as shown in the code below.
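To make concrete what I mean by the mask: the per-environment gradients are stacked and a component is kept only if every environment agrees on its sign. A toy sketch of just that step (shapes and names are mine, not from the repo):

import torch

# 5 parameters, 3 environments: one gradient column per environment.
grad = torch.randn(5, 3)

mask = torch.abs(torch.sign(grad).mean(dim=-1)) == 1  # do all environments agree on the sign?
masked_mean = grad.mean(dim=-1) * mask                # zeroed where they disagree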

Is there a way to do this leveraging PyTorch autograd + CUDA, without having to write a custom training loop?

Code taken from the linked notebook

import numpy as np
import torch

# loss_fn(x, theta, y) is defined earlier in the notebook.

def opt(x, y, method, lr, weight_decay, n_iters, verbose=False):
    thetas, iters, losses = [], [0], []
    # Random init of the 5 parameters (the * 0.1 makes theta a non-leaf tensor).
    theta = torch.randn(5, requires_grad=True) * 0.1
    thetas.append(theta.data.numpy())
    
    with torch.no_grad():
        loss = loss_fn(x, theta, y)
        losses.append(loss.item())
    
    for i in range(n_iters + 1):
        lr *= 0.9995  # learning-rate decay

        # One gradient per environment, kept separate instead of averaged.
        grads = []
        loss_e = 0.
        for e in range(x.shape[0]):
            loss_e = loss_fn(x[e], theta, y[e])
            grad_e = torch.autograd.grad(loss_e, theta)[0]
            grads.append(grad_e)

        # Shape: (n_params, n_environments)
        grad = torch.stack(grads, dim=-1)
        
        if method == 'geom_mean':
            n_agreement_domains = len(grads)
            # Mask is 1 only where every environment agrees on the gradient sign.
            signs = torch.sign(grad)
            mask = torch.abs(signs.mean(dim=-1)) == 1
            avg_grad = grad.mean(dim=-1) * mask
            # Geometric mean of |grad| over environments, signed by the masked average.
            prod_grad = torch.sign(avg_grad) * \
                        torch.exp(torch.sum(torch.log(torch.abs(grad) + 1e-10), dim=1)) \
                        ** (1. / n_agreement_domains)
            final_grads = prod_grad
        elif method == 'and_mask':
            # AND-mask: zero out components whose sign differs across environments.
            signs = torch.sign(grad)
            mask = torch.abs(signs.mean(dim=-1)) == 1
            avg_grad = grad.mean(dim=-1) * mask
            final_grads = avg_grad
        elif method == 'arithm_mean':
            # Plain arithmetic mean over environments (the default I want to replace).
            avg_grad = grad.mean(dim=-1)
            final_grads = avg_grad
        else:
            raise ValueError(f"unknown method: {method}")
            
        # Manual SGD step with the (possibly masked) gradient.
        theta = theta - lr * final_grads
        
        # weight decay
        theta -= weight_decay * lr * theta
        
        # Record theta and the loss every n_iters // 200 iterations.
        if not i % (n_iters // 200):
            thetas.append(theta.data.numpy())
            iters.append(i)
            with torch.no_grad():
                loss = loss_fn(x, theta, y)
                losses.append(loss.item())
        
        # Print a progress dot every n_iters // 5 iterations.
        if not i % (n_iters // 5):
            print(".", end="")
            with torch.no_grad():
                loss = loss_fn(x, theta, y)
                if verbose:
                    print(f"loss: {loss.item():.6f}, theta: {theta.data.numpy()}, it: {i}")
              
    with torch.no_grad():
        loss = loss_fn(x, theta, y)
        print(f"loss: {loss.item():.6f}, theta: {theta.data.numpy()}, it: {i}")
    return np.stack(thetas), iters, losses
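For reference, the closest I can get outside this toy loop is to compute the per-environment gradients myself, apply the mask, write the result into each parameter's .grad, and let a standard optimizer take the step. A rough sketch (model, loss_fn and batches are hypothetical placeholders, not from the repo):

import torch

def masked_step(model, optimizer, loss_fn, batches):
    # One backward pass per environment, gradients kept separate.
    params = [p for p in model.parameters() if p.requires_grad]
    per_env_grads = []
    for x_e, y_e in batches:  # one (x, y) pair per environment
        loss_e = loss_fn(model(x_e), y_e)
        per_env_grads.append(torch.autograd.grad(loss_e, params))

    optimizer.zero_grad()
    for p, grads in zip(params, zip(*per_env_grads)):
        grad = torch.stack(grads, dim=-1)                     # (*p.shape, n_envs)
        mask = torch.abs(torch.sign(grad).mean(dim=-1)) == 1  # sign agreement
        p.grad = grad.mean(dim=-1) * mask
    optimizer.step()

This works, but it is exactly the kind of hand-written step I was hoping to avoid, hence the question.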