Hi there,
I am debugging a piece of a much larger project that uses the Gumbel-softmax function to draw samples from a categorical distribution over angles in [-pi, pi]; the sampled angles are used downstream to build 3D coordinates, and the final loss is an MSE on those coordinates. A cross-entropy loss on the logits obviously learns the task directly, but I set up the examples below as a proxy for my downstream task, which I would ideally not train with cross-entropy.
When training the larger model, I noticed that the logits going into gumbel_softmax were converging to the same values regardless of the input, and produced the same samples regardless of the temperature. I set up the two examples below to investigate, and I am confused about why one implementation is able to learn from the Gumbel-softmax samples while the other is not.
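For reference, my mental model of what F.gumbel_softmax(..., hard=True) is doing is the straight-through trick sketched below: the forward pass returns a one-hot sample, but the gradient flows through the underlying soft sample. This is a rough sketch of the idea rather than the exact library code:

import torch
import torch.nn.functional as F

def straight_through_gumbel_softmax(logits, tau=1.0, dim=-1):
    # Add Gumbel(0, 1) noise and take a softmax to get the "soft" relaxed sample.
    gumbels = -torch.empty_like(logits).exponential_().log()
    y_soft = F.softmax((logits + gumbels) / tau, dim=dim)
    # Hard one-hot sample taken from the soft sample's argmax.
    index = y_soft.argmax(dim, keepdim=True)
    y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
    # Straight-through estimator: one-hot value forward, y_soft's gradient backward.
    return y_hard - y_soft.detach() + y_soft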
Example 1: This model is unable to learn: the loss does not decrease and the samples collapse onto a single value or a few values:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter

torch.random.manual_seed(0)
num_data = 1000
input_dim = 100
num_classes = 72

mlp = nn.Sequential(
    nn.Linear(input_dim, 128),
    nn.GELU(),
    nn.Linear(128, 128),
    nn.GELU(),
    nn.Linear(128, num_classes),
)
mlp.train()

inputs = torch.randn((num_data, input_dim))
outputs = torch.randint(0, num_classes, (num_data,))
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)

# Map each class index to an angle in [-pi, pi]; shape (num_classes, 1).
outputs_to_value = torch.linspace(-torch.pi, torch.pi, num_classes).unsqueeze(-1)

for epoch in range(1000):
    optimizer.zero_grad()
    logits = mlp(inputs)
    # Straight-through one-hot sample from the categorical distribution.
    sample = F.gumbel_softmax(logits, tau=1, hard=True, dim=-1)
    sample.retain_grad()
    # Convert the one-hot sample to an angle, then compare cos/sin of the angles.
    sample_values = sample @ outputs_to_value
    cos_loss = F.mse_loss(torch.cos(sample_values), torch.cos(outputs_to_value[outputs]))
    sin_loss = F.mse_loss(torch.sin(sample_values), torch.sin(outputs_to_value[outputs]))
    loss = cos_loss + sin_loss
    loss.backward(retain_graph=True)
    if epoch == 0 or (epoch + 1) % 100 == 0:
        print(f'Epoch {epoch} Loss: {loss.item()}')
        print(f'\t Grad Sum: {sample.grad.abs().sum()}')
    optimizer.step()

print("Prediction Counts:")
print(scatter(torch.ones(logits.shape[0]), logits.argmax(dim=-1), dim=0, reduce='sum', dim_size=num_classes))
print("Labels Counts:")
print(scatter(torch.ones(outputs.shape[0]), outputs, dim=0, reduce='sum', dim_size=num_classes))
Example 1 Outputs:
Epoch 0 Loss: 1.934694766998291
Grad Sum: 145.82659912109375
Epoch 99 Loss: 1.9157741069793701
Grad Sum: 144.36563110351562
Epoch 199 Loss: 2.0428879261016846
Grad Sum: 144.54237365722656
Epoch 299 Loss: 2.023682117462158
Grad Sum: 145.2460479736328
Epoch 399 Loss: 2.023682117462158
Grad Sum: 145.2460479736328
Epoch 499 Loss: 2.023682117462158
Grad Sum: 145.2460479736328
Epoch 599 Loss: 2.023682117462158
Grad Sum: 145.2460479736328
Epoch 699 Loss: 2.023682117462158
Grad Sum: 145.2460479736328
Epoch 799 Loss: 2.023682117462158
Grad Sum: 145.2460479736328
Epoch 899 Loss: 2.023682117462158
Grad Sum: 145.2460479736328
Epoch 999 Loss: 2.023682117462158
Grad Sum: 145.2460479736328
Prediction Counts:
tensor([474., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 526.])
Labels Counts:
tensor([16., 8., 11., 11., 10., 9., 15., 17., 12., 21., 13., 11., 10., 13.,
15., 15., 15., 20., 9., 14., 13., 16., 14., 14., 15., 14., 18., 13.,
16., 19., 13., 13., 18., 9., 11., 16., 16., 13., 10., 20., 18., 11.,
20., 13., 12., 13., 17., 15., 10., 17., 13., 4., 16., 16., 8., 12.,
17., 20., 14., 16., 10., 13., 8., 12., 16., 21., 22., 15., 10., 10.,
14., 11.])
Example 2: In contrast, this model is able to learn: the loss decreases and performance improves with more epochs:
torch.random.manual_seed(0)
num_data = 1000
input_dim = 100
num_classes = 72

mlp = nn.Sequential(
    nn.Linear(input_dim, 128),
    nn.GELU(),
    nn.Linear(128, 128),
    nn.GELU(),
    nn.Linear(128, num_classes),
)
mlp.train()

inputs = torch.randn((num_data, input_dim))
outputs = torch.randint(0, num_classes, (num_data,))
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)

# Map each class index to a (cos, sin) pair on the unit circle; shape (num_classes, 2).
outputs_to_value = torch.linspace(-torch.pi, torch.pi, num_classes)
outputs_to_value = torch.stack([torch.cos(outputs_to_value), torch.sin(outputs_to_value)], dim=-1)

for epoch in range(1000):
    optimizer.zero_grad()
    logits = mlp(inputs)
    # Straight-through one-hot sample from the categorical distribution.
    sample = F.gumbel_softmax(logits, tau=1, hard=True, dim=-1)
    sample.retain_grad()
    # Convert the one-hot sample to a (cos, sin) pair and compare directly.
    sample_values = sample @ outputs_to_value
    loss = F.mse_loss(sample_values, outputs_to_value[outputs])
    loss.backward(retain_graph=True)
    if epoch == 0 or (epoch + 1) % 100 == 0:
        print(f'Epoch {epoch} Loss: {loss.item()}')
        print(f'\t Grad Sum: {sample.grad.abs().sum()}')
    optimizer.step()

print("Prediction Counts:")
print(scatter(torch.ones(logits.shape[0]), logits.argmax(dim=-1), dim=0, reduce='sum', dim_size=num_classes))
print("Labels Counts:")
print(scatter(torch.ones(outputs.shape[0]), outputs, dim=0, reduce='sum', dim_size=num_classes))
Example 2 Outputs:
Epoch 0 Loss: 0.9673473834991455
Grad Sum: 56.99043655395508
Epoch 99 Loss: 0.5496492981910706
Grad Sum: 39.74175262451172
Epoch 199 Loss: 0.11724147200584412
Grad Sum: 17.508560180664062
Epoch 299 Loss: 0.07937148213386536
Grad Sum: 14.189104080200195
Epoch 399 Loss: 0.06150461733341217
Grad Sum: 12.445345878601074
Epoch 499 Loss: 0.04786432906985283
Grad Sum: 11.127017974853516
Epoch 599 Loss: 0.04411398246884346
Grad Sum: 10.569339752197266
Epoch 699 Loss: 0.0373353511095047
Grad Sum: 9.741028785705566
Epoch 799 Loss: 0.03433249145746231
Grad Sum: 9.20258617401123
Epoch 899 Loss: 0.02726987563073635
Grad Sum: 8.206332206726074
Epoch 999 Loss: 0.0253863837569952
Grad Sum: 8.08150577545166
Prediction Counts:
tensor([ 0., 0., 24., 0., 35., 2., 4., 0., 15., 16., 42., 0., 0., 2.,
60., 0., 6., 17., 1., 19., 24., 17., 0., 15., 0., 32., 33., 16.,
0., 0., 0., 66., 0., 0., 38., 0., 8., 1., 4., 30., 45., 2.,
0., 31., 6., 16., 0., 50., 10., 1., 14., 0., 0., 8., 16., 40.,
17., 1., 18., 35., 0., 7., 0., 9., 25., 0., 57., 2., 1., 0.,
62., 0.])
Labels Counts:
tensor([16., 8., 11., 11., 10., 9., 15., 17., 12., 21., 13., 11., 10., 13.,
15., 15., 15., 20., 9., 14., 13., 16., 14., 14., 15., 14., 18., 13.,
16., 19., 13., 13., 18., 9., 11., 16., 16., 13., 10., 20., 18., 11.,
20., 13., 12., 13., 17., 15., 10., 17., 13., 4., 16., 16., 8., 12.,
17., 20., 14., 16., 10., 13., 8., 12., 16., 21., 22., 15., 10., 10.,
14., 11.])
I suspect the issue lies in the magnitudes of the gradients coming back through F.gumbel_softmax, since they clearly differ between the two setups, but I'm having trouble seeing why this would make such a large difference. Am I missing something obvious? Any insight would be appreciated!
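In case it helps, this is the kind of check I have been running to compare how much gradient actually reaches the logits (rather than just the one-hot sample) in the two setups. It is a rough diagnostic sketch that reuses mlp, inputs, outputs, and the Example 2 version of outputs_to_value from above:

logits = mlp(inputs)
logits.retain_grad()  # keep the gradient on this non-leaf tensor for inspection
sample = F.gumbel_softmax(logits, tau=1, hard=True, dim=-1)
sample.retain_grad()
sample_values = sample @ outputs_to_value
loss = F.mse_loss(sample_values, outputs_to_value[outputs])  # Example 2 loss shown here
loss.backward()
print(f'sample.grad abs sum: {sample.grad.abs().sum().item()}')
print(f'logits.grad abs sum: {logits.grad.abs().sum().item()}')
print(f'logits.grad abs max: {logits.grad.abs().max().item()}')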