Hi, I have a custom `torch.autograd.Function` defined like this:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function


class SourceMemory(Function):
    def __init__(self, M, alpha=0.01):
        super(SourceMemory, self).__init__()
        self.M = M
        self.alpha = alpha

    def forward(self, inputs, targets):
        self.save_for_backward(inputs, targets)
        outputs = inputs.mm(self.M.t())
        return outputs

    def backward(self, grad_outputs):
        inputs, targets = self.saved_tensors
        grad_inputs = None
        if self.needs_input_grad[0]:
            grad_inputs = grad_outputs.mm(self.M)
        # Running-average update of the memory row for each target,
        # followed by L2 re-normalization of that row.
        for x, y in zip(inputs, targets):
            self.M[y] = self.alpha * self.M[y] + (1. - self.alpha) * x
            self.M[y] = F.normalize(self.M[y], p=2, dim=0)
        return grad_inputs, None


class IntraNet(nn.Module):
    def __init__(self, beta=0.05, alpha=0.01, num_classes=0, num_features=0, weight=None):
        super(IntraNet, self).__init__()
        self.beta = beta
        self.alpha = alpha
        self.weight = weight
        # Memory of one feature vector per class.
        self.register_buffer('M', torch.zeros(num_classes, num_features))

    def forward(self, inputs, targets, features=None, epoch=None):
        self.alpha = self.alpha * epoch
        outputs = SourceMemory(self.M, alpha=self.alpha)(features, targets)
        outputs /= self.beta
        loss = F.cross_entropy(inputs, targets, weight=self.weight)
        return loss, outputs
```
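For context, this is roughly how I'd exercise the criterion in isolation (a minimal sketch; the batch size and `num_classes=751` are made-up example values, not my real config):

```python
# Standalone sketch with made-up sizes, just to show the shapes involved.
criterion = IntraNet(beta=0.05, alpha=0.01, num_classes=751, num_features=4096)

logits = torch.randn(32, 751, requires_grad=True)  # [batch_size, num_classes]
feats = torch.randn(32, 4096, requires_grad=True)  # [batch_size, 4096]
pids = torch.randint(0, 751, (32,))                # target ids

loss, mem_outputs = criterion(logits, pids, features=feats, epoch=1)
```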
`M` is a memory containing one feature vector per class. The `IntraNet` module is used as a criterion in the training loop:
```python
for i, inputs in enumerate(data_loader):
    inputs, pids = self._parse_data(inputs)
    outputs, features = self.model(inputs)
    # this will update the buffer, i.e. the memory M
    source_pid_loss, _ = self.inter_criterion(outputs, pids, features=features, epoch=epoch)
```
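The loss is then backpropagated in the usual way (a sketch; `self.optimizer` is hypothetical shorthand for my actual optimizer setup):

```python
    # Continuation of the loop body above (sketch). This backward pass
    # is what I expect to trigger SourceMemory.backward and update M.
    self.optimizer.zero_grad()
    source_pid_loss.backward()
    self.optimizer.step()
```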
where `self.inter_criterion` is the `IntraNet` module, and the model returns `outputs` of shape `[batch_size, num_classes]` and `features` of shape `[batch_size, 4096]`.
The problem is that `backward` of `SourceMemory` is never called. Am I doing something wrong here? I want the buffer `M` to be filled during the backward pass.
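For what it's worth, this is how I'm checking that `backward` never runs, just a debug print at the top of it (sketch):

```python
    def backward(self, grad_outputs):
        # Debug check: this line never gets printed during training.
        print('SourceMemory.backward called')
        inputs, targets = self.saved_tensors
        # ... rest as above ...
```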
Thank you