Is there a way to compute the gradient w.r.t. a subset of samples?

Hi, I am looking for a way to compute gradients w.r.t. only a subset of the samples in a batch in PyTorch.

For example, I first compute the loss on 4 samples (batch_size=4). Then I want to compute the gradients for only 2 of them, selected by a mask (see the code). Ideally, the backward pass for 2 samples should take about half the time of the backward pass for all 4 samples.

Is there a way to run backward on only a subset of the samples at a lower time cost in PyTorch? I have seen that PyTorch released the 'functorch' library, but it does not seem to meet my needs.


import torch
from torchvision import models

samples = torch.rand(4, 3, 224, 224)  # batch_size = 4
model = models.resnet18()

outputs = model(samples)
loss = outputs.mean()

# the mask means that only the gradients of the first two samples are required
mask = torch.tensor([1, 1, 0, 0])

# how to perform backward according to the mask with less time cost?
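
One direction that might fit, as a minimal sketch rather than a confirmed solution: run the forward pass for the selected samples with autograd enabled and for the remaining samples under `torch.no_grad()`, so that backward only traverses a graph built from the selected samples. The small `nn.Sequential` model below is a hypothetical stand-in for `resnet18`, chosen because it has no BatchNorm layers. This matters because BatchNorm in training mode couples the samples in a batch, so splitting the batch would change the outputs there. The boolean mask and the 32x32 input size are also assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for resnet18: no BatchNorm, so each sample's
# output depends only on that sample.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

samples = torch.rand(4, 3, 32, 32)  # batch_size = 4
mask = torch.tensor([1, 1, 0, 0], dtype=torch.bool)  # assumed boolean mask

# Selected samples: forward with autograd, so a graph is recorded for them.
out_selected = model(samples[mask])

# Remaining samples: forward without recording a graph. They still
# contribute to the loss value, but backward never visits them.
with torch.no_grad():
    out_rest = model(samples[~mask])

# Same value as outputs.mean() over the full batch.
total = out_selected.numel() + out_rest.numel()
loss = (out_selected.sum() + out_rest.sum()) / total

# Backward only flows through the two selected samples.
loss.backward()

# Parameter gradients that come from the selected samples only.
grads = [p.grad.clone() for p in model.parameters()]
```

Under these assumptions, the parameter gradients should match what a full-batch backward would give if the contributions of the masked-out samples were zeroed. Whether the backward time actually drops to about half depends on the model and hardware, so it is worth timing on the real setup.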