I know that p.copy(p - H) breaks the computational graph, so I need a workaround to retrieve the gradient of the loss wrt to the parameters of the optimizer
class SCALTRA(nn.Module):
def __init__(self, model, epochs,
loss_criterion,
beta = 0.9, tau = 1, gamma = 0.9):
super().__init__()
# Initialization of the learnable parameters
#self.lambda_ = nn.Parameter(torch.ones(1)*lambda_, requires_grad = True)
self.tau = nn.Parameter(torch.ones(1)*tau, requires_grad = True)
self.gamma = gamma
self.beta = beta
self.B = []
self.momentum = []
for _, p in model.named_parameters():
self.B.append(nn.Parameter(torch.randn(p.shape[0]), requires_grad = True))
self.momentum.append(torch.zeros_like(p))
self.params = nn.ParameterList(self.B + [self.tau])
# Some more initialization
self.loss_criterion = loss_criterion
self.meta_optimizer = torch.optim.Adam(params=self.parameters(), lr=1e-1, weight_decay = 5e-4)
self.epochs = epochs
def forward(self, model, X, y):
train_losses = []
for _ in tqdm(range(self.epochs)):
for i, (_, p) in enumerate(model.named_parameters()):
# Local optimization step
if len(p.shape) == 1:
H = 1/(self.B[i]**2 + self.tau*torch.ones(self.B[i].shape[0])) * self.momentum[i]
else:
H = torch.linalg.inv(torch.outer(self.B[i],self.B[i]) + self.tau*torch.eye(self.B[i].shape[0])) @ self.momentum[i]
with torch.no_grad():
# Convex smoothing
p.copy_(p - H)
self.gamma *= 0.99
loss = self.loss_criterion(model(X), y)
# Zero gradients for optimizer parameters
self.zero_grad()
# Compute gradients of loss with respect to optimizer parameters
loss.backward(retain_graph=True)
# Update optimizer parameters
self.meta_optimizer.step()
# Update momentum
for i, (_, p) in enumerate(model.named_parameters()):
self.momentum[i] = self.beta * self.momentum[i] + (1 - self.beta) * p.grad
train_losses.append(loss.detach().numpy())
return train_losses
How can I do so?