Training got stuck with custom autograd module

Hello community, I tried to replace linear-relu-linear structure with a fused customized autograd function. The codes are shown below. However, when I tried to replace the feed-forward network and the last classifier in simple transformer architecture (almost strictly according to the attention is all your need), the training process (very normal code logic, no tricks) would get stuck. I tried to debug the code, and I found out that if I only replace the last classifier layer with the custom function, everything works fine, but when it comes to replacing the intermediate layers (feed-forward layer), sadly everything got stuck. I have no idea why this was happening and I appreciate your help. Thanks in advance.

autograd function:

class MLPScratch(torch.autograd.Function):
    def forward(ctx, X, W1, b1, W2, b2):

        linear = F.linear(X, W1, b1)
        activated = F.relu(linear)
        output = F.linear(activated, W2, b2)

        ctx.save_for_backward(X, W1, b1, W2, b2, linear, activated)
        return output

    def backward(ctx, grad_out):
        X, W1, b1, W2, b2, linear, activated = ctx.saved_tensors
        grad_X = grad_W1 = grad_b1 = grad_W2 = grad_b2 = None

        if b2 is not None:
            grad_b2 = torch.mean(grad_out, dim=0, keepdim=True)
        if grad_out.ndim == 2:
            grad_out_transpose = grad_out.T
            grad_out_transpose = grad_out.permute(*(i for i in range(grad_out.ndim - 2)),
                                                  grad_out.ndim - 1, grad_out.ndim - 2).contiguous()
        grad_W2 = grad_out_transpose @ activated
        grad_activated = grad_out @ W2
        grad_activated_shape = grad_activated.size()
        grad_activated_flattened = grad_activated.view(-1)
        linear_flattened = linear.view(-1)
        grad_before_activated = torch.empty_like(linear.view(-1))
        for i in range(grad_before_activated.size(0)):
            grad_before_activated[i] = grad_activated_flattened[i] if linear_flattened[i] > 0 else 0
        grad_before_activated = grad_before_activated.reshape(grad_activated_shape)
        if b1 is not None:
            grad_b1 = torch.mean(grad_before_activated, dim=0, keepdim=True)

        if grad_out.ndim == 2:
            grad_before_activated_transpose = grad_before_activated.T
            grad_before_activated_transpose = grad_before_activated.permute(
                *(i for i in range(grad_before_activated.ndim - 2)),
                grad_before_activated.ndim - 1, grad_before_activated.ndim - 2).contiguous()
        grad_W1 = grad_before_activated_transpose @ X

        grad_X = grad_before_activated @ W1
        return grad_X, grad_W1, grad_b1, grad_W2, grad_b2

main call:

class FusedMLP(nn.Module):
    def __init__(self, input_channel, hidden_channel, output_channel, bias=True, device=None, dtype=None):
        super(FusedMLP, self).__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        hidden_shape_weight = (hidden_channel, input_channel)
        hidden_shape_bias = (1, hidden_channel)
        output_shape_weight = (output_channel, hidden_channel)
        output_shape_bias = (1, output_channel)
        self.W1 = nn.Parameter(torch.empty(*hidden_shape_weight, **factory_kwargs))
        if bias:
            self.b1 = nn.Parameter(torch.empty(*hidden_shape_bias, **factory_kwargs))
            self.b1 = self.register_parameter("b1", None)
        self.W2 = nn.Parameter(torch.empty(*output_shape_weight, **factory_kwargs))
        if bias:
            self.b2 = nn.Parameter(torch.empty(*output_shape_bias, **factory_kwargs))
            self.b2 = self.register_parameter("b2", None)


    def forward(self, X):
        return MLPScratch.apply(X, self.W1, self.b1, self.W2, self.b2)

    def reset_parameters(self):
        if self.b1 is not None:
            torch.nn.init.constant_(self.b1, 0.)
        if self.b2 is not None:
            torch.nn.init.constant_(self.b2, 0.)

BTW, I have already validate my implementation with gradcheck.

Maybe you can try using torch.autograd.gradcheck to make sure your custom function is computing the correct gradients.

Hi @soulitzer I validated my implementation with gradcheck. The code should be fine.