Next step in training a neural network?

I’m trying to train a model. I’m done with writing the dataloader and setting up the model, but I’m not sure what the next step should be. If anyone can suggest what to do next, that would be really helpful.

This is what I have so far:

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class DSB(Dataset):
    def __init__(self, root, subset='train', transform=None):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.subset = subset
        self.data_path, self.label_path = [], []

        def load_images(path):
            # Collect all file paths in the directory, sorted so images
            # and labels line up by index.
            images_dir = [os.path.join(path, file) for file in os.listdir(path)
                          if os.path.isfile(os.path.join(path, file))]
            images_dir.sort()
            return images_dir

        if self.subset == 'train':
            self.data_path = load_images(self.root + 'train')
            self.label_path = load_images(self.root + 'train_label')
        elif self.subset == 'val':
            self.data_path = load_images(self.root + 'val')
            self.label_path = load_images(self.root + 'val_label')
        else:
            raise RuntimeError('Invalid subset ' + self.subset +
                               ', it must be one of: \'train\', \'val\'')

    def __getitem__(self, index):
        img = Image.open(self.data_path[index])
        # The constructor only accepts 'train' or 'val', so a label always exists.
        target = Image.open(self.label_path[index])

        if self.transform is not None:
            img = self.transform(img)
            target = self.transform(target)

        return img, target

    def __len__(self):
        return len(self.data_path)


train_dataset = DSB(root='/media/predible/da5df9e4-cdc6-4d55-91e8-b2383e89165f/dsbdata/',
                    subset='train',
                    transform=transforms.Compose([
                        # Scale was renamed to Resize in torchvision
                        transforms.Resize((256, 256)),
                        transforms.ToTensor()]))

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=8,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=1)


# Quick sanity check: visualize the first image/mask pair
# (im_show is a plotting helper defined elsewhere).
img_list = []
for i in range(1):
    img, label = train_dataset[i]
    img_list.append(img)
    img_list.append(label)

im_show(img_list)

class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size)
        self.activation = activation

    def forward(self, x):
        out = self.activation(self.conv(x))
        out = self.activation(self.conv2(out))
        return out


class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetUpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, 2, stride=2)
        self.conv = nn.Conv2d(in_size, out_size, kernel_size)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size)
        self.activation = activation

    def center_crop(self, layer, target_size):
        # Crop the skip-connection feature map to match the upsampled size.
        batch_size, n_channels, layer_width, layer_height = layer.size()
        xy1 = (layer_width - target_size) // 2
        return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size)]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.size()[2])
        out = torch.cat([up, crop1], 1)
        out = self.activation(self.conv(out))
        out = self.activation(self.conv2(out))
        return out


class UNet(nn.Module):
    def __init__(self, imsize):
        super(UNet, self).__init__()
        self.imsize = imsize
        self.activation = F.relu

        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.pool3 = nn.MaxPool2d(2)
        self.pool4 = nn.MaxPool2d(2)

        self.conv_block1_64 = UNetConvBlock(1, 64)
        self.conv_block64_128 = UNetConvBlock(64, 128)
        self.conv_block128_256 = UNetConvBlock(128, 256)
        self.conv_block256_512 = UNetConvBlock(256, 512)
        self.conv_block512_1024 = UNetConvBlock(512, 1024)

        self.up_block1024_512 = UNetUpBlock(1024, 512)
        self.up_block512_256 = UNetUpBlock(512, 256)
        self.up_block256_128 = UNetUpBlock(256, 128)
        self.up_block128_64 = UNetUpBlock(128, 64)

        self.last = nn.Conv2d(64, 2, 1)

    def forward(self, x):
        block1 = self.conv_block1_64(x)
        pool1 = self.pool1(block1)

        block2 = self.conv_block64_128(pool1)
        pool2 = self.pool2(block2)

        block3 = self.conv_block128_256(pool2)
        pool3 = self.pool3(block3)

        block4 = self.conv_block256_512(pool3)
        pool4 = self.pool4(block4)

        block5 = self.conv_block512_1024(pool4)

        up1 = self.up_block1024_512(block5, block4)
        up2 = self.up_block512_256(up1, block3)
        up3 = self.up_block256_128(up2, block2)
        up4 = self.up_block128_64(up3, block1)

        # Pass dim explicitly: log-probabilities over the channel dimension.
        return F.log_softmax(self.last(up4), dim=1)

Define a criterion and an optimizer and you are good to go!
Have a look at the MNIST example.
This example is also a good starting point.
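To make that concrete, here is a minimal training-loop sketch against the code above. It is not the one true recipe: nn.NLLLoss is chosen because the model already applies F.log_softmax, while the optimizer (SGD), the learning rate, the epoch count, and the 0.5 mask-binarization threshold are all assumptions you should adjust for your data. Note also that the unpadded 3x3 convolutions shrink the output (a 256x256 input comes out around 68x68), so the sketch center-crops the target to the output size, as in the original U-Net paper.

import torch.nn as nn
import torch.optim as optim

model = UNet(imsize=256)
criterion = nn.NLLLoss()  # pairs with the model's log_softmax output
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # assumed hyperparameters

num_epochs = 10  # hypothetical value

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for imgs, targets in train_loader:
        optimizer.zero_grad()

        # The first conv block expects 1-channel images; convert the PIL
        # images to grayscale in the Dataset if yours are RGB.
        outputs = model(imgs)  # (B, 2, h, w) log-probabilities

        # Valid (unpadded) convolutions shrink the map, so center-crop the
        # target to the output size and turn the float mask into class indices.
        h = outputs.size(2)
        off = (targets.size(2) - h) // 2
        targets = targets[:, 0, off:off + h, off:off + h]  # (B, h, w)
        targets = (targets > 0.5).long()  # binarize mask -> {0, 1}

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print('epoch {}: avg loss {:.4f}'.format(epoch, running_loss / len(train_loader)))

Once that runs, add a similar loop over a validation loader wrapped in torch.no_grad() (with model.eval()) to watch for overfitting, and checkpoint with torch.save(model.state_dict(), 'unet.pth').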
