Hello community, I tried to replace linear-relu-linear structure with a fused customized autograd function. The codes are shown below. However, when I tried to replace the feed-forward network and the last classifier in simple transformer architecture (almost strictly according to the attention is all your need), the training process (very normal code logic, no tricks) would get stuck. I tried to debug the code, and I found out that if I only replace the last classifier layer with the custom function, everything works fine, but when it comes to replacing the intermediate layers (feed-forward layer), sadly everything got stuck. I have no idea why this was happening and I appreciate your help. Thanks in advance.
autograd function:
class MLPScratch(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, X, W1, b1, W2, b2):
linear = F.linear(X, W1, b1)
activated = F.relu(linear)
output = F.linear(activated, W2, b2)
ctx.save_for_backward(X, W1, b1, W2, b2, linear, activated)
return output
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, grad_out):
X, W1, b1, W2, b2, linear, activated = ctx.saved_tensors
grad_X = grad_W1 = grad_b1 = grad_W2 = grad_b2 = None
if b2 is not None:
grad_b2 = torch.mean(grad_out, dim=0, keepdim=True)
if grad_out.ndim == 2:
grad_out_transpose = grad_out.T
else:
grad_out_transpose = grad_out.permute(*(i for i in range(grad_out.ndim - 2)),
grad_out.ndim - 1, grad_out.ndim - 2).contiguous()
grad_W2 = grad_out_transpose @ activated
grad_activated = grad_out @ W2
grad_activated_shape = grad_activated.size()
grad_activated_flattened = grad_activated.view(-1)
linear_flattened = linear.view(-1)
grad_before_activated = torch.empty_like(linear.view(-1))
for i in range(grad_before_activated.size(0)):
grad_before_activated[i] = grad_activated_flattened[i] if linear_flattened[i] > 0 else 0
grad_before_activated = grad_before_activated.reshape(grad_activated_shape)
if b1 is not None:
grad_b1 = torch.mean(grad_before_activated, dim=0, keepdim=True)
if grad_out.ndim == 2:
grad_before_activated_transpose = grad_before_activated.T
else:
grad_before_activated_transpose = grad_before_activated.permute(
*(i for i in range(grad_before_activated.ndim - 2)),
grad_before_activated.ndim - 1, grad_before_activated.ndim - 2).contiguous()
grad_W1 = grad_before_activated_transpose @ X
grad_X = grad_before_activated @ W1
return grad_X, grad_W1, grad_b1, grad_W2, grad_b2
main call:
class FusedMLP(nn.Module):
def __init__(self, input_channel, hidden_channel, output_channel, bias=True, device=None, dtype=None):
super(FusedMLP, self).__init__()
factory_kwargs = {'device': device, 'dtype': dtype}
hidden_shape_weight = (hidden_channel, input_channel)
hidden_shape_bias = (1, hidden_channel)
output_shape_weight = (output_channel, hidden_channel)
output_shape_bias = (1, output_channel)
self.W1 = nn.Parameter(torch.empty(*hidden_shape_weight, **factory_kwargs))
if bias:
self.b1 = nn.Parameter(torch.empty(*hidden_shape_bias, **factory_kwargs))
else:
self.b1 = self.register_parameter("b1", None)
self.W2 = nn.Parameter(torch.empty(*output_shape_weight, **factory_kwargs))
if bias:
self.b2 = nn.Parameter(torch.empty(*output_shape_bias, **factory_kwargs))
else:
self.b2 = self.register_parameter("b2", None)
self.reset_parameters()
def forward(self, X):
return MLPScratch.apply(X, self.W1, self.b1, self.W2, self.b2)
def reset_parameters(self):
torch.nn.init.xavier_uniform_(self.W1)
torch.nn.init.xavier_uniform_(self.W2)
if self.b1 is not None:
torch.nn.init.constant_(self.b1, 0.)
if self.b2 is not None:
torch.nn.init.constant_(self.b2, 0.)
BTW, I have already validate my implementation with gradcheck.