Insert adapters in a transformer

Hi everyone,

I’m trying to insert adapters, small trainable modules embedded into a transformer, into GPT-2. Here is my code so far:

import torch.nn as nn
from transformers import GPT2LMHeadModel


class Adapter(nn.Module):
    """
    The adapters first project the original
    d-dimensional features into a smaller dimension, m, apply
    a nonlinearity, then project back to d dimensions.
    """
    def __init__(self, size=6, model_dim=768):
        super().__init__()
        self.adapter_block = nn.Sequential(
            nn.Linear(model_dim, size),
            nn.ReLU(),
            nn.Linear(size, model_dim)
        )

    def forward(self, x):

        ff_out = self.adapter_block(x)
        # Skip connection
        adapter_out = ff_out + x

        return adapter_out


class Adaptered(nn.Module):
    def __init__(self, orig_layer):
        super().__init__()
        self.orig_layer = orig_layer
        self.adapter = Adapter()

    def forward(self, *x):
        orig_out = self.orig_layer(*x)
        output = (self.adapter.forward(orig_out[0].unsqueeze(0))[0],)

        return output



class GPT2_with_adapter(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = GPT2LMHeadModel.from_pretrained('gpt2', return_dict=False)
        # Freeze the original model parameters
        for params in self.model.parameters():
            params.requires_grad = False
        # Embed adapter layers into the transformer blocks 
        for block in self.model.transformer.h:  # 12 blocks in GPT-2 small
            block.attn.c_proj = Adaptered(block.attn.c_proj)
            block.mlp.c_proj = Adaptered(block.mlp.c_proj)

    def get_model(self):

        return self.model

The problem is that the adapter weights don’t get updated during training. Can anyone point out what the issue is?
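
A quick sanity check for this setup is to count the trainable parameters. Because the freeze loop in GPT2_with_adapter runs before the adapters are created, the new adapter modules keep the default requires_grad=True, while the wrapped c_proj layers stay frozen. The sketch below is a minimal illustration under assumptions: a small hypothetical nn.Sequential stands in for GPT-2 (nothing is downloaded), the Adapter class is the one from the post, and count_params is a hypothetical helper. Each adapter with size=6 and model_dim=768 has 768*6 + 6 + 6*768 + 768 = 9,990 parameters, so the 24 adapters (two per block, 12 blocks) add 239,760 in total.

```python
import torch.nn as nn


class Adapter(nn.Module):
    """Bottleneck adapter from the post: d -> m -> d plus a skip connection."""

    def __init__(self, size=6, model_dim=768):
        super().__init__()
        self.adapter_block = nn.Sequential(
            nn.Linear(model_dim, size),
            nn.ReLU(),
            nn.Linear(size, model_dim),
        )

    def forward(self, x):
        return self.adapter_block(x) + x


def count_params(module, trainable_only=False):
    """Count parameters; with trainable_only=True, only those with requires_grad."""
    return sum(
        p.numel()
        for p in module.parameters()
        if p.requires_grad or not trainable_only
    )


# Hypothetical stand-in for the pretrained backbone (GPT-2 itself is not loaded).
backbone = nn.Sequential(nn.Linear(768, 768), nn.Linear(768, 768))

# 1) Freeze first, like the loop in GPT2_with_adapter.__init__ ...
for p in backbone.parameters():
    p.requires_grad = False

# 2) ... then attach a new module; its parameters are trainable by default.
backbone.add_module("adapter", Adapter())

print(count_params(Adapter()))                      # 9990
print(count_params(backbone, trainable_only=True))  # 9990
print(24 * count_params(Adapter()))                 # 239760
```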

I cannot reproduce the issue; I do see gradients in your new custom layers:

from transformers import GPT2Tokenizer

model = GPT2_with_adapter().get_model()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

out = model(**inputs)
loss = out[0].mean()
loss.backward()

print(model.transformer.h[0].attn.c_proj.adapter.adapter_block[0].weight.grad.abs().sum())
# tensor(6387.1465)
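
To check all adapters at once rather than a single layer, one option is to walk named_parameters() after loss.backward() and report every trainable parameter. This is a sketch, not code from the thread: grad_report is a hypothetical helper name, and the toy two-layer model stands in for GPT-2 (layer 0 plays the frozen backbone, layer 2 the adapter). On the real model, grad_report(model) would be called right after the backward pass in the snippet above.

```python
import torch
import torch.nn as nn


def grad_report(model):
    """Return {name: L1 norm of grad} for every parameter with requires_grad=True.

    A value of None means the parameter has not received a gradient yet.
    """
    return {
        name: None if p.grad is None else p.grad.abs().sum().item()
        for name, p in model.named_parameters()
        if p.requires_grad
    }


torch.manual_seed(0)
# Toy stand-in: layer 0 is the frozen "backbone", layer 2 the trainable "adapter".
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
model[0].requires_grad_(False)

loss = model(torch.randn(4, 8)).mean()
loss.backward()

for name, norm in grad_report(model).items():
    print(name, norm)  # only "2.weight" and "2.bias" are listed
```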

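I also had to fix an error in your code snippet, which raised:

# TypeError: dropout(): argument 'input' (position 1) must be Tensor, not tuple

by removing the tuple creation in your custom module. The wrapped c_proj layer returns a tensor, and the dropout layer that consumes its output expects one, so Adaptered.forward has to return a tensor as well.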
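
The error can be reproduced without GPT-2. As the traceback shows, the output of c_proj goes straight into a dropout layer, so a wrapper that returns a 1-tuple breaks that call, while one that returns the tensor works. This is a minimal sketch: TupleWrapper, TensorWrapper and block_tail are hypothetical names, and block_tail is only a simplified stand-in for the code that follows c_proj inside a GPT-2 block.

```python
import torch
import torch.nn as nn


class TupleWrapper(nn.Module):
    """Like the original Adaptered: returns a 1-tuple instead of a tensor."""

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        return (self.layer(x),)


class TensorWrapper(nn.Module):
    """Returns a plain tensor, just like the layer it wraps."""

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        return self.layer(x)


dropout = nn.Dropout(p=0.1)


def block_tail(c_proj, x):
    # Simplified stand-in for what follows c_proj: its output is fed to dropout.
    return dropout(c_proj(x))


x = torch.randn(2, 5, 768)
proj = nn.Linear(768, 768)

print(block_tail(TensorWrapper(proj), x).shape)  # torch.Size([2, 5, 768])
try:
    block_tail(TupleWrapper(proj), x)
except TypeError as err:
    print("TypeError:", err)
```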

Thanks for replying @ptrblck.

You are right, the new custom layers have non-zero gradients. But somehow, the weights of those layers remain unchanged after the optimizer has taken a step:

class Adaptered(nn.Module):
    def __init__(self, orig_layer):
        super().__init__()
        self.orig_layer = orig_layer
        self.adapter = Adapter()
        # To keep track of previous weights
        self.prev_weights = self.adapter.adapter_block[0].weight

    def forward(self, *x):
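        # This is non-zero after loss.backward() has been called
        # (.grad is None before the first backward pass, hence the check)
        if self.adapter.adapter_block[0].weight.grad is not None:
            print(self.adapter.adapter_block[0].weight.grad.abs().sum())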
        # But this remains 0 after the optimizer has taken a step
        print((self.adapter.adapter_block[0].weight - self.prev_weights).abs().sum())
        self.prev_weights = self.adapter.adapter_block[0].weight
        orig_out = self.orig_layer(*x)
        output = (self.adapter.forward(orig_out[0].unsqueeze(0))[0],)[0]

        return output

Could you please double-check this? Thanks.

If it helps, this is how my optimizer is defined:

model = GPT2_with_adapter()
model = model.get_model()

model.train()
optimizer = AdamW(model.parameters(), weight_decay=args.weight_decay)
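
Passing model.parameters() works here, because the frozen GPT-2 weights never get a .grad and torch.optim.AdamW skips parameters whose gradient is None. Handing the optimizer only the trainable parameters simply makes explicit what is being trained. The post doesn't say which AdamW is imported, so this sketch assumes torch.optim.AdamW; the toy model and the lr/weight_decay values are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-in: layer 0 is a frozen "backbone", layer 2 a trainable "adapter".
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
model[0].requires_grad_(False)

# Only hand the optimizer the parameters that are meant to be trained.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-3, weight_decay=0.01)

frozen_before = model[0].weight.detach().clone()
trained_before = model[2].weight.detach().clone()

optimizer.zero_grad()
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
optimizer.step()

print(torch.equal(model[0].weight, frozen_before))   # True: frozen weights untouched
print(torch.equal(model[2].weight, trained_before))  # False: trainable weights moved
```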

No, I cannot confirm it; in my run the weights are updated by the optimizer:

import torch
import transformers

model = GPT2_with_adapter().get_model()
tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
optimizer = torch.optim.AdamW(model.parameters(), lr=1.)

print("before update")
print(model.transformer.h[0].attn.c_proj.adapter.adapter_block[0].weight.abs().sum())
# tensor(82.4845, grad_fn=<SumBackward0>)

out = model(**inputs)
loss = out[0].mean()
loss.backward()

print("grad")
print(model.transformer.h[0].attn.c_proj.adapter.adapter_block[0].weight.grad.abs().sum())
# tensor(111252.1250)

optimizer.step()

print("after update")
print(model.transformer.h[0].attn.c_proj.adapter.adapter_block[0].weight.abs().sum())
# tensor(3098.1809, grad_fn=<SumBackward0>)
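
One more thing about the tracking code in the Adaptered version with prev_weights: self.prev_weights = self.adapter.adapter_block[0].weight stores a reference to the same parameter tensor, not a copy, and optimizer.step() updates parameters in place. The difference weight - self.prev_weights is therefore zero no matter how much the weights change; a snapshot taken with .detach().clone() avoids this. The toy nn.Linear layer and lr=1.0 below are only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(layer.parameters(), lr=1.0)

alias = layer.weight                      # same tensor object, NOT a copy
snapshot = layer.weight.detach().clone()  # independent copy of the current values

layer(torch.randn(3, 4)).sum().backward()
optimizer.step()  # modifies layer.weight in place, so `alias` sees the new values

print((layer.weight - alias).abs().sum().item())     # 0.0
print((layer.weight - snapshot).abs().sum().item())  # clearly non-zero
```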

You are right @ptrblck. I think the problem was the combination of my scheduler and a low learning rate, which meant the weights barely changed in the first few iterations. Thanks again!
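
That matches a common pattern: with a linear warmup schedule the learning rate at the very first optimizer step is 0, and for the next steps it is only a small fraction of an already low base rate, so the weights barely move. The actual scheduler isn't shown in the thread; this sketch assumes a linear warmup built with torch.optim.lr_scheduler.LambdaLR, and warmup_steps=100 and lr=1e-5 are hypothetical values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(layer.parameters(), lr=1e-5)  # low base lr (assumed)

warmup_steps = 100  # hypothetical warmup length
# Linear warmup: the lr multiplier goes 0, 1/100, 2/100, ... and caps at 1.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, step / warmup_steps)
)

lrs, changes = [], []
for _ in range(3):
    before = layer.weight.detach().clone()
    optimizer.zero_grad()
    layer(torch.randn(8, 4)).pow(2).mean().backward()
    lrs.append(optimizer.param_groups[0]["lr"])  # lr used by this step
    optimizer.step()
    scheduler.step()
    changes.append((layer.weight - before).abs().max().item())

for lr, change in zip(lrs, changes):
    print(f"lr={lr:.1e}  largest weight change={change:.1e}")
```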

Lastly, I’d be really glad if you could briefly recommend a better way of implementing these adapters. Or do you think my current approach is good enough?
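
Not an answer from the thread, but a minimal sketch of one way to tidy this up, assuming the same GPT-2 module layout used in the post (transformer.h[i].attn.c_proj and transformer.h[i].mlp.c_proj). Here Adaptered passes the whole batch through the adapter and returns a tensor, like the layer it wraps. The original orig_out[0].unsqueeze(0) keeps only the first sequence of the batch; that is harmless with a batch size of 1, as in the snippets here, but with larger batches every sequence would effectively get the first sequence's output. The up-projection is also zero-initialised so each adapter starts as an identity map and the pretrained outputs are unchanged before training; that is a common design choice for adapters, not something from the post. add_adapters is a hypothetical helper name.

```python
import torch
import torch.nn as nn


class Adapter(nn.Module):
    """Bottleneck adapter: model_dim -> size -> model_dim, plus a skip connection."""

    def __init__(self, size=6, model_dim=768):
        super().__init__()
        self.adapter_block = nn.Sequential(
            nn.Linear(model_dim, size),
            nn.ReLU(),
            nn.Linear(size, model_dim),
        )
        # Zero-init the up-projection so the adapter starts as an identity map.
        nn.init.zeros_(self.adapter_block[2].weight)
        nn.init.zeros_(self.adapter_block[2].bias)

    def forward(self, x):
        return x + self.adapter_block(x)


class Adaptered(nn.Module):
    """Runs the wrapped layer, then the adapter; returns a tensor like the layer."""

    def __init__(self, orig_layer, size=6, model_dim=768):
        super().__init__()
        self.orig_layer = orig_layer
        self.adapter = Adapter(size, model_dim)

    def forward(self, *args, **kwargs):
        # The whole batch goes through the adapter; no indexing or tuple wrapping.
        return self.adapter(self.orig_layer(*args, **kwargs))


def add_adapters(model, size=6, model_dim=768):
    """Hypothetical helper: freeze `model`, then wrap both c_proj layers per block.

    Assumes the GPT-2 layout from the post (model.transformer.h[i].attn/mlp.c_proj).
    """
    model.requires_grad_(False)  # freeze everything before adding adapters
    for block in model.transformer.h:
        block.attn.c_proj = Adaptered(block.attn.c_proj, size, model_dim)
        block.mlp.c_proj = Adaptered(block.mlp.c_proj, size, model_dim)
    return model
```

With the model from the post this would be used as model = add_adapters(GPT2LMHeadModel.from_pretrained("gpt2")), with an optimizer over the parameters that have requires_grad=True. One side effect of the zero initialisation: on the first backward pass the down-projection (adapter_block[0]) receives a zero gradient, because its gradient flows through the all-zero up-projection; it starts getting gradients once the up-projection has been updated. Small random weights (a near-identity init) avoid that if it matters.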