Hello l am trying to do an auto encoder with 2 inputs base on the Siamese method but l got this issue and l don’t know why. Could you help me ?
ValueError Traceback (most recent call last)
Cell In[72], line 3
1 epochs=10
2 for epoch in range(1,epochs+1):
----> 3 training(epoch)
Cell In[71], line 40, in training(epochs)
35 #Clear the gradients
37 optimizer.zero_grad()
—> 40 output1,ouptut2 = Model(img_pre,img_post)
42 loss = criterion(output1,output2,data)
44 loss.backward()
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don’t have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/codes/Auto_Encoder/Auto_encoder.py:106, in Auto_encoder.forward(self, x1, x2)
104 def forward(self, x1,x2):
→ 106 encoded = self.encoder(x1)
107 decoded1 = self.decoder(encoded)
110 encoded = self.encoder(x2)
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don’t have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
215 def forward(self, input):
216 for module in self:
→ 217 input = module(input)
218 return input
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don’t have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:138, in _BatchNorm.forward(self, input)
137 def forward(self, input: Tensor) → Tensor:
→ 138 self._check_input_dim(input)
140 # exponential_average_factor is set to self.momentum
141 # (when it is available) only so that it gets updated
142 # in ONNX graph when this node is exported to ONNX.
143 if self.momentum is None:
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:410, in BatchNorm2d._check_input_dim(self, input)
408 def _check_input_dim(self, input):
409 if input.dim() != 4:
→ 410 raise ValueError(“expected 4D input (got {}D input)”.format(input.dim()))
ValueError: expected 4D input (got 3D input)
my code :
class Auto_encoder(nn.Module):
def __init__(self, in_channels=3, out_channels=16, latent_dim=200, act_fn=nn.ReLU()):
super(Auto_encoder,self).__init__()
#Encoder
self.encoder = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1), # (32, 32)
act_fn,
nn.BatchNorm2d(out_channels),
nn.Conv2d(out_channels, out_channels, 3, padding=1),
act_fn,
nn.BatchNorm2d(out_channels),
nn.Conv2d(out_channels, 2*out_channels, 3, padding=1, stride=2), # (16, 16)
act_fn,
nn.BatchNorm2d(out_channels),
nn.Conv2d(2*out_channels, 2*out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
act_fn,
nn.BatchNorm2d(out_channels),
nn.Conv2d(2*out_channels, 4*out_channels, 3, padding=1, stride=2), # (8, 8)
act_fn,
nn.BatchNorm2d(out_channels),
nn.Conv2d(4*out_channels, 4*out_channels, 3, padding=1),
act_fn,
nn.BatchNorm2d(out_channels),
nn.Flatten(),
nn.Linear(4*out_channels*8*8, latent_dim),
act_fn
)
#Decoder
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 4*out_channels*8*8),
act_fn,
nn.ConvTranspose2d(4*out_channels, 4*out_channels, 3, padding=1), # (8, 8)
act_fn,
nn.ConvTranspose2d(4*out_channels, 2*out_channels, 3, padding=1,
stride=2, output_padding=1), # (16, 16)
act_fn,
nn.ConvTranspose2d(2*out_channels, 2*out_channels, 3, padding=1),
act_fn,
nn.ConvTranspose2d(2*out_channels, out_channels, 3, padding=1,
stride=2, output_padding=1), # (32, 32)
act_fn,
nn.ConvTranspose2d(out_channels, out_channels, 3, padding=1),
act_fn,
nn.ConvTranspose2d(out_channels, in_channels, 3, padding=1)
)
def forward(self, x1,x2):
encoded = self.encoder(x1)
decoded1 = self.decoder(encoded)
encoded = self.encoder(x2)
decoded2 = self.decoder(encoded)
return decoded1, decoded2
def training(epochs):
Model.train()
epoch_loss = 0.0
running_loss_train=0.0
#for label,data in dataloader_train:
for data, _ in (dataloader_train):
img_pre, img_post = data
img_pre.to(device=device) # move to device, e.g. GPU
img_post.to(device=device)
#Clear the gradients
optimizer.zero_grad()
output1,ouptut2 = Model(img_pre,img_post)
loss = criterion(output1,output2,data)
loss.backward()
# Update Weight
optimizer.step()
#Calculate prediction
prediction_train=output.argmax(dim=1)
running_loss_train =+ loss.item()
epoch_loss += running_loss_train / len(train_loader)
#acc_train += (prediction_train == label).type(torch.float).sum().item()
Tp = ((label == 1) & (prediction_train == 1)).sum().item()
fp = ((label == 0) & (prediction_train == 1)).sum().item()
Fn = ((label == 1) & (prediction_train == 0)).sum().item()
# calculate precision, recall, and F1-score
Precision_train = Tp / (Tp + Fp)
Recall_train = Tp / (Tp + Fn)
F1_score_train = 2 * (Precision * Recall) / (Precision + Recall)