So, the code snippet I tried does not reproduce the error, but when I run my original code I do get the error.

This is the module with which I calculate the loss. It is a Bayesian neural network used for continual learning, and the loss is the ELBO loss.

```
from __future__ import absolute_import
from __future__ import print_function
# NOTE(review): `N` is never used below -- this import was likely auto-inserted
# by an IDE and pulls in all of tkinter; safe to remove.
from tkinter import N
import torch
# NOTE(review): anomaly detection adds large per-op overhead; enable only while
# hunting the autograd error, not in regular training runs.
torch.autograd.set_detect_anomaly(True)
torch.use_deterministic_algorithms(True, warn_only=True)
import torch.nn as nn
from torch.nn import functional as F
import pdb
def _accuracy(output, target, topk=(1,)):
#pdb.set_trace()
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
class ClassificationLossVI(nn.Module):
    """ELBO classification loss for a variational-inference Bayesian network
    used in continual learning.

    ``forward`` expects:
      output_dict:
        'prediction_mean'     -- assumed (batch, num_classes); TODO confirm
        'prediction_variance' -- same shape as the mean; TODO confirm
        'kl_div'              -- a *callable* returning the KL divergence term
      target_dict:
        'target1'       -- (batch,) ground-truth class indices
        'task_labels'   -- labels of the current task (used by the labels trick)
        'coresets_list' -- per-task coreset label lists (merged training only)

    Returns a dict with 'total_loss', 'xe' (cross-entropy part) and
    'top<k>' accuracy entries.
    """
    def __init__(self, args, topk=3):
        super(ClassificationLossVI, self).__init__()
        # Report top-1 .. top-`topk` accuracies.
        self._topk = tuple(range(1, topk+1))
        # Flags steering the "labels trick" / coreset behavior; all are
        # assumed to come from the experiment's argparse namespace.
        self.label_trick = args.label_trick
        self.label_trick_valid = args.label_trick_valid
        self.coreset_training = args.coreset_training
        self.coreset_kld = args.coreset_kld
        self.merged_training = args.merged_training

    def forward(self, output_dict, target_dict):
        # Number of Monte-Carlo samples of the logits (currently fixed at 1).
        samples = 1
        # Expand (B, C) -> (B, C, samples); expand() returns a view, so all
        # sample columns alias the same underlying storage.
        prediction_mean = output_dict['prediction_mean'].unsqueeze(dim=2).expand(-1, -1, samples)
        prediction_variance = output_dict['prediction_variance'].unsqueeze(dim=2).expand(-1, -1, samples)
        target = target_dict['target1']
        target_expanded = target.unsqueeze(dim=1).expand(-1, samples)
        # Standard normal for the reparameterisation trick.
        normal_dist = torch.distributions.normal.Normal(torch.zeros_like(prediction_mean), torch.ones_like(prediction_mean))
        if self.training:
            losses = {}
            normals = normal_dist.sample()
            # Reparameterised logit sample: mean + sigma * eps.
            # NOTE(review): sqrt() yields NaN for negative inputs and has an
            # infinite gradient at 0 -- a plausible trigger for the autograd
            # anomaly error; verify prediction_variance is strictly positive.
            prediction = prediction_mean + torch.sqrt(prediction_variance) * normals
            # -------------------------------------------------------------------------------
            # Labels trick
            # -------------------------------------------------------------------------------
            if self.label_trick is False or self.coreset_kld==1:
                # check the dtype of prediction tensor
                # Plain ELBO: cross-entropy data term + KL regulariser.
                loss = F.cross_entropy(prediction, target_expanded, reduction='mean')
                kl_div = output_dict['kl_div']
                losses['total_loss'] = loss + kl_div()
                with torch.no_grad():
                    # Average predictive distribution over the MC samples.
                    p = F.softmax(prediction, dim=1).mean(dim=2)
                    # NOTE(review): recomputes the same cross-entropy as
                    # `loss` above; could reuse loss.detach() instead.
                    losses['xe'] = F.cross_entropy(prediction, target_expanded, reduction='mean')
                    acc_k = _accuracy(p, target, topk=self._topk)
                    for acc, k in zip(acc_k, self._topk):
                        losses["top%i" % k] = acc
            else:
                # Labels trick: restrict the softmax to the classes actually
                # seen so far. The -30 offset maps raw task labels to class
                # indices -- presumably dataset-specific; TODO confirm.
                task_targets = [item -30 for item in target_dict['task_labels']]
                ordered_task_targets = torch.unique(torch.Tensor(task_targets).long(), sorted=True)
                if self.merged_training is True:
                    # Merge coreset labels from previous tasks into the set
                    # of classes the loss is computed over.
                    coreset_targets = target_dict['coresets_list']
                    if len(coreset_targets)>0:
                        flat_coreset_targets = [item for sublist in coreset_targets for item in sublist]
                        seen_targets=torch.cat((torch.Tensor(flat_coreset_targets),torch.Tensor(task_targets)), 0)
                        ordered_task_targets = torch.unique(seen_targets, sorted=True).long()
                # Get the current batch labels (and sort them for reassignment)
                labels = target.clone().detach()
                # Remap each seen class label to its position in the reduced
                # output slice. NOTE(review): sequential in-place remapping is
                # only safe when every label value is >= its sorted index
                # (true for non-negative unique ints); the -30 offset above
                # can produce negative values -- verify against the caller.
                for t_idx, t in enumerate(ordered_task_targets):
                    labels[labels==t] = t_idx
                labels_expanded = labels.unsqueeze(dim=1).expand(-1, samples)
                # Cross-entropy over the seen-classes slice only.
                loss_label_trick = F.cross_entropy(prediction[:, ordered_task_targets, :], labels_expanded, reduction='mean')
                kl_div = output_dict['kl_div']
                losses['total_loss'] = loss_label_trick + kl_div()
                with torch.no_grad():
                    p = F.softmax(prediction[:, ordered_task_targets, :], dim=1).mean(dim=2)
                    # NOTE(review): duplicate of loss_label_trick above.
                    losses['xe'] = F.cross_entropy(prediction[:, ordered_task_targets, :], labels_expanded, reduction='mean')
                    acc_k = _accuracy(p, labels, topk=self._topk)
                    for acc, k in zip(acc_k, self._topk):
                        losses["top%i" % k] = acc
                # ---------------------------------------------------------------------------------------------------
        else:
            # Evaluation: everything under no_grad; loss is the negative
            # log of the MC-averaged predictive probability of the target.
            if self.label_trick and self.label_trick_valid:
                with torch.no_grad():
                    normals = normal_dist.sample()
                    prediction = prediction_mean + torch.sqrt(prediction_variance) * normals
                    labels = target.clone().detach()
                    # NOTE(review): unlike the training branch, no -30 offset
                    # is applied here -- confirm this asymmetry is intended.
                    task_targets = target_dict['task_labels'][0] #shape: [10, 10]
                    ordered_task_targets = torch.unique(task_targets, sorted=True)
                    for t_idx, t in enumerate(ordered_task_targets):
                        labels[labels==t] = t_idx
                    losses = {}
                    kl_div = output_dict['kl_div']
                    p = F.softmax(prediction[:, ordered_task_targets, :], dim=1).mean(dim=2)
                    # NLL of the averaged predictive distribution + KL term.
                    losses['total_loss'] = - torch.log(p[range(p.shape[0]), labels]).mean() + kl_div()
                    losses['xe'] = - torch.log(p[range(p.shape[0]), labels]).mean()
                    acc_k = _accuracy(p, labels, topk=self._topk)
                    for acc, k in zip(acc_k, self._topk):
                        losses["top%i" % k] = acc
            else:
                # NOTE(review): leftover debugging breakpoint -- this will
                # halt every evaluation pass; remove for real runs.
                pdb.set_trace()
                with torch.no_grad():
                    normals = normal_dist.sample()
                    prediction = prediction_mean + torch.sqrt(prediction_variance) * normals
                    p = F.softmax(prediction, dim=1).mean(dim=2)
                    losses = {}
                    kl_div = output_dict['kl_div']
                    losses['total_loss'] = - torch.log(p[range(p.shape[0]), target]).mean() + kl_div()
                    losses['xe'] = - torch.log(p[range(p.shape[0]), target]).mean()
                    acc_k = _accuracy(p, target, topk=self._topk)
                    for acc, k in zip(acc_k, self._topk):
                        losses["top%i" % k] = acc
        return losses

    def set_coreset_kld_flag(self, _flag):
        # Toggles the coreset-KLD branch selection in forward().
        self.coreset_kld=_flag
```