Hi, I have a simple segmentation network.
During training I run my evaluation metrics on a validation dataset which gives good results. But after training when i use the network to predict segmentation mask I get very bad results and the problem is especially peculiar because running prediction on validation set also gives bad results. eval() and training() is properly set during training and validation phase. Is there any idea on whats going on?
Could you check, if your preprocessing pipeline is the same (in your training script and your prediction script)?
I assume you were able to load the state_dict
and didn’t get any errors?
Yes I can load state_dict
just fine.
This is my training code:
def train_network(train_dataset, valid_dataset, dirs, num_epochs=100, batch_size=10,
retrain=True, file=None, unet=False):
network = get_network(unet)
optimizer = optim.Adam(network.parameters(), lr=0.001, betas=(0.9, 0.999), amsgrad=True, eps=1e-06,
weight_decay=0.0)
scheduler = sched.ReduceLROnPlateau(optimizer, 'min', verbose=True, patience=50, cooldown=10, factor=0.75)
last_checkpoint_path = dirs['checkpoint'] + r'/checkpoint-last'
checkpoint = None
if os.path.exists(last_checkpoint_path) and not retrain:
checkpoint = torch.load(last_checkpoint_path)
network.load_state_dict(checkpoint['model_state_dict'])
network.train()
if cuda.is_available():
LOGGER.info("Using CUDA: " + torch.cuda.get_device_name(torch.cuda.current_device()))
network.cuda()
else:
LOGGER.warning("No CUDA")
if os.path.exists(last_checkpoint_path) and not retrain:
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])
LOGGER.debug(network)
wei = torch.tensor([1.0, 2.0, 2.0, 15.0]).cuda()
criterion = nn.CrossEntropyLoss(weight=wei)
data_loader = data.DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True, drop_last=False)
signal.signal(signal.SIGINT, signal.default_int_handler)
try:
for epoch in range(1, num_epochs + 1):
running_loss = 0.0
for batch_index, sample in enumerate(data_loader):
optimizer.zero_grad()
outputs = network(sample["image"].cuda())
# loss = criterion(outputs, sample["mask"].float().cuda(), sample["weights"].float().cuda())
target = sample["class_mask"].squeeze(dim=1).long().cuda()
loss = criterion(outputs, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
del target
running_loss = running_loss / len(data_loader)
test_loss, accuracy, conf = perform_test_network(valid_dataset, network, criterion, batch_size=1)
scheduler.step(running_loss)
LOGGER.info(
f'Epoch {epoch: 5d}/{num_epochs: 5d}, '
f'train_loss: {running_loss:8.5f}, '
f'test_loss: {test_loss:8.5f}, accu: {accuracy:2.9f}, '
f'background accuracy: {conf[0, 0]:2.3f}, '
f'frame accuracy: {conf[1, 1]:2.3f}, '
f'feet accuracy: {conf[2, 2]:2.3f}, '
f'defect accuracy: {conf[3, 3]:2.3f}')
if epoch % 100 == 0 or epoch == num_epochs:
out_file = os.path.join(dirs['checkpoint'], f'checkpoint-epoch_{epoch}_time_{nice_time()}')
save_checkpoint(epoch, network, optimizer, scheduler, loss, out_file)
LOGGER.info(f"Epoch model saved: {out_file}")
save_checkpoint(epoch, network, optimizer, scheduler, loss, last_checkpoint_path)
perform_test_and_visualize(network, train_dataset, dirs, epoch, batch_size=2)
except KeyboardInterrupt:
LOGGER.info(f"Training interupted, saving model for future trainig.")
save_checkpoint(epoch, network, optimizer, scheduler, loss, last_checkpoint_path)
return network, optimizer, loss, scheduler
This is my prediction code:
def predict(dataset, checkpoint_path, dirs, cuda=False, unet=False):
net = get_network(unet)
if cuda:
net.cuda()
checkpoint = torch.load(checkpoint_path)
net.load_state_dict(checkpoint['model_state_dict'])
data_loader = data.DataLoader(dataset, batch_size=1)
with torch.no_grad():
i = 0
for _, sample in enumerate(data_loader):
i += 1
timing = time.time()
if cuda:
sample_f = sample['image'].float().cuda()
output = net(sample_f).cpu()
else:
sample_f = sample['image'].float()
output = net(sample_f)
output_class = torch.argmax(output, dim=1)
rgbout = convert_tensor_to_RGB(output_class)
filename = dirs['test'] + '/img_' + "_".join(sample["filename"])
out_filename = filename + '.bmp'
save_image(rgbout, out_filename, padding=0, normalize=False)
LOGGER.info(f"Output file written: {out_filename}")