Hi,
I get an error in my fusion CNN. Two input images (each 256×256) should be fused by the CNN.
If anyone has an idea, please help.
Thanks.
My CNN and my training code:
# Define the network.
class FunFuseAn(nn.Module):
    """Two-branch CNN that fuses an MRI image batch with a PET image batch.

    Each modality passes through a low-frequency (LF) branch (one wide
    convolution -> 16 channels) and a high-frequency (HF) branch (three
    stacked convolutions -> 64 channels). The two HF feature maps are summed
    and reduced to 16 channels by ``recon1``; that result is averaged with the
    two LF feature maps and projected to a single channel by ``recon2``,
    followed by ``tanh``.

    Every convolution uses stride 1 and "same" padding (padding = k // 2), so
    the spatial size of the input is preserved (e.g. 256x256 in, 256x256 out).
    """

    def __init__(self):
        # Must be ``__init__`` (not ``init``): with any other name nn.Module's
        # constructor runs instead, none of the layers below are created, and
        # ``forward`` fails with AttributeError.
        super().__init__()
        # ----- MRI low-frequency branch: (N,1,H,W) -> (N,16,H,W) -----
        self.mri_lf = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=9, stride=1, padding=4),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # ----- MRI high-frequency branch: (N,1,H,W) -> (N,64,H,W) -----
        self.mri_hf = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # ----- PET low-frequency branch: (N,1,H,W) -> (N,16,H,W) -----
        self.pet_lf = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # ----- PET high-frequency branch: (N,1,H,W) -> (N,64,H,W) -----
        self.pet_hf = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # ----- Reconstruction 1: fused HF features (N,64,H,W) -> (N,16,H,W) -----
        self.recon1 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # ----- Reconstruction 2: (N,16,H,W) -> single-channel image (N,1,H,W) -----
        self.recon2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=1, kernel_size=5, stride=1, padding=2),
        )

    def forward(self, x, y, flatten=True):
        """Fuse an MRI batch ``x`` with a PET batch ``y``.

        Args:
            x: MRI batch of shape (N, 1, H, W). The first conv has
                ``in_channels=1`` and BatchNorm2d needs 4-D input.
            y: PET batch of shape (N, 1, H, W), with the same N, H, W as ``x``
                (the branch outputs are added element-wise).
            flatten: if True (the default, the original behaviour) return a
                1-D tensor of N*H*W values. If False return the fused image
                batch with shape (N, 1, H, W), which is what image-similarity
                losses such as SSIM expect.

        Returns:
            The fused result after ``tanh``, so every value is in (-1, 1).
        """
        mri_low = self.mri_lf(x)
        mri_high = self.mri_hf(x)
        pet_low = self.pet_lf(y)
        pet_high = self.pet_hf(y)

        # High-frequency fusion: sum the two 64-channel HF maps, then reduce.
        recon_hf = self.recon1(mri_high + pet_high)
        # Low-frequency fusion: average the two LF maps with the reduced HF map
        # (all three are (N,16,H,W)).
        fuse_lf = (mri_low + pet_low + recon_hf) / 3

        fused = torch.tanh(self.recon2(fuse_lf))
        if flatten:
            return fused.reshape(-1)
        return fused
# Execute the network once as a shape sanity check on random inputs.
x = torch.randn(14, 1, 256, 256)
y = torch.randn(14, 1, 256, 256)
cnn = FunFuseAn()
cnn = cnn.float()  # parameters are already float32; kept for explicitness
# Run the check in eval mode without autograd. In train mode this forward
# pass would update every BatchNorm layer's running statistics with random
# noise before training starts. Building the autograd graph would also keep
# all intermediate 14x256x256 activations alive for nothing.
cnn.eval()
with torch.no_grad():
    output = cnn(x, y)
cnn.train()  # restore training mode for the training loop
print(cnn)
# image_length, image_width, gray_channels, batch_size, epoch, lr, images_pet, images_mri = init_param()
# Import the training dataset.
# Define the network.
# Define the optimizer and loss function.
# NOTE(review): LR, EPOCH, SSIM, train_data_loader and cnn are not defined in
# this snippet; they must be provided elsewhere before this code runs.
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
# NOTE(review): assumes SSIM() returns a *similarity* index (1 == identical),
# as pytorch_ssim.SSIM and pytorch_msssim.SSIM do. Minimizing that raw index
# would push the fused image away from both sources, so the loss below uses
# (1 - SSIM). If your SSIM class already returns a loss, drop the "1 -".
loss_func = SSIM()

# Training.
cnn.train()
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_data_loader):
        # Train on the batch the loader provides, not on random noise.
        # Conv2d/BatchNorm2d need (N, C, H, W) input. Add the channel axis
        # only if the loader yields (N, H, W); TODO confirm the loader's
        # output shape. Calling unsqueeze(0) on a 4-D batch gives a 5-D
        # (1, N, 1, H, W) tensor, which Conv2d rejects.
        b_x = b_x.float()
        b_y = b_y.float()
        if b_x.dim() == 3:
            b_x = b_x.unsqueeze(1)
        if b_y.dim() == 3:
            b_y = b_y.unsqueeze(1)

        # The network may return a flattened tensor. SSIM compares image
        # batches, so restore the input's (N, 1, H, W) shape.
        output = cnn(b_x, b_y).reshape_as(b_x)
        # Reward similarity of the fused image to both the PET and MRI input.
        loss = (1 - loss_func(output, b_y)) + (1 - loss_func(output, b_x))

        optimizer.zero_grad()  # clear gradients for this training step
        loss.backward()        # backpropagation, compute gradients
        optimizer.step()       # apply gradients