Here’s a toy model I was playing with; calling backward() on its output throws a RuntimeError complaining that the tensor has no grad_fn.
import torch
import torch.nn as nn

class DevModel(nn.Module):
    def __init__(self):
        super(DevModel, self).__init__()
        self.fc1 = nn.Linear(300, 1)
        self.fc2 = nn.Linear(300, 1)

    def forward(self, inputs_ticket, inputs_article):
        # Score each position: (2, 17, 300) -> (2, 17)
        out_ticket = self.fc1(inputs_ticket).squeeze()
        out_article = self.fc2(inputs_article).squeeze()
        # Hard 0/1 masks from thresholding the scores
        mask_ticket = (out_ticket > 0.2).float()
        mask_article = (out_article > 0.2).float()
        # Broadcast the masks over the feature dimension: (2, 17) -> (2, 17, 300)
        mask_ticket = mask_ticket.unsqueeze(2).expand(2, 17, 300)
        mask_article = mask_article.unsqueeze(2).expand(2, 17, 300)
        # Zero out the positions that fall below the threshold
        extracted_ticket = torch.mul(inputs_ticket, mask_ticket)
        extracted_article = torch.mul(inputs_article, mask_article)
        # Average over the sequence dimension: (2, 17, 300) -> (2, 300)
        ave_ticket = torch.mean(extracted_ticket, dim=1)
        ave_article = torch.mean(extracted_article, dim=1)
        # Pairwise ticket/article similarity scores: (2, 2)
        scores = torch.matmul(ave_ticket, ave_article.transpose(1, 0))
        return scores
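For reference, this is how I trigger the error. The input shapes just match the hard-coded expand() calls above, and the random tensors are only for the repro:

model = DevModel()
inputs_ticket = torch.randn(2, 17, 300)
inputs_article = torch.randn(2, 17, 300)
scores = model(inputs_ticket, inputs_article)
print(scores.grad_fn)    # None -- scores is not connected to any graph
scores.sum().backward()  # RuntimeError: element 0 of tensors does not require grad
                         # and does not have a grad_fn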
My guess is that the mask ops are the problem: the comparison (out > 0.2) isn’t differentiable, and since the masks are the only path from fc1/fc2 to scores, the whole graph gets cut and gradients can’t propagate backward. Does anyone have suggestions for the right way to do this?
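One idea I tried is replacing the hard threshold with a steep sigmoid so the mask keeps a grad_fn. This is only a sketch, and the steepness k = 50.0 is an arbitrary value I made up, not anything tuned:

    def forward(self, inputs_ticket, inputs_article):
        out_ticket = self.fc1(inputs_ticket).squeeze(-1)   # (2, 17)
        out_article = self.fc2(inputs_article).squeeze(-1)
        # Soft mask: sigmoid(k * (x - 0.2)) approaches the hard 0/1 mask
        # as k grows, but stays differentiable so gradients reach fc1/fc2.
        k = 50.0  # steepness -- arbitrary guess, not tuned
        mask_ticket = torch.sigmoid(k * (out_ticket - 0.2)).unsqueeze(-1)
        mask_article = torch.sigmoid(k * (out_article - 0.2)).unsqueeze(-1)
        # Broadcasting (2, 17, 1) against (2, 17, 300) replaces the expand()
        extracted_ticket = inputs_ticket * mask_ticket
        extracted_article = inputs_article * mask_article
        ave_ticket = extracted_ticket.mean(dim=1)          # (2, 300)
        ave_article = extracted_article.mean(dim=1)
        return torch.matmul(ave_ticket, ave_article.transpose(1, 0))

This version runs and backward() does populate gradients for fc1/fc2, but the mask is no longer exactly 0/1, so I’m not sure it’s the right approach. Is there a standard way to keep a hard mask and still get gradients through it?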