I’m trying to implement part of the idea in the paper
What I find with this approach is that I have to implement each deconvolution manually, and I don’t really see how to automate this more.
When I run the deconvolution, the activations are projected back to image space, but the resulting image only looks pixelated.
This is the network I am currently using:
import torch
from torch import nn
class Backbone(nn.Module):
    """Five-convolution feature extractor with a first-stage deconvolution helper.

    The forward pass applies conv -> batch norm -> ReLU five times, with a
    3x3 max-pool after the first stage and another after the third stage.
    ``deconv1`` maps the activations of the first pooling layer back to the
    input resolution using max-unpooling and a transposed convolution that
    reuses the ``conv1`` weights.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=2, stride=1, padding=2)
        self.batch1 = nn.BatchNorm2d(16)
        # return_indices=True so the arg-max locations are available for unpooling.
        self.maxpool1 = nn.MaxPool2d(kernel_size=(3, 3), return_indices=True)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=2, stride=1, padding=2)
        self.batch2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=2, stride=1, padding=2)
        self.batch3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=2, stride=1, padding=2)
        self.batch4 = nn.BatchNorm2d(128)
        self.maxpool = nn.MaxPool2d(kernel_size=(3, 3))
        self.conv5 = nn.Conv2d(128, 256, kernel_size=2, stride=1, padding=2)
        self.batch5 = nn.BatchNorm2d(256)
        self.relu = nn.ReLU()
        self.maxupool = nn.MaxUnpool2d(kernel_size=(3, 3))

    def deconv1(
        self,
        image_tensor: torch.Tensor,
    ) -> torch.Tensor:
        """Project the ``maxpool1`` activations for ``image_tensor`` back to input space.

        Steps: run a forward pass while capturing the output, pooling indices
        and input size of ``maxpool1``; unpool with those indices; apply ReLU;
        apply a transposed convolution with the (detached) ``conv1`` weights;
        scale by 255 and clamp to ``[0, 255]``. ``batch1`` is not inverted.

        Args:
            image_tensor: Input fed to ``forward``. ``conv1`` requires 3 input
                channels; a batched ``(N, 3, H, W)`` layout is assumed here.

        Returns:
            A detached tensor with values in ``[0, 255]``. With the layer
            settings used in this class it has the same spatial size as the
            input.

        Note:
            A full forward pass is executed, so if the module is in training
            mode the BatchNorm running statistics are updated as a side
            effect. Call ``eval()`` first to avoid that.
        """
        captured = {}

        def capture_pool(module, inputs, output):
            # maxpool1 returns (pooled, indices) because return_indices=True.
            captured["pooled"] = output[0].detach()
            captured["indices"] = output[1].detach()
            # Size of the un-pooled map, needed to undo the floor in pooling.
            captured["size"] = inputs[0].size()

        handle = self.maxpool1.register_forward_hook(capture_pool)
        try:
            # Only detached copies are kept, so no autograd graph is needed.
            with torch.no_grad():
                self.forward(image_tensor)
        finally:
            # Always unregister: otherwise every call would leave another
            # hook (and its captured tensors) attached to maxpool1.
            handle.remove()

        kernel = self.conv1.weight.detach()

        x = self.maxupool(
            captured["pooled"],
            captured["indices"],
            output_size=captured["size"],
        )
        x = self.relu(x)
        # conv1.weight is (16, 3, kH, kW), which conv_transpose2d reads as
        # 16 input channels -> 3 output channels.
        x = nn.functional.conv_transpose2d(x, kernel, stride=1, padding=2)
        return x.mul(255.0).clamp(0.0, 255.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the five conv/batch-norm/ReLU stages and return the final feature map."""
        x = self.relu(self.batch1(self.conv1(x)))
        # The pooling indices are only needed by deconv1 (captured via a hook).
        x, _ = self.maxpool1(x)

        x = self.relu(self.batch2(self.conv2(x)))
        x = self.relu(self.batch3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.relu(self.batch4(self.conv4(x)))
        x = self.relu(self.batch5(self.conv5(x)))
        return x
Is there a better way to frame this deconvolution problem so that it becomes more automatic?