Hi,
I am currently trying to implement a 3dUnet {crop-type-mapping/unet3d.py at master · roserustowicz/crop-type-mapping · GitHub} for satellite time series regression.
My input is of the shape [b*t*c*h*w], and the mask is [b*c*h*w]. My training loss (MSE) is not decreasing during my preliminary training results (it stays exactly the same), and i’m just not sure where i’ve gone wrong.
torch.Size([1, 12, 15, 256, 256])
torch.Size([1, 1, 256, 256])
## model
"""
Taken from https://github.com/roserustowicz/crop-type-mapping/
Implementation by the authors of the paper :
"Semantic Segmentation of crop type in Africa: A novel Dataset and analysis of deep learning methods"
R.M. Rustowicz et al.
Slightly modified to support image sequences of varying length in the same batch.
"""
import torch
import torch.nn as nn
def conv_block(in_dim, middle_dim, out_dim):
model = nn.Sequential(
nn.Conv3d(in_dim, middle_dim, kernel_size=3, stride=1, padding=1),
nn.BatchNorm3d(middle_dim),
nn.LeakyReLU(inplace=True),
nn.Conv3d(middle_dim, out_dim, kernel_size=3, stride=1, padding=1),
nn.BatchNorm3d(out_dim),
nn.LeakyReLU(inplace=True),
)
return model
def center_in(in_dim, out_dim):
model = nn.Sequential(
nn.Conv3d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
nn.BatchNorm3d(out_dim),
nn.LeakyReLU(inplace=True))
return model
def center_out(in_dim, out_dim):
model = nn.Sequential(
nn.Conv3d(in_dim, in_dim, kernel_size=3, stride=1, padding=1),
nn.BatchNorm3d(in_dim),
nn.LeakyReLU(inplace=True),
nn.ConvTranspose3d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1))
return model
def up_conv_block(in_dim, out_dim):
model = nn.Sequential(
nn.ConvTranspose3d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.BatchNorm3d(out_dim),
nn.LeakyReLU(inplace=True),
)
return model
class UNet3D(nn.Module):
def __init__(self, in_channel, n_classes, timesteps=12, dropout=0.5):
super(UNet3D, self).__init__()
self.in_channel = in_channel
self.n_classes = n_classes
feats = 16
self.en3 = conv_block(in_channel, feats * 4, feats * 4)
self.pool_3 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
self.en4 = conv_block(feats * 4, feats * 8, feats * 8)
self.pool_4 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
self.center_in = center_in(feats * 8, feats * 16)
self.center_out = center_out(feats * 16, feats * 8)
self.dc4 = conv_block(feats * 16, feats * 8, feats * 8)
self.trans3 = up_conv_block(feats * 8, feats * 4)
self.dc3 = conv_block(feats * 8, feats * 4, feats * 2)
self.final = nn.Conv3d(feats * 2, n_classes, kernel_size=3, stride=1, padding=1)
self.fn = nn.Linear(timesteps, 1)
self.logsoftmax = nn.LogSoftmax(dim=1)
self.dropout = nn.Dropout(p=dropout, inplace=True)
def forward(self, x):
x = x.float()
x = x.permute(0, 2, 1, 3, 4)
out = x.cuda()
en3 = self.en3(out)
pool_3 = self.pool_3(en3)
en4 = self.en4(pool_3)
pool_4 = self.pool_4(en4)
center_in = self.center_in(pool_4)
center_out = self.center_out(center_in)
concat4 = torch.cat([center_out,en4],dim=1)
dc4 = self.dc4(concat4)
trans3 = self.trans3(dc4)
concat3 = torch.cat([trans3,en3],dim=1)
dc3 = self.dc3(concat3)
final = self.final(dc3)
final = final.permute(0,1,3,4,2) # BxCxHxWxT
shape_num = final.shape[0:4]
final = final.reshape(-1,final.shape[4])
final = self.dropout(final)
final = self.fn(final)
final = final.reshape(shape_num)
final = self.logsoftmax(final)
return final
## initilise the model
model = UNet3D(in_channel=15, n_classes=1)
# initliase loss and optimiser
loss_module = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
criterion = loss_module
# training
def train(model, optimizer, criterion, train_loader, device=None):
model.train()
print('Training')
train_running_loss = 0.0
train_running_correct = 0
train_running_RMSE = 0.0
counter = 0
for i, batch in enumerate(train_loader):
counter += 1
batch = recursive_todevice(batch, device)
input, label = batch
print(input.shape)
print(label.shape)
optimizer.zero_grad()
# forward pass
outputs = model(input)
# calculate the loss
loss = criterion(outputs, label)
train_running_loss += loss.item()
rmse = torch.sqrt(loss)
train_running_RMSE += rmse
## Log losses to Neptune vis
run["training/epoch/loss"].log(loss)
run["training/epoch/rmse"].log(rmse)
# calculate the accuracy
#TODO
# _, preds = torch.max(outputs.data, 1)
# train_running_correct += (preds == labels).sum().item()
##
# backpropagation
loss.backward()
# update the optimizer parameters
optimizer.step()
# loss and accuracy for the complete epoch
epoch_loss = train_running_loss / counter
#epoch_acc = 100. * (train_running_correct / len(train_loader.dataset)) ## TODO ACCURACY
return epoch_loss #, epoch_acc
# validation
def validate(model, criterion, val_loader, device=None):
model.eval()
print('Validation')
valid_running_loss = 0.0
valid_running_correct = 0
val_running_RMSE = 0.0
counter = 0
with torch.no_grad():
for i, batch in enumerate(val_loader):
counter += 1
batch = recursive_todevice(batch, device)
input, label = batch
# forward pass
outputs = model(input)
# calculate the loss
loss = criterion(outputs, label)
valid_running_loss += loss.item()
rmse = torch.sqrt(loss)
val_running_RMSE += rmse
## Log losses to Neptune vis
run["val/epoch/loss"].log(loss)
run["val/epoch/rmse"].log(rmse)
# calculate the accuracy
# _, preds = torch.max(outputs.data, 1)
# valid_running_correct += (preds == labels).sum().item()
# loss and accuracy for the complete epoch
epoch_loss = valid_running_loss / counter
#epoch_acc = 100. * (valid_running_correct / len(testloader.dataset))
return epoch_loss#, epoch_acc
# define how many epochs to train for
epochs = 5
# lists to keep track of losses and accuracies
train_loss, valid_loss = [], []
train_acc, valid_acc = [], []
# start the training
for epoch in range(epochs):
print(f"[INFO]: Epoch {epoch+1} of {epochs}")
train_epoch_loss = train(model=model, optimizer=optimizer, criterion=criterion,
train_loader=train_loader, device = device)
valid_epoch_loss = validate(model=model, criterion=criterion,
val_loader=val_loader, device = device)
train_loss.append(train_epoch_loss)
valid_loss.append(valid_epoch_loss)
print(f"Training loss: {train_epoch_loss:.3f}")
print(f"Validation loss: {valid_epoch_loss:.3f}")
# save the best model till now if we have the least loss in the current epoch
save_best_model(
valid_epoch_loss, epoch, model, optimizer, criterion
)
print('-'*50)
# save the trained model weights for a final time
save_model(epochs, model, optimizer, criterion)
# save the loss and accuracy plots
print('TRAINING COMPLETE')