Learning directly from F.gumbel_softmax samples - weird gradients?

Hi there,

I am debugging a piece of a much larger project that uses the Gumbel-softmax trick to draw samples from a categorical distribution over angles in [-pi, pi]; the sampled angles are used downstream to build 3D coordinates, and the final loss is an MSE on those coordinates. A cross-entropy loss on the logits learns the task easily, but I set up the examples below as a proxy for the downstream task, which I would ideally prefer not to train with cross-entropy.
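
For context, here is a heavily stripped-down sketch of what I mean by that pipeline (the angle binning matches the examples below, but the coordinate construction and the names are purely illustrative, not my real model):

import torch
import torch.nn.functional as F

# Illustrative only: bin angles in [-pi, pi], draw a hard Gumbel-softmax sample
# over the bins, and turn the sampled angle into a toy 3D point. The real
# coordinate builder in my project is more involved.
num_classes = 72
bin_angles = torch.linspace(-torch.pi, torch.pi, num_classes)

logits = torch.randn(8, num_classes, requires_grad=True)
sample = F.gumbel_softmax(logits, tau=1, hard=True, dim=-1)   # (8, num_classes) one-hot
angle = sample @ bin_angles.unsqueeze(-1)                     # (8, 1) sampled angle
coords = torch.cat([torch.cos(angle), torch.sin(angle),
                    torch.zeros_like(angle)], dim=-1)         # (8, 3) toy coordinates
loss = F.mse_loss(coords, torch.zeros_like(coords))           # stand-in for the real MSE target
loss.backward()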

When training the larger model, I noticed that the logits going into gumbel_softmax converged to the same values regardless of the inputs and produced the same samples regardless of the temperature. I set up the two examples below to test why, and I'm confused about why one implementation is able to learn from the Gumbel-softmax samples while the other is not.

Example 1: This model fails to learn; the loss does not decrease and the predictions collapse onto one or two classes:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter  # used for the class-count printouts at the end

torch.random.manual_seed(0)
num_data = 1000
input_dim = 100
num_classes = 72

mlp = nn.Sequential(
    nn.Linear(input_dim, 128), 
    nn.GELU(),
    nn.Linear(128, 128),
    nn.GELU(),
    nn.Linear(128, num_classes)
)
mlp.train()

inputs = torch.randn((num_data, input_dim))
outputs = torch.randint(0, num_classes, (num_data,))

optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)
# Map each class index to an angle in [-pi, pi]; shape (num_classes, 1)
outputs_to_value = torch.linspace(-torch.pi, torch.pi, num_classes).unsqueeze(-1)

for epoch in range(1000):
    optimizer.zero_grad()

    logits = mlp(inputs)
    # Hard one-hot sample; gradients flow through the straight-through soft relaxation
    sample = F.gumbel_softmax(logits, tau=1, hard=True, dim=-1)
    sample.retain_grad()
    sample_values = sample @ outputs_to_value

    cos_loss = F.mse_loss(torch.cos(sample_values), torch.cos(outputs_to_value[outputs]))
    sin_loss = F.mse_loss(torch.sin(sample_values), torch.sin(outputs_to_value[outputs]))
    loss = cos_loss + sin_loss

    loss.backward(retain_graph=True)
    if epoch == 0 or (epoch + 1) % 100 == 0:
        print(f'Epoch {epoch} Loss: {loss.item()}')
        print(f'\t Grad Sum: {sample.grad.abs().sum()}')

    optimizer.step()

print("Prediction Counts:")
print(scatter(torch.ones(logits.shape[0]), logits.argmax(dim=-1), dim=0, reduce='sum', dim_size=num_classes))
print("Labels Counts:")
print(scatter(torch.ones(outputs.shape[0]), outputs, dim=0, reduce='sum', dim_size=num_classes))

Example 1 Outputs:

Epoch 0 Loss: 1.934694766998291
	 Grad Sum: 145.82659912109375
Epoch 99 Loss: 1.9157741069793701
	 Grad Sum: 144.36563110351562
Epoch 199 Loss: 2.0428879261016846
	 Grad Sum: 144.54237365722656
Epoch 299 Loss: 2.023682117462158
	 Grad Sum: 145.2460479736328
Epoch 399 Loss: 2.023682117462158
	 Grad Sum: 145.2460479736328
Epoch 499 Loss: 2.023682117462158
	 Grad Sum: 145.2460479736328
Epoch 599 Loss: 2.023682117462158
	 Grad Sum: 145.2460479736328
Epoch 699 Loss: 2.023682117462158
	 Grad Sum: 145.2460479736328
Epoch 799 Loss: 2.023682117462158
	 Grad Sum: 145.2460479736328
Epoch 899 Loss: 2.023682117462158
	 Grad Sum: 145.2460479736328
Epoch 999 Loss: 2.023682117462158
	 Grad Sum: 145.2460479736328
Prediction Counts:
tensor([474.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0., 526.])
Labels Counts:
tensor([16.,  8., 11., 11., 10.,  9., 15., 17., 12., 21., 13., 11., 10., 13.,
        15., 15., 15., 20.,  9., 14., 13., 16., 14., 14., 15., 14., 18., 13.,
        16., 19., 13., 13., 18.,  9., 11., 16., 16., 13., 10., 20., 18., 11.,
        20., 13., 12., 13., 17., 15., 10., 17., 13.,  4., 16., 16.,  8., 12.,
        17., 20., 14., 16., 10., 13.,  8., 12., 16., 21., 22., 15., 10., 10.,
        14., 11.])

Example 2: This model, by contrast, is able to learn; the loss decreases and performance improves with more epochs:

# Same imports as Example 1
torch.random.manual_seed(0)
num_data = 1000
input_dim = 100
num_classes = 72

mlp = nn.Sequential(
    nn.Linear(input_dim, 128), 
    nn.GELU(),
    nn.Linear(128, 128),
    nn.GELU(),
    nn.Linear(128, num_classes)
)
mlp.train()

inputs = torch.randn((num_data, input_dim))
outputs = torch.randint(0, num_classes, (num_data,))

optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)
# Embed each class as its (cos, sin) pair; shape (num_classes, 2)
outputs_to_value = torch.linspace(-torch.pi, torch.pi, num_classes)
outputs_to_value = torch.stack([torch.cos(outputs_to_value), torch.sin(outputs_to_value)], dim=-1)

for epoch in range(1000):
    optimizer.zero_grad()

    logits = mlp(inputs)
    sample = F.gumbel_softmax(logits, tau=1, hard=True, dim=-1)
    sample.retain_grad()
    sample_values = sample @ outputs_to_value

    loss = F.mse_loss(sample_values, outputs_to_value[outputs])
    loss.backward(retain_graph=True)

    if epoch == 0 or (epoch + 1) % 100 == 0:
        print(f'Epoch {epoch} Loss: {loss.item()}')
        print(f'\t Grad Sum: {sample.grad.abs().sum()}')

    optimizer.step()

print("Prediction Counts:")
print(scatter(torch.ones(logits.shape[0]), logits.argmax(dim=-1), dim=0, reduce='sum', dim_size=num_classes))
print("Labels Counts:")
print(scatter(torch.ones(outputs.shape[0]), outputs, dim=0, reduce='sum', dim_size=num_classes))

Example 2 Outputs:

Epoch 0 Loss: 0.9673473834991455
	 Grad Sum: 56.99043655395508
Epoch 99 Loss: 0.5496492981910706
	 Grad Sum: 39.74175262451172
Epoch 199 Loss: 0.11724147200584412
	 Grad Sum: 17.508560180664062
Epoch 299 Loss: 0.07937148213386536
	 Grad Sum: 14.189104080200195
Epoch 399 Loss: 0.06150461733341217
	 Grad Sum: 12.445345878601074
Epoch 499 Loss: 0.04786432906985283
	 Grad Sum: 11.127017974853516
Epoch 599 Loss: 0.04411398246884346
	 Grad Sum: 10.569339752197266
Epoch 699 Loss: 0.0373353511095047
	 Grad Sum: 9.741028785705566
Epoch 799 Loss: 0.03433249145746231
	 Grad Sum: 9.20258617401123
Epoch 899 Loss: 0.02726987563073635
	 Grad Sum: 8.206332206726074
Epoch 999 Loss: 0.0253863837569952
	 Grad Sum: 8.08150577545166
Prediction Counts:
tensor([ 0.,  0., 24.,  0., 35.,  2.,  4.,  0., 15., 16., 42.,  0.,  0.,  2.,
        60.,  0.,  6., 17.,  1., 19., 24., 17.,  0., 15.,  0., 32., 33., 16.,
         0.,  0.,  0., 66.,  0.,  0., 38.,  0.,  8.,  1.,  4., 30., 45.,  2.,
         0., 31.,  6., 16.,  0., 50., 10.,  1., 14.,  0.,  0.,  8., 16., 40.,
        17.,  1., 18., 35.,  0.,  7.,  0.,  9., 25.,  0., 57.,  2.,  1.,  0.,
        62.,  0.])
Labels Counts:
tensor([16.,  8., 11., 11., 10.,  9., 15., 17., 12., 21., 13., 11., 10., 13.,
        15., 15., 15., 20.,  9., 14., 13., 16., 14., 14., 15., 14., 18., 13.,
        16., 19., 13., 13., 18.,  9., 11., 16., 16., 13., 10., 20., 18., 11.,
        20., 13., 12., 13., 17., 15., 10., 17., 13.,  4., 16., 16.,  8., 12.,
        17., 20., 14., 16., 10., 13.,  8., 12., 16., 21., 22., 15., 10., 10.,
        14., 11.])

I suspect the issue lies in the magnitude of the gradients coming out of F.gumbel_softmax, since they clearly differ between the two versions, but I'm having trouble seeing why that would make such a large difference. Am I missing something obvious? Any insight would be appreciated!
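
In case it helps anyone reproduce what I'm seeing, here is a small standalone check I've been using to compare the gradients that flow into the hard sample and the logits under the two loss formulations (same logits, same targets, same Gumbel noise; only the order of the cos/sin and the matmul differs; the names are just for this snippet):

import torch
import torch.nn.functional as F

num_classes = 72
angles = torch.linspace(-torch.pi, torch.pi, num_classes)

torch.random.manual_seed(0)
logits = torch.randn(4, num_classes)
target = torch.randint(0, num_classes, (4,))

# Formulation 1 (Example 1): pick the sampled angle first, then apply cos/sin.
logits1 = logits.clone().requires_grad_(True)
torch.random.manual_seed(1)                       # same Gumbel noise for both formulations
sample1 = F.gumbel_softmax(logits1, tau=1, hard=True, dim=-1)
sample1.retain_grad()
angle1 = sample1 @ angles.unsqueeze(-1)           # (4, 1) sampled angle
loss1 = (F.mse_loss(torch.cos(angle1), torch.cos(angles[target]).unsqueeze(-1))
         + F.mse_loss(torch.sin(angle1), torch.sin(angles[target]).unsqueeze(-1)))
loss1.backward()

# Formulation 2 (Example 2): embed each class as (cos, sin) first, then pick the sampled row.
table = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)   # (num_classes, 2)
logits2 = logits.clone().requires_grad_(True)
torch.random.manual_seed(1)                       # reuse the same Gumbel noise
sample2 = F.gumbel_softmax(logits2, tau=1, hard=True, dim=-1)
sample2.retain_grad()
loss2 = F.mse_loss(sample2 @ table, table[target])
loss2.backward()

print("sample grad (formulation 1):", sample1.grad.abs().sum().item())
print("sample grad (formulation 2):", sample2.grad.abs().sum().item())
print("logit grad  (formulation 1):", logits1.grad.abs().sum().item())
print("logit grad  (formulation 2):", logits2.grad.abs().sum().item())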