can somebody tell me what is wrong with my code?
all the images have the size of 256x256
images have 3 channels and mask has 1 channel
and 1 batch for each
I don’t understand what this error is
RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [1, 1, 256, 256]
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import os, sys, random, time
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
import re
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#model
class unet(nn.Module):
def __init__(self):
super(unet, self).__init__()
self.encoder1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size = 3, stride = 1, padding = 1, bias=False),
nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True))
self.encoder2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True))
self.encoder3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True))
self.encoder4 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True))
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
self.conv_mid = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True),
nn.Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True))
self.decoder4 = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True))
self.decoder3 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True))
self.decoder2 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True))
self.decoder1 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
nn.ReLU(inplace=True))
self.upconv4 = nn.ConvTranspose2d(1024, 1024, kernel_size=(2, 2), stride=2)
self.upconv3 = nn.ConvTranspose2d(512, 512, kernel_size=(2, 2), stride=(2, 2))
self.upconv2 = nn.ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
self.upconv1 = nn.ConvTranspose2d(128, 128, kernel_size=(2, 2), stride=(2, 2))
self.conv1x1_out = nn.Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
def forward(self, x):
out = self.encoder1(x)
out = self.pool1(out)
out = self.encoder2(out)
out = self.pool2(out)
out = self.encoder3(out)
out = self.pool3(out)
out = self.encoder4(out)
out = self.pool4(out)
out = self.conv_mid(out)
out = self.upconv4(out)
out = self.decoder4(out)
out = self.upconv3(out)
out = self.decoder3(out)
out = self.upconv2(out)
out = self.decoder2(out)
out = self.upconv1(out)
out = self.decoder1(out)
out = self.conv1x1_out(out)
return out
model = unet().to(device)
class BasicDataset(Dataset):
def __init__(self, imgs_dir, masks_dir):
self.imgs_dir = imgs_dir
self.masks_dir = masks_dir
self.mriids = next(os.walk(self.imgs_dir))[2]
self.maskids = next(os.walk(self.masks_dir))[2]
def __len__(self):
return len(self.mriids)
def __getitem__(self, i):
def atoi(text):
return int(text) if text.isdigit() else text
def natural_keys(text):
'''
alist.sort(key=natural_keys) sorts in human order
http://nedbatchelder.com/blog/200712/human_sorting.html
(See Toothy's implementation in the comments)
'''
return [atoi(c) for c in re.split(r'(\d+)', text) ]
self.mriids = sorted(self.mriids, key = natural_keys)
self.maskids = sorted(self.maskids, key = natural_keys)
mriidx = self.mriids[i]
maskidx = self.maskids[i]
mask_file = glob(self.masks_dir + maskidx)
img_file = glob(self.imgs_dir + mriidx)
mask = Image.open(mask_file[0])
img = Image.open(img_file[0])
mask_np = np.array(mask)
img_np = np.array(img)
mask_np = np.expand_dims(mask_np, axis=-1)
mask_np = np.transpose(mask_np, (2, 0, 1))
img_np = np.transpose(img_np, (2, 0, 1))
#print(img_np.shape)
#print(mask_np.shape)
return {'image': torch.from_numpy(img_np), 'mask': torch.from_numpy(mask_np)}
epochs = 1000
batch_size = 1
lr = 0.00001
momentum = 0.99
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
loss_func = nn.CrossEntropyLoss().to(device)
gen = BasicDataset('/home/intern/Desktop/YH/Brain_MRI/BrainMRI_train/MRI/MRI/', '/home/intern/Desktop/YH/Brain_MRI/BrainMRI_train/mask/mask/')
train_loader = DataLoader(gen, batch_size=batch_size)
total_batch = len(gen)
model.train()
for epoch in range(epochs):
for batch in train_loader:
t0 = time.time()
avg_cost = 0.0
mri = batch['image']
true_mask = batch['mask']
mri = mri.type(torch.FloatTensor)
true_mask = true_mask.type(torch.LongTensor)
mri = mri.to(device)
true_mask = true_mask.to(device)
pred_mask = model(mri)
loss = loss_func(pred_mask, true_mask)
optimizer.zero_grad()
loass.backward()
optimizer.step()
avg_cost += loss / total_batch
t1 = time.time()
print('[Epoch:{}], cost = {}, time = {}'.format(epoch+1, avg_cost, t1-t0))
print('training Finished!')
Output
{'mask': tensor([[[[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]]], dtype=torch.uint8), 'image': tensor([[[[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 1, ..., 1, 0, 0],
[0, 0, 0, ..., 1, 0, 0],
[0, 0, 0, ..., 0, 0, 0]],
[[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[3, 4, 4, ..., 6, 5, 2],
[4, 4, 4, ..., 5, 6, 4],
[3, 2, 2, ..., 2, 1, 2]],
[[4, 4, 3, ..., 0, 0, 0],
[2, 1, 1, ..., 0, 0, 0],
[2, 3, 3, ..., 0, 0, 0],
...,
[2, 2, 1, ..., 1, 1, 1],
[1, 2, 2, ..., 2, 1, 1],
[1, 1, 1, ..., 1, 2, 3]]]], dtype=torch.uint8)}
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-69-a858eeed6905> in <module>()
22
23 pred_mask = model(mri)
---> 24 loss = loss_func(pred_mask, true_mask)
25
26 optimizer.zero_grad()
/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
/usr/local/lib/python3.5/dist-packages/torch/nn/modules/loss.py in forward(self, input, target)
914 def forward(self, input, target):
915 return F.cross_entropy(input, target, weight=self.weight,
--> 916 ignore_index=self.ignore_index, reduction=self.reduction)
917
918
/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2019 if size_average is not None or reduce is not None:
2020 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2021 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
2022
2023
/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
1838 ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
1839 elif dim == 4:
-> 1840 ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
1841 else:
1842 # dim == 3 or dim > 4
RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [1, 1, 256, 256]