I am new to PyTorch and I am trying to do semantic segmentation.
I am trying to do semantic segmentation with two classes - Edge and Non-Edge.
I have 224x224x3 images and 224x224 binary segmentation masks. I am reshaping the masks to be 224x224x1 (I read somewhere that this is the format that I should pass to the model). Here is a sample image and mask: https://imgur.com/IfAO2zv
I want to try whatever model, loss, and optimizer to proceed with the training. I am currently trying with torchvision.models.segmentation.
fcn_resnet50
which I found that can be used for segmentation from the docs (I am not sure if I have to modify it or use it as it is).
I get the following errors when I try different loss functions:
BCELoss
AttributeError: 'collections.OrderedDict' object has no attribute 'size'
CrossEntropyLoss
AttributeError: 'collections.OrderedDict' object has no attribute 'log_softmax'
NLLLoss
AttributeError: 'collections.OrderedDict' object has no attribute 'dim'
Here is the code:
roof_edges_dataset.py
import os
import cv2
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from utils import create_binary_mask, get_labelme_shapes, plot_segmentation_dataset
class RoofEdgesDataset(Dataset):
def __init__(self, im_path, ann_path, transform=None):
self.im_path = im_path
self.ann_path = ann_path
self.transform = transform
self.im_fn_list = sorted(os.listdir(im_path), key=lambda x: int(x.split('.')[0]))
self.ann_fn_list = sorted(os.listdir(ann_path), key=lambda x: int(x.split('.')[0]))
def __len__(self):
return len(self.im_fn_list)
def __getitem__(self, index):
im_path = os.path.join(self.im_path, self.im_fn_list[index])
im = cv2.imread(im_path)
ann_path = os.path.join(self.ann_path, self.ann_fn_list[index])
ann = create_binary_mask(im, get_labelme_shapes(ann_path))
ann = ann.reshape(ann.shape[0], ann.shape[1], 1)
ann = transforms.ToTensor()(ann)
if self.transform:
im = self.transform(im)
return im, ann
main.py
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch import optim
from torch.utils.data import DataLoader
from roof_edges_dataset import RoofEdgesDataset
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hyperparameters
in_im_shape = (3, 224, 224)
num_classes = 2 # Edge / Non-edge
learning_rate = 0.001
batch_size = 4
n_epochs = 10
# Data - 60% Train - 20% Val - 20% Test
transformations = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset = RoofEdgesDataset(im_path='data/images', ann_path='data/annotations', transform=transformations)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
# Model
model = torchvision.models.segmentation.fcn_resnet50(pretrained=False, progress=True, num_classes=2)
model.to(device)
print(model)
# Loss and Optimizer
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Train
for epoch in range(n_epochs):
for batch_idx, (image, annotation) in enumerate(train_loader):
image = image.to(device=device)
annotation = annotation.to(device=device)
# forward
output = model(image)
loss = criterion(output, annotation)
# backward
optimizer.zero_grad()
loss.backward()
# gradient descent (adam step)
optimizer.step()
if (batch_idx + 1) % 2 == 0:
print(
f'Epoch [{epoch + 1}/{n_epochs}], Step [{batch_idx + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')
# Evaluate
How should I proceed? What am I doing wrong? How can I fix it?
Also, any examples, guides, tutorials, references, and everything that will help me solve my issue and understand the topic better is welcome.