Hi,
I am trying to use Opacus with a Vision Transformer (ViT) model. Here is what the ViT looks like:
class ViT(nn.Module):
    def __init__(self, ...):
        .......
        self.class_token = nn.Parameter(torch.zeros(1, 1, dim))
        ........
    def forward(self, x):
        ......
        x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1)
        ......
        return x
Using Opacus with this model raises an error because of the bare nn.Parameter, self.class_token = nn.Parameter(torch.zeros(1, 1, dim)), unless we compute grad_sample for the whole ViT model:
@register_grad_sampler(ViT)
def compute_vit_grad_sample():
    ......
But if I re-write ViT like this:
class cls_token(nn.Module):
    def __init__(self, dim):
        super().__init__()  # required before assigning parameters
        self.class_token = nn.Parameter(torch.zeros(1, 1, dim))
    def forward(self, x):
        b = x.shape[0]  # batch size
        x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1)
        return x
class ViT(nn.Module):
    def __init__(self, ...):
        .......
        self.class_token = cls_token(dim)
    def forward(self, x):
        ......
        x = self.class_token(x)  # call the submodule instance, not the class
        ......
        return x
With this change, is it true that I now only need to compute grad_sample for the cls_token module instead of the whole ViT module?
@register_grad_sampler(cls_token)
def compute_cls_grad_sample():
    ......
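To make the question concrete, here is a minimal sketch of what I assume the per-sample gradient for cls_token would be. The (layer, activations, backprops) signature and the returned dict reflect my understanding of Opacus custom grad samplers and are an assumption; the registration itself is left as a comment, and the sanity check uses only plain PyTorch. The sizes (batch 3, 5 tokens, dim 4) and the weights tensor are made up for illustration.

```python
import torch
import torch.nn as nn


class cls_token(nn.Module):
    """Prepends a learnable class token to a (batch, tokens, dim) input."""

    def __init__(self, dim):
        super().__init__()
        self.class_token = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        b = x.shape[0]
        return torch.cat((self.class_token.expand(b, -1, -1), x), dim=1)


# With Opacus this function would be decorated with
# @register_grad_sampler(cls_token); that part is assumed, not tested here.
def compute_cls_grad_sample(layer, activations, backprops):
    # backprops: gradient of the loss w.r.t. the layer output, (b, n + 1, dim).
    # Only output position 0 depends on class_token, so the per-sample
    # gradient is that slice, reshaped to (b, 1, 1, dim).
    return {layer.class_token: backprops[:, :1, :].unsqueeze(1)}


# Sanity check against autograd, one sample at a time.
torch.manual_seed(0)
layer = cls_token(dim=4)
x = torch.randn(3, 5, 4)
weights = torch.randn(3, 6, 4)  # arbitrary weights to build a scalar loss

out = layer(x)
out.retain_grad()  # keep the gradient w.r.t. the layer output
(out * weights).sum().backward()
grad_sample = compute_cls_grad_sample(layer, x, out.grad)[layer.class_token]

expected = torch.stack([
    torch.autograd.grad(
        (layer(x[i:i + 1]) * weights[i:i + 1]).sum(), layer.class_token
    )[0]
    for i in range(3)
])
print(torch.allclose(grad_sample, expected))
```

If this reasoning is right, the sampler for cls_token only has to slice the first token position out of backprops, which is much simpler than handling the whole ViT.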
Thanks!