# Bilevel training loop: the inner loop updates each entry of theta_list with
# sgd_step on training batches; the outer loop updates lamb with SGD on the
# validation loss averaged over the M entries.
#
# NOTE(review): the loop nesting below was reconstructed from a paste that had
# lost its indentation -- confirm it matches the intended structure.

# optim.SGD expects an iterable of parameters, so a single tensor is wrapped
# in a list; anything else is passed through unchanged.
opt = optim.SGD(
    [lamb] if isinstance(lamb, torch.Tensor) else lamb,
    lr=lr_h,
    weight_decay=wd_h,
    momentum=mm_h,
)

# The initial thetas are requested with requires_grad=False.
theta_list = loss_cls.init_theta(requires_grad=False)
print("\n\ntestxx: ", (theta_list[0][0]).size())

for it_h in range(T):
    # Inner problem: K sgd_step calls per theta, each on a fresh training batch.
    for i in range(M):
        for it_l in range(K):
            z_tr = [item.to(device) for item in next(train_dataset_loader)]
            # sgd_step returns a pair: the first element replaces theta_list[i];
            # the second is bound to loss_tr_sgd and not used in this section.
            theta_list[i], loss_tr_sgd = sgd_step(
                theta_list[i], lamb, loss_cls.loss_in, z_tr, lr_l, wd_l
            )

    # Outer problem: one validation batch, shared by all M thetas.
    z_val = [item.to(device) for item in next(val_dataset_loader)]
    loss_val = 0
    for i in range(M):
        loss_val += loss_cls.loss_out(lamb, theta_list[i], z_val).mean()
    loss_val /= M

    # Clear stale gradients, backpropagate the averaged validation loss, and
    # take one optimizer step on the parameters registered with opt.
    # NOTE(review): theta_list is carried over to the next outer iteration
    # as-is; whether it still references this iteration's autograd graph
    # depends on sgd_step, which is not visible here -- confirm.
    opt.zero_grad()
    loss_val.backward()
    opt.step()
I’ve read the other posts, but unless I’m mistaken, none of them solves my problem.