I have this encoder architecture:
import torch.nn as nn

# ResidualBlock, calculate_channel_sizes and calculate_layer_size are my own
# helpers; simplified versions are given after the class.

class Encoder(nn.Module):
    def __init__(self, in_channel, img_width, hidden_dim, device, max_filters=512,
                 num_layers=4, small_conv=False, norm_type='batch', num_groups=1,
                 kernel_size=4, stride_size=2, padding_size=0, activation=nn.PReLU()):
        super().__init__()
        self.nchannel = in_channel
        self.hidden_dim = hidden_dim
        self.img_width = img_width
        self.device = device
        self.enc_kernel = kernel_size
        self.enc_stride = stride_size
        self.enc_padding = padding_size
        self.res_kernel = 3
        self.res_stride = 1
        self.res_padding = 1
        # Note: this single activation module instance is reused below, so its
        # parameters (e.g. the PReLU weight) are shared across all layers.
        self.activation = activation
        ########################
        # ENCODER-CONVOLUTION LAYERS
        if small_conv:
            num_layers += 1
        channel_sizes = calculate_channel_sizes(
            self.nchannel, max_filters, num_layers
        )
        # Encoder
        encoder_layers = nn.ModuleList()
        # Encoder convolutions
        for i, (in_channels, out_channels) in enumerate(channel_sizes):
            if small_conv and i == 0:
                # 1x1 convolution
                encoder_layers.append(nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                ))
            else:
                encoder_layers.append(nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=self.enc_kernel,
                    stride=self.enc_stride,
                    padding=self.enc_padding,
                    bias=False,
                ))
            # Normalization
            if norm_type == 'batch':
                encoder_layers.append(nn.BatchNorm2d(out_channels))
            elif norm_type == 'layer':
                encoder_layers.append(nn.GroupNorm(num_groups, out_channels))
            # Activation
            encoder_layers.append(self.activation)
            if i == num_layers // 2:
                # Add a residual block halfway through the encoder
                encoder_layers.append(ResidualBlock(
                    out_channels,
                    self.res_kernel,
                    self.res_stride,
                    self.res_padding,
                    norm_type=norm_type,
                    nonlinearity=self.activation,
                ))
        # Flatten the encoder output
        encoder_layers.append(nn.Flatten())
        self.encoder = nn.Sequential(*encoder_layers)
        # Calculate the shape of the flattened feature map
        self.h_dim, self.h_image_dim = self.get_flattened_size(self.img_width)
        # Linear layers
        layers = []
        layers.append(nn.Linear(self.h_dim, hidden_dim, bias=False))
        if norm_type == 'batch':
            layers.append(nn.BatchNorm1d(hidden_dim))
        elif norm_type == 'layer':
            layers.append(nn.LayerNorm(hidden_dim))
        layers.append(self.activation)
        layers.append(nn.Linear(hidden_dim, hidden_dim, bias=False))
        if norm_type == 'batch':
            layers.append(nn.BatchNorm1d(hidden_dim))
        elif norm_type == 'layer':
            layers.append(nn.LayerNorm(hidden_dim))
        layers.append(self.activation)
        self.linear_layers = nn.Sequential(*layers)
        self.to(device=self.device)
    def forward(self, X):
        # Encode (note: the input tensor must have shape [batch_size, channels, height, width])
        h = self.encoder(X)
        print(f"computed hidden dimension {self.h_dim}, {h.shape} and {self.h_image_dim}, input shape {X.shape}, architecture {self.encoder} input channel {self.nchannel}, image width {self.img_width}")
        # Get latent variables
        return self.linear_layers(h)
    def get_flattened_size(self, image_size):
        # Walk every Conv2d in the encoder (including the ones inside the
        # residual block) and track the resulting spatial size.
        for layer in self.encoder.modules():
            if isinstance(layer, nn.Conv2d):
                kernel_size = layer.kernel_size[0]
                stride = layer.stride[0]
                padding = layer.padding[0]
                filters = layer.out_channels
                image_size = calculate_layer_size(
                    image_size, kernel_size, stride, padding
                )
        return filters * image_size * image_size, image_size
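For completeness, the size helpers behave essentially as follows (simplified sketches that reproduce the printed architecture below; ResidualBlock is a two-convolution residual block whose structure also appears in that printout):

def calculate_channel_sizes(in_channels, max_filters, num_layers):
    # Filter counts double each layer and end at max_filters, e.g. for
    # in_channels=40, max_filters=256, num_layers=3:
    # [(40, 64), (64, 128), (128, 256)]
    out_channels = max_filters // (2 ** (num_layers - 1))
    sizes = []
    for _ in range(num_layers):
        sizes.append((in_channels, out_channels))
        in_channels, out_channels = out_channels, out_channels * 2
    return sizes

def calculate_layer_size(image_size, kernel_size, stride, padding):
    # Standard Conv2d output-size formula.
    return (image_size - kernel_size + 2 * padding) // stride + 1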
Here is how I called this class:
def calculate_conv_params(input_size):
    height, width, channels = input_size
    if height > 100 or width > 100:
        kernel_size = 5
    elif height > 50 or width > 50:
        kernel_size = 4
    else:
        kernel_size = 3
    stride = 1  # stride 1 so the convolutions do not downsample
    padding = (kernel_size - 1) // 2  # "same" padding (exact only for odd kernels)
    return kernel_size, stride, padding

kernel, stride, padding = calculate_conv_params((input_width, input_width, input_channel))
self.base = Encoder(input_channel, input_width, self.hidden_size, device,
                    max_filters=256, num_layers=3,
                    kernel_size=kernel, stride_size=stride, padding_size=padding)
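For this call, input_width is 40 and, judging from the printed output below, input_channel is 40, so calculate_conv_params returns kernel_size=3, stride=1, padding=1; each convolution then maps a side of length 40 to (40 - 3 + 2*1)/1 + 1 = 40, so the spatial dimensions are preserved through the whole encoder.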
Here is the printed architecture, together with the input shape and the sizes computed for the linear layers:
computed hidden dimension 409600, torch.Size([256, 30720]) and 40, input shape torch.Size([256, 40, 40, 3]), architecture Sequential(
  (0): Conv2d(40, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): PReLU(num_parameters=1)
  (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): PReLU(num_parameters=1)
  (6): ResidualBlock(
    (layers): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): PReLU(num_parameters=1)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): PReLU(num_parameters=1)
    )
  )
  (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): PReLU(num_parameters=1)
  (10): Flatten(start_dim=1, end_dim=-1)
) input channel 40, image width 40
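Reading off the numbers: get_flattened_size predicted h_dim = 409600 = 256 * 40 * 40, but the encoder actually produced torch.Size([256, 30720]), and 30720 = 256 * 40 * 3. Note also that the printed input shape is torch.Size([256, 40, 40, 3]), while the first Conv2d was built for 40 input channels.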
This is the error message for this specific input (the same code worked for other images):
    actor_features = self.base(obs)
  File "/home/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lustre03/project/6057506/onpolicy/algorithms/utils/cnn.py", line 268, in forward
    return self.linear_layers(h)
  File "/home/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x30720 and 409600x64)
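My current guess (unverified): the observation arrives channels-last, so the convolutions read it as 40 channels over a 40 x 3 image, which would explain the 30720. If that is right, something along these lines should fix it (a sketch, assuming obs really is [batch, height, width, channels]):

# Hypothetical sketch: convert a channels-last observation to the
# [batch, channels, height, width] layout that nn.Conv2d expects.
obs = obs.permute(0, 3, 1, 2).contiguous()  # [256, 40, 40, 3] -> [256, 3, 40, 40]
# The Encoder would then be constructed with in_channel=3 instead of 40,
# so that get_flattened_size agrees with the real flattened output.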
Can you kindly suggest a way to fix this error, or confirm whether my guess above is correct?