How to compute grad_sample for a large model

Hi,
I am trying to use Opacus with a Vision Transformer model. Here is what the ViT looks like:

class ViT(nn.Module):
    def __init__(self, ...):
        ...
        self.class_token = nn.Parameter(torch.zeros(1, 1, dim))
        ...

    def forward(self, x):
        ...
        x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1)
        ...
        return x

Using Opacus with this model raises an error because of the nn.Parameter defined directly on the module (self.class_token = nn.Parameter(torch.zeros(1, 1, dim))), unless we register a grad sampler for the whole ViT model:

@register_grad_sampler(ViT)
def compute_vit_grad_sample(layer, activations, backprops):
    ...

But if I rewrite ViT like this:

class cls_token(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.class_token = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        b = x.shape[0]  # batch size
        x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1)
        return x


class ViT(nn.Module):
    def __init__(self, ...):
        ...
        self.class_token = cls_token(dim)

    def forward(self, x):
        ...
        x = self.class_token(x)
        ...
        return x

In this way, is it true that I only need to compute grad_sample for cls_token instead of the whole ViT module?

@register_grad_sampler(cls_token)
def compute_cls_grad_sample(layer, activations, backprops):
    ...

Thanks!

Hi doudeimouyi,

Thanks for your interest! The second approach (wrapping the class token in an nn.Module and implementing the grad sampler only for that module) is correct.

Indeed, with this approach you are calling the forward method of the cls_token module, so Opacus can attach its hooks to it and correctly compute the grad samples.
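For illustration, here is a minimal sketch of what that grad sampler could look like, assuming backprops is the gradient of the loss with respect to the module output, with shape (B, T + 1, dim). Note that the exact signature of registered samplers (and whether activations arrives as a tensor or a list) can differ between Opacus versions, so treat this as a sketch rather than a drop-in implementation:

from opacus.grad_sample import register_grad_sampler

@register_grad_sampler(cls_token)
def compute_cls_grad_sample(layer, activations, backprops):
    # Position 0 of the output is the prepended class token, so each
    # sample's gradient w.r.t. class_token is just that slice of the
    # backprops, reshaped to (B, *class_token.shape) = (B, 1, 1, dim).
    return {layer.class_token: backprops[:, :1, :].unsqueeze(1)}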

Note, however, that you need to replicate this approach (encapsulate the nn.Parameter in an nn.Module for which you define a custom grad sampler) every time you use a parameter directly in a forward pass instead of relying on the forward method of a child module.
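As a second, hypothetical example of the same pattern (the PosEmbed name and shapes are made up for illustration), a learnable positional embedding added directly to the input would be wrapped like this:

class PosEmbed(nn.Module):
    def __init__(self, seq_len, dim):
        super().__init__()
        self.pos = nn.Parameter(torch.zeros(1, seq_len, dim))

    def forward(self, x):
        # x: (B, seq_len, dim); the embedding is broadcast over the batch
        return x + self.pos

@register_grad_sampler(PosEmbed)
def compute_pos_embed_grad_sample(layer, activations, backprops):
    # The output depends additively on `pos`, so each sample's gradient
    # w.r.t. `pos` is its own backprop, reshaped to (B, *pos.shape).
    return {layer.pos: backprops.unsqueeze(1)}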

Do not hesitate to reach out if you have further questions.