Hello
I was reading the paper "Learning Explanations That Are Hard to Vary" and found its accompanying GitHub repo. To keep it short: before the parameter update theta = theta - lr * final_grads,
PyTorch (CUDA) by default computes the arithmetic mean of the per-example gradients, whereas I want to compute the geometric mean, or to apply a mask, as shown in the code.
Is there a way to do this leveraging pytorch autograd + cuda without the need to write a custom training loop?
Code taken from the linked notebook
def opt(x, y, method, lr, weight_decay, n_iters, verbose=False):
    """Optimize a 5-dim parameter vector with per-environment gradient aggregation.

    For each iteration, computes one gradient per environment (leading axis of
    ``x``/``y``) and combines them according to ``method``:

    - ``'arithm_mean'``: plain average of the per-environment gradients.
    - ``'and_mask'``: average, zeroed wherever the environments disagree on sign
      (the AND-mask from "Learning Explanations That Are Hard to Vary").
    - ``'geom_mean'``: masked geometric mean of the gradient magnitudes, signed
      like the masked average.

    Args:
        x: inputs, shape (n_envs, ...); ``x[e]`` is passed to ``loss_fn``.
        y: targets, shape (n_envs, ...).
        method: one of 'geom_mean', 'and_mask', 'arithm_mean'.
        lr: initial learning rate (decayed by 0.9995 per step).
        weight_decay: multiplicative weight-decay coefficient.
        n_iters: number of optimization steps.
        verbose: if True, print loss at coarse checkpoints.

    Returns:
        (thetas, iters, losses): stacked parameter snapshots (np.ndarray),
        the iteration indices of the snapshots, and the recorded full-data
        losses.

    Raises:
        ValueError: if ``method`` is not one of the three supported names.

    Note:
        relies on a module-level ``loss_fn(x, theta, y)``.
    """
    thetas, iters, losses = [], [0], []
    # Scale first, then mark as a leaf: this keeps theta a proper leaf tensor
    # (randn(..., requires_grad=True) * 0.1 would make theta a non-leaf whose
    # history is retained forever). Values are identical under a fixed seed.
    theta = (torch.randn(5) * 0.1).requires_grad_()
    thetas.append(theta.data.numpy())
    with torch.no_grad():
        loss = loss_fn(x, theta, y)
    losses.append(loss.item())

    # Guard against ZeroDivisionError when n_iters < 200 (or < 5).
    snapshot_every = max(1, n_iters // 200)
    print_every = max(1, n_iters // 5)

    for i in range(n_iters + 1):
        lr *= 0.9995  # exponential learning-rate decay
        # One gradient of the per-environment loss w.r.t. theta, per environment.
        grads = []
        for e in range(x.shape[0]):
            loss_e = loss_fn(x[e], theta, y[e])
            grads.append(torch.autograd.grad(loss_e, theta)[0])
        grad = torch.stack(grads, dim=-1)  # shape: (5, n_envs)

        if method == 'geom_mean':
            n_agreement_domains = len(grads)
            signs = torch.sign(grad)
            # Components where every environment agrees on the gradient sign.
            mask = torch.abs(signs.mean(dim=-1)) == 1
            avg_grad = grad.mean(dim=-1) * mask
            # Geometric mean of |grad| across environments; sign(avg_grad)
            # carries both the common sign and the agreement mask (it is 0
            # where the mask zeroed the average). 1e-10 avoids log(0).
            prod_grad = torch.sign(avg_grad) * \
                torch.exp(torch.sum(torch.log(torch.abs(grad) + 1e-10), dim=-1)) \
                ** (1. / n_agreement_domains)
            final_grads = prod_grad
        elif method == 'and_mask':
            signs = torch.sign(grad)
            # Zero out components on which the environments disagree.
            mask = torch.abs(signs.mean(dim=-1)) == 1
            final_grads = grad.mean(dim=-1) * mask
        elif method == 'arithm_mean':
            final_grads = grad.mean(dim=-1)
        else:
            raise ValueError()

        # Gradient step. Detach so the autograd graph does not grow across
        # iterations (the original kept the full update history alive).
        theta = (theta - lr * final_grads).detach().requires_grad_()
        # weight decay (in-place on a leaf, so it must run under no_grad)
        with torch.no_grad():
            theta -= weight_decay * lr * theta

        if not i % snapshot_every:
            thetas.append(theta.data.numpy())
            iters.append(i)
            with torch.no_grad():
                loss = loss_fn(x, theta, y)
            losses.append(loss.item())
        if not i % print_every:
            print(".", end="")
            with torch.no_grad():
                loss = loss_fn(x, theta, y)
            if verbose:
                print(f"loss: {loss.item():.6f}, theta: {theta.data.numpy()}, it: {i}")

    with torch.no_grad():
        loss = loss_fn(x, theta, y)
    print(f"loss: {loss.item():.6f}, theta: {theta.data.numpy()}, it: {i}")
    return np.stack(thetas), iters, losses