I’m trying to train a U-Net model on my custom dataset. My data is a set of satellite images, each with 7 bands, together with their corresponding masks.
The masks are RGB images and don’t have any discrete classes, so I just need the network to learn a mapping from each image to its mask, which represents the crop land-cover type for that image in the training set.
So I thought that mean squared error (MSE) loss would be the right choice, but the loss is NaN starting from the first epoch.
This is my U-Net model:
def double_conv(in_c, out_c):
    """Build the two-convolution stage used by every U-Net level.

    The stage is (Conv 3x3 -> BatchNorm -> ReLU) applied twice. The
    convolutions are unpadded, so each one trims one pixel from every
    border: the output is 4 pixels smaller than the input in both height
    and width.

    Args:
        in_c: Number of channels of the incoming feature map.
        out_c: Number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` holding the six layers.
    """
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    )
def crop_img(tensor, target_tensor):
    """Center-crop ``tensor`` to the spatial size of ``target_tensor``.

    Used to shrink an encoder feature map so it can be concatenated with
    the (smaller) decoder feature map along the channel dimension.

    Height and width are handled independently, so non-square feature
    maps are cropped correctly. When the size difference is odd, the
    extra row/column is removed from the bottom/right edge.

    Args:
        tensor: 4-D tensor ``(N, C, H, W)`` to crop.
        target_tensor: 4-D tensor whose last two dimensions give the
            desired output height and width.

    Returns:
        A view of ``tensor`` with shape ``(N, C, target_H, target_W)``.

    Raises:
        ValueError: If the target is larger than ``tensor`` in either
            spatial dimension (cropping cannot enlarge a tensor).
    """
    src_h, src_w = tensor.size()[2:4]
    target_h, target_w = target_tensor.size()[2:4]

    if target_h > src_h or target_w > src_w:
        raise ValueError(
            "cannot crop a {}x{} map to the larger size {}x{}".format(
                src_h, src_w, target_h, target_w
            )
        )

    # Floor division puts any odd leftover pixel on the bottom/right side.
    top = (src_h - target_h) // 2
    left = (src_w - target_w) // 2
    return tensor[:, :, top:top + target_h, left:left + target_w]
class UNet(nn.Module):
    """Three-level U-Net that regresses a multi-channel mask from an image.

    The encoder applies four unpadded ``double_conv`` stages separated by
    2x2 max-pooling. The decoder upsamples with transposed convolutions,
    concatenates the center-cropped encoder features (skip connections)
    and refines them with further ``double_conv`` stages. Because the
    convolutions are unpadded the decoder output is smaller than the
    input, so the result is upsampled (nearest neighbour) back to
    ``out_size`` before the final 1x1 convolution.

    No activation is applied to the output; it is a raw regression map.

    Args:
        in_channels: Number of bands in the input image (default 7).
        out_channels: Number of channels in the predicted mask
            (default 3, for RGB masks).
        out_size: ``(height, width)`` the prediction is resized to
            (default ``(153, 153)``).
    """

    def __init__(self, in_channels=7, out_channels=3, out_size=(153, 153)):
        super().__init__()
        # Plain tuple: not a parameter/buffer, so it is not part of state_dict.
        self.out_size = tuple(out_size)

        self.max_pool_2X2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.down_conv_1 = double_conv(in_channels, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)

        # Decoder: each transposed conv halves the channels and doubles the
        # resolution; the following double_conv takes twice as many channels
        # because the skip connection is concatenated first.
        self.up_trans_1 = nn.ConvTranspose2d(
            in_channels=512, out_channels=256, kernel_size=2, stride=2
        )
        self.up_conv_1 = double_conv(512, 256)

        self.up_trans_2 = nn.ConvTranspose2d(
            in_channels=256, out_channels=128, kernel_size=2, stride=2
        )
        self.up_conv_2 = double_conv(256, 128)

        self.up_trans_3 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64, kernel_size=2, stride=2
        )
        self.up_conv_3 = double_conv(128, 64)

        # Refinement convolutions applied while resizing back to out_size.
        self.d_conv1 = nn.Conv2d(64, 32, 3, padding=1)
        self.d_conv2 = nn.Conv2d(32, 16, 3, padding=1)

        # 1x1 convolution mapping features to the mask channels.
        self.out = nn.Conv2d(
            in_channels=16,
            out_channels=out_channels,
            kernel_size=1,
        )

    def forward(self, image):
        """Run the network on a batch of images.

        Args:
            image: Tensor of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, out_channels, *out_size)``.
        """
        # Encoder. x1, x3 and x5 are kept for the skip connections.
        x1 = self.down_conv_1(image)
        x2 = self.max_pool_2X2(x1)
        x3 = self.down_conv_2(x2)
        x4 = self.max_pool_2X2(x3)
        x5 = self.down_conv_3(x4)
        x6 = self.max_pool_2X2(x5)
        x7 = self.down_conv_4(x6)  # 512 x 11 x 11 for a 153 x 153 input

        # Decoder: upsample, crop the matching encoder map to the same
        # spatial size, concatenate on the channel axis, then convolve.
        x = self.up_trans_1(x7)
        y = crop_img(x5, x)
        x = self.up_conv_1(torch.cat([x, y], 1))

        x = self.up_trans_2(x)
        y = crop_img(x3, x)
        x = self.up_conv_2(torch.cat([x, y], 1))

        x = self.up_trans_3(x)
        y = crop_img(x1, x)
        x = self.up_conv_3(torch.cat([x, y], 1))

        # Resize back to the requested output size in two steps, with a
        # padded convolution after each step.
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = F.relu(self.d_conv1(x))
        x = F.interpolate(x, size=self.out_size, mode='nearest')
        x = F.relu(self.d_conv2(x))

        return self.out(x)
And here is my loss function and training code:
# Loss, optimizer and learning-rate schedule.
# MSE is used because the masks are continuous RGB targets, not class ids.
lr = 0.001
criterion = nn.MSELoss()
optimizer = optim.RMSprop(model.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

if train_on_gpu:
    model.cuda()

n_epochs = 10
for epoch in range(1, n_epochs + 1):
    # Sum of per-sample losses and number of samples seen this epoch.
    train_loss = 0.0
    n_samples = 0

    ###################
    # train the model #
    ###################
    for data in train_loader:
        data_proc = process_img(data)
        images, labels = data_proc['rasters'], data_proc['labels']
        if train_on_gpu:
            images, labels = images.cuda(), labels.cuda()

        # clear the gradients of all optimized variables
        optimizer.zero_grad()
        # forward pass: compute predicted outputs by passing inputs to the model
        outputs = model(images)
        # calculate the loss
        loss = criterion(outputs, labels)
        # backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer.step()

        # criterion returns the batch mean, so weight it by the batch size
        # to accumulate a per-sample sum (the last batch may be smaller).
        batch_size = images.size(0)
        train_loss += loss.item() * batch_size
        n_samples += batch_size

    # Average over samples, matching the per-sample weighting above
    # (dividing by the number of batches would inflate the value).
    train_loss = train_loss / max(n_samples, 1)

    # ReduceLROnPlateau only lowers the learning rate when it is fed the
    # monitored metric once per epoch.
    scheduler.step(train_loss)

    # print avg training statistics
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(
        epoch,
        train_loss
    ))
I also normalized the values of all my image bands so that they lie between 0 and 1.
Thanks