I am having difficulty getting backpropagation to work while training a custom hypernetwork. The training loop looks like this:
for epoch in trange(1, num_epochs + 1, desc="Training"):
    for i, batch in enumerate(tqdm(dataloader, desc='Epoch', leave=False)):
        data, shifteddata, labels = batch
        weights, embedding, preds = model(data)
        # reconstruction_loss
        transformer_model.init_weights_hyper(weights)
        _, transformer_loss = transformer_model(data, shifteddata)
        transformer_loss = transformer_loss.mean()
        wandb.log({"Reconstruction loss:": transformer_loss.item()}, epoch+1)
        classification_loss = F.cross_entropy(preds, labels)
        wandb.log({"Classification loss:": classification_loss.item()}, epoch+1)
        model.zero_grad()
        loss = lambda_v*classification_loss + (1-lambda_v)*transformer_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
The hypernetwork has two objectives: a classification loss computed at a bottleneck layer, and the main loss, which measures how well the target network (here, a transformer) performs on the same data when it uses the generated weights. The problem shows up when I inspect the gradients after propagating the loss backwards. In the wandb.ai graphs, every layer after the classification head seems to be disconnected from the computation graph and has no gradient at all. This makes no sense to me, because the transformer loss is computed from the output that comes after the classification head, so I see no reason for those layers to be cut off.
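One way to see exactly which layers are cut off, independent of the wandb plots, is to inspect `.grad` right after `loss.backward()`. A minimal sketch with toy stand-ins: `params_without_grad`, `encoder` and `decoder` are hypothetical names, not from the post, and `.detach()` is used only to simulate a path that autograd does not record.

```python
import torch
import torch.nn as nn


def params_without_grad(module: nn.Module) -> list[str]:
    """Names of trainable parameters that got no gradient from the last backward()."""
    return [n for n, p in module.named_parameters() if p.requires_grad and p.grad is None]


# Hypothetical two-stage stand-in: encoder + classification head, then a decoder.
encoder = nn.Linear(4, 3)
decoder = nn.Linear(3, 2)

x = torch.randn(5, 4)
logits = encoder(x)
generated = decoder(logits)

classification_loss = logits.pow(2).mean()
# If the downstream loss reaches `generated` only through a path autograd does not
# record (simulated here with .detach()), the decoder drops out of the graph.
downstream_loss = generated.detach().pow(2).mean()
loss = 0.5 * classification_loss + 0.5 * downstream_loss
loss.backward()

print(params_without_grad(encoder))  # []
print(params_without_grad(decoder))  # ['weight', 'bias']
```

In the training loop above, calling `params_without_grad(model)` between `loss.backward()` and `optimizer.step()` would list the hypernetwork parameters that received no gradient.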
The hypernetwork looks like this:
class hyperNetwork(nn.Module):
    def count_params(self, model_params):
        model = GPT(model_params['vocab_size'],model_params['n_embd'],model_params['n_layer'],model_params['block_size'])
        gpt_params = 0
        for i, j in enumerate(model.named_parameters()):
            if j[0]=='pos_emb':
                continue
            else:
                gpt_params = gpt_params + np.prod(j[1].shape)
        return gpt_params

    def __init__(self, vocab_size, GPT_params, encoder_params, decoder_params):
        super().__init__()
        self.vocab_size = vocab_size
        self.encoder_params = encoder_params
        ## initialize the multiscale transformer
        self.encoder = MultiScaleTransformer(self.vocab_size, encoder_params['embed_dim'], encoder_params['hidden_dim'], encoder_params['num_heads'], encoder_params['num_layers'], encoder_params['num_classes'], encoder_params['projection_dim'], encoder_params['classification_head'])
        # this GPT is never used, it is just used to find the number of parameters to initialize
        self.required_params = self.count_params(GPT_params)
        self.decoder = Decoder(self.required_params)

    def forward(self, x, target=None):
        if self.encoder_params['classification_head'] == True:
            labels, embedding = self.encoder(x)
            weights = self.decoder(labels)
            return weights, embedding, labels
        else:
            embedding = self.encoder(x)
            weights = self.decoder(embedding)
            return weights, embedding
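For reference, the count that `count_params` computes can be written more directly with `Tensor.numel()`. A minimal sketch, assuming (as in the original) that only a parameter named exactly `pos_emb` is excluded; `count_generated_params` and `TinyTarget` are hypothetical stand-ins, since the GPT class is not shown in the post.

```python
import torch
import torch.nn as nn


def count_generated_params(model: nn.Module, skip: tuple[str, ...] = ("pos_emb",)) -> int:
    """Number of scalars the hypernetwork must emit for `model`, excluding `skip`."""
    return sum(p.numel() for name, p in model.named_parameters() if name not in skip)


class TinyTarget(nn.Module):
    """Hypothetical stand-in for the GPT target network."""

    def __init__(self, vocab_size=10, n_embd=4, block_size=8):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, n_embd)                   # 40 values
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))  # 32 values, skipped
        self.head = nn.Linear(n_embd, vocab_size)                         # 40 + 10 values


print(count_generated_params(TinyTarget()))  # 90
```

Using `numel()` avoids the `np.prod` call and returns a plain Python `int`, which is convenient when the count is used to size the decoder's output layer.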
As a sanity check, I removed the transformer and used a dummy loss that depends directly on the output of the hypernetwork's final layer; in that case the gradients are definitely preserved. So I suspect the graph gets disconnected when I initialize the transformer's weights, but I'm not sure what the issue is. The weight-initialization code looks like this:
def init_weights_hyper(self, weights):
    W = weights.clone()
    idx = 0
    for name,param in self.named_parameters():
        if name=='pos_emb':
            continue
        else:
            data_size = param.data.shape
            values_reqd = np.prod(data_size)
            w_idx = weights[0,idx:idx+values_reqd]
            idx = idx+values_reqd
            param.data = nn.parameter.Parameter(w_idx.reshape(param.data.shape))
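A toy reproduction of the suspicion above, assuming PyTorch 2.x (for `torch.func.functional_call`). Here `nn.Linear(3, 6)` stands in for the hypernetwork, `nn.Linear(2, 2)` stands in for the transformer, and `slice_weights` is a hypothetical helper. Assignments to `.data` are not recorded by autograd, and wrapping a tensor in `nn.Parameter` creates a new leaf tensor, so in variant (a) the hypernetwork receives no gradient. Variant (b) passes the generated tensors to the target module through `functional_call`, which keeps them in the graph. This is one possible approach, not necessarily the only fix.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


def slice_weights(module, flat, skip=("pos_emb",)):
    """Split a flat vector into {param_name: tensor} without leaving the autograd graph."""
    params, idx = {}, 0
    for name, p in module.named_parameters():
        if name in skip:
            continue
        n = p.numel()
        params[name] = flat[idx:idx + n].reshape(p.shape)
        idx += n
    if idx != flat.numel():
        raise ValueError(f"expected {idx} generated values, got {flat.numel()}")
    return params


torch.manual_seed(0)
x = torch.randn(5, 2)

# (a) Copy the generated values in through .data, as init_weights_hyper does.
hyper_a, target_a = nn.Linear(3, 6), nn.Linear(2, 2)
flat_a = hyper_a(torch.randn(1, 3))[0]
for name, tensor in slice_weights(target_a, flat_a).items():
    target_a.get_parameter(name).data = nn.Parameter(tensor)
target_a(x).pow(2).mean().backward()
print(hyper_a.weight.grad is None)      # True: the hypernetwork is cut off

# (b) Run the target module with the generated tensors via functional_call.
hyper_b, target_b = nn.Linear(3, 6), nn.Linear(2, 2)
flat_b = hyper_b(torch.randn(1, 3))[0]
out = functional_call(target_b, slice_weights(target_b, flat_b), (x,))
out.pow(2).mean().backward()
print(hyper_b.weight.grad is not None)  # True: the gradient reaches the hypernetwork
```

Applied to the loop above, the `init_weights_hyper` call and the transformer forward pass would become something like `_, transformer_loss = functional_call(transformer_model, slice_weights(transformer_model, weights[0]), (data, shifteddata))`, assuming the transformer's forward signature is as shown in the post. Parameters missing from the dict (here `pos_emb`) fall back to the module's own tensors, because `functional_call` is non-strict by default.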
Does anyone have an idea of what I'm doing wrong in the training?