Backward not called, is this a bug?

I am trying to write my own autograd Function, but backward is not called in some situations. Here is my code:

import torch
import torch.nn as nn
from torch.autograd import Variable, Function

# my own autograd function: identity forward, adds a TV-style gradient in backward
class TVLossFunction(Function):

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        print('TV Loss backward Called')
        x, = ctx.saved_tensors
        # horizontal and vertical finite differences of the saved input
        u_diff = x[:, :, :-1, :-1] - x[:, :, :-1, 1:]
        v_diff = x[:, :, :-1, :-1] - x[:, :, 1:, :-1]
        # gradient of the TV term, scattered back to the contributing positions
        grad_input = torch.zeros(x.size())
        grad_input[:, :, :-1, :-1] = u_diff + v_diff
        grad_input[:, :, :-1, 1:] -= u_diff
        grad_input[:, :, 1:, :-1] -= v_diff
        # forward is the identity, so also pass the upstream gradient through
        grad_input.add_(grad_output.data)
        return Variable(grad_input),

class OkNet(nn.Module):
    def __init__(self):
        super(OkNet, self).__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, input):
        out = TVLossFunction.apply(input)
        out = self.conv(out)
        return out

class FailedNet(nn.Module):
    def __init__(self):
        super(FailedNet, self).__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, input):
        out = self.conv(input)
        out = TVLossFunction.apply(out)
        return out

# Test Ok
data = torch.FloatTensor([[1, 2, 3, 4, 5],
                          [2, 3, 4, 5, 1],
                          [3, 4, 5, 1, 2],
                          [4, 5, 1, 2, 3],
                          [5, 1, 2, 3, 4],
                          ]).expand((3, 3, 5, 5))
input = Variable(torch.FloatTensor().resize_as_(data).copy_(data), requires_grad=True)
net = OkNet()
output = net(input)
output.backward(Variable(torch.zeros(output.size())))

# OkNet works. It prints: TV Loss backward Called

# Test Failed
data = torch.FloatTensor([[1, 2, 3, 4, 5],
                          [2, 3, 4, 5, 1],
                          [3, 4, 5, 1, 2],
                          [4, 5, 1, 2, 3],
                          [5, 1, 2, 3, 4],
                          ]).expand((3, 3, 5, 5))
input = Variable(torch.FloatTensor().resize_as_(data).copy_(data), requires_grad=True)
net = FailedNet()
output = net(input)
output.backward(Variable(torch.zeros(output.size())))

# Nothing is printed

My PyTorch version is 0.3.0.

Hi,
This has already been answered here.


It works with clone(). Thanks!
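
For reference, a minimal sketch of the clone fix, assuming the same PyTorch 0.3-style TVLossFunction as posted above; only forward changes, backward stays exactly as in the question. Returning x.clone() instead of x itself makes the output a new Variable produced by TVLossFunction, so its backward is also invoked when the input already comes from another layer, as in FailedNet.

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # return a clone rather than the input object itself,
        # so the output is recorded as coming from this Function
        return x.clone()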