Can't figure out the error?

class ImageClassification(nn.Module):

def __init__(self):

    super().__init__()

    self.conv1 = nn.Conv3d(1, 8, 3, 2) 

    self.conv1_bn = nn.BatchNorm3d(8)

    self.conv2 = nn.Conv3d(8, 16, 3, 1)

    self.conv2_bn = nn.BatchNorm3d(16)

    self.conv3 = nn.Conv3d(16,32,3,1)

    self.conv3_bn = nn.BatchNorm3d(32)

    self.fc1 = nn.Linear(4*4*1*32, 256)

    self.fc1_bn = nn.BatchNorm3d(256)

    self.fc2 = nn.Linear(256, 128) 

    self.fc2_bn = nn.BatchNorm3d(128)

    self.fc3 = nn.Linear(128, 2)

def forward(self, X):

    X = self.conv1(X)

    X = F.relu(self.conv1_bn(X))

    X = F.max_pool3d(X,2)

    X = self.conv2(X)

    X = F.relu(self.conv2_bn(X))

    X = F.max_pool3d(X,2)

    X = self.conv3(X)

    X = F.relu(self.conv3_bn(X))

    X = F.max_pool3d(X,2)

    X = x.view(-1,512)

    X = self.fc1(X)

    X = self.fc1_bn(X)

    X = self.fc2(X)

    X = F.relu(self.fc2_bn(X))

    X=  F.relu(self.fc3(X))

    return F.log_softmax(X, dim=1)

ValueError Traceback (most recent call last)
in ()
17
18 # Apply the model
—> 19 y_pred = model(X_train) # we don’t flatten X-train here
20 loss = criterion(y_pred, y_train)
21

4 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []

in forward(self, X)
26 X = x.view(-1,512)
27 X = self.fc1(X)
—> 28 X = self.fc1_bn(X)
29 X = self.fc2(X)
30 X = F.relu(self.fc2_bn(X))

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/batchnorm.py in forward(self, input)
133
134 def forward(self, input: Tensor) → Tensor:
→ 135 self._check_input_dim(input)
136
137 # exponential_average_factor is set to self.momentum

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/batchnorm.py in _check_input_dim(self, input)
512 def _check_input_dim(self, input):
513 if input.dim() != 5:
→ 514 raise ValueError(“expected 5D input (got {}D input)”.format(input.dim()))
515
516

ValueError: expected 5D input (got 2D input)

Let me explain what did you wrong in 2 steps,

X = x.view(-1, 512)
X = self.fc1(X)

X from the above has a size of (batch_size, 256) as you define the output size of self.fc1 256.

Next,

X = self.fc1_bn(X)

The input X is a 2-dimensional tensor but fc1_bn is defined as 3-D batch normalization layer.
If you really want to use 3-D BatchNorm, then you should reform X before the normalization e.g.

X = X.view(1, -1, 1, 1, 256)
X = self.fc1_bn(X)

I recommend you not use batch normalization after nn.Linear. It means nothing with 2-dimensional tensor.
Just an activation is enough such as ReLU, Sigmoid, and so on.
Which activation function is the best? It’s totally up to your task.