How can I save the proper embeddings and weights after training?

Hi, thank you as always for your kind replies.
I have a question about saving the embedding and weight vectors after training the model.


Preliminaries

I am performing a link prediction task with the PyG library.
To do this, I use LinkNeighborLoader, which samples the supervision edges together with their paired nodes.
Each dataset has node features x, edge_index, edge_label, and an edge_class attribute that indicates which dataset each edge comes from.
Note that edge_class is imbalanced, since the number of edges differs across the datasets.


There are 4 datasets (indexed 0 to 3), and a single LinkNeighborLoader combines them into each mini-batch.

# Each value is the index of the dataset the edge comes from
print(data.edge_class)
# tensor([1, 3, 2, 3, 0, 2, 2, 0], device='cuda:2') #batch 0
# tensor([0, 2, 0, 2, 2, 1, 1, 3], device='cuda:2') #batch 1
# tensor([0, 2, 0, 0, 0, 0, 2, 1], device='cuda:2') #batch 2
...
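For context, the combined data object and the loader are built roughly like this (a simplified sketch; the dataset list, num_neighbors, and batch_size are placeholders rather than my exact values):

import torch
from torch_geometric.loader import LinkNeighborLoader

# Tag every supervision edge with the index of the dataset it came from
data.edge_class = torch.cat([
    torch.full((d.edge_label_index.size(1),), i, dtype=torch.long)
    for i, d in enumerate(datasets)  # datasets = [ds0, ds1, ds2, ds3]
])

train_loader = LinkNeighborLoader(
    data,
    num_neighbors=[10, 10],                  # placeholder fan-out
    edge_label_index=data.edge_label_index,
    edge_label=data.edge_label,
    batch_size=8,                            # matches the prints above
    shuffle=True,
)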

class MyModel(torch.nn.Module):
    def __init__(self, args, in_channels, hidden_channels, out_channels, num_relations):
        super(MyModel, self).__init__()
...
    def forward(self, d1, d2, node_id, neighbor_cl_ids, edge_index, edge_type):

        x_dict, attn1_dict, attn2_dict = {}, {}, {}

        # trial 2: concatenate the shared features with the dataset-specific ones
        for e_i_class in torch.unique(neighbor_cl_ids):
            cl = int(e_i_class.item())
            x = torch.cat((d1[node_id], d2[cl][node_id]), dim=1)
            x, (edge_index_1, a_1) = self.conv1(x, edge_index, edge_type, return_attention_weights=True)
            x = self.act(x)
            x = self.dropout(x)
            x, (edge_index_2, a_2) = self.conv2(x, edge_index, edge_type, return_attention_weights=True)

            x_dict[cl] = x
            attn1_dict[cl] = (edge_index_1, a_1)
            attn2_dict[cl] = (edge_index_2, a_2)

        return x_dict, attn1_dict, attn2_dict

In the model, I use dictionaries to store the embedding vectors and the attention scores of the first and second layers for each dataset.
During the epochs, I tried to train the model like this:

for epoch in range(3):
...
    model.train()
    for data in tqdm(train_loader):
        data = data.to(device)
        print(data.edge_class)

        optimizer.zero_grad()
        z, a1, a2 = model(data.x[0], data.x[1], data.n_id, data.edge_class, data.edge_index, data.edge_type) 
# z: dictionary of node embedding vectors; a1 & a2: dictionaries of attention scores from the GAT-based model
# Each dictionary is keyed by the dataset index, and each value holds the corresponding vectors for that dataset

However, after running the epochs, I could only get values for some of the datasets, not all of them (i.e., there was no entry for the 2nd dataset, idx=1).
I suspect that only the last batch's results survive in the dictionaries.
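If it helps to clarify what I mean, after the loop the name z simply refers to whatever the final batch returned, roughly like this:

z = None
for data in train_loader:
    ...
    z, a1, a2 = model(...)  # rebinds z on every iteration
# after the loop, z only holds the dictionary from the last batch, so any
# edge class that was absent from that batch is also absent from z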

print(z)
{0: tensor([[ 0.3912, -0.7021,  0.1847,  ..., -1.2331,  0.6514,  0.3439],
         ...,
         [ 0.3215, -1.5592, -0.0601,  ..., -2.0088,  0.7314,  0.1563]],
        device='cuda:2', grad_fn=<AddBackward0>),
 2: tensor([[ 0.4996, -0.5098,  0.2439,  ..., -1.1096,  0.5557,  0.3468],
         ...,
         [ 0.3767, -1.5131, -0.0300,  ..., -1.9102,  0.6989,  0.1949]],
        device='cuda:2', grad_fn=<AddBackward0>),
 3: tensor([[ 2.9430e-01, -5.2454e-01,  4.5100e-04,  ..., -9.9463e-01,
           5.5265e-01,  3.9856e-01],
         ...,
         [ 3.2548e-01, -1.6286e+00, -1.1648e-01,  ..., -2.0169e+00,
           6.8402e-01,  2.6424e-01]], device='cuda:2', grad_fn=<AddBackward0>)}

After tracing through the code, I found that all values were preserved after removing optimizer.zero_grad().
But removing optimizer.zero_grad() feels like the wrong approach, since my model is trained on mini-batches and is not an RNN-based model where accumulating gradients would make sense.
How should I fix my code so that all of the updated embedding vectors and attention scores are saved?

Thank you for reading this messy question.

Update

I found that the embedding vectors and attention scores for the whole dataset were present after adding the following steps to the code.

    model.train()
    for data in tqdm(train_loader):
        tr_loss = []
        data = data.to(device)
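        # As far as I understand, data.input_id holds the indices of this batch's
        # seed links within the original edge_label_index, so the custom edge_class
        # attribute has to be re-indexed manually for every mini-batch: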
        data.edge_class = data.edge_class[data.input_id]

        optimizer.zero_grad()
        z, a1, a2 = model(data.x[0], data.x[1], data.n_id, data.edge_class, data.edge_index, data.edge_type)  # feed the full node features into the model
        
        # Compute the loss separately per edge class, then average over the classes
        for i in torch.unique(data.edge_class):
            i = int(i.item())
            tr_out = model.decode(z[i], data.edge_label_index[:, data.edge_class == i])
            tr_loss.append(criterion(tr_out, data.edge_label[data.edge_class == i].float()))

        tr_loss_mean = sum(tr_loss) / len(tr_loss)
        tr_loss_mean.backward()
        optimizer.step()
print(z)
#{0: tensor([[ 0.3331, -1.0215,  0.9440,  ...,  0.3577, -0.2995, -0.3058],
#         ...,
#         [ 0.8367, -0.2150,  0.0894,  ..., -0.4216,  0.4663,  0.1768]],
#        device='cuda:2', grad_fn=<AddBackward0>),
# 1: tensor([[ 0.3921, -0.2746,  0.9942,  ..., -0.5685,  0.0403, -0.6099],
#         ...,
#         [ 0.9800, -0.3124,  0.1543,  ..., -0.4559,  0.3139,  0.2093]],
#        device='cuda:2', grad_fn=<AddBackward0>),
# 2: tensor([[ 0.7992, -0.4216,  1.1965,  ..., -0.2722, -0.0562, -0.5231],
#         ...,
#         [ 0.9281, -0.2738,  0.1264,  ..., -0.4438,  0.3720,  0.2010]],
#        device='cuda:2', grad_fn=<AddBackward0>),
# 3: tensor([[ 0.8567, -0.6530,  1.1218,  ..., -1.0058,  0.0113, -0.7525],
#         ...,
#         [ 0.8974, -0.2511,  0.1250,  ..., -0.4186,  0.3938,  0.1992]],
#        device='cuda:2', grad_fn=<AddBackward0>)}
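For completeness, once these dictionaries are correct, I plan to persist the embeddings and the model weights roughly like this (a sketch; the file names are arbitrary):

# Detach and move to CPU so the saved tensors do not keep the autograd graph alive
torch.save({k: v.detach().cpu() for k, v in z.items()}, 'embeddings.pt')
torch.save({k: (ei.cpu(), a.detach().cpu()) for k, (ei, a) in a1.items()}, 'attn1_scores.pt')
torch.save(model.state_dict(), 'model_weights.pt')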

By the way, when I added a validation step after training, it returned only the last edge class, like this:

...
    model.eval()
    with torch.no_grad():
        y_val_pred, y_val_pred_prob, y_val_true = [], [], []
        for data in tqdm(val_loader):
            val_loss = []
            data = data.to(device)
            data.edge_class = data.edge_class[data.input_id]

            y_val_true.append(data.edge_label)
            z_val, a1_val, a2_val = model(data.x[0], data.x[1], data.n_id, data.edge_class, data.edge_index, data.edge_type)

            for i in torch.unique(data.edge_class):
                i = int(i.item())

                try:
                    val_out = model.decode(z_val[i], data.edge_label_index[:, data.edge_class == i])
                    val_out_sig = val_out.sigmoid()
                    val_loss.append(criterion(val_out, data.edge_label[data.edge_class == i].float()))
                except KeyError:
                    pass
... # The rest of the loop just stores the loss values
print(z_val)
#{3: tensor([[ 1.6437,  0.7458,  2.3925,  ..., -0.4451,  0.4667, -0.5906],
#         ...,
#         [ 0.4809, -0.9566,  0.2333,  ..., -0.1838, -0.2463, -0.3823]],
#        device='cuda:2')}

Although I am guessing that model.eval() and with torch.no_grad() are the causes, I couldn't track down why the validation step only returns the last batch's vectors, and I don't know how to get the embedding vectors for the remaining edge classes…

Update

Throughout the debugging process, I realized that because the validation loader does not shuffle the data, the batches are processed sequentially from edge class 0 to 3, so z_val ends up holding only the final batch (which contains only edge class 3).
Indeed, when I set the batch_size of the validation LinkNeighborLoader to the total number of edges (i.e., no mini-batching), z_val returns all the embedding vectors I expected.
However, loading the whole dataset onto my GPU at once is not feasible because of OOM errors.
In this case, how can I get the full embeddings and attention scores for all edge classes while keeping the mini-batch approach?
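For reference, the direction I am considering is to merge each batch's dictionaries into accumulators outside the loop (a rough sketch; all_z and all_attn1 are my own names), but I am not sure whether this is the right way:

from collections import defaultdict

all_z, all_attn1 = defaultdict(list), defaultdict(list)

model.eval()
with torch.no_grad():
    for data in tqdm(val_loader):
        data = data.to(device)
        data.edge_class = data.edge_class[data.input_id]
        z_val, a1_val, a2_val = model(data.x[0], data.x[1], data.n_id,
                                      data.edge_class, data.edge_index, data.edge_type)
        # Merge this batch's results instead of overwriting the previous ones
        for k, emb in z_val.items():
            all_z[k].append(emb.cpu())
        for k, (ei, a) in a1_val.items():
            all_attn1[k].append((ei.cpu(), a.cpu()))
# all_z[k] now holds one embedding tensor per batch that contained edge class k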