Extending a 2D custom layer to a 3D custom layer



Can anyone help me implement KernelConv3d from KernelConv2d?

Your current code is not properly formatted, which makes it quite hard to debug, so please format it if possible.
Also, could you describe where you are currently stuck, e.g. are you seeing an error?

import torch


class LinearKernel(torch.nn.Module):
    def __init__(self):
        super(LinearKernel, self).__init__()
    
    def forward(self, x_unf, w, b):
        t = x_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
        if b is not None:
            return t + b
        return t
        
        
class PolynomialKernel(LinearKernel):
    def __init__(self, cp=2.0, dp=3, train_cp=True):
        super(PolynomialKernel, self).__init__()
        self.cp = torch.nn.parameter.Parameter(torch.tensor(cp, requires_grad=train_cp))
        self.dp = dp

    def forward(self, x_unf, w, b):
        return (self.cp + super(PolynomialKernel, self).forward(x_unf, w, b))**self.dp


class GaussianKernel(torch.nn.Module):
    def __init__(self, gamma):
        super(GaussianKernel, self).__init__()
        self.gamma = torch.nn.parameter.Parameter(
                            torch.tensor(gamma, requires_grad=True))
    
    def forward(self, x_unf, w, b):
        l = x_unf.transpose(1, 2)[:, :, :, None] - w.view(1, 1, -1, w.size(0))
        l = torch.sum(l**2, 2)
        t = torch.exp(-self.gamma * l)
        if b is not None:
            return t + b
        return t
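
For illustration, here is a minimal shape check of these kernels on an unfolded input (the sizes are made up, and the snippet assumes the classes above are defined):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)         # (B, C, H, W)
w = torch.randn(16, 3, 3, 3)        # (out_channels, C, kH, kW)
x_unf = F.unfold(x, kernel_size=3)  # (B, C*kH*kW, L) = (2, 27, 36)

out = PolynomialKernel()(x_unf, w, None)
print(out.shape)                    # torch.Size([2, 16, 36])
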
class KernelConv2d(torch.nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, kernel_fn=PolynomialKernel,
                 stride=1, padding=0, dilation=1, groups=1, bias=None,
                 padding_mode='zeros'):
        '''
        Follows the same API as torch Conv2d except kernel_fn.
        kernel_fn should be one of the kernel classes above (it is instantiated in __init__).
        '''
        super(KernelConv2d, self).__init__(in_channels, out_channels, 
                                           kernel_size, stride, padding,
                                           dilation, groups, bias, padding_mode)
        self.kernel_fn = kernel_fn()
   
    def compute_shape(self, x):
        h = (x.shape[2] + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1
        w = (x.shape[3] + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1
        return h, w
    
    def forward(self, x):
        x_unf = torch.nn.functional.unfold(x, self.kernel_size, self.dilation, self.padding, self.stride)
        h, w = self.compute_shape(x)
        return self.kernel_fn(x_unf, self.weight, self.bias).view(x.shape[0], -1, h, w)
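
As a sanity check (a minimal sketch with made-up sizes), this layer should behave as a drop-in replacement for torch.nn.Conv2d shape-wise:

x = torch.randn(2, 3, 15, 15)
conv = KernelConv2d(3, 8, kernel_size=3, padding=1)
print(conv(x).shape)  # torch.Size([2, 8, 15, 15])
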
class KernelConv3d(torch.nn.Conv3d):
    def __init__(self, in_channels, out_channels, kernel_size, kernel_fn=PolynomialKernel,
                 stride=1, padding=0, dilation=1, groups=1, bias=None,
                 padding_mode='zeros'):
        '''
        Follows the same API as torch Conv3d except kernel_fn.
        kernel_fn should be one of the kernel classes above (it is instantiated in __init__).
        '''
        super(KernelConv3d, self).__init__(in_channels, out_channels, 
                                           kernel_size, stride, padding,
                                           dilation, groups, bias, padding_mode)
        self.kernel_fn = kernel_fn()
   
    def compute_shape(self, x):
        d = (x.shape[2] + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1
        h = (x.shape[3] + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1
        w = (x.shape[4] + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) // self.stride[2] + 1
        return d, h, w
    
    def forward(self, x):
        # NOTE: F.unfold only supports 4D (batched image) input, so this call
        # fails for the 5D input of a Conv3d; this is the error discussed below.
        x_unf = torch.nn.functional.unfold(x, self.kernel_size, self.dilation, self.padding, self.stride)
        d, h, w = self.compute_shape(x)
        return self.kernel_fn(x_unf, self.weight, self.bias).view(x.shape[0], -1, d, h, w)

I have extended class KernelConv2d(torch.nn.Conv2d) to class KernelConv3d(torch.nn.Conv3d) by including the depth of the input along with its height and width, and tried to run the following code:

from torchvision import models
from torchsummary import summary

model = HybridSN().to(device)  # HybridSN: the user's model (defined elsewhere)
summary(model, (15, 15, 15, 1))

I am getting the following error:
AssertionError: kernel_size must be int or 2-tuple for 4D input

I guess it is due to the fact that torch.nn.functional.unfold supports only 4D tensors. Can you please resolve my issue?

I don’t know which line of code raises this error, but you are correct that F.unfold expects 4-dimensional tensors. For other shapes you could call tensor.unfold directly and would then have to specify the dimensions in each call.
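
For reference, Tensor.unfold(dimension, size, step) returns all sliding windows of length size taken every step elements along a single dimension, so one call per spatial dimension is needed. A minimal 1D sketch:

import torch

x = torch.arange(10.)
print(x.unfold(0, 3, 2))
# tensor([[0., 1., 2.],
#         [2., 3., 4.],
#         [4., 5., 6.],
#         [6., 7., 8.]])
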

Could you please suggest the int values that I need to pass to tensor.unfold for my problem? I am using the following GitHub repository to extend my KernelConv2d to KernelConv3d.

I think you could reuse the already computed kernel size and stride.
Padding is not supported in this method, so you could add the padding via F.pad beforehand.
In case you are using a dilation other than 1, I don’t think you could use tensor.unfold directly, unfortunately.

Could you please write the appropriate statements using F.pad and dilation=1, as you suggested, to resolve the issue?

Here is an example using F.unfold and tensor.unfold for the 4D case (you can add another unfold call for the additional dimension):

import torch
import torch.nn.functional as F

B, C, H, W = 2, 3, 4, 4
x = torch.arange(1, 1 + (B * C * H * W)).float().view(B, C, H, W)
kernel_size = 2
dilation = 1
padding = 1
stride = 2

# reference result computed with F.unfold
x_unf = F.unfold(
    input=x, kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride)
print(x_unf.shape)  # torch.Size([2, 12, 9])

# same result with tensor.unfold: pad manually, then unfold each spatial dim
x_pad = F.pad(x, [1, 1, 1, 1])
out = x_pad.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
print(out.shape)  # torch.Size([2, 3, 3, 3, 2, 2])
out = out.permute(0, 2, 3, 1, 4, 5).contiguous().view(out.size(0), -1, kernel_size * kernel_size * x.size(1))
out = out.permute(0, 2, 1)
print((out == x_unf).all())
> tensor(True)

(thanks @tom for finding an initial error :slight_smile: )
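
Building on that example, the 5D case needed for KernelConv3d should follow the same pattern with one more unfold call. Here is a sketch (assuming dilation=1, padding of 1 on every spatial dimension, and the same kernel size and stride everywhere; the final layout matches what a hypothetical 3D version of F.unfold would return):

import torch
import torch.nn.functional as F

B, C, D, H, W = 2, 3, 4, 4, 4
x = torch.arange(1., 1 + B * C * D * H * W).view(B, C, D, H, W)
kernel_size = 2
stride = 2

x_pad = F.pad(x, [1, 1, 1, 1, 1, 1])  # pads W, H and D (last dimension first)
out = (x_pad.unfold(2, kernel_size, stride)
            .unfold(3, kernel_size, stride)
            .unfold(4, kernel_size, stride))
print(out.shape)  # torch.Size([2, 3, 3, 3, 3, 2, 2, 2])

# rearrange to the (B, C*kD*kH*kW, L) layout a 3D F.unfold would produce
out = out.permute(0, 2, 3, 4, 1, 5, 6, 7).contiguous()
out = out.view(B, -1, C * kernel_size ** 3).permute(0, 2, 1)
print(out.shape)  # torch.Size([2, 24, 27])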