Computing per-sample gradients w.r.t. the last layer's parameters

Hi everyone,

I’m trying to implement importance sampling based on https://github.com/idiap/importance-sampling into my PyTorch project.

My code:

size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
    X, y = X.to(device), y.to(device)
    pred = model(X)

    # Compute per sample loss
    per_sample_loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    per_sample_loss = per_sample_loss_fn(pred, y)

    # Compute per sample gradient w.r.t. the last layer
    last_layer_params = model.lin3.parameters()
    per_sample_last_layer_grads = torch.autograd.grad(per_sample_loss, last_layer_params)

    # Compute prediction error
    loss = loss_fn(pred, y)

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

I am getting the following error when computing the per-sample gradients w.r.t. the last layer:

  File "/home/stdmichal/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/tmp/pycharm_project_451/test.py", line 128, in <module>
    idx, scores = train(train_dataloader, model, loss_fn, optimizer)
  File "/tmp/pycharm_project_451/test.py", line 46, in train
    per_sample_last_layer_grads = torch.autograd.grad(per_sample_loss, last_layer_params)
  File "/home/stdmichal/miniconda3/envs/pytorch_env/lib/python3.7/site-packages/torch/autograd/__init__.py", line 218, in grad
    grad_outputs_ = _make_grads(outputs, grad_outputs_)
  File "/home/stdmichal/miniconda3/envs/pytorch_env/lib/python3.7/site-packages/torch/autograd/__init__.py", line 50, in _make_grads
    raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs
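The error comes from `torch.autograd.grad` needing a scalar output (or an explicit `grad_outputs`), while `per_sample_loss` has shape `(batch,)`. Passing `grad_outputs=torch.ones_like(per_sample_loss)` makes the call succeed, but it returns the gradient of the summed loss, not one gradient per sample. A minimal (and slow) sketch of a true per-sample version calls `autograd.grad` once per sample with `retain_graph=True`, which also keeps the graph alive for a later `loss.backward()`. Note that `model.lin3.parameters()` is a generator and must be turned into a list to be reused. The `Net` model and the random data below are hypothetical stand-ins; only the name `lin3` comes from the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class Net(nn.Module):
    # Hypothetical toy model; only the name `lin3` mirrors the question.
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(4, 8)
        self.lin2 = nn.Linear(8, 8)
        self.lin3 = nn.Linear(8, 3)

    def forward(self, x):
        x = torch.relu(self.lin1(x))
        x = torch.relu(self.lin2(x))
        return self.lin3(x)

model = Net()
X = torch.randn(5, 4)
y = torch.randint(0, 3, (5,))

pred = model(X)
per_sample_loss = nn.CrossEntropyLoss(reduction="none")(pred, y)

# parameters() is a generator; materialize it so it can be reused.
last_layer_params = list(model.lin3.parameters())

# One autograd.grad call per scalar loss; retain_graph keeps the graph alive
# for the next sample (and for a later loss.backward()).
per_sample_grads = [
    torch.autograd.grad(per_sample_loss[i], last_layer_params, retain_graph=True)
    for i in range(len(per_sample_loss))
]

# Gradient norm of each sample over all lin3 parameters (weight and bias).
per_sample_norms = torch.stack([
    torch.sqrt(sum(g.pow(2).sum() for g in grads)) for grads in per_sample_grads
])

# For comparison: grad_outputs of ones yields the gradient of the summed loss.
summed = torch.autograd.grad(per_sample_loss, last_layer_params,
                             grad_outputs=torch.ones_like(per_sample_loss))
print(per_sample_norms)
```

The loop costs one backward pass (restricted to `lin3`) per sample, so it is mainly useful as a reference to check faster methods against.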

I found some similar problems (1, 2), but I am not trying to compute the whole backward pass as in 1, nor can I reduce the loss, since I need per-sample gradients to determine each sample's importance.
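If the last layer is a plain `nn.Linear` followed by cross-entropy, the per-sample gradient has a closed form and needs no per-sample backward pass. With logits z_i = W h_i + b, the gradient of sample i's loss w.r.t. its logits is delta_i = softmax(z_i) - onehot(y_i), so dL_i/dW = delta_i h_i^T, dL_i/db = delta_i, and the squared gradient norm is ||delta_i||^2 (||h_i||^2 + 1). A sketch assuming exactly that architecture (the toy model and data are hypothetical) captures h_i with a forward hook and cross-checks one sample against autograd:

```python
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy network; only the name `lin3` mirrors the question.
model = nn.Sequential(OrderedDict([
    ("lin1", nn.Linear(4, 8)), ("act1", nn.ReLU()),
    ("lin2", nn.Linear(8, 8)), ("act2", nn.ReLU()),
    ("lin3", nn.Linear(8, 3)),
]))
X = torch.randn(5, 4)
y = torch.randint(0, 3, (5,))

captured = {}

def save_input(module, args, output):
    # args[0] is the batch of activations h_i fed into lin3
    captured["h"] = args[0].detach()

model.lin3.register_forward_hook(save_input)

logits = model(X)
h = captured["h"]                                   # (B, 8)
with torch.no_grad():
    # dL_i/dz_i for cross-entropy, one row per sample
    delta = F.softmax(logits, dim=1) - F.one_hot(y, num_classes=3).to(logits.dtype)
grad_W = delta.unsqueeze(2) * h.unsqueeze(1)        # (B, 3, 8): dL_i/dW
grad_b = delta                                      # (B, 3):    dL_i/db
grad_norms = delta.norm(dim=1) * torch.sqrt(h.pow(2).sum(dim=1) + 1)

# Cross-check sample 0 against autograd; should agree up to floating-point error
loss0 = F.cross_entropy(logits[0:1], y[0:1])
gW0, gb0 = torch.autograd.grad(loss0, [model.lin3.weight, model.lin3.bias])
print(torch.allclose(grad_W[0], gW0, atol=1e-6))
```

Only `grad_norms` is needed for importance scores, and it never materializes the (B, 3, 8) tensor.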

What would be the correct way to do this please?

@mijalapenos

Instead of manually calculating the gradients, we can register hooks to tap into them during the backward pass:

import torch
import torch.optim as optim

# `model` (with a `layer3` submodule), `X` and `y` are assumed to be defined already
loss = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

grad_val = None

def get_grad(grad):
    # Store a detached copy of the (batch-aggregated) gradient w.r.t. layer3.weight
    global grad_val
    grad_val = grad.detach().clone()
    return grad

model.layer3.weight.register_hook(get_grad)

for cur_epoch in range(10):
    model.zero_grad()
    output = model(X)
    final_loss = loss(output, y)
    final_loss.backward()
    print("Cur grad is {0}".format(grad_val))
    optimizer.step()
    print("Epoch {0} Loss is {1}".format(cur_epoch, final_loss.item()))

Thank you very much for your response. Sadly, I am not sure this is applicable in my case. In the final implementation, I need to compute the gradient for each sample separately (which could be solved by running with a batch size of 1). Furthermore, to optimize performance, I want to avoid computing the full backward pass and instead use only the gradient of the last layer to approximate each sample's gradient norm, as suggested in the paper. (The code I posted is just a snippet where I'm testing the process.)
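Running with a batch size of 1 amounts to evaluating a single-sample loss function once per sample; `torch.func` (available in PyTorch 2.0 and later) can vectorize exactly that with `vmap(grad(...))`. A sketch, assuming a recent PyTorch and a hypothetical toy model, that differentiates only the `lin3` parameters and treats every other parameter as a constant:

```python
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Hypothetical toy network; only the name `lin3` mirrors the question.
model = nn.Sequential(OrderedDict([
    ("lin1", nn.Linear(4, 8)), ("act1", nn.ReLU()),
    ("lin2", nn.Linear(8, 8)), ("act2", nn.ReLU()),
    ("lin3", nn.Linear(8, 3)),
]))
X = torch.randn(5, 4)
y = torch.randint(0, 3, (5,))

params = {k: v.detach() for k, v in model.named_parameters()}
last = {k: v for k, v in params.items() if k.startswith("lin3.")}
frozen = {k: v for k, v in params.items() if not k.startswith("lin3.")}

def sample_loss(last_params, x, target):
    # Evaluate one sample as a batch of size 1 using the given lin3 parameters
    out = functional_call(model, {**frozen, **last_params}, (x.unsqueeze(0),))
    return F.cross_entropy(out, target.unsqueeze(0))

# grad differentiates w.r.t. the first argument only (the lin3 parameters);
# vmap maps over the batch dimension of X and y, sharing `last` across samples.
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(last, X, y)

# Per-sample gradient norm over lin3.weight and lin3.bias
per_sample_norms = torch.sqrt(
    sum(g.pow(2).flatten(1).sum(1) for g in per_sample_grads.values())
)
```

The result is a dict keyed like `named_parameters()`, e.g. `per_sample_grads["lin3.weight"]` with shape (5, 3, 8).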

Is there a way to make PyTorch compute the gradient only for some layers (in my case, only the last layer)?
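One way to get gradients for the last layer only is to run every earlier layer under `torch.no_grad()`, so autograd records a graph for `lin3` alone and the backward pass never touches the rest of the network. This is a sketch of a separate scoring pass, assuming a hypothetical `nn.Sequential` model whose last entry is `lin3`; the regular training step would still need its own full forward and backward pass.

```python
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy network; only the name `lin3` mirrors the question.
model = nn.Sequential(OrderedDict([
    ("lin1", nn.Linear(4, 8)), ("act1", nn.ReLU()),
    ("lin2", nn.Linear(8, 8)), ("act2", nn.ReLU()),
    ("lin3", nn.Linear(8, 3)),
]))
X = torch.randn(5, 4)
y = torch.randint(0, 3, (5,))

with torch.no_grad():
    h = model[:-1](X)          # every layer except lin3; no graph is recorded

logits = model.lin3(h)          # the graph covers lin3 only
per_sample_loss = F.cross_entropy(logits, y, reduction="none")
per_sample_loss.sum().backward()

# Only lin3's parameters received gradients
print({name: p.grad is not None for name, p in model.named_parameters()})
```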

Expanding on what @anantguptadbl said, you can get per-sample gradients by registering forward pre-hooks and full backward hooks. You can read more here.
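The hook-based idea can be sketched for a single `nn.Linear` layer: a forward pre-hook stores the layer's input activations, a full backward hook stores the gradient w.r.t. its output, and their per-sample outer product is each sample's weight gradient. This is only a sketch of the approach with a hypothetical model and data, not the code of any particular package. It uses `reduction='sum'` so that each row of `grad_output` is exactly one sample's dL_i/dz_i; with the default `'mean'`, every row would be scaled by 1/B.

```python
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy network; only the name `lin3` mirrors the question.
model = nn.Sequential(OrderedDict([
    ("lin1", nn.Linear(4, 8)), ("act1", nn.ReLU()),
    ("lin2", nn.Linear(8, 8)), ("act2", nn.ReLU()),
    ("lin3", nn.Linear(8, 3)),
]))
X = torch.randn(5, 4)
y = torch.randint(0, 3, (5,))

acts, grads = {}, {}

def save_activation(module, args):
    # Input activations of lin3, shape (B, 8)
    acts["h"] = args[0].detach()

def save_grad_output(module, grad_input, grad_output):
    # Gradient of the loss w.r.t. lin3's output, shape (B, 3)
    grads["delta"] = grad_output[0].detach()

model.lin3.register_forward_pre_hook(save_activation)
model.lin3.register_full_backward_hook(save_grad_output)

logits = model(X)
F.cross_entropy(logits, y, reduction="sum").backward()

# Per-sample gradients: outer product of output gradient and input activation
grad_W = torch.einsum("bo,bi->boi", grads["delta"], acts["h"])   # (B, 3, 8)
grad_b = grads["delta"]                                          # (B, 3)
```

Summing `grad_W` over the batch recovers the ordinary `lin3.weight.grad`, which is a quick sanity check for the hooks.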

I managed to implement everything by modifying the autograd-hacks package and setting all but the last trainable layer to not require gradients, which speeds it up. The implementation is working now, thank you everyone!
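Freezing every layer except the last one comes down to toggling `requires_grad` on the parameters. A minimal sketch, assuming a hypothetical model whose final layer is named `lin3` and assuming the flags are restored before a regular training step that should update all layers (`set_trainable_last_only` is a hypothetical helper):

```python
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy network; only the name `lin3` mirrors the question.
model = nn.Sequential(OrderedDict([
    ("lin1", nn.Linear(4, 8)), ("act1", nn.ReLU()),
    ("lin2", nn.Linear(8, 8)), ("act2", nn.ReLU()),
    ("lin3", nn.Linear(8, 3)),
]))
X = torch.randn(5, 4)
y = torch.randint(0, 3, (5,))

def set_trainable_last_only(model, last_prefix="lin3."):
    # Hypothetical helper: only parameters of the last layer keep requires_grad=True
    for name, p in model.named_parameters():
        p.requires_grad_(name.startswith(last_prefix))

set_trainable_last_only(model)
loss = F.cross_entropy(model(X), y)
loss.backward()    # nothing upstream of lin3 requires grad, so backward stops there

# Restore before a regular training step that updates all layers
for p in model.parameters():
    p.requires_grad_(True)
```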