Hi,
Suppose I have a module like this:
class Loss(nn.Module):
    """Cross-entropy plus a margin regulariser over the non-target classes.

    Per position the loss is ``CE(logits, label) + mg`` where
    ``mg = lam / 2 * sum_{j != label} (q_j - 1 / (C - 1)) * log(q_j)`` and
    ``q`` is the softmax of the logits after the target logit has been
    shifted down by ``1e6`` (so the target gets ~zero probability mass).

    Args:
        lam: weight of the margin term (applied as ``lam / 2``).
        ignore_index: label value whose positions are excluded from the
            loss. Defaults to ``-100``, the same default as
            ``nn.CrossEntropyLoss``.
        q_grad_hook: optional callable registered with
            ``Tensor.register_hook`` on ``q`` on every forward pass. During
            ``backward()`` it receives the gradient of the loss w.r.t. ``q``;
            returning a tensor replaces that gradient, returning ``None``
            leaves it unchanged (useful to just record/inspect it).
        log_q_grad_hook: same as ``q_grad_hook`` but for ``log_q``.
    """

    def __init__(self, lam=0.3, ignore_index=-100, q_grad_hook=None,
                 log_q_grad_hook=None):
        super(Loss, self).__init__()
        self.lam = lam
        self.ignore_index = ignore_index
        self.q_grad_hook = q_grad_hook
        self.log_q_grad_hook = log_q_grad_hook
        self.ce_crit = nn.CrossEntropyLoss(
            reduction='none', ignore_index=ignore_index)

    def forward(self, logits, label):
        '''
        args: logits: tensor of shape (N, C, H, W, ...)
        args: label: tensor of shape (N, H, W, ...)
        returns: scalar mean loss over positions whose label != ignore_index
                 (0 when every position is ignored)
        '''
        logits = logits.float()
        with torch.no_grad():
            num_classes = logits.size(1)
            coeff = 1. / (num_classes - 1.)
            # Ignored labels (e.g. -100) are not valid scatter indices; map
            # them to class 0 for the one-hot and drop those positions below.
            ignored = label == self.ignore_index
            safe_label = label.masked_fill(ignored, 0)
            idx = torch.zeros_like(logits).scatter_(
                1, safe_label.unsqueeze(1), 1.)

        # Built outside no_grad so q / log_q stay in the autograd graph and
        # their gradients can be observed or replaced through hooks.
        lgts = logits - idx * 1.e6
        q = lgts.softmax(dim=1)
        log_q = lgts.log_softmax(dim=1)
        # register_hook raises on tensors that do not require grad.
        if q.requires_grad:
            if self.q_grad_hook is not None:
                q.register_hook(self.q_grad_hook)
            if self.log_q_grad_hook is not None:
                log_q.register_hook(self.log_q_grad_hook)

        # Mask out the target class: its log_q is about -1e6, which would
        # otherwise add a huge constant and a spurious gradient to the loss.
        mg_loss = ((q - coeff) * log_q) * (self.lam / 2) * (1. - idx)
        mg_loss = mg_loss.sum(dim=1)
        ce_loss = self.ce_crit(logits, label)
        loss = (ce_loss + mg_loss)[~ignored]
        if loss.numel() == 0:
            # Every position ignored: a graph-connected zero instead of NaN.
            return loss.sum()
        return loss.mean()
I would like to inspect, or modify, the gradients of the tensors q and log_q during loss.backward(). How can I do this?