Hello sir, I am sorry for the confusion. I just realized that my PatchMerging3D class processes the input successfully the first time it is reached inside the transformer (I printed the shapes to verify this). However, when training reaches the PatchMerging3D class a second time, I get the following error:
Error:
Traceback (most recent call last):
  File "train.py", line 175, in <module>
    trainer.fit(net, datamodule=data_module)
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 771, in fit
    self._call_and_handle_interrupt(
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 724, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 812, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1222, in _run
    self._call_callback_hooks("on_fit_start")
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1637, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\pytorch_lightning\callbacks\model_summary.py", line 56, in on_fit_start
    model_summary = summarize(pl_module, max_depth=self._max_depth)
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\pytorch_lightning\utilities\model_summary.py", line 427, in summarize
    return ModelSummary(lightning_module, max_depth=max_depth)
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\pytorch_lightning\utilities\model_summary.py", line 187, in __init__
    self._layer_summary = self.summarize()
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\pytorch_lightning\utilities\model_summary.py", line 244, in summarize
    self._forward_example_input()
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\pytorch_lightning\utilities\model_summary.py", line 274, in _forward_example_input
    model(input_)
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\3_3D-EffiViTCaps-main\module\effiViTcaps.py", line 190, in forward
    conv_3_1 = self.patchMergingblock_2(conv_2_1)
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\torch\nn\modules\module.py", line 1582, in _call_impl
    result = forward_call(*args, **kwargs)
  File "D:\3_3D-EffiViTCaps-main\main_block\UCTransNet.py", line 190, in forward
    x = self.reduction(x)  # Applies the linear reduction
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\LGD\anaconda3\envs\EffiViTCaps\lib\site-packages\torch\nn\modules\linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x512 and 1024x128)
Shapes printed out:
x0 shape: torch.Size([1, 16, 16, 16, 64])
x1 shape: torch.Size([1, 16, 16, 16, 64])
x2 shape: torch.Size([1, 16, 16, 16, 64])
x3 shape: torch.Size([1, 16, 16, 16, 64])
x4 shape: torch.Size([1, 16, 16, 16, 64])
x5 shape: torch.Size([1, 16, 16, 16, 64])
x6 shape: torch.Size([1, 16, 16, 16, 64])
x7 shape: torch.Size([1, 16, 16, 16, 64])
x.view shape: torch.Size([1, 16, 16, 16, 512])
x shape before reduction: torch.Size([1, 16, 16, 16, 512])
x shape transpose: torch.Size([1, 128, 16, 16, 16])
x1 shape: torch.Size([1, 64, 16, 16])
en1 shape: torch.Size([1, 64, 16, 16, 16])
x1 shape: torch.Size([1, 64, 16, 16, 16])
en1 shape: torch.Size([1, 64, 16, 16, 16])
x2 shape: torch.Size([1, 128, 8, 8])
en2 shape: torch.Size([1, 128, 8, 8, 8])
x2 shape: torch.Size([1, 128, 8, 8, 8])
en2 shape: torch.Size([1, 128, 8, 8, 8])
x3 shape: torch.Size([1, 256, 4, 4])
en3 shape: torch.Size([1, 256, 4, 4, 4])
x3 shape: torch.Size([1, 256, 4, 4, 4])
en3 shape: torch.Size([1, 256, 4, 4, 4])
x4 shape: torch.Size([1, 512, 2, 2])
en4 shape: torch.Size([1, 512, 2, 2, 2])
x4 shape: torch.Size([1, 512, 2, 2, 2])
en4 shape: torch.Size([1, 512, 2, 2, 2])
x0 shape: torch.Size([1, 8, 8, 8, 64])
x1 shape: torch.Size([1, 8, 8, 8, 64])
x2 shape: torch.Size([1, 8, 8, 8, 64])
x3 shape: torch.Size([1, 8, 8, 8, 64])
x4 shape: torch.Size([1, 8, 8, 8, 64])
x5 shape: torch.Size([1, 8, 8, 8, 64])
x6 shape: torch.Size([1, 8, 8, 8, 64])
x7 shape: torch.Size([1, 8, 8, 8, 64])
x.view shape: torch.Size([1, 8, 8, 8, 512])
x shape before reduction: torch.Size([1, 8, 8, 8, 512])
My PatchMerging3D class code:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchMerging3D(nn.Module):  # Inherits from nn.Module
    """3D Patch Merging Layer.

    Args:
        input_dim (int): Number of input channels.
        output_dim (int): Number of output channels after reduction.
    """

    def __init__(self, input_dim, output_dim, norm_layer=nn.LayerNorm):
        super().__init__()  # Calls the __init__ method of the parent class
        self.reduction = nn.Linear(8 * input_dim, output_dim, bias=False)  # Linear layer that reduces the merged channels
        self.norm = norm_layer(8 * input_dim)  # Normalization layer (currently unused in forward)

    def forward(self, x):
        # If x is a tuple, extract the first element
        if isinstance(x, tuple):
            x = x[0]
        # Move the channel dimension to the end (swaps dims 1 and 4)
        x = x.transpose(1, 4)
        B, D, H, W, C = x.shape
        # Pad if any spatial dimension is odd
        pad_input = (H % 2 == 1) or (W % 2 == 1) or (D % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, D % 2, 0, W % 2, 0, H % 2))
        # Perform patch merging: take the eight 2x2x2 sub-grids
        x0 = x[:, 0::2, 0::2, 0::2, :]
        x1 = x[:, 0::2, 0::2, 1::2, :]
        x2 = x[:, 0::2, 1::2, 0::2, :]
        x3 = x[:, 0::2, 1::2, 1::2, :]
        x4 = x[:, 1::2, 0::2, 0::2, :]
        x5 = x[:, 1::2, 0::2, 1::2, :]
        x6 = x[:, 1::2, 1::2, 0::2, :]
        x7 = x[:, 1::2, 1::2, 1::2, :]
        # Concatenate the eight sub-grids along the channel dimension
        x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
        print(f"x0 shape: {x0.shape}")
        print(f"x1 shape: {x1.shape}")
        print(f"x2 shape: {x2.shape}")
        print(f"x3 shape: {x3.shape}")
        print(f"x4 shape: {x4.shape}")
        print(f"x5 shape: {x5.shape}")
        print(f"x6 shape: {x6.shape}")
        print(f"x7 shape: {x7.shape}")
        # Reshape x to (B, D/2, H/2, W/2, 8 * C) to apply LayerNorm
        x = x.view(B, D // 2, H // 2, W // 2, -1)
        print(f"x.view shape: {x.shape}")
        # Dynamically create a LayerNorm over all non-batch dims (note: this
        # re-initializes fresh, untrained parameters on every forward pass)
        norm_layer = nn.LayerNorm(x.size()[1:]).to(x.device)
        x = norm_layer(x)
        # Apply the linear reduction
        print(f"x shape before reduction: {x.shape}")
        x = self.reduction(x)
        # Move the channel dimension back to position 1
        x = x.transpose(1, 4)
        print(f"x shape transpose: {x.shape}")
        return x
First Pass:
x0, x1, …, x7 shapes: [1, 16, 16, 16, 64]
x.view shape: [1, 16, 16, 16, 512]
x shape before reduction: [1, 16, 16, 16, 512]
x shape transpose: [1, 128, 16, 16, 16]
Second Pass:
x0, x1, …, x7 shapes: [1, 8, 8, 8, 64]
x.view shape: [1, 8, 8, 8, 512]
x shape before reduction: [1, 8, 8, 8, 512]
The error occurs at this point, before the transpose back.
Analysis:
The first pass through PatchMerging3D behaves as expected: the spatial dimensions are halved and the channels go from 64, to 512 merged features, to 128 after the reduction. On the second pass, however, the printed shapes show that the tensor entering the block again has only 64 channels, so merging again produces 8 * 64 = 512 features, which does not match the size the second block's linear layer was built for.
Problem:
The shape mismatch comes from the input size that is baked into self.reduction at construction time. The traceback shows a weight with 1024 input features and 128 output features, i.e. patchMergingblock_2 appears to have been created with input_dim=128 (expecting 8 * 128 = 1024 merged features), but the tensor it actually receives has C = 64 channels, so the merged tensor only has 8 * 64 = 512 features. The flattened input (mat1, 512 x 512) therefore cannot be multiplied with the weight (mat2, 1024 x 128).
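If I read the traceback correctly, the mismatch can be reproduced in isolation like this (the constructor arguments of patchMergingblock_2 are my guess, inferred from the 1024 x 128 weight shape):

import torch
import torch.nn as nn

# Assumed construction of patchMergingblock_2: input_dim=128, output_dim=128,
# giving a reduction layer with 8 * 128 = 1024 input features (mat2 = 1024 x 128).
reduction = nn.Linear(8 * 128, 128, bias=False)

# The tensor that actually arrives has C = 64, so merging yields 8 * 64 = 512
# features at 8 x 8 x 8 spatial resolution, exactly as in the printed shapes.
x = torch.randn(1, 8, 8, 8, 8 * 64)

reduction(x)
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x512 and 1024x128)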
Solution:
To make the layer robust to each pass, stop baking the input size in at construction time and instead create (or recreate) the reduction layer inside the forward method, based on the merged feature size of the tensor that actually arrives. (The cleaner alternative would be to construct each block with an input_dim matching the channel count it really receives.)
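As far as I understand that suggestion, the layer would be rebuilt inside forward along these lines. Below is a minimal sketch of my understanding (the class name LazyPatchMerging3D and the lazy rebuilding are mine, not from the EffiViTCaps repository; the transpose, padding, and eight-way merge follow my code above):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LazyPatchMerging3D(nn.Module):
    """Sketch: rebuild the reduction lazily so in_features always equal 8 * C."""

    def __init__(self, output_dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.output_dim = output_dim
        self.norm_layer = norm_layer
        self.norm = None       # created on the first forward pass
        self.reduction = None  # created on the first forward pass

    def forward(self, x):
        if isinstance(x, tuple):
            x = x[0]
        x = x.transpose(1, 4)  # channels to the last dimension, as above
        B, D, H, W, C = x.shape
        if (H % 2 == 1) or (W % 2 == 1) or (D % 2 == 1):
            x = F.pad(x, (0, 0, 0, D % 2, 0, W % 2, 0, H % 2))
        # Eight-way 2x2x2 merge, same ordering as x0 .. x7 above
        # -> (B, D/2, H/2, W/2, 8 * C)
        x = torch.cat([x[:, i::2, j::2, k::2, :]
                       for i in (0, 1) for j in (0, 1) for k in (0, 1)], -1)
        merged = x.shape[-1]  # 8 * C for whatever C actually arrived
        # (Re)build norm and reduction if the merged size has changed; caching
        # them lets their weights persist across forward passes (rebuilding
        # would discard previously learned weights).
        if self.reduction is None or self.reduction.in_features != merged:
            self.norm = self.norm_layer(merged).to(x.device)
            self.reduction = nn.Linear(merged, self.output_dim,
                                       bias=False).to(x.device)
        # LayerNorm over the channel dimension only (Swin-style), then reduce
        x = self.reduction(self.norm(x))
        return x.transpose(1, 4)  # channels back to dim 1

Two caveats I am aware of with this kind of lazy construction: parameters first created inside forward are invisible to an optimizer that was configured before the first forward pass, so a dummy forward through the model is needed before the optimizer is set up; and torch.nn.LazyLinear(output_dim, bias=False) can infer in_features automatically on first use, which might replace the manual rebuild if the merged size never changes after the first pass.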
I found this analysis, problem description, and suggested solution on the internet, but I could not manage to get the dynamic recreation of the reduction layer working inside the forward method based on the current input shape. Is the sketch above the right direction? Thanks in advance, sir!