This is a tiny model in Python that shows that the embed_table weights are correctly updated during training:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed( 1337 )
vocab_size = 3 # 0, 1 and 2
inputs = torch.tensor( (1, 2, 0, 1) )
targets = torch.tensor( (2, 0, 1, 2) )
class BigramModel (nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_table = nn.Embedding( vocab_size, vocab_size )
    def forward (self, idx):
        return self.embed_table( idx )
    def print_sum_weights (self):
        w = self.embed_table.weight
        print( f'SUM = {w.sum().item():.5f}' )
model = BigramModel()
model.print_sum_weights()
optimizer = torch.optim.AdamW( model.parameters(), lr=0.001 )
n_steps = 3
for step in range( n_steps ):
    print( '---', step+1, '---' )
    logits = model( inputs )
    loss = F.cross_entropy( logits, targets )
    print( f'loss = {loss.item():.5f}' )
    optimizer.zero_grad( set_to_none=True )
    loss.backward()
    optimizer.step()
    model.print_sum_weights()
The printed output is:
SUM = -6.67094
--- 1 ---
loss = 0.83275
SUM = -6.67388
--- 2 ---
loss = 0.83176
SUM = -6.67681
--- 3 ---
loss = 0.83078
SUM = -6.67974
We can see that the sum of the weights is changing and the loss is going down. Perfect.
Now I write the same model in C++ with the same seed:
#include <torch/torch.h>
//************************************************************************
int g_vocab_size = 3; // 0, 1 and 2
torch::Tensor g_inputs = torch::tensor( { 1, 2, 0, 1 } );
torch::Tensor g_targets = torch::tensor( { 2, 0, 1, 2 } );
//************************************************************************
struct _BigramModel : torch::nn::Module
{
    torch::nn::Embedding m_embed_table { nullptr };
    _BigramModel (void) {
        m_embed_table = register_module( "embedding",
            torch::nn::Embedding(g_vocab_size, g_vocab_size));
    }
    torch::Tensor forward (torch::Tensor idx) {
        return m_embed_table( idx );
    }
    void print_sum_weights (void) {
        auto w = m_embed_table->weight;
        std::cout << "SUM = " << w.sum().item().toFloat() << std::endl;
    }
};
TORCH_MODULE_IMPL( BigramModel, _BigramModel );
//************************************************************************
int main (void)
{
    torch::manual_seed( 1337 );
    // Create the model
    auto model = BigramModel();
    model->print_sum_weights();
    // Create the optimizer
    int learning_rate = 0.001;
    auto optimizer = torch::optim::AdamW( model->parameters(),
                                          learning_rate );
    // Training
    int n_steps = 3;
    for (int step=0; step<n_steps; step++)
    {
        std::cout << "--- " << step+1 << " ---" << std::endl;
        // Forward operation
        torch::Tensor logits = model( g_inputs );
        // Evaluate the loss
        torch::Tensor loss = torch::nn::functional::
            cross_entropy( logits, g_targets );
        std::cout << "loss = " << loss.item().toFloat() << std::endl;
        // Calculate the new grads
        optimizer.zero_grad( /*set_to_none*/true );
        loss.backward();
        // SHOULD apply the grads to the weights <<<<<
        optimizer.step();
        // Debugging
        model->print_sum_weights();
    }
    // Done
    return 0;
}
//************************************************************************
And here is the printed output of this C++ version:
SUM = -6.67094
--- 1 ---
loss = 0.83275
SUM = -6.67094
--- 2 ---
loss = 0.83275
SUM = -6.67094
--- 3 ---
loss = 0.83275
SUM = -6.67094
The first printed sum and the first loss are identical to the Python version, but after that the weights and the loss never change in the C++ version.
It seems that optimizer.step() is doing nothing.
I guess I am missing something here, but I don't know what.
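In case it is useful, here is a small debugging sketch I am thinking of adding right after loss.backward() (untested, just an idea): it would print the gradient that actually reaches the embedding weight and the learning rate the optimizer really holds.

// Untested debugging idea, to be placed right after loss.backward()
// Does a gradient actually reach the embedding weight?
auto grad = model->m_embed_table->weight.grad();
if (grad.defined())
    std::cout << "grad sum = " << grad.sum().item().toFloat() << std::endl;
else
    std::cout << "grad is undefined" << std::endl;
// What learning rate does the optimizer really hold?
auto& opts = static_cast<torch::optim::AdamWOptions&>(
    optimizer.param_groups()[0].options() );
std::cout << "lr = " << opts.lr() << std::endl;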