Hi,
I would like to implement a locally connected network for decoding purpose.
You can see this layer as a transposed convolution but without weight sharing.
I use in my code the fold function to perform the sliding windows operation, unfortunately it seems that it doesn’t work for 4D tensor…
What I don’t understand, is that the unfold operation (that i use for the encoding locally connected part) work with 4D tensor.
Do you have any idea to make it work ?
Or an alternative to perform the fold operation ?
(here is the code for the locally connected layer (encoding and decoding))
# Main encoder locally connected linear layer
class LocallyConnected2d(nn.Module):
def calculate_spatial_output_shape(self, input_shape, kernel_size, dilation, padding, stride):
return [np.floor(((input_shape[index]+2*padding[index]-dilation[index]*(kernel_size[index]-1)-1)/stride[index])+1).astype(int) for index in range(len(input_shape))]
def __init__(self, input_shape, in_channels, out_channels, kernel_size, dilation, padding, stride):
super().__init__()
self.kernel_size = kernel_size
self.out_channels = out_channels
self.dilation = dilation
self.padding = padding
self.stride = stride
# calculate desired output shape and generate weight/bias matrix
self.output_height, self.output_width = self.calculate_spatial_output_shape(input_shape, kernel_size,dilation, padding, stride)
self.weight_tensor_depth = in_channels * kernel_size[0] * kernel_size[1]
self.spatial_blocks_size = self.output_height * self.output_width
self.weights = nn.Parameter(torch.empty((1, self.weight_tensor_depth, self.spatial_blocks_size, out_channels),requires_grad=True, dtype=torch.float))
self.bias = nn.Parameter(torch.empty((1, out_channels, self.output_height, self.output_width),requires_grad=True, dtype=torch.float))
# init weight and bias
torch.nn.init.xavier_uniform_(self.weights)
torch.nn.init.xavier_uniform_(self.bias)
def forward(self, input):
input_unf = torch.nn.functional.unfold(input, self.kernel_size, dilation=self.dilation, padding=self.padding, stride=self.stride)
local_conv_unf = (input_unf.view((*input_unf.shape, 1)) * self.weights)
return local_conv_unf.sum(dim=1).transpose(2, 1).reshape((-1, self.out_channels, self.output_height, self.output_width)) + self.bias
# Main decoder locally connected linear layer
class LocallyConnected2dTranspose(nn.Module):
def calculate_spatial_transposed_output_shape(self, input_shape, kernel_size, dilation, input_padding, out_padding, stride):
return [np.floor((input_shape[index]-1)*stride[index]-2*input_padding[index]+dilation[index]*kernel_size[index]-1+out_padding[index]+1).astype(int) for index in range(len(input_shape))]
def __init__(self, input_shape, in_channels, out_channels, kernel_size, dilation, input_padding, out_padding, stride):
super().__init__()
self.kernel_size = kernel_size
self.out_channels = out_channels
self.dilation = dilation
self.input_padding = input_padding
self.out_padding = out_padding
self.stride = stride
# calculate desired output shape and generate weight/bias matrix
self.output_height, self.output_width = self.calculate_spatial_transposed_output_shape(input_shape, kernel_size, dilation, input_padding, out_padding, stride)
# weight and spatial block
self.weight_tensor_depth = in_channels * kernel_size[0] * kernel_size[1]
self.spatial_blocks_size = self.output_height * self.output_width
self.weights = nn.Parameter(torch.empty((1, self.weight_tensor_depth, self.spatial_blocks_size, out_channels),requires_grad=True, dtype=torch.float))
print(self.weights.shape)
self.bias = nn.Parameter(torch.empty((1, out_channels, self.output_height, self.output_width),requires_grad=True, dtype=torch.float))
# init weight and bias
torch.nn.init.xavier_uniform_(self.weights)
torch.nn.init.xavier_uniform_(self.bias)
def forward(self, input):
input_f = torch.nn.functional.fold(input, [self.output_height, self.output_width], self.kernel_size, dilation=1, padding=1, stride=2)
local_conv_f = (input_f.view((*input_f.shape, 1)) * self.weights)
print(local_conv_f.shape)
return local_conv_f.sum(dim=1).transpose(2, 1).reshape((-1, self.out_channels, self.output_height, self.output_width)) + self.bias
layer = LocallyConnected2d(input_shape=[128,128], in_channels=3, out_channels=49, kernel_size=[5,5], dilation=(1,1), padding=(2,2), stride=(2,2))
xt = torch.randn(1,3,128,128)
print(layer(xt).shape)
layer_ = LocallyConnected2dTranspose([64,64], 49, 3, [5,5], (1,1), (2,2), (1,1), (2,2))
ht = torch.randn(1,49,64,64)
print(layer_(ht).shape)
Thanks in advance for your answer !
Best regards,
Munch Quentin.