Is it possible to deal with 'non-scalar' losses?

Context

I am trying to use PyTorch’s optimizers to perform non-linear curve fitting. I have working code overall; here it is:

import torch
import torch.nn as nn
import torch.nn.functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
import matplotlib.pyplot as plt

class Net(nn.Module):
    def __init__(self, p, IF, y, device):
        super().__init__()
        self.device = device
        self.pars   = nn.Parameter(p.to(self.device))
        self.IF     = IF.to(self.device)
        self.y      = y.to(self.device)

    def forward(self):
        # note: .detach() returns a new tensor, so this line has no effect
        self.pars.detach()
        return self.loss(self.pars)
    
    def loss(self,p):
        out = kmodel(p,self.IF,self.device) - self.y
        out = torch.sum((out)**2, dim=1)
        return torch.mean(out)

################
def kmodel(p,Cp,device=0):
    torch.autograd.set_detect_anomaly(True)  # debugging aid; noticeably slows autograd down
    Nv = p.shape[0]
    Nt = Cp.shape[1]
    
    ton = p[:,0].clone()
    vp  = p[:,1].clone().view(Nv,1).to(device)
        
    Cp = Cp.unsqueeze(-1).unsqueeze(1)
    # build the affine transform on the same device as the parameters,
    # otherwise this fails when running on CUDA
    affine = torch.zeros(Nv, 2, 3, device=p.device)
    # The affine transformation we want is identity + translation
    affine[:, 0, 0] = 1
    affine[:, 1, 1] = 1
    affine[:, 1, 2] = -ton  # -2 * ton / Nt
    
    grid = F.affine_grid(affine, (Nv, 1, Nt, 1))#, align_corners=True)
    Cp = F.grid_sample(Cp, grid) #, align_corners=True)
    Cp = Cp.squeeze(1).squeeze(-1)

    out = vp*Cp
    return out

# ================== SIMULATE SOME DATA ==================
n = 5
p0 = torch.linspace(.1, 0.5, n) # ton
p1 = torch.linspace(0.5, 0.75, n) # vp
p_true = []
for i in range(n):
    for j in range(n):
        p_true.append([p0[i], p1[j]])

p_true = torch.tensor(p_true).to(device)
Nv = p_true.shape[0]
print(Nv)

IF = torch.tensor([[0.0309, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0220, 0.0000, 0.0000,
         0.0074, 0.0000, 0.3594, 2.0850, 2.6615, 2.7160, 1.8999, 1.3524, 1.5733,
         1.8137, 1.7839, 1.4743, 1.4680, 1.4860, 1.3490, 1.3471, 1.3055, 1.3520,
         1.2053, 1.3447, 1.2073, 1.2478, 1.3872, 1.2890, 1.1951, 1.1463, 1.2195,
         1.1701, 1.1634, 1.1765, 1.0766, 1.2907, 1.1409, 1.0373, 1.0479, 1.0766,
         1.0842, 1.0371, 1.0241, 1.0439, 0.9803, 0.9966, 1.1955, 0.9665, 0.9679,
         1.0223, 0.9714, 0.9415, 1.0756, 1.0336, 0.9744, 0.9825, 1.0016, 0.9592,
         0.8908, 0.9362, 0.9347, 0.9198, 0.9353, 0.9279, 1.0776]]).to(device)   
Cp = torch.repeat_interleave(IF, repeats=Nv, dim=0).to(device)

%time
y_true = kmodel(p_true,Cp.clone(),device=device).to(device)
y = y_true + 0.00*torch.randn(y_true.size()).to(device)

plt.plot(IF.cpu().numpy().T,'-*', color='C9')
plt.plot(y.cpu().numpy()[:16,:].T);

# ================== OPTIMIZE ==================

from tqdm.notebook import tqdm

l = []
training_steps = 2000
lr = 0.01

Cp = torch.repeat_interleave(IF.clone(), repeats=Nv, dim=0)
par = torch.tensor([[0.1,0.1]]*Nv).to(device) + 1e-6
model = Net(par, Cp, y, device)
optim = torch.optim.Adamax(model.parameters(), lr = lr)
                
with tqdm(total=training_steps) as pbar:
    for epoch in range(training_steps):
        optim.zero_grad()
        #ypred = model()
        ll = model() #loss(ypred, _)
        ll.backward()
        optim.step()
        pbar.set_postfix(loss=ll.item(),p_est=list(model.parameters())[0].data[-1], 
                         #grad=list(model.parameters())[0].grad[-1], 
                         lr = lr, 
                         refresh=False)
        pbar.update(1)
        l.append(ll.item())
        if ll.item() <= 1e-3:
            break

The end goal is to estimate those 2 parameters of my kmodel function for each of the provided time series (each row of the tensor y).

Problem

I can access the gradient with:

for p in model.parameters():
    print(p.grad.shape)
    print(p.grad)

and if I do this, I can see that the gradient has shape [Nv, 2], where:

  • Nv = # of time series
  • 2 = # of parameters to estimate

This is perfect for what I am trying to do.
However, when I compute the loss as:

def loss(self,p):
        out = kmodel(p,self.IF,self.device) - self.y
        out = torch.sum((out)**2, dim=1)
        return torch.mean(out)

The info about all the Nv time series I am fitting is collapsed together, as if I were dealing with a single model with (Nv * 2) unknown parameters.
What I would really like to do is to define a loss like:

def loss(self,p):
        out = kmodel(p,self.IF,self.device) - self.y
        out = torch.sum((out)**2, dim=1)
        return out

so that I have a separate value for each row (they are all independent of each other) and can do a separate backward pass for each of them.

It’s hard to prove, but I feel that by grouping all the time series into a single loss value I am slowing down (and affecting) my convergence, because I am treating all the parameters as if they belonged to one model, which is not really the case.

And most importantly, having a loss for each time series should ideally allow me to scale my problem (almost) linearly with Nv (each backward pass taking care of only 2 parameters), while right now every time I add a new row I increase the complexity of my model quite a lot … A rough sketch of what I mean is below.
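For concreteness, here is a minimal sketch of the kind of per-row backward I have in mind (toy tensors standing in for kmodel and my data, not the actual code above):

import torch

# one parameter row per time series (here: 3 series, 2 parameters each)
p = torch.zeros(3, 2, requires_grad=True)
y = torch.randn(3, 70)
pred = p[:, :1] * torch.ones(3, 70) + p[:, 1:]   # stand-in for kmodel

per_row_loss = torch.sum((pred - y) ** 2, dim=1)  # shape [3], one loss per series

# backward() on a non-scalar tensor needs an explicit "gradient" argument;
# passing ones is equivalent to calling backward on per_row_loss.sum()
per_row_loss.backward(torch.ones_like(per_row_loss))
print(p.grad.shape)  # torch.Size([3, 2])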

Question

Do you think I have any chance of succeeding in this quest? :sweat_smile:

Hi,

Your problem is a bit different from a classical neural net, as you have one set of weights per sample that you are deliberately trying to over-fit to it.

In your current code, even though you have a single loss which is the mean, each element before the mean is computed completely independently from the others (different input, different parameters, etc.), so the gradients of the mean will be the same as if you called backward on each element of the sum one by one (divided by the number of elements because of the mean).

So you can do any of these (and mix them), because the different parts of your model are independent :slight_smile:

Note that in general it is better to sum the losses and do a single backward (speed-wise).
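A tiny check of that equivalence with toy tensors (two independent scalar parameters standing in for your per-row parameters):

import torch

# two independent "rows", each with its own parameter and target
p = torch.tensor([1.0, 2.0], requires_grad=True)
target = torch.tensor([0.5, 1.5])

# single backward on the mean of the per-row losses
((p - target) ** 2).mean().backward()
grad_mean = p.grad.clone()

# row-by-row backward, one call per element; gradients accumulate in p.grad
p.grad = None
for i in range(2):
    ((p[i] - target[i]) ** 2).backward()

# the mean version is the row-wise version divided by the number of rows
print(torch.allclose(grad_mean, p.grad / 2))  # True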

Just to be sure I got your point: gradients are independent for each row in my input dataset, and even if I mix all of them in the loss, autograd is still aware of this and treats each row/time series exactly the same (in terms of how the parameter values are updated when I call optim.step()) as if I were to compute the backward pass for each one of them independently, right?

Are you suggesting that I change my current loss, from:

def loss(self,p):
        out = kmodel(p,self.IF,self.device) - self.y
        out = torch.sum((out)**2, dim=1)
        return torch.mean(out)

to

def loss(self,p):
        out = kmodel(p,self.IF,self.device) - self.y
        out = torch.sum((out)**2, dim=1)
        return torch.sum(out)

or

def loss(self,p):
        out = kmodel(p,self.IF,self.device) - self.y
        return torch.sum((out)**2)

Will any of these options make any difference in terms of speed/memory usage?
I am especially interested in this because, beyond this toy example, in my real use case Nv will be on the order of millions … :grimacing:
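For reference, this is roughly how I would try to measure it (a rough sketch where the model output is replaced by a random tensor, so it only compares the reductions, not kmodel itself; on GPU the timed region would also need torch.cuda.synchronize()):

import time
import torch

def time_loss(loss_fn, p, y, n_iter=100):
    # rough CPU timing of forward + backward for one loss variant
    t0 = time.perf_counter()
    for _ in range(n_iter):
        p.grad = None
        loss_fn(p, y).backward()
    return (time.perf_counter() - t0) / n_iter

# stand-in for the model output: one row per "time series"
p = torch.randn(1000, 70, requires_grad=True)
y = torch.randn(1000, 70)

variants = {
    "mean of row sums": lambda p, y: torch.mean(torch.sum((p - y) ** 2, dim=1)),
    "sum of row sums":  lambda p, y: torch.sum(torch.sum((p - y) ** 2, dim=1)),
    "plain sum":        lambda p, y: torch.sum((p - y) ** 2),
}
for name, fn in variants.items():
    print(name, time_loss(fn, p, y))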

if I were to compute the backward pass for each one of them independently, right?

Autograd just computes the gradients.
If the gradients are independent, then what autograd computes will be independent as well :slight_smile:

Will any of these options make any difference in terms of speed/memory usage?

These three options won’t really make any difference. The only difference between the first two is that all the gradients you compute will be scaled by out.nelement().
For the third one, you will just add an extra square backward. So not much of a difference either.

My personal opinion is that 2 or 3 are better, as the gradient computed for each independent time series is independent of the total number of time series (the first one just has a constant scaling, so it’s not the end of the world).
For the squaring, it will depend on your application :slight_smile:

So, what’s the take-home message of this statement? Returning a torch.sum is “better” than a torch.mean, or am I missing something else here?

Yes.
Speed-wise they are the same.
But for your application, I would recommend the sum, so that training sample[0] behaves the same whether len(sample) = 10 or len(sample) = 100.
If you do the mean, then you will have to scale your lr up when len(sample) increases.
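A toy illustration of that scaling, with independent scalar parameters standing in for your time series:

import torch

def grad_seen_by_sample0(n, reduce):
    # n independent samples, each with its own parameter; target is zero
    p = torch.full((n,), 2.0, requires_grad=True)
    reduce(p ** 2).backward()
    return p.grad[0].item()

# with mean, sample 0's gradient shrinks as the dataset grows...
print(grad_seen_by_sample0(10, torch.mean), grad_seen_by_sample0(100, torch.mean))
# ...with sum, it stays the same regardless of how many samples there are
print(grad_seen_by_sample0(10, torch.sum), grad_seen_by_sample0(100, torch.sum))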
