Hi,
I implemented a Triton kernel that is called inside a torch.autograd.Function. When I used it in training, I got an OOM, which led me to believe that activation checkpointing doesn't work with torch.autograd.Function. My guess is that when I explicitly save tensors to the ctx (via ctx.save_for_backward), they get stored regardless, and torch.utils.checkpoint does nothing about them. Is that correct? If so, how should I adapt my code to also support checkpointing?
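For reference, the pattern I mean is roughly the following. This is only a minimal sketch: the name MyTritonOp is made up, and the actual Triton launch is replaced by a plain torch.relu just to show where tensors get saved on ctx.

import torch
from torch.autograd import Function
from torch.utils.checkpoint import checkpoint

class MyTritonOp(Function):
    # Hypothetical stand-in for my real Function; the real forward launches a Triton kernel.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # tensors saved explicitly on ctx
        return torch.relu(x)       # placeholder for the Triton kernel call

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * (x > 0).to(grad_out.dtype)

def block(x):
    return MyTritonOp.apply(x)

x = torch.randn(8, 8, device="cuda", requires_grad=True)
# Does wrapping this in checkpoint free what ctx.save_for_backward stored?
y = checkpoint(block, x, use_reentrant=False)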
Also, I played around with the checkpoint function using this script on one V100 GPU (32 GB of VRAM).
import torch
from torch.nn import Linear
from torch.utils.checkpoint import checkpoint
from torch.autograd import Function


class LinearFun(Function):
    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        return inp @ weight.T

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors
        grad_w = grad_output.T @ inp
        grad_inp = grad_output @ weight
        return grad_inp, grad_w


class NewLinear(Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, inp):
        return LinearFun.apply(inp, self.weight)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = Linear(1024, 1024)
        self.l2 = Linear(1024, 1024)
        self.l3 = Linear(1024, 1)
        # self.l1 = NewLinear(1024, 1024)
        # self.l2 = NewLinear(1024, 1024)
        # self.l3 = NewLinear(1024, 1)

    def forward(self, inp):
        inp = self.l1(inp)
        inp = self.l2(inp)
        inp = self.l3(inp)
        return inp


model = Model().cuda()
inp = torch.randn(int(1e6), 1024, device="cuda")

# <---- USE CHECKPOINTING
out = checkpoint(model, inp, use_reentrant=False)

# <---- NOT USING CHECKPOINTING
# out = model(inp)

loss = out.sum()
print('========= before backward =========')
loss.backward()
print('========= after backward =========')
I tweaked the number of input rows to int(1e6) since that is enough to bring both scenarios to an OOM. But I don't see checkpointing allowing me to go to larger input sizes than the plain forward pass.
Am I missing something?
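In case it helps to quantify it, I assume I can also compare peak memory for the two variants directly instead of pushing all the way to OOM, e.g. with something like this (a sketch with a smaller input so both runs fit):

torch.cuda.reset_peak_memory_stats()
model = Model().cuda()
inp = torch.randn(int(2e5), 1024, device="cuda")

out = checkpoint(model, inp, use_reentrant=False)  # or: out = model(inp)
loss = out.sum()
loss.backward()

print(f"peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")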