Hi,
Regarding my previous post, "New Convolutional Layer", I created a new custom layer, shown in the code below. Although the forward method seems to be working, I am having trouble with the backward method.
Can someone help me with the backward method? Any advice on how I can implement it?
Thanks
class Myquadratic(torch.autograd.Function):
    """Quadratic-plus-linear 3x3 convolution followed by ReLU.

    For every 3x3 patch ``p`` (flattened to a vector) of the zero-padded
    input, output channel ``o`` is::

        relu(p^T W1[o] p + conv2d(x_padded, weight)[o] + bias[o])

    NOTE(review): ``W1`` is re-sampled from a truncated normal on *every*
    forward call and is not a learnable input of this function, so no
    gradient is returned for it. If the quadratic kernel is meant to be
    trained it should be passed in as an argument - confirm the intent.

    The hard-coded ``W1`` shape ``(1, 32, 9, 9)`` means the einsum only
    matches inputs whose flattened 3x3 patches have 9 entries (a single
    input channel) and a ``weight`` with 32 output channels.
    """

    @staticmethod
    def _sample_quadratic_kernel(dtype, device):
        """Draw the ``(1, 32, 9, 9)`` quadratic kernel from a truncated normal.

        Four standard-normal candidates are drawn per entry and one that
        lies in ``(-2, 2)`` is selected when such a candidate exists.
        """
        shape = (1, 32, 9, 9)
        candidates = torch.empty(
            shape + (4,), dtype=dtype, device=device
        ).normal_()
        valid = (candidates < 2) & (candidates > -2)
        chosen = valid.max(-1, keepdim=True)[1]
        return candidates.gather(-1, chosen).squeeze(-1)

    @staticmethod
    def forward(ctx, x, weight, bias):
        """Compute ``relu(quadratic(x) + conv2d(pad(x), weight) + bias)``.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            x: input batch; padded by one pixel on each spatial side and
                cut into 3x3 patches with stride 1.
            weight: kernel of the linear convolution term.
            bias: added to the pre-activation; must broadcast against it.

        Returns:
            The activated output, cast to ``x``'s dtype.
        """
        x, weight, bias = x.detach(), weight.detach(), bias.detach()
        kh, kw = 3, 3
        dh, dw = 1, 1

        # Created on the input's dtype/device so the einsum below does not
        # fail for non-float32 or non-CPU inputs.
        W1 = Myquadratic._sample_quadratic_kernel(x.dtype, x.device)

        # Pad tensor to get the same output
        x = F.pad(x, (1, 1, 1, 1))

        patches = x.unfold(2, kh, dh).unfold(3, kw, dw)
        patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
        patches = patches.view(*patches.size()[:3], -1)
        patches = patches.permute(0, 3, 1, 2)
        patches = torch.unsqueeze(patches, -1)

        # V[a, b, c, d, o] = sum_ij p_i * p_j * W1[d, o, i, j]
        V = torch.einsum('aibcd,ajbcd,doij->abcdo', patches, patches, W1)
        V2 = torch.sum(V, dim=3).permute(0, 3, 1, 2)

        result = F.relu(V2 + F.conv2d(x, weight) + bias)

        # ReLU passes gradient only where its output is strictly positive.
        ctx.save_for_backward(x, W1, weight, result > 0)
        ctx.bias_shape = bias.shape
        return torch.as_tensor(result, dtype=x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients with respect to ``x``, ``weight`` and ``bias``.

        Args:
            ctx: context populated by ``forward``.
            grad_output: gradient of the loss w.r.t. the forward output.

        Returns:
            ``(grad_x, grad_weight, grad_bias)``; an entry is ``None`` when
            the corresponding input does not require a gradient.
        """
        # Local import keeps the block self-contained.
        from torch.nn import grad as nn_grad

        x_pad, W1, weight, active = ctx.saved_tensors
        # Gradient w.r.t. the pre-activation.
        g = grad_output.masked_fill(~active, 0)

        grad_x = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            # Linear term: transposed convolution of g with the kernel.
            grad_pad = nn_grad.conv2d_input(x_pad.shape, weight, g)

            # Quadratic term: d(p^T W p)/dp = (W + W^T) p, evaluated per
            # patch and scattered back onto the padded image with fold
            # (the adjoint of the patch extraction done in forward).
            patches = F.unfold(x_pad, kernel_size=(3, 3))
            w_sym = W1[0] + W1[0].transpose(1, 2)
            grad_patches = torch.einsum(
                'aol,oij,ajl->ail', g.flatten(2), w_sym, patches
            )
            grad_pad = grad_pad + F.fold(
                grad_patches,
                output_size=tuple(x_pad.shape[-2:]),
                kernel_size=(3, 3),
            )
            # Drop the one-pixel border that forward added.
            grad_x = grad_pad[:, :, 1:-1, 1:-1]

        if ctx.needs_input_grad[1]:
            grad_weight = nn_grad.conv2d_weight(x_pad, weight.shape, g)

        if ctx.needs_input_grad[2]:
            # Undo the broadcast of bias over the pre-activation.
            grad_bias = g.sum_to_size(ctx.bias_shape)

        return grad_x, grad_weight, grad_bias
class conv2D(Module):
    """Module wrapper that applies ``Myquadratic`` with its own parameters.

    Holds a ``(32, 1, 3, 3)`` weight drawn uniformly from ``[0, 1)`` and a
    ``(1, 32, 1, 1)`` bias initialised to ``0.1``.
    """

    def __init__(self):
        super().__init__()
        initial_weight = torch.rand(32, 1, 3, 3)
        initial_bias = torch.full((1, 32, 1, 1), 0.1, dtype=torch.float)
        self.weight = nn.Parameter(initial_weight)
        self.bias = nn.Parameter(initial_bias)

    def forward(self, x):
        """Run the custom autograd function on ``x`` with this layer's parameters."""
        return Myquadratic.apply(x, self.weight, self.bias)