Hello, I am training a UNet, and the input and output images are grayscale.
When I run it, I don’t know how to fix the error:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
class DoubleConv(nn.Module):
    """Two (3x3 Conv -> BatchNorm -> ReLU) stages; spatial size is preserved.

    The first conv maps in_channels -> out_channels, the second keeps
    out_channels. Convs use padding=1 so H and W do not change, and
    bias=False because BatchNorm makes the bias redundant.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        layers = []
        channels = in_channels
        for _ in range(2):
            layers.append(nn.Conv2d(channels, out_channels, 3, 1, 1, bias=False))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
            channels = out_channels
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv stages to a (N, C, H, W) tensor."""
        return self.conv(x)
class UNET(nn.Module):
    """U-Net for image-to-image tasks (e.g. segmentation).

    Encoder: repeated DoubleConv + 2x2 max-pool.
    Bottleneck: DoubleConv at the deepest resolution.
    Decoder: transposed-conv upsampling, concatenation with the matching
    encoder feature map (skip connection), then DoubleConv.

    Args:
        in_channels: channels of the input image (1 for grayscale, 3 for RGB).
        out_channels: channels of the output map (1 for binary segmentation).
        features: channel widths of the encoder stages, shallow to deep.
            Default is a tuple, not a list — a mutable default argument
            would be shared across all instantiations.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: each stage maps the running channel count to `feature`.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder: each stage is a pair [upsample, DoubleConv]. The
        # DoubleConv takes feature * 2 channels because the upsampled map
        # is concatenated with the skip connection before it runs.
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2),
            )
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network; output spatial size matches the input.

        Args:
            x: input tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest skip first

        # self.ups alternates [upsample, DoubleConv], so step by 2;
        # idx // 2 picks the skip connection matching this decoder stage.
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            # Max-pooling floors odd sizes (e.g. 75 -> 37 -> up to 74),
            # so the upsampled map can be smaller than the skip; resize
            # it to match before concatenating along the channel dim.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)
def test():
    """Smoke test: a grayscale batch must map to an identically-shaped output.

    Uses an odd effective size (600 pooled four times passes through 75 -> 37)
    so the skip-connection resize path in UNET.forward is exercised.
    """
    batch = torch.randn((3, 1, 600, 600))
    model = UNET(in_channels=1, out_channels=1)
    out = model(batch)
    print(batch.shape)
    assert out.shape == batch.shape


if __name__ == "__main__":
    test()
It gives the error:
`Traceback (most recent call last):`
File "C:/DeepLearning/UNet/train.py", line 120, in <module>
main()
File "C:/DeepLearning/UNet/train.py", line 98, in main
check_accuracy(val_loader, model, device=DEVICE)
File "C:\DeepLearning\UNet\utils.py", line 62, in check_accuracy
for x, y in loader:
File "C:\Users\Tsai\anaconda3\envs\tf\lib\site-packages\torch\utils\data\dataloader.py", line 517, in __next__
data = self._next_data()
File "C:\Users\Tsai\anaconda3\envs\tf\lib\site-packages\torch\utils\data\dataloader.py", line 1199, in _next_data
return self._process_data(data)
File "C:\Users\Tsai\anaconda3\envs\tf\lib\site-packages\torch\utils\data\dataloader.py", line 1225, in _process_data
data.reraise()
File "C:\Users\Tsai\anaconda3\envs\tf\lib\site-packages\torch\_utils.py", line 429, in reraise
raise self.exc_type(msg)
ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "C:\Users\Tsai\anaconda3\envs\tf\lib\site-packages\torch\utils\data\_utils\worker.py", line 202, in _worker_loop
data = fetcher.fetch(index)
File "C:\Users\Tsai\anaconda3\envs\tf\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "C:\Users\Tsai\anaconda3\envs\tf\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "C:\DeepLearning\UNet\dataset.py", line 24, in __getitem__
augmentations = self.transform(image=image, mask=mask)
File "C:\Users\Tsai\anaconda3\envs\tf\lib\site-packages\albumentations\core\composition.py", line 182, in __call__
data = t(force_apply=force_apply, **data)
File "C:\Users\Tsai\anaconda3\envs\tf\lib\site-packages\albumentations\core\transforms_interface.py", line 89, in __call__
return self.apply_with_params(params, **kwargs)
File "C:\Users\Tsai\anaconda3\envs\tf\lib\site-packages\albumentations\core\transforms_interface.py", line 102, in apply_with_params
res[key] = target_function(arg, **dict(params, **target_dependencies))
File "C:\Users\Tsai\anaconda3\envs\tf\lib\site-packages\albumentations\augmentations\transforms.py", line 1496, in apply
return F.normalize(image, self.mean, self.std, self.max_pixel_value)
File "C:\Users\Tsai\anaconda3\envs\tf\lib\site-packages\albumentations\augmentations\functional.py", line 141, in normalize
img -= mean
ValueError: operands could not be broadcast together with shapes (600,600,2) (3,) (600,600,2)