Hi, I am looking for a way to compute gradients w.r.t. only a subset of the samples in a batch in PyTorch.
For example, I first compute the loss on 4 samples (batch_size=4). Then I want to compute the gradients for only 2 of those samples, selected by a
mask (see the code below). Ideally, the backward time for 2 samples should be around half that of the full 4-sample batch.
Is there a way to perform backward on only part of the samples with a lower time cost in PyTorch? I found that PyTorch released the functorch library, but it does not seem to meet my needs.
import torch
from torchvision import models

samples = torch.rand(4, 3, 224, 224)  # batch_size = 4
model = models.resnet18()
outputs = model(samples)
loss = outputs.mean()

# the mask means the gradients of the first two samples are required
mask = torch.tensor([1, 1, 0, 0])
# how to perform backward according to the mask with less time cost?
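For context, here is a minimal runnable sketch of one candidate approach: index out the masked samples before the forward pass, so both the forward and the backward only touch the 2 selected samples. This is an assumption about what is acceptable (it requires a separate forward pass if the full-batch loss is also needed), and it uses a tiny stand-in model instead of resnet18 to keep the sketch self-contained; the indexing works the same way for any nn.Module.

```python
import torch
import torch.nn as nn

# Tiny stand-in model (assumption: substitute torchvision's resnet18
# in the real setting; any nn.Module behaves the same here).
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))

samples = torch.rand(4, 3, 8, 8)               # batch_size = 4
mask = torch.tensor([1, 1, 0, 0], dtype=torch.bool)

# Select the masked samples *before* the forward pass, so the
# autograd graph (and hence backward cost) covers only 2 samples.
subset = samples[mask]                         # shape (2, 3, 8, 8)
loss = model(subset).mean()
loss.backward()

print(subset.shape[0])   # number of samples actually backpropagated
```

The trade-off is that `outputs` for the unmasked samples are not produced here; if the full-batch outputs are needed anyway, this approach effectively runs part of the forward twice.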