I’m looking to train a model. I have finished writing the data-loader part and setting up the model, but I’m not sure what the next step should be. If anyone can suggest what I should do next, that would be really helpful.
This is what I have so far:
class DSB(Dataset):
    """Paired image / label dataset read from sorted directory listings.

    For ``subset='train'`` images come from ``<root>/train`` and labels from
    ``<root>/train_label``; for ``subset='val'`` from ``<root>/val`` and
    ``<root>/val_label``. Images and labels are paired by position after
    sorting each listing, so both directories must hold the same number of
    files in matching sort order.

    Args:
        root: dataset root directory; ``~`` is expanded.
        subset: ``'train'`` or ``'val'``.
        transform: optional callable applied to both the image and the label.

    Raises:
        RuntimeError: if ``subset`` is not ``'train'`` or ``'val'``.
    """

    def __init__(self, root, subset='train', transform=None):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.subset = subset
        self.data_path, self.label_path = [], []

        def load_images(path):
            # Regular files only (sub-directories are skipped), sorted so
            # that the i-th image and the i-th label line up.
            return sorted(
                os.path.join(path, name)
                for name in os.listdir(path)
                if os.path.isfile(os.path.join(path, name))
            )

        if self.subset not in ('train', 'val'):
            raise RuntimeError(
                'Invalid dataset subset {!r}, it must be one of: '
                '\'train\', \'val\''.format(self.subset))

        # os.path.join works whether or not ``root`` ends with a separator;
        # plain string concatenation produced e.g. '/data/dsbdatatrain'.
        # Assigning ``data_path`` for both subsets also fixes the old 'val'
        # branch, which wrote to a misspelled ``datapath`` attribute and left
        # the dataset empty.
        self.data_path = load_images(os.path.join(self.root, self.subset))
        self.label_path = load_images(
            os.path.join(self.root, self.subset + '_label'))

    def __getitem__(self, index):
        """Return ``(img, target)`` for sample ``index``.

        Both are PIL images as opened from disk, or the result of
        ``self.transform`` when a transform is set.
        """
        # NOTE(review): Image.open keeps each file's native mode (grayscale,
        # RGB, RGBA, ...); the model's input channel count must match it.
        # Consider an explicit .convert(...) -- TODO confirm with the data.
        img = Image.open(self.data_path[index])
        # __init__ rejects 'test', so a label is always loaded today; the
        # branch is kept for a possible unlabelled subset.
        target = (Image.open(self.label_path[index])
                  if self.subset != 'test' else None)
        if self.transform is not None:
            img = self.transform(img)
            if target is not None:
                target = self.transform(target)
        return img, target

    def __len__(self):
        """Number of images in the selected subset."""
        return len(self.data_path)
# Root of the unpacked dataset; DSB reads its 'train' / 'train_label'
# (and 'val' / 'val_label') sub-directories beneath it.
DATA_ROOT = '/media/predible/da5df9e4-cdc6-4d55-91e8-b2383e89165f/dsbdata/'

# ``transforms.Scale`` is deprecated and removed from current torchvision;
# ``transforms.Resize`` is its drop-in replacement with the same default
# (bilinear) interpolation.
# NOTE(review): DSB applies this same transform to the image AND its label.
# Bilinear resizing blends label values along object borders, and ToTensor
# turns the label into floats in [0, 1], whereas a per-pixel classification
# loss such as nn.NLLLoss needs integer class indices. TODO confirm how the
# label masks should be resized and encoded.
train_dataset = DSB(root=DATA_ROOT,
                    subset='train',
                    transform=transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.ToTensor(),
                    ]))

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=8,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=1)

# Visual sanity check: show the first sample's image next to its label.
img, label = train_dataset[0]
img_list = [img, label]
im_show(img_list)
class UNetConvBlock(nn.Module):
    """Two stacked unpadded convolutions, each followed by ``activation``.

    The convolutions use stride 1 and no padding, so each one shrinks the
    height and width by ``kernel_size - 1`` (``2 * (kernel_size - 1)`` for
    the whole block).

    Args:
        in_size: number of input channels.
        out_size: number of output channels of both convolutions.
        kernel_size: square kernel size of both convolutions.
        activation: elementwise callable applied after each convolution.
    """

    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock, self).__init__()
        # ``conv`` is registered before ``conv2``; that order fixes the
        # parameter order and the order in which weights are initialised.
        self.conv = nn.Conv2d(in_channels=in_size,
                              out_channels=out_size,
                              kernel_size=kernel_size)
        self.conv2 = nn.Conv2d(in_channels=out_size,
                               out_channels=out_size,
                               kernel_size=kernel_size)
        self.activation = activation

    def forward(self, x):
        """Apply conv -> activation -> conv2 -> activation to ``x``."""
        features = x
        for layer in (self.conv, self.conv2):
            features = self.activation(layer(features))
        return features
