Modifying the LoRA implementation to use external inputs

I’ve been trying to modify LoRA to act as an auxiliary network for an LLM, similar to what ControlNet does with the U-Net in Stable Diffusion, according to my high-level understanding.

I have it set up to replace all nn.Linear layers with my modified LinearWithLoRAP layers, and I’ve wrapped the top-level module (RootWrapperLoRAP) with a forward function that parses the extra input and stores it. Where I’m stuck is how to actually get the extra input from the top level down to my LoRA layers’ forward functions.
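For reference, here is a minimal sketch of the kind of recursive swap a replace_layers helper typically performs. The actual helper isn’t shown, so replace_linear, factory and TaggedLinear are hypothetical names. It only looks for nn.Linear children; to my knowledge GPT-Neo builds its attention and MLP projections from nn.Linear, whereas GPT-2 uses a Conv1D module that this check would miss.

```python
import torch
import torch.nn as nn


def replace_linear(module, factory):
    """Recursively replace every nn.Linear below `module` with factory(linear).

    Replaced layers are not visited again, so a factory that wraps the
    original Linear does not get wrapped a second time.
    """
    # list() so we are not iterating a live view while assigning children
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, factory(child))
        else:
            replace_linear(child, factory)


class TaggedLinear(nn.Module):
    """Hypothetical placeholder standing in for a LoRA-augmented layer."""

    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear(x)


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Sequential(nn.Linear(8, 2)))
replace_linear(model, TaggedLinear)
y = model(torch.randn(1, 4))  # still runs: each wrapper just calls its Linear
```

Because replacement happens through setattr on the parent, nested containers (like the inner Sequential here, or the blocks inside a transformer) are handled the same way as top-level children.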

I’ve tried using a global variable (just to get it working), a callback to the top-level Module where I store the value, and a wrapper around every Module that passes the extra input down as a separate argument.

The global variable and the callback both return None, even though I assign them values.
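One way around both problems is to share a single mutable object by reference instead of rebinding a name. A common reason a global comes back None is that `from some_module import globalExtra` copies the binding at import time, so a later `global globalExtra` assignment in another module isn’t visible; a callback can likewise capture None if it is invoked while the layers are being built rather than inside forward. Whether either applies here is a guess. The sketch below sets the value once at the top and lets every injected layer read it; ExtraHolder, LinearWithExtra and RootWithExtra are hypothetical names, and the `extra @ A @ B` term is an assumed stand-in for LoRALayer.

```python
import torch
import torch.nn as nn


class ExtraHolder:
    """Mutable container shared by reference: every holder of it sees .value updates."""

    def __init__(self):
        self.value = None


class LinearWithExtra(nn.Module):
    """Illustrative LoRA-style layer: frozen linear(x) plus a low-rank
    projection of the extra input read from the shared holder."""

    def __init__(self, linear, holder, rank=4, alpha=1.0, num_extra_features=1):
        super().__init__()
        self.linear = linear
        self.holder = holder  # plain Python object, not registered as a submodule
        self.A = nn.Parameter(torch.randn(num_extra_features, rank) * 0.01)
        self.B = nn.Parameter(torch.zeros(rank, linear.out_features))
        self.scale = alpha / rank

    def forward(self, x):
        extra = self.holder.value
        if extra is None:
            raise RuntimeError("extra input was not set before forward()")
        return self.linear(x) + self.scale * (extra @ self.A @ self.B)


class RootWithExtra(nn.Module):
    """Stores the extra input once at the top, then runs the unchanged model."""

    def __init__(self, model, holder):
        super().__init__()
        self.model = model
        self.holder = holder

    def forward(self, input_ids, extra):
        self.holder.value = extra
        try:
            return self.model(input_ids)  # extra is not passed positionally
        finally:
            self.holder.value = None  # reset so a stale value can't leak into the next call


torch.manual_seed(0)
holder = ExtraHolder()
inner = nn.Sequential(
    LinearWithExtra(nn.Linear(3, 5), holder),
    nn.ReLU(),
    LinearWithExtra(nn.Linear(5, 2), holder),
)
root = RootWithExtra(inner, holder)
out = root(torch.randn(2, 3), torch.tensor([[1.0]]))
```

With a real model, the same holder would be handed to each layer at swap time, for example through a factory such as `lambda lin: LinearWithExtra(lin, holder)`. The try/finally plays the role of the commented-out `self.Extra = None` line.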

The Module wrapper raises either “IndexError: Dimension specified as -2 but tensor has no dimensions” or “ModuleWrapper couldnt handle argument past_key_values”. I guess this comes from my ignorance of how args and kwargs are handled, which leads my ModuleWrapper to shield the original Module from the past_key_values cache.
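A likely source of both errors, assuming the wrapped model is Hugging Face’s GPTNeoForCausalLM: to my knowledge its forward takes input_ids first and past_key_values second, so `self.model(input_ids, extra)` binds extra to the cache. In the GPT-Neo code I’m familiar with, the model then reads `past_key_values[0][0].size(-2)`; for a 1x1 tensor, `[0][0]` is a 0-dim tensor, which would produce exactly that IndexError. The sketch below uses `inspect.signature` on a hypothetical stand-in function with the same leading parameter order to show the binding. Separately, even with that fixed, the wrapper approach needs every parent block to forward extra to its children; the stock GPT-Neo blocks call their submodules without it, so each ModuleWrapper receives `extra=None`.

```python
import inspect


def causal_lm_forward(input_ids=None, past_key_values=None, attention_mask=None):
    """Hypothetical stand-in whose leading parameters mirror the order
    (input_ids, past_key_values, ...) of GPTNeoForCausalLM.forward."""
    return input_ids, past_key_values, attention_mask


def describe_binding(func, *args, **kwargs):
    """Return {parameter_name: value} for the arguments actually passed."""
    return dict(inspect.signature(func).bind(*args, **kwargs).arguments)


# Mirrors self.model(input_ids, extra): extra is bound to past_key_values
print(describe_binding(causal_lm_forward, "ids", "extra"))
# {'input_ids': 'ids', 'past_key_values': 'extra'}

# Mirrors self.model(input_ids): nothing is bound to past_key_values
print(describe_binding(causal_lm_forward, "ids"))
# {'input_ids': 'ids'}
```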

```python
# Using the model EleutherAI/gpt-neo-125M
# LoRALayer, replace_layers and replace_modules are defined elsewhere (not shown).
import torch


class LinearWithLoRAP(torch.nn.Module):
    def __init__(self, linear, rank, alpha, model):  # `model` is currently unused
        super().__init__()
        num_extra_features = 1  # size of extra inputs
        self.linear = linear
        self.lora = LoRALayer(
            rank, num_extra_features, linear.out_features, alpha
        )

    def forward(self, x, extra):
        return self.linear(x) + self.lora(extra)


class RootWrapperLoRAP(torch.nn.Module):

    def __init__(self, model, rank, alpha):
        super().__init__()
        self.model = model
        self.Extra = None
        for param in self.model.parameters():
            param.requires_grad = False
        #add_lora(self.model)
        #self.get_extra = lambda x : lambda x : self.Extra
        #replace_layers(self.model, self.get_extra)

    def get_extra(self):
        return self.Extra  #if self.Extra else torch.Tensor(1)

    def finish(self):
        # replace_layers replaces Linear layers with LinearWithLoRAP (no issues);
        # replace_modules wraps every module that isn't a wrapper in ModuleWrapper (no issues).
        # Placed here since self.model and self.get_extra weren't defined during __init__,
        # but self.get_extra still returns None, so it doesn't work either way.
        replace_layers(self.model, self.get_extra)
        replace_modules(self.model)  # Module Wrapper attempt

    def forward(self, input_tuple):
        global globalExtra
        input_ids, extra = input_tuple
        self.Extra = extra  # Callback attempt
        globalExtra = extra  # Global attempt
        print("FINAL_IN", input_ids, "EXTRA", self.Extra)  # self.Extra is not None here
        # Module Wrapper attempt. For Hugging Face causal LMs such as GPT-Neo, the second
        # positional parameter of forward() is past_key_values, so `extra` is bound there.
        x = self.model(input_ids, extra)
        #self.Extra = None
        return x


class ModuleWrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, extra=None, **kwargs):
        if isinstance(self.module, LinearWithLoRAP):
            output = self.module(*args, extra=extra, **kwargs)
        else:
            output = self.module(*args, **kwargs)

        # Return the wrapped module's output unchanged: the parent module
        # expects exactly what the original module returned, not a tuple.
        return output
```

I realize my problem is analogous to skip connections, but the implementations I’ve seen (ResNet) share state by building every Module to pass the shared/past state down. I like the idea of “injecting” my custom LoRA, keeping the implementation modular, and not having to redesign each architecture to handle my extra inputs. For the same reason, I haven’t considered concatenating the extra input, since I’d have to handle that in each forward function.

If relevant, my testing data looks like {'isEnglish': …, 'text': …}, where I plan to use isEnglish as a 1x1 tensor that’s -1.0 or 1.0 as the input to the LoRAs. All the LoRAs receive the same 1x1 input tensor. The model runs if I hardcode an extra input value, so this doesn’t seem to be the issue.
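If isEnglish ever varies within a batch, the extra tensor becomes (batch, 1) instead of (1, 1), and the LoRA term needs an extra axis to broadcast against the (batch, seq_len, out_features) output of a linear layer inside the transformer. A minimal sketch, assuming the LoRA term is computed as `extra @ A @ B` (LoRALayer isn’t shown, so lora_extra_term, A and B are hypothetical):

```python
import torch


def lora_extra_term(extra, A, B):
    """Low-rank projection of the extra input, shaped to broadcast against
    a linear output of shape (batch, seq_len, out_features)."""
    delta = extra @ A @ B      # (batch, out_features) or (1, out_features)
    return delta.unsqueeze(1)  # (batch or 1, 1, out_features): broadcasts over seq_len


batch, seq_len, out_features, rank = 2, 7, 16, 4
A = torch.randn(1, rank)
B = torch.randn(rank, out_features)
linear_out = torch.randn(batch, seq_len, out_features)

shared = torch.tensor([[1.0]])              # one isEnglish value for the whole batch
per_sample = torch.tensor([[1.0], [-1.0]])  # one isEnglish value per sample

y_shared = linear_out + lora_extra_term(shared, A, B)
y_per_sample = linear_out + lora_extra_term(per_sample, A, B)
```

Without the `unsqueeze(1)`, a (batch, out_features) term lines its batch axis up with seq_len: here that raises a size-mismatch error, and if seq_len happened to equal the batch size it would silently add each sample’s term to the wrong token positions.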

My next steps are to look into how ControlNet injects itself, and into PyTorch hooks, but since I’ve already spent a decent amount of time twiddling my thumbs trying to fix and understand bugs, I thought I’d ask y’all.
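On the hooks idea: since PyTorch 2.0, `register_forward_pre_hook` accepts `with_kwargs=True`, which lets a hook on the top-level model pull an `extra` keyword argument out before `forward()` runs, so the Hugging Face forward never sees it. A minimal sketch reusing the shared-holder idea; attach_extra_hooks and LinearWithExtra are hypothetical names, and `types.SimpleNamespace` just stands in for a holder object.

```python
import types

import torch
import torch.nn as nn


class LinearWithExtra(nn.Module):
    """Illustrative layer: frozen linear(x) plus a low-rank term from holder.value."""

    def __init__(self, linear, holder, rank=4, num_extra_features=1):
        super().__init__()
        self.linear = linear
        self.holder = holder
        self.A = nn.Parameter(torch.randn(num_extra_features, rank) * 0.01)
        self.B = nn.Parameter(torch.zeros(rank, linear.out_features))

    def forward(self, x):
        return self.linear(x) + self.holder.value @ self.A @ self.B


def attach_extra_hooks(root, holder):
    """Let callers write root(inputs, extra=t) although root.forward has no `extra`."""

    def pop_extra(module, args, kwargs):
        kwargs = dict(kwargs)
        holder.value = kwargs.pop("extra")  # KeyError if the caller forgot extra=
        return args, kwargs  # forward() is then called without `extra`

    def clear_extra(module, args, output):
        holder.value = None  # returning None keeps the output unchanged

    root.register_forward_pre_hook(pop_extra, with_kwargs=True)
    root.register_forward_hook(clear_extra)


holder = types.SimpleNamespace(value=None)
model = nn.Sequential(
    LinearWithExtra(nn.Linear(3, 5), holder),
    nn.ReLU(),
    LinearWithExtra(nn.Linear(5, 2), holder),
)
attach_extra_hooks(model, holder)
out = model(torch.randn(2, 3), extra=torch.tensor([[1.0]]))
```

With GPT-Neo as the root, the call would look like `model(input_ids, extra=torch.tensor([[1.0]]))`. Because pop_extra requires `extra`, a call without it fails fast with a KeyError instead of quietly running the LoRA layers on a missing value.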

Basically, how can I pass an input from the top of my model down to the various LinearWithLoRAP layers I’ve injected into a GPT LLM? I’m not that worried about efficacy or whether this is a decent way to augment a model; it’s just a cool idea.