I’m trying to build a VNet model for medical image segmentation. The input is 5-dimensional, (1, 1, 128, 128, 64), i.e. (batch, channel, height, width, depth).
I have used the VNet architecture from this paper, with MONAI for image preprocessing. When I try to train the model, I get the following runtime error:
RuntimeError: input must be 4-dimensional
Can someone please help me understand this? I have used Conv3d for every convolution, and my understanding is that it accepts a 5D tensor. My input matches the format the PyTorch model expects, so why am I getting this error?
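For reference, here is a minimal check, independent of the VNet code and with layer sizes chosen only for illustration, that `nn.Conv3d` accepts a 5D float32 tensor of the shape described above. PyTorch documents the layout as (N, C, D, H, W), but the convolution treats the three spatial axes alike, so calling the last axis "depth" instead of the first is only a naming difference as long as the kernel is cubic.

```python
import torch
import torch.nn as nn

# nn.Conv3d expects (N, C, D, H, W). The three spatial axes are handled the same
# way, so a (batch, channel, height, width, depth) tensor also works with a cubic kernel.
# 16 output channels and a 5x5x5 kernel are illustrative choices.
conv = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=5, padding=2)

x = torch.randn(1, 1, 128, 128, 64)  # float32 by default, on the CPU
y = conv(x)                          # padding=2 keeps every spatial size unchanged
print(x.dim(), y.shape)
```

If this runs on the CPU but the same call fails on the GPU, the difference lies in the device or dtype rather than in the tensor's rank.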
The forward pass is defined as follows:
```python
def forward(self, x):
    # Input stage: convolution followed by PReLU
    x = self.inp_prelu(self.inp_conv(x))
    # ... (remaining layers omitted)
```
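The excerpt does not show how `inp_conv` and `inp_prelu` are defined. A hypothetical definition consistent with the names used in `forward` might look like the sketch below. The class name `InputBlock`, the channel count, the kernel size, and the explicit rank check are all assumptions, not the poster's code. The check makes a wrong-rank input fail with a readable message before it reaches the convolution.

```python
import torch
import torch.nn as nn


class InputBlock(nn.Module):
    """Hypothetical V-Net-style input stage: 5x5x5 Conv3d followed by PReLU."""

    def __init__(self, in_channels: int = 1, out_channels: int = 16):
        super().__init__()
        self.inp_conv = nn.Conv3d(in_channels, out_channels, kernel_size=5, padding=2)
        self.inp_prelu = nn.PReLU(out_channels)  # one learnable slope per channel

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fail early with a clear message if the tensor is not (N, C, D, H, W).
        if x.dim() != 5:
            raise ValueError(f"expected a 5D input, got shape {tuple(x.shape)}")
        return self.inp_prelu(self.inp_conv(x))


block = InputBlock()
out = block(torch.randn(1, 1, 64, 64, 32))
print(out.shape)
```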
The model is instantiated as shown below:
```python
model = VNet(in_channels=1, out_channels=1)
model = model.to(device)
```
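`device` is not defined in the excerpt. A common pattern, assumed here rather than taken from the post, is to select CUDA when it is available and then confirm where the parameters actually ended up and which dtype they have. A small `nn.Conv3d` stands in for the `VNet` class, which is not shown.

```python
import torch
import torch.nn as nn

# Pick the GPU if PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for VNet(in_channels=1, out_channels=1); the real class is not shown.
model = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
model = model.to(device)

# Every parameter should now live on `device`; float32 is PyTorch's default dtype.
param = next(model.parameters())
print(param.device, param.dtype)
```

Note that `torch.device("cuda")` and `torch.device("cuda:0")` do not compare equal, so comparing `param.device.type` with `device.type` is the safer check.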
The training loop starts as shown below:
```python
for epoch in range(num_epochs):
    # Training phase
    model.train()
    running_loss = 0.0
    for batch in train_ds_loader:
        inputs, labels = batch["image"].to(device), batch["label"].to(device)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(inputs)
```
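The loop stops at the forward pass. A self-contained sketch of complete training steps follows. The toy list of `{"image", "label"}` dictionaries imitates what a MONAI dictionary-based pipeline typically yields, and the stand-in model, `BCEWithLogitsLoss`, Adam optimizer, and learning rate are assumptions chosen only so the example runs. The inputs deliberately start as float64 to show the cast to float32 on the way to the device.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy stand-in for the MONAI dataset: dicts with "image" and "label" keys.
# The default collate function stacks each key into a batch tensor.
data = [{"image": torch.randn(1, 16, 16, 8, dtype=torch.float64),
         "label": torch.randint(0, 2, (1, 16, 16, 8)).double()} for _ in range(4)]
train_ds_loader = DataLoader(data, batch_size=2)

model = nn.Conv3d(1, 1, kernel_size=3, padding=1).to(device)  # stand-in for VNet
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 2

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch in train_ds_loader:
        # Move to the device AND cast to float32 so inputs match the model's weights.
        inputs = batch["image"].to(device, dtype=torch.float32)
        labels = batch["label"].to(device, dtype=torch.float32)

        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(inputs)            # forward pass
        loss = criterion(outputs, labels)  # compare logits with binary labels
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        running_loss += loss.item()
    print(f"epoch {epoch}: loss {running_loss / len(train_ds_loader):.4f}")
```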
I tried printing the shape before and after the input layer, and I tried separating the convolution and PReLU layers to see where the issue is. All I have been able to learn is that the input tensor has shape [1, 1, 64, 64, 32] and that the error occurs at the input convolution step.
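Instead of adding print statements by hand, forward pre-hooks can report the shape, dtype, and device of the tensor entering every layer. Pre-hooks run before the layer itself, so the last line printed before a crash describes the input of the failing layer. The toy `nn.Sequential` below stands in for the VNet, and its layers are illustrative.

```python
import torch
import torch.nn as nn

# Toy model standing in for the VNet; layer choices are illustrative.
model = nn.Sequential(
    nn.Conv3d(1, 16, kernel_size=5, padding=2),
    nn.PReLU(16),
    nn.Conv3d(16, 32, kernel_size=2, stride=2),  # halves each spatial dimension
)

log = []


def report_input(module, args):
    # Runs BEFORE the module's forward; args is the tuple of positional inputs.
    x = args[0]
    line = (f"{module.__class__.__name__}: shape={tuple(x.shape)} "
            f"dtype={x.dtype} device={x.device}")
    log.append(line)
    print(line)


# Hook the top-level layers only; iterate model.modules() to reach nested ones too.
for layer in model:
    layer.register_forward_pre_hook(report_input)

out = model(torch.randn(1, 1, 64, 64, 32))
```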
I can post the entire code for the model, forward pass, and training if necessary. Since the error occurs in this part and the rest is never executed, I’ve only posted this much in the hope that it’s easier to read. Please tell me if the entire code is required.
I seem to have resolved the issue by switching from GPU to CPU. The GPU was not supporting float64 tensors. Once I realised this, I tried converting the tensor to float32 using x.float(), and when that did not work, I used .to(torch.float32), which didn’t work either. For now the code runs on CPU, but I need it to work on GPU.
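Two common reasons a `.float()` or `.to(torch.float32)` call appears to have no effect are, first, that tensor conversions return a new tensor instead of modifying the original, so the result must be reassigned, and second, that the model's own parameters are float64, in which case the model has to be cast as well. The excerpt does not show whether either applies here; the sketch only illustrates both behaviours with a small stand-in `nn.Conv3d`.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 8, 8, 4, dtype=torch.float64)

x.float()        # returns a NEW float32 tensor, which is discarded here
print(x.dtype)   # x itself is still float64
x = x.float()    # reassign to keep the converted tensor
print(x.dtype)   # now float32

# A model whose weights are float64 (e.g. after model.double()) rejects float32 input.
model = nn.Conv3d(1, 2, kernel_size=3, padding=1).double()
try:
    model(x)
except RuntimeError as err:
    print("dtype mismatch:", err)

# nn.Module.float() casts the parameters in place and returns the module itself.
model = model.float()
y = model(x)
print(y.dtype)
```

Moving and casting in one call, e.g. `inputs.to(device, dtype=torch.float32)`, also returns a new tensor and therefore needs the same reassignment.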
I think the network takes (batch, H, W, depth), i.e. it works on 1-channel input by default, or else there is a variable that controls the number of input channels. Could you provide some of the code related to the architecture?
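If the data were 4D, (batch, H, W, depth) with no channel axis as suggested above, the channel dimension could be inserted with `unsqueeze(1)` before the first `Conv3d`, whose `in_channels` argument fixes how many input channels it expects. A brief sketch with illustrative shapes:

```python
import torch
import torch.nn as nn

conv = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, padding=1)

x4d = torch.randn(2, 64, 64, 32)  # (batch, H, W, depth): no channel axis
# In recent PyTorch versions, passing x4d directly would be read as an unbatched
# (C, D, H, W) tensor with C=2, which does not match in_channels=1.
x5d = x4d.unsqueeze(1)            # (batch, 1, H, W, depth)
print(x5d.shape)

y = conv(x5d)                     # in_channels=1 must equal x5d.shape[1]
print(y.shape)
```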