Difficulties in training a hyper-network

I am having some difficulties in getting the backpropagation to work while training a custom hyper network. The training epoch looks like this:

for epoch in trange(1, num_epochs + 1, desc="Training"):

        for i, batch in enumerate(tqdm(dataloader, desc='Epoch', leave=False)):
            
            data, shifteddata, labels = batch
            
            weights, embedding, preds = model(data)
          
            # reconstruction_loss
            transformer_model.init_weights_hyper(weights)

            _,transformer_loss = transformer_model(data, shifteddata)

            transformer_loss = transformer_loss.mean()

            wandb.log({"Reconstruction loss:": transformer_loss.item()}, epoch+1)
            classification_loss = F.cross_entropy(preds, labels)
            
            wandb.log({"Classification loss:": classification_loss.item()}, epoch+1)

            model.zero_grad()
            loss = lambda_v*classification_loss + (1-lambda_v)*transformer_loss
            
             
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)


            optimizer.step()
            lr_scheduler.step()

The hyper network has two different objective functions; one that calculates the classification loss in a bottleneck layer, and another main loss that is given by how well/bad the learning network (in this case how a transformer performs on the same data). The problem arises when looking at the gradients while propagating the loss backwards. When I look at the graphs on wandb.ai, it seems that all the nodes after the classification head have been disconnected from the computation graph, and no gradients exist at all for them. This does not make sense because the error is being propagated from the output that comes after the classification head, so I see no reason why they should have been disconnected.

The hyper network looks like this:

class hyperNetwork(nn.Module):

    def count_params(self, model_params):

        model = GPT(model_params['vocab_size'],model_params['n_embd'],model_params['n_layer'],model_params['block_size'])

        gpt_params = 0
        for i, j in enumerate(model.named_parameters()):
            if j[0]=='pos_emb':
                continue
            else:
                gpt_params = gpt_params + np.prod(j[1].shape)
        
        return gpt_params

    def __init__(self, vocab_size, GPT_params, encoder_params, decoder_params):

        super().__init__()

        self.vocab_size = vocab_size
        self.encoder_params = encoder_params
        ## initialize the multiscale transformer

        self.encoder = MultiScaleTransformer(self.vocab_size, encoder_params['embed_dim'], encoder_params['hidden_dim'], encoder_params['num_heads'], encoder_params['num_layers'], encoder_params['num_classes'],  encoder_params['projection_dim'], encoder_params['classification_head'])
        # this GPT is never used, it is just used to find the number of parameters to initialize
        self.required_params = self.count_params(GPT_params)
        self.decoder = Decoder(self.required_params)
        


    def forward(self, x, target=None):

        if self.encoder_params['classification_head'] == True:
            labels, embedding = self.encoder(x)
            weights = self.decoder(labels)
            return weights, embedding, labels
        else:
            embedding = self.encoder(x)
            weights = self.decoder(embedding)
            return weights, embedding

As a sanity check, I removed the transformer model, and used a dummy loss function that depends directly on the output of the final layer of the hyper network; and there the gradients are definitely being preserved. I imagine somehow the graph is getting disconnected when I’m initializing the weights of the transformer, but I’m not sure what the issue is. The code for initializing the weights looks like this:

def init_weights_hyper(self, weights):

    W = weights.clone()    
    idx = 0
    for name,param in self.named_parameters():
        if name=='pos_emb':
            continue
        else:
            data_size = param.data.shape
            values_reqd = np.prod(data_size)
            w_idx = weights[0,idx:idx+values_reqd]
            idx = idx+values_reqd

            param.data = nn.parameter.Parameter(w_idx.reshape(param.data.shape))

Does anyone have any idea about what I’m doing wrong in the training?