class UNetUpBlock(nn.Module):
    """U-Net decoder step: upsample, join a cropped skip tensor, convolve twice.

    ``x`` is upsampled 2x by a 2x2 / stride-2 transposed convolution to
    ``out_size`` channels. The skip tensor ``bridge`` is centre-cropped to the
    upsampled height and width, concatenated with it along the channel axis,
    and passed through two unpadded convolutions, each followed by
    ``activation``.

    ``self.conv`` takes ``in_size`` channels, so the concatenation
    (``out_size`` upsampled channels + ``bridge`` channels) must total
    ``in_size``. With a ``bridge`` of ``out_size`` channels this means
    ``in_size == 2 * out_size``.

    Args:
        in_size: channels of ``x`` and of the concatenated tensor.
        out_size: output channels.
        kernel_size: kernel size of the two convolutions.
        activation: elementwise callable applied after each convolution.
        space_dropout: accepted for API compatibility; currently unused.
    """

    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu,
                 space_dropout=False):
        super(UNetUpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, 2, stride=2)
        self.conv = nn.Conv2d(in_size, out_size, kernel_size)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size)
        self.activation = activation

    def center_crop(self, layer, target_size):
        """Return the centred ``target_size`` window of a 4-D ``layer``.

        Args:
            layer: tensor whose last two dimensions are height and width
                (the layout ``nn.Conv2d`` produces).
            target_size: an int for a square crop, or a ``(height, width)``
                pair. Must not exceed the layer's own height/width.
        """
        try:
            target_h, target_w = target_size
        except TypeError:
            # A single int (or 0-d tensor) means a square crop.
            target_h = target_w = target_size
        layer_h, layer_w = layer.size()[2], layer.size()[3]
        # Offsets are computed per axis; the old code reused the height
        # offset and size for the width, which broke non-square maps.
        top = (layer_h - target_h) // 2
        left = (layer_w - target_w) // 2
        return layer[:, :, top:top + target_h, left:left + target_w]

    def forward(self, x, bridge):
        """Upsample ``x``, concatenate the cropped ``bridge``, convolve twice."""
        up = self.up(x)
        # Crop the skip tensor to the upsampled (height, width), not a square.
        crop1 = self.center_crop(bridge, up.size()[2:])
        out = torch.cat([up, crop1], 1)
        out = self.activation(self.conv(out))
        out = self.activation(self.conv2(out))
        return out
class UNet(nn.Module):
    """Four-level U-Net built from unpadded convolution blocks.

    Encoder: ``UNetConvBlock`` stages with 64, 128, 256 and 512 channels, each
    followed by 2x2 max-pooling, then a 1024-channel bottleneck block. Decoder:
    four ``UNetUpBlock`` stages that take the matching encoder outputs as skip
    connections. A final 1x1 convolution maps to ``n_classes`` channels, and
    ``log_softmax`` is taken over that class axis.

    NOTE: every 3x3 convolution is unpadded, so the output is spatially
    smaller than the input. A 256x256 input gives a 68x68 output, and
    572x572 gives 388x388. A full-resolution target must be centre-cropped to
    the output size (or the convolutions padded) before a per-pixel loss
    such as ``nn.NLLLoss`` can be computed.

    Args:
        imsize: stored as ``self.imsize``; not used by ``forward``.
        in_channels: number of input image channels (default 1).
        n_classes: number of output classes (default 2).
    """

    def __init__(self, imsize, in_channels=1, n_classes=2):
        super(UNet, self).__init__()
        self.imsize = imsize
        # Not used by forward (the blocks carry their own activation); kept
        # so existing attribute access keeps working.
        self.activation = F.relu
        # MaxPool2d has no parameters; separate instances are kept so the
        # attribute names stay unchanged.
        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.pool3 = nn.MaxPool2d(2)
        self.pool4 = nn.MaxPool2d(2)
        self.conv_block1_64 = UNetConvBlock(in_channels, 64)
        self.conv_block64_128 = UNetConvBlock(64, 128)
        self.conv_block128_256 = UNetConvBlock(128, 256)
        self.conv_block256_512 = UNetConvBlock(256, 512)
        self.conv_block512_1024 = UNetConvBlock(512, 1024)
        self.up_block1024_512 = UNetUpBlock(1024, 512)
        self.up_block512_256 = UNetUpBlock(512, 256)
        self.up_block256_128 = UNetUpBlock(256, 128)
        self.up_block128_64 = UNetUpBlock(128, 64)
        self.last = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        """Return per-pixel class log-probabilities, shape (N, n_classes, H', W').

        ``x`` is expected as (N, in_channels, H, W). H' and W' are smaller
        than H and W; see the class docstring.
        """
        block1 = self.conv_block1_64(x)
        pool1 = self.pool1(block1)
        block2 = self.conv_block64_128(pool1)
        pool2 = self.pool2(block2)
        block3 = self.conv_block128_256(pool2)
        pool3 = self.pool3(block3)
        block4 = self.conv_block256_512(pool3)
        pool4 = self.pool4(block4)
        block5 = self.conv_block512_1024(pool4)
        up1 = self.up_block1024_512(block5, block4)
        up2 = self.up_block512_256(up1, block3)
        up3 = self.up_block256_128(up2, block2)
        up4 = self.up_block128_64(up3, block1)
        # Explicit dim=1, the class axis of the (N, C, H, W) output. Calling
        # log_softmax without dim is deprecated; for 4-D input PyTorch's
        # implicit choice was also dim 1, so results are unchanged.
        return F.log_softmax(self.last(up4), dim=1)