class ImageClassification(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv3d(1, 8, 3, 2)
self.conv1_bn = nn.BatchNorm3d(8)
self.conv2 = nn.Conv3d(8, 16, 3, 1)
self.conv2_bn = nn.BatchNorm3d(16)
self.conv3 = nn.Conv3d(16,32,3,1)
self.conv3_bn = nn.BatchNorm3d(32)
self.fc1 = nn.Linear(4*4*1*32, 256)
self.fc1_bn = nn.BatchNorm3d(256)
self.fc2 = nn.Linear(256, 128)
self.fc2_bn = nn.BatchNorm3d(128)
self.fc3 = nn.Linear(128, 2)
def forward(self, X):
X = self.conv1(X)
X = F.relu(self.conv1_bn(X))
X = F.max_pool3d(X,2)
X = self.conv2(X)
X = F.relu(self.conv2_bn(X))
X = F.max_pool3d(X,2)
X = self.conv3(X)
X = F.relu(self.conv3_bn(X))
X = F.max_pool3d(X,2)
X = x.view(-1,512)
X = self.fc1(X)
X = self.fc1_bn(X)
X = self.fc2(X)
X = F.relu(self.fc2_bn(X))
X= F.relu(self.fc3(X))
return F.log_softmax(X, dim=1)
ValueError Traceback (most recent call last)
in ()
17
18 # Apply the model
—> 19 y_pred = model(X_train) # we don’t flatten X-train here
20 loss = criterion(y_pred, y_train)
21
4 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
in forward(self, X)
26 X = x.view(-1,512)
27 X = self.fc1(X)
—> 28 X = self.fc1_bn(X)
29 X = self.fc2(X)
30 X = F.relu(self.fc2_bn(X))
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/batchnorm.py in forward(self, input)
133
134 def forward(self, input: Tensor) → Tensor:
→ 135 self._check_input_dim(input)
136
137 # exponential_average_factor is set to self.momentum
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/batchnorm.py in _check_input_dim(self, input)
512 def _check_input_dim(self, input):
513 if input.dim() != 5:
→ 514 raise ValueError(“expected 5D input (got {}D input)”.format(input.dim()))
515
516
ValueError: expected 5D input (got 2D input)