Hello,
I’m writing a program that consists of an autoencoder followed by a CNN. My dataset consists of RGB images (so I’m dealing with 3 channels), and the input size of the CNN needs to be 3x224x224.
I’ve written a convolutional AE and easily got a 3-channel output. However, I’m not sure how to achieve that with a simple fully-connected autoencoder. My AE looks like this:
class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(3 * 224 * 224, 8192),
            nn.ReLU(True),
            # (more linear layers + ReLUs here)
        )
        self.decoder = nn.Sequential(
            # (more linear layers + ReLUs here)
            nn.ReLU(True),
            nn.Linear(8192, 3 * 224 * 224),
            nn.Tanh())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
The relevant bits of my code:
(relevant definitions)
batch_size = 20
model = autoencoder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
(training function)
for data in train_loader:
    img, _ = data
    img = img.to(device)
    # forward
    output = model(img)
    loss = criterion(output, img)
    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
The above autoencoder code was giving me a size-mismatch error: RuntimeError: size mismatch, m1: [13440 x 224], m2: [150528 x 8192] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:2070. I’m not 100% sure what m1 and m2 refer to.
Based on what I’ve read, it seems that the input of a fully-connected autoencoder needs to be flattened (so that the dimensions are batch_size x (channels x h x w)), so I changed the forward function to:
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.encoder(x)
        x = self.decoder(x)
        return x
This, however, gives me another error: RuntimeError: input and target shapes do not match: input [20 x 150528], target [20 x 3 x 224 x 224] at /pytorch/aten/src/THNN/generic/MSECriterion.c:12
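To double-check my understanding, here is a scaled-down repro of what I think is happening (16x16 images and a 64-unit bottleneck instead of 224x224 and 8192, so it runs quickly; the shape logic should be the same):

```python
import torch
import torch.nn as nn

# Scaled-down stand-in for my model: a single Linear encoder/decoder pair.
C, H, W, HIDDEN, BATCH = 3, 16, 16, 64, 20
encoder = nn.Linear(C * H * W, HIDDEN)
decoder = nn.Linear(HIDDEN, C * H * W)

img = torch.randn(BATCH, C, H, W)    # a batch shaped like mine
flat = img.view(img.size(0), -1)     # flatten to [20, 768]
out = decoder(encoder(flat))         # output is still flat: [20, 768]

print(out.shape)   # torch.Size([20, 768])
print(img.shape)   # torch.Size([20, 3, 16, 16])
# nn.MSELoss()(out, img) then complains that input and target shapes differ
```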
How can I fix this?
Also, since I flatten the input to fit it into the fully-connected autoencoder, how can I get the output back in 3D form? (and by “3D” I mean “2D with 3 channels”)
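My current guess is that I need to reshape the decoder’s output back to (N, 3, 224, 224) at the end of forward. A single-layer sketch of what I have in mind (the layer sizes here are placeholders, not my real architecture):

```python
import torch
import torch.nn as nn

class Autoencoder(nn.Module):
    """Fully-connected AE that takes and returns (N, C, H, W) tensors."""
    def __init__(self, in_shape=(3, 224, 224), hidden=8192):
        super().__init__()
        self.in_shape = in_shape
        n = in_shape[0] * in_shape[1] * in_shape[2]
        self.encoder = nn.Sequential(nn.Linear(n, hidden), nn.ReLU(True))
        self.decoder = nn.Sequential(nn.Linear(hidden, n), nn.Tanh())

    def forward(self, x):
        x = x.view(x.size(0), -1)                  # flatten to (N, C*H*W)
        x = self.decoder(self.encoder(x))
        return x.view(x.size(0), *self.in_shape)   # back to (N, C, H, W)
```

Would reshaping like this be the right approach, or is there a more idiomatic way?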
Many thanks,
NK