I also have this error and I can’t find the exact culprit
mat1 and mat2 shapes cannot be multiplied (3x196608 and 784x512)
I’m new to Machine Learning and it’s my first time with PyTorch.
I’m trying to build the following UNet: UNet
First, here’s how I’m building the dataset:
class TissueDataset(Dataset):
    """Paired (image, mask) dataset loaded from two directories of JPEGs.

    Images are resized to 256x256 with bilinear interpolation; masks with
    nearest-neighbour (order=0) so class labels are not blended.
    """

    def __init__(self, img_path, target_path):
        super().__init__()
        # BUG FIX: glob() returns files in arbitrary order, so the two
        # lists could pair image i with the wrong mask.  Sorting both
        # makes the pairing deterministic (assumes matching filenames —
        # TODO confirm the two folders use the same naming scheme).
        self.imgs = sorted(glob(os.path.join(img_path, "*.jpg")))
        self.targets = sorted(glob(os.path.join(target_path, "*.jpg")))

    def __getitem__(self, idx):
        image = imread(self.imgs[idx])
        label = imread(self.targets[idx])
        # BUG FIX: skimage's resize() takes the *output image shape*,
        # i.e. (H, W) / (H, W, C) for the H x W x C array imread returns
        # — NOT PyTorch's (C, H, W) layout.  Passing (3, 256, 256) was
        # scrambling the data.  Resize spatially here, reorder axes below.
        image = resize(image, (256, 256), order=1, preserve_range=True)
        label = resize(label, (256, 256), order=0, preserve_range=True).astype(int)
        # Convert to the tensors a conv net expects: CHW float input,
        # integer target.  Assumes imread gives an HWC colour image —
        # TODO confirm; grayscale masks come back as plain HxW arrays.
        image = torch.from_numpy(image).permute(2, 0, 1).float()
        label = torch.from_numpy(label).long()
        return image, label

    def __len__(self):
        return len(self.imgs)
# Training loader: batches of 3 (image, mask) pairs, reshuffled each epoch.
# The original f-string prefixes were redundant — the paths contain no
# placeholders, so plain string literals are clearer.
trainloader = DataLoader(
    TissueDataset(
        img_path='data/tissue/train/jpg',
        target_path='data/tissue/train/lbl',
    ),
    batch_size=3,
    shuffle=True,
)
Here is an example image: Image
And here is the UNet model - I’ve tried different things but all end up with the same error while training.
I don’t know where exactly the error is, or whether what I’m doing even corresponds to the image.
class DoubleConv2d(Module):
    """Two 3x3 same-padding convolutions, each followed by a ReLU.

    Spatial size is preserved (kernel 3, stride 1, padding 1); only the
    channel count changes, from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv2d, self).__init__()
        first = Conv2d(in_channels, out_channels, 3, 1, 1)
        second = Conv2d(out_channels, out_channels, 3, 1, 1)
        self.stack = Sequential(first, ReLU(), second, ReLU())

    def forward(self, x):
        out = self.stack(x)
        return out
class UNet(Module):
    """U-Net: 3-channel input image -> 1-channel segmentation logit map.

    Encoder: DoubleConv2d blocks at 112/224/448 channels, each followed by
    2x2 max-pooling.  Bottleneck: DoubleConv2d(mid[-1], 2*mid[-1]) — this
    IS the standard U-Net layout (the bottom doubles the channels; the
    first ConvTranspose2d then consumes 2*mid[-1]), so the original
    "should both be mid[-1]?" question is answered: no, it is correct.
    Decoder: transposed-conv upsampling, skip-concatenation with the
    matching encoder feature map, then DoubleConv2d.

    Input H and W must be divisible by 2**3 = 8 so the upsampled maps
    line up exactly with the stored skip connections.
    """

    def __init__(self):
        super(UNet, self).__init__()
        inn = 3                  # input channels (RGB)
        out = 1                  # output channels (mask logits)
        mid = [112, 224, 448]    # encoder feature widths
        self.encoder = ModuleList()
        self.bottom = DoubleConv2d(mid[-1], 2 * mid[-1])
        self.decoder = ModuleList()
        self.end = Conv2d(mid[0], 1 * out, 1)   # 1x1 conv to logits
        self.maxpool = MaxPool2d(2, 2)
        for dim in mid:
            self.encoder.append(DoubleConv2d(inn, dim))
            inn = dim
        for dim in mid[::-1]:
            # Up-conv halves the channels to `dim`; concatenating the
            # skip doubles them back, hence DoubleConv2d(2*dim, dim).
            self.decoder.append(ConvTranspose2d(2 * dim, dim, 2, 2))
            self.decoder.append(DoubleConv2d(2 * dim, dim))

    def forward(self, x):
        connections = []
        for module in self.encoder:
            x = module(x)
            connections.append(x)   # saved pre-pooling for the skip path
            x = self.maxpool(x)
        x = self.bottom(x)
        for i, module in enumerate(self.decoder):
            x = module(x)
            # BUG FIX: the original tested `i % 0 == 0`, which raises
            # ZeroDivisionError on the very first decoder step.  Even
            # indices are the ConvTranspose2d layers; after each one we
            # concatenate the matching encoder map along the channel axis.
            if i % 2 == 0:
                connection = connections.pop()
                x = torch.cat((connection, x), dim=1)
        x = self.end(x)
        # BUG FIX: return the (N, 1, H, W) logit map directly.  The
        # original flattened with x.view(x.size(0), -1), destroying the
        # spatial map and producing shapes like (3, 196608) — feeding
        # that into a Linear(784, 512) (sized for 28x28 inputs) is what
        # raises "mat1 and mat2 shapes cannot be multiplied
        # (3x196608 and 784x512)".
        return x
print(UNet())  # sanity check: prints the module hierarchy (no forward pass is run)
Any help or pointers would be greatly appreciated.