I try to write my own function, but backward not called in some situation. My codes:
import torch
import torch.nn as nn
from torch.autograd import Variable, Function
# my own function
class TVLossFunction(Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x
@staticmethod
def backward(ctx, grad_output):
print('TV Loss backward Called')
x, = ctx.saved_tensors
u_diff = x[:, :, :-1, :-1] - x[:, :, :-1, 1:]
v_diff = x[:, :, :-1, :-1] - x[:, :, 1:, :-1]
grad_input = torch.zeros(x.size())
grad_input[:, :, :-1, :-1] = u_diff + v_diff
grad_input[:, :, :-1, 1:] -= u_diff
grad_input[:, :, 1:, :-1] -= v_diff
grad_input.add_(grad_output.data)
return Variable(grad_input),
class OkNet(nn.Module):
def __init__(self):
super(OkNet, self).__init__()
self.conv = nn.Conv2d(3, 3, 3)
def forward(self, input):
out = TVLossFunction.apply(input)
out = self.conv(out)
return out
class FailedNet(nn.Module):
def __init__(self):
super(FailedNet, self).__init__()
self.conv = nn.Conv2d(3, 3, 3)
def forward(self, input):
out = self.conv(input)
out = TVLossFunction.apply(out)
return out
# Test Ok
data = torch.FloatTensor([[1, 2, 3, 4, 5],
[2, 3, 4, 5, 1],
[3, 4, 5, 1, 2],
[4, 5, 1, 2, 3],
[5, 1, 2, 3, 4],
]).expand((3, 3, 5, 5))
input = Variable(torch.FloatTensor().resize_as_(data).copy_(data), requires_grad=True)
net = OkNet()
output = net(input)
output.backward(Variable(torch.zeros(output.size())))
# OkNet works。 It prints out : TV Loss backward Called
# Test Failed
data = torch.FloatTensor([[1, 2, 3, 4, 5],
[2, 3, 4, 5, 1],
[3, 4, 5, 1, 2],
[4, 5, 1, 2, 3],
[5, 1, 2, 3, 4],
]).expand((3, 3, 5, 5))
input = Variable(torch.FloatTensor().resize_as_(data).copy_(data), requires_grad=True)
net = FailedNet()
output = net(input)
output.backward(Variable(torch.zeros(output.size())))
# Nothing print out
My pytorch version is 0.3.0。