I have a loss function that requires multiple internal passes:
def my_loss_func(logits, sigma, labels, num_passes):
    """Monte-Carlo loss over `num_passes` noise-perturbed forward passes.

    Each pass perturbs `logits` with Gaussian noise scaled by `sigma`
    (both are network outputs carrying gradients), accumulates
    exp(labels * (logsumexp(z) - z)) across passes, and returns the mean
    log of the averaged per-pass loss.  As a side product it tracks, per
    image, the std over passes of the summed channel-0 softmax volume.

    Args:
        logits: tensor of shape (B, C, ...); channel dim assumed at dim 1
            for the volume bookkeeping — TODO confirm against caller.
        sigma:  noise scale, broadcastable against `logits`; has grad.
        labels: tensor broadcastable against `logits`.
        num_passes: number of stochastic passes to average over.

    Returns:
        (actual_loss, batch_std): scalar tensor with grad, and a
        NumPy array of shape (B,) with the per-image volume std.
    """
    total_loss = 0.0
    img_batch_size = logits.shape[0]
    vol_std = np.zeros((img_batch_size, num_passes))

    for fpass in range(num_passes):
        # randn_like inherits device AND dtype from logits (the original
        # hard-coded 'cuda:0', which broke CPU / non-default-GPU runs).
        noise = torch.randn_like(logits)
        stochastic_output = logits + sigma * noise

        # The volume std is diagnostics only — computing it under
        # no_grad keeps it out of the autograd graph entirely, which is
        # the single biggest memory saving available here (the original
        # built the graph and only detached afterwards).
        with torch.no_grad():
            temp_vol = torch.softmax(stochastic_output, dim=1)[:, 0, ...]
            # reshape by the real batch size (original hard-coded 4).
            vol_std[:, fpass] = (
                temp_vol.reshape(img_batch_size, -1).sum(1).cpu().numpy()
            )

        # logsumexp is the numerically stable form of log(sum(exp(.)))
        # and avoids materializing the intermediate exp tensor.
        exponent_B = torch.logsumexp(stochastic_output, dim=-1, keepdim=True)
        inner_logits = exponent_B - stochastic_output
        soft_inner_logits = labels * inner_logits
        total_loss = total_loss + torch.exp(soft_inner_logits)
        # NOTE(review): the graph for every pass must be retained until
        # backward() because the passes are averaged before the log; for
        # a further memory win, consider calling backward() per pass on
        # an equivalent reformulation, or use torch.utils.checkpoint.

    mean_loss = total_loss / num_passes
    actual_loss = torch.mean(torch.log(mean_loss))
    batch_std = np.std(vol_std, axis=1)
    return actual_loss, batch_std
Both logits and sigma are network outputs and therefore have associated gradients. I run into memory issues when num_passes exceeds 50. Are there any other ways in which I could fully optimise memory allocation to allow for a greater number of passes? I'm not at all concerned with readability or ugly solutions — anything will do.