class DTT2B(DTL2R):
def __init__(self):
super(DTT2B, self).__init__()
def forward(self, weight, x):
self.dim = x.size()
self.intermediate = torch.zeros(self.dim)
self.weight = weight
h = x.clone()
for i in range(1, self.dim[2]):
self.intermediate[:, :, i, :] = h[:, :, i-1, :] - h[:, :, i, :]
h[:, :, i, :] += weight[i, :]*self.intermediate[:, :, i, :]
return h
def backward(self, grad_weight, grad_output):
print '2'*200
intermediate, weight = self.intermediate, self.weight
grad_weight, grad_input = grad_weight.clone(), grad_output.clone()
print grad_output.size()
for i in range(self.dim[2]-1, 0, -1):
grad_weight[i, :] += (grad_input[:, :, i+1, :] *
intermediate[:, :, :, i]).sum(dim=0).sum(dim=0)
grad_input[:, :, i-1, :] = weight[i, :] * grad_input[:, :, i, :]
grad_input[:, :, i, :] = (1-weight[i, :]) * grad_input[:, :, i, :]
print self.__class__.__name__, grad_input, grad_weight
return grad_input, grad_weight
class DTB2T(DTL2R):
def __init__(self):
super(DTB2T, self).__init__()
def forward(self, weight, x):
self.dim = x.size()
self.intermediate = torch.zeros(self.dim)
self.weight = weight
h = x.clone()
for i in range(self.dim[2]-2, -1, -1):
self.intermediate[:, :, i, :] = h[:, :, i+1, :] - h[:, :, i, :]
h[:, :, i, :] += weight[i+1, :]*self.intermediate[:, :, i, :]
return h
def backward(self, grad_weight, grad_output):
print '3'*200
intermediate, weight = self.intermediate, self.weight
grad_weight, grad_input = grad_weight.clone(), grad_output.clone()
for i in range(self.dim[2]-1):
grad_weight[i+1, :] += (grad_input[:, :, i, :] *
intermediate[:, :, i, :]).sum(dim=0).sum(dim=0)
grad_input[:, :, i+1, :] = weight[i+1, :] * grad_input[:, :, i, :]
grad_input[:, :, i, :] = (1-weight[i+1, :]) * grad_input[:, :, i, :]
print self.__class__.__name__, grad_input, grad_weight
print grad_input.size(), grad_weight.size()
return grad_input, grad_weight
DT1 = DTT2B()
DT2 = DTB2T()
x=...
h=DT1(x)
h=DT2(h)
loss=softmax2d_loss()
loss.backward()# here would be error
Could any one can help me? I want to propagate multiple gradient to previous layer.
TypeError: backward() takes exactly 3 arguments (2 given)