Hi all,
I am trying to implement a customized convolution. I have two inputs, image and x, with shapes 1 x N x N and batch_size x 1 x N x N, respectively. The kernels depend on image and are applied to x. That is, we learn a map (a 9 x 9 linear map plus a bias) from each 3 x 3 patch of the image to a 3 x 3 kernel, and then apply that kernel to the corresponding 3 x 3 patch of x.
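To pin down the indexing, here is the same operation written per output pixel. Indices are 0-based, n is the batch index, and the singleton channel dimension is dropped. W (9 x 3 x 3) and β (9 entries) stand for the weight and bias of the learned map, and a tilde marks zero padding by one pixel on each side. This assumes the 9 outputs of the map fill the 3 x 3 kernel in row-major order (entry 3p + q goes to row p, column q), which is what view((3, 3)) and unflatten(2, (3, 3)) do in the code below:

```latex
k_{ij}[p,q] = \beta_{3p+q} + \sum_{c=0}^{2}\sum_{d=0}^{2} W_{3p+q,\,c,\,d}\,\tilde{I}_{i+c,\,j+d},
\qquad
y_{n,i,j} = \sum_{p=0}^{2}\sum_{q=0}^{2} k_{ij}[p,q]\,\tilde{x}_{n,\,i+p,\,j+q},
\qquad 0 \le i, j < N
```

The first sum is exactly what a 1-to-9-channel 3 x 3 Conv2d computes (PyTorch convolutions are cross-correlations), which is why the map can be implemented as self.KL.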
Below is my current implementation. I have two versions: forward uses too much memory and is relatively slow compared to a built-in Conv2d, while forward1 saves memory but, because of its naive Python for-loops over every pixel, is VERY slow. How can I make this fast with limited CUDA memory?
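To see where the memory in the unfold-based forward goes, the snippet below repeats its pad/unfold/multiply steps on random tensors (the sizes are arbitrary, chosen only for illustration). unfold itself only returns a strided view of the padded x, but the elementwise product with the per-pixel kernels materializes a bs x 1 x N x N x 3 x 3 tensor, nine times the size of x. During training, the backward pass typically creates temporaries of that size as well.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only.
bs, N = 8, 64
x = torch.randn(bs, 1, N, N)
K = torch.randn(N, N, 3, 3)  # one 3 x 3 kernel per pixel

xp = F.pad(x, (1, 1, 1, 1))                   # bs x 1 x (N+2) x (N+2), a real copy
patches = xp.unfold(2, 3, 1).unfold(3, 3, 1)  # bs x 1 x N x N x 3 x 3, a strided view
prod = patches * K                            # materializes the full 6-D tensor

print(patches.data_ptr() == xp.data_ptr())    # True: unfold did not copy
print(prod.numel() // x.numel())              # 9: the product is 9x the size of x
```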
import torch
import torch.nn as nn
import torch.nn.functional as F


class SmallSMBlock(nn.Module):
    def __init__(self):
        super().__init__()
        # 9 x 9 linear map (plus bias) from a 3 x 3 image patch to a 3 x 3 kernel
        self.KL = nn.Conv2d(1, 9, kernel_size=3, padding='same', bias=True)

    def forward1(self, image, x):  # 1 x N x N, bs x 1 x N x N
        # Memory-light but slow: one Python iteration per output pixel
        N = x.shape[-1]
        y = torch.zeros_like(x)
        image = F.pad(image, (1,) * 4)
        x = F.pad(x, (1,) * 4)
        for i in range(1, N + 1):
            for j in range(1, N + 1):
                # Kernel for output pixel (i-1, j-1): apply the map to the 3 x 3 image patch
                K = (self.KL.weight.squeeze() * image[0, i-1:i+2, j-1:j+2]).sum(dim=(1, 2)) + self.KL.bias
                K = K.view((3, 3))
                xx = x[:, 0, i-1:i+2, j-1:j+2]  # bs x 3 x 3 patch of x
                y[:, 0, i-1, j-1] = (xx * K).sum(dim=(-2, -1))
        return y

    def forward(self, image, x):  # 1 x N x N, bs x 1 x N x N
        K = self.KL(image)  # 1 x N x N -> 9 x N x N
        K = K.permute((1, 2, 0))  # 9 x N x N -> N x N x 9
        K = K.unflatten(2, (3, 3))  # N x N x 9 -> N x N x 3 x 3
        x = F.pad(x, (1, 1, 1, 1))  # bs x 1 x N x N -> bs x 1 x (N+2) x (N+2)
        x = x.unfold(2, 3, 1).unfold(3, 3, 1)  # bs x 1 x (N+2) x (N+2) -> bs x 1 x N x N x 3 x 3 (a view)
        # The product materializes a bs x 1 x N x N x 3 x 3 tensor (9x the size of x)
        y = (x * K).sum(dim=(-2, -1))  # -> bs x 1 x N x N
        return y
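One possible direction, sketched under the assumption that the math stays exactly as in forward: instead of unfolding x, loop over the 9 kernel taps (p, q) and accumulate K[3p + q] times the correspondingly shifted view of the padded x. ShiftAccumulateSMBlock is a hypothetical name; the class only reorders the sum in forward, so no bs x 1 x N x N x 3 x 3 tensor is ever built, and the extra memory should stay at a few bs x 1 x N x N tensors. It also replaces the N x N Python iterations of forward1 with just 9.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ShiftAccumulateSMBlock(nn.Module):
    """Hypothetical variant with the same math as the unfold-based forward.

    Loops over the 9 kernel taps and accumulates one bs x 1 x H x W
    product at a time instead of building the 6-D unfolded product.
    """

    def __init__(self):
        super().__init__()
        self.KL = nn.Conv2d(1, 9, kernel_size=3, padding='same', bias=True)

    def forward(self, image, x):  # 1 x H x W, bs x 1 x H x W
        H, W = x.shape[-2:]
        K = self.KL(image)              # 9 x H x W; channel 3*p + q is tap (p, q)
        xp = F.pad(x, (1, 1, 1, 1))     # bs x 1 x (H+2) x (W+2)
        y = torch.zeros_like(x)
        for p in range(3):
            for q in range(3):
                # Window of the padded x seen by tap (p, q); slicing gives a view
                shifted = xp[:, :, p:p + H, q:q + W]
                # K[3p+q] is H x W and broadcasts over the bs x 1 leading dims
                y = y + K[3 * p + q] * shifted
        return y


# Usage: same inputs as SmallSMBlock
block = ShiftAccumulateSMBlock()
out = block(torch.randn(1, 32, 32), torch.randn(8, 1, 32, 32))  # 8 x 1 x 32 x 32
```

The trade-off is 9 small elementwise kernels (plus their backward) per call instead of one large fused product, so it may not beat forward on speed for small inputs; it mainly caps the memory. Whether it is faster overall on a given GPU would need measuring.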