I am indexing the model's classifier logits so that I can compute the loss on certain nodes only. I assumed that, since the loss is defined only over those nodes, only their connections (weights) in the classifier would be updated and the gradients of the other connections would remain unchanged. However, when I checked the gradients before optimizer.step(), I noticed that the gradient matrix is fully populated with values. I can't debug my code because I am not sure where the problem stems from. Any help would be appreciated.
This is the loss code:
class ClassificationLossVI(nn.Module):
    """Cross-entropy over noise-perturbed logits plus a KL term.

    In training mode, ``forward`` perturbs the predicted mean with Gaussian
    noise scaled by the square root of the predicted variance, computes the
    cross-entropy of the perturbed logits, and adds ``output_dict['kl_div']()``.
    When ``args.label_trick`` is anything other than ``False``, the
    cross-entropy is restricted to the logits of the classes listed in
    ``target_dict['task_labels'][0]`` and the targets are re-indexed to match.

    NOTE(review): ``kl_div()`` is added to the loss on every call. If it
    depends on all classifier weights (not verifiable from this class), every
    weight receives a gradient even though the cross-entropy term only reads
    the selected logits -- presumably why a full gradient matrix is observed.

    NOTE(review): no evaluation-mode branch is visible here; with
    ``self.training`` false, ``forward`` returns ``None``.
    """

    def __init__(self, args, topk=3, samples=64):
        """Store the loss configuration.

        Args:
            args: object exposing a ``label_trick`` attribute.
            topk: accuracies are reported for k = 1..topk.
            samples: number of noise samples drawn per logit (default 64,
                the value that used to be hard-coded in ``forward``).
        """
        super().__init__()
        self._topk = tuple(range(1, topk + 1))
        self.label_trick = args.label_trick
        self._samples = samples

    @staticmethod
    def _relabel_to_task_indices(target, ordered_task_targets):
        """Map each entry of ``target`` to its position in ``ordered_task_targets``.

        Entries that do not occur in ``ordered_task_targets`` keep their
        original value. ``target`` itself is not modified.
        """
        labels = target.clone().detach()
        for task_index, task_label in enumerate(ordered_task_targets):
            # Compare against the untouched ``target`` so an index written in
            # an earlier iteration can never be matched again by a later one.
            labels[target == task_label] = task_index
        return labels

    def forward(self, output_dict, target_dict):
        """Compute the training losses and top-k accuracies.

        Args:
            output_dict: mapping with ``'prediction_mean'``,
                ``'prediction_variance'`` and a callable ``'kl_div'``.
            target_dict: mapping with ``'target1'`` (batch targets) and, when
                the label trick is active, ``'task_labels'``.

        Returns:
            In training mode, a dict with ``'total_loss'`` (cross-entropy plus
            KL), ``'xe'`` (the cross-entropy detached from the graph) and one
            ``'top<k>'`` entry per k. Otherwise ``None``.
        """
        samples = self._samples
        # Add a trailing sample axis; the original author's notes give
        # torch.Size([10, 170, 64]) as an example shape after expansion.
        prediction_mean = output_dict['prediction_mean'].unsqueeze(dim=2).expand(-1, -1, samples)
        prediction_variance = output_dict['prediction_variance'].unsqueeze(dim=2).expand(-1, -1, samples)
        target = target_dict['target1']
        normal_dist = torch.distributions.normal.Normal(
            torch.zeros_like(prediction_mean), torch.ones_like(prediction_mean)
        )

        if self.training:
            normals = normal_dist.sample()
            prediction = prediction_mean + torch.sqrt(prediction_variance) * normals

            if self.label_trick is False:
                # Plain case: score against every class with the raw targets.
                logits = prediction
                labels = target
            else:
                # Label trick: keep only the logits of the current task's
                # classes and re-index the targets into that reduced range.
                task_targets = target_dict['task_labels'][0]
                ordered_task_targets = torch.unique(task_targets, sorted=True)
                labels = self._relabel_to_task_indices(target, ordered_task_targets)
                logits = prediction[:, ordered_task_targets, :]

            # One target per noise sample, matching the trailing axis of logits.
            labels_expanded = labels.unsqueeze(dim=1).expand(-1, samples)
            loss = F.cross_entropy(logits, labels_expanded, reduction='mean')
            kl_div = output_dict['kl_div']
            losses = {'total_loss': loss + kl_div()}

            with torch.no_grad():
                # Average class probabilities over the noise samples.
                p = F.softmax(logits, dim=1).mean(dim=2)
                # Same value the cross-entropy above produced; no need to
                # evaluate it a second time just for logging.
                losses['xe'] = loss.detach()
                acc_k = _accuracy(p, labels, topk=self._topk)
                for acc, k in zip(acc_k, self._topk):
                    losses["top%i" % k] = acc
            return losses