Hi!
I’m currently working on using an optimal transport distance as the misfit function for an inverse problem. I’m using PyTorch’s loss.backward()
to compute the gradient of the optimal transport distance. However, the backward pass does not seem to produce any update to my model. Does PyTorch’s autograd support differentiating through an optimal transport distance?
Thank you.
Here are the relevant snippets of my code:
# NOTE(review): the original paste carried no indentation; the nesting below is
# a reconstruction and should be confirmed against the real script.
for b in batches:
    # One strided selection shared by sources, receivers and observed data.
    batch_sel = slice(b, None, BATCHES)
    batch_xs = xs[batch_sel]
    batch_xr = xr[batch_sel]

    # Forward simulation for this batch and the matching observed data.
    data_sim = pde_solver(batch_xs, batch_xr)
    data_true = noisy_data[batch_sel, ...]

    # Convert both datasets to normalized densities and their cumulative sums.
    sim, data_sim_cdf = OT_data_normalization(data_sim, 6.0, 1.0)
    true, data_true_cdf = OT_data_normalization(data_true, 6.0, 1.0)

    # Only the loss evaluation is placed under autocast.
    with torch.cuda.amp.autocast():
        loss = OTloss(data_sim_cdf, data_true_cdf, sim, T)

    # Scaled backward pass, then unscale so clipping sees true gradient values.
    opt.zero_grad()
    scaler.scale(loss).backward()
    scaler.unscale_(opt)
    # NOTE(review): clip_grad_value_ expects a tensor or an iterable of
    # tensors; confirm pde_solver.model is one of those and not an nn.Module.
    nn.utils.clip_grad_value_(pde_solver.model, clip_value=1e3)
    scaler.step(opt)
    scaler.update()
    # NOTE(review): stepping the scheduler per batch is assumed here; confirm
    # whether it was meant to run once per epoch instead.
    scheduler.step()
# Data normalization function
def OT_data_normalization(data, b, c):
    """Turn each trace into a normalized density and its cumulative sum.

    The input is flattened to ``(ns * nr, nt)``, mapped through
    ``log(1 + exp(b * data))`` so every sample is non-negative, shifted by
    ``c`` and divided by the row sum so each row adds up to one.

    Args:
        data: Tensor with exactly three dimensions ``(ns, nr, nt)``.
        b: Scalar multiplying ``data`` inside the exponential.
        c: Scalar added to every sample before normalizing.

    Returns:
        Tuple ``(data_norm, data_cdf)``, both of shape ``(ns * nr, nt)``:
        the row-normalized data and its cumulative sum along the last axis.

    Raises:
        ValueError: If ``data`` does not have exactly three dimensions.
    """
    ns, nr, nt = data.shape
    scaled = b * data.reshape(ns * nr, nt)
    # Evaluating log(exp(x) + 1) literally overflows to inf for large x, and
    # the row normalization below then produces NaN. logaddexp(x, 0) is the
    # same function computed without forming exp(x) explicitly.
    positive = torch.logaddexp(scaled, torch.zeros_like(scaled))
    shifted = positive + c
    data_norm = shifted / torch.sum(shifted, dim=-1, keepdim=True)
    data_cdf = torch.cumsum(data_norm, dim=-1)
    return data_norm, data_cdf
# compute OT loss
def OTloss(cdf_sim, cdf_obs, dsim, t):
    """Sum of squared time displacements weighted by the simulated density.

    For every value in ``cdf_sim`` the insertion position in ``cdf_obs`` is
    looked up with ``torch.searchsorted``; positions equal to ``t.shape[-1]``
    are redirected to the last entry of ``t``. The result is
    ``sum((t - t[position]) ** 2 * dsim)``.

    Args:
        cdf_sim: Cumulative sums of the simulated data.
        cdf_obs: Cumulative sums of the observed data, sorted along the
            last axis.
        dsim: Weights multiplying the squared displacement.
        t: Sample coordinates indexed by the lookup positions.
            Assumes ``t`` is 1-D with the same length as the last axis of
            the CDFs -- TODO confirm.

    Returns:
        Scalar tensor holding the summed loss.

    Note:
        The lookup positions are integers, so autograd does not propagate
        through them; gradients reach the inputs only via ``dsim`` (and
        ``t`` if it requires grad).
    """
    n_samples = t.shape[-1]
    positions = torch.searchsorted(cdf_obs, cdf_sim)
    # A position one past the end is mapped to -1, i.e. the last entry of t.
    positions = positions.masked_fill(positions == n_samples, -1)
    displacement = t - t[positions]
    weighted = displacement.pow(2) * dsim
    return weighted.sum()