The code below is extracted from here. I have two questions:

- How can one control `torch.compile` so that it operates in training mode, where it generates the backward pass, or in inference mode, where it does not generate the backward pass?
- What is the behavior when the `torch.compile`-decorated functions (`swiglu` and `swiglu_back`, shown below) are called from the `forward` and `backward` static methods of a `torch.autograd.Function`? Do they operate in inference mode (i.e. without generating a backward pass) by default?
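For context, my current understanding is that in eager mode it is the surrounding grad mode (`torch.no_grad()` / `requires_grad`), not the function itself, that decides whether a backward graph gets recorded. A minimal eager-mode sketch (the `@torch.compile` decorator is dropped here only so the snippet runs anywhere):

```python
import torch
import torch.nn.functional as F

# Same SwiGLU as below, but uncompiled for portability.
def swiglu(y):
    y_1, y_2 = torch.chunk(y, 2, -1)
    return F.silu(y_1) * y_2

x = torch.randn(4, 8, requires_grad=True)

# Grad mode enabled: autograd records a backward graph.
out = swiglu(x)
print(out.grad_fn is not None)  # True

# Under no_grad, no backward graph is recorded.
with torch.no_grad():
    out = swiglu(x)
print(out.grad_fn is None)  # True
```

My question is whether `torch.compile` follows the same rule, or whether there is a separate switch for it.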
```python
import torch
import torch.nn.functional as F

@torch.compile
def swiglu(y):
    y_1, y_2 = torch.chunk(y, 2, -1)
    return F.silu(y_1) * y_2

@torch.compile
def swiglu_back(g, y):
    y_1, y_2 = torch.chunk(y, 2, -1)
    return torch.cat(
        (g * torch.sigmoid(y_1) * (1 + y_1 * (1 - torch.sigmoid(y_1))) * y_2,
         g * F.silu(y_1)), -1
    )
```
```python
class SwiGLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return swiglu(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return swiglu_back(grad_output, input)
```
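For what it's worth, the hand-written backward itself is correct; here is a self-contained sketch that checks it against numerical gradients with `torch.autograd.gradcheck` (the `@torch.compile` decorators are dropped so it runs without a compiler toolchain; this should not change the autograd behavior being asked about):

```python
import torch
import torch.nn.functional as F

# Uncompiled copies of the functions above.
def swiglu(y):
    y_1, y_2 = torch.chunk(y, 2, -1)
    return F.silu(y_1) * y_2

def swiglu_back(g, y):
    y_1, y_2 = torch.chunk(y, 2, -1)
    # d/dy_1 silu(y_1) = sigmoid(y_1) * (1 + y_1 * (1 - sigmoid(y_1)))
    return torch.cat(
        (g * torch.sigmoid(y_1) * (1 + y_1 * (1 - torch.sigmoid(y_1))) * y_2,
         g * F.silu(y_1)), -1
    )

class SwiGLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return swiglu(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return swiglu_back(grad_output, input)

# gradcheck compares the custom backward with finite differences
# (double precision is required for the numerical comparison).
x = torch.randn(3, 8, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(SwiGLUFunction.apply, (x,)))  # True
```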