Making HVP More Scalable for Larger Training Set Sizes

Hello Everyone,

I am currently working on a research project in which I need to calculate the Hessian Vector Product (HVP) w.r.t. my model's parameters. I am able to calculate the HVP successfully; however, the calculation slows down considerably and uses a large amount of memory as the size of the training set I am using increases.

Let me make things more explicit:

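	import time
	from torch import autograd

	def GetHVPFunc_alt(self, datax, datay, testx, testy):
		# calc_loss, To_List, Reverse_To_List and hvp_alt are helpers defined elsewhere in my code
		params = [ p for p in self.model.parameters() if p.requires_grad ]
		# Gradient of the test loss w.r.t. the trainable parameters
		predtest = self.model(testx)
		losstest = calc_loss(predtest, testy)
		grads_test = self.To_List(autograd.grad(losstest, params))

		def HVP(v):
			s = time.time()
			# Fresh forward pass over the whole training set on every call
			predtrain = self.model(datax)
			losstrain = calc_loss(predtrain, datay)

			v_t = self.Reverse_To_List(v, params)
			out = hvp_alt(losstrain, params, v_t)
			print("Time HVP: ", time.time() - s)
			return out
		return HVP, grads_test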

I want to calculate H*v, where H is the Hessian of the loss function w.r.t. the model parameters. As I increase the number of points in datax and datay (the training-set images and labels, respectively), the time to compute the model output and the loss grows as expected, but the time and memory used by hvp_alt grow much faster. I suspect this is because the function calls autograd.grad twice: the first call uses create_graph=True, which builds a graph of the backward pass itself, and that graph (like the stored activations) grows with the number of inputs to the model. See the code below:

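import torch
from torch.autograd import grad

def hvp(y, w, v):
    """Hessian-vector product of the scalar y w.r.t. the parameters w, applied to v."""
    if len(w) != len(v):
        raise ValueError("w and v must have the same length.")
    # First backprop: create_graph=True keeps the gradient differentiable
    first_grads = grad(y, w, retain_graph=True, create_graph=True)
    # Elementwise products: the scalar sum_i <grad_i, v_i>
    elemwise_products = 0
    for grad_elem, v_elem in zip(first_grads, v):
        elemwise_products += torch.sum(grad_elem * v_elem)
    # Second backprop: the gradient of <grad, v> w.r.t. w equals H v
    return_grads = grad(elemwise_products, w, create_graph=False)
    return return_grads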
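To make the double-backward idea concrete, here is a minimal self-contained check on a quadratic loss 0.5 * wᵀAw, whose Hessian is exactly A (for symmetric A), so H v can be compared against A v directly. The function name double_backward_hvp and the toy data are hypothetical and only illustrate the technique; this is a sketch, not the repository's code.

```python
import torch
from torch.autograd import grad

def double_backward_hvp(loss, params, v):
    """H v via two reverse passes: the gradient of (grad(loss) . v)."""
    # First backward builds a differentiable graph of the gradient
    first = grad(loss, params, create_graph=True)
    dot = sum(torch.sum(g * u) for g, u in zip(first, v))
    # Second backward: d/dparams (g . v) = H v
    return grad(dot, params)

torch.manual_seed(0)
B = torch.randn(4, 4, dtype=torch.float64)
A = B @ B.T + torch.eye(4, dtype=torch.float64)  # symmetric positive definite
w = torch.randn(4, dtype=torch.float64, requires_grad=True)
v = torch.randn(4, dtype=torch.float64)

loss = 0.5 * w @ A @ w            # Hessian of this loss is A
(hv,) = double_backward_hvp(loss, [w], [v])
print(torch.allclose(hv, A @ v))  # True
```

The cost of the second pass is comparable to the first, but the graph kept alive by create_graph=True is what drives memory.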

I should note that the code above is not my own; it is taken from the GitHub repository nimarb/pytorch_influence_functions, a PyTorch reimplementation of Influence Functions from the ICML 2017 best paper, Understanding Black-box Predictions via Influence Functions by Pang Wei Koh and Percy Liang, which is the paper I am basing my research on. I have been using 100 data points from CIFAR-10 to approximate the HVP, and each iteration of my overall code (I am approximating the inverse HVP, i.e. H^-1 * v) finishes in ~6 seconds; when I increase the training set size to 1000, the run time increases to ~60 seconds.
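For context on how the inverse HVP is usually approximated in this setting: Koh and Liang use the stochastic LiSSA approach (Agarwal et al.), which only needs repeated HVPs via the recursion h_j = v + (1 - damping) * h_{j-1} - H h_{j-1} / scale, with h / scale as the final estimate. The sketch below runs that recursion on a quadratic where the exact answer A⁻¹v is available; the function name lissa_inverse_hvp and the chosen scale, damping and iteration count are assumptions for illustration, and hvp_fn stands in for whatever HVP routine is used (in practice often evaluated on a fresh random minibatch each iteration).

```python
import torch

def lissa_inverse_hvp(hvp_fn, v, scale, damping=0.0, num_iters=1000):
    """Approximate H^{-1} v using only Hessian-vector products.

    Fixed point of the recursion is (H + scale * damping * I)^{-1} v after the
    final division, so damping > 0 trades bias for stability. It converges when
    the spectral radius of (1 - damping) I - H / scale is below 1.
    """
    h = v.clone()
    for _ in range(num_iters):
        h = v + (1.0 - damping) * h - hvp_fn(h) / scale
    return h / scale

torch.manual_seed(0)
B = torch.randn(4, 4, dtype=torch.float64)
A = B @ B.T + torch.eye(4, dtype=torch.float64)  # plays the role of H
v = torch.randn(4, dtype=torch.float64)

# scale chosen so the eigenvalues of A / scale lie in (0, 1)
scale = 1.1 * torch.linalg.eigvalsh(A).max().item()
approx = lissa_inverse_hvp(lambda u: A @ u, v, scale=scale, num_iters=2000)
exact = torch.linalg.solve(A, v)
print(torch.allclose(approx, exact))  # True
```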

In short, is there a way for me to get an accurate, efficient approximation of the HVP that doesn't depend so heavily on training set size? Is my current solution of using a subset of the data sufficient, or will this result in an inaccurate HVP?
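One relevant property: if the training loss is a mean over N examples, its Hessian is the size-weighted average of the per-chunk Hessians, so H v can be accumulated chunk by chunk. Peak memory then depends on the chunk size rather than N, although run time is still linear in N. A sketch, assuming the model has no batch-dependent layers while computing the HVP (for example BatchNorm in eval mode, since in train mode chunking changes the batch statistics); chunked_hvp and chunk_size are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def chunked_hvp(model, loss_fn, x, y, v, chunk_size=128):
    """H v for the mean loss over (x, y), accumulated over chunks of the data."""
    params = [p for p in model.parameters() if p.requires_grad]
    total = [torch.zeros_like(p) for p in params]
    n = x.shape[0]
    for start in range(0, n, chunk_size):
        xb, yb = x[start:start + chunk_size], y[start:start + chunk_size]
        loss = loss_fn(model(xb), yb)  # mean loss over this chunk
        grads = torch.autograd.grad(loss, params, create_graph=True)
        dot = sum(torch.sum(g * u) for g, u in zip(grads, v))
        hv = torch.autograd.grad(dot, params)
        weight = xb.shape[0] / n       # chunk's share of the full-data mean
        for acc, h in zip(total, hv):
            acc.add_(h, alpha=weight)
        # The chunk's graph is released here, so memory scales with chunk_size
    return total

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(5, 8), nn.Tanh(), nn.Linear(8, 3)).double()
x = torch.randn(300, 5, dtype=torch.float64)
y = torch.randint(0, 3, (300,))
v = [torch.randn_like(p) for p in model.parameters()]

chunked = chunked_hvp(model, F.cross_entropy, x, y, v, chunk_size=64)
full = chunked_hvp(model, F.cross_entropy, x, y, v, chunk_size=300)
print(all(torch.allclose(a, b) for a, b in zip(chunked, full)))  # True
```

Using a uniformly random subset instead, as in the current approach, gives an unbiased but noisier estimate of the same H v.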

Feel free to ask for clarification if anything was confusing.

Thank you so much,


I could be wrong here, but it really depends on how big the graph is. How big is your model? You could also look into torch.jit, which may reduce the overhead since it tries to remove the dependency on the Python interpreter.



We are actually working on forward mode right now and that should help quite a bit in your case.
Do you have a small self-contained script that shows the model you’re using?
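As an illustration of why forward mode helps here: an HVP can be computed forward-over-reverse, i.e. as a forward-mode JVP of the gradient, which replaces the second reverse pass and the graph it has to keep. A sketch, assuming a PyTorch release that ships torch.func (2.0 or later) with grad, jvp and functional_call; hvp_fwd_over_rev, the MSE loss and the toy model are hypothetical choices for the example.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, jvp

def hvp_fwd_over_rev(model, x, y, v):
    """H v for the MSE loss of `model`, via a JVP (forward mode) of the gradient."""
    params = {k: p.detach() for k, p in model.named_parameters()}

    def loss_of(p):
        out = functional_call(model, p, (x,))
        return ((out - y) ** 2).mean()

    # jvp returns (gradient at params, derivative of the gradient along v) = (g, H v)
    _, hv = jvp(grad(loss_of), (params,), (v,))
    return hv

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 6), nn.Tanh(), nn.Linear(6, 2)).double()
x = torch.randn(32, 4, dtype=torch.float64)
y = torch.randn(32, 2, dtype=torch.float64)
v = {k: torch.randn_like(p) for k, p in model.named_parameters()}

hv = hvp_fwd_over_rev(model, x, y, v)

# Reference: reverse-over-reverse (double backward), as in the hvp function above
params = list(model.parameters())
loss = ((model(x) - y) ** 2).mean()
g = torch.autograd.grad(loss, params, create_graph=True)
ref = torch.autograd.grad(sum((gi * vi).sum() for gi, vi in zip(g, v.values())), params)
print(all(torch.allclose(hv[k], r) for k, r in zip(v, ref)))  # True
```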


Hello Alban,

I’m using a smaller model I got from a tutorial right now for testing. Here is a script that describes the model:

import torch
import torch.nn as nn

# 3x3 convolution
def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                     stride=stride, padding=1, bias=False)

# Residual block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

# ResNet
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            # Projection shortcut so the residual matches the block's output shape
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for i in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

model = ResNet(ResidualBlock, [2, 2, 2])
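Since the HVP is taken w.r.t. every trainable parameter (v has one entry per parameter, and H is P×P for P parameters), one simple way to answer the "how big is your model" question is to count them. count_parameters is a hypothetical helper; calling count_parameters(model) on the ResNet above gives the number to report.

```python
import torch.nn as nn

def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Number of scalar parameters, i.e. the length of v in H v."""
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)

print(count_parameters(nn.Linear(3, 2)))  # 8 (6 weights + 2 biases)
```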

Is there anything else you want to see? I agree that the model size will also affect the execution time. Also, out of curiosity, how would this "forward mode" help me?
