Backpropagation with one-hot encoded feature

I’m trying to one-hot encode a feature in my network, but I’m not sure how to properly update the weights or if this is a good approach at all. Any feedback is appreciated.

My current approach:

A helper class for the encoded feature:

import torch


class Encoder(torch.nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # Map the one-hot vector of length num_classes down to a single learned value
        self.fc = torch.nn.Linear(in_features=num_classes, out_features=1)

    def forward(self, x):
        # one_hot() returns an integer tensor, so cast to float before the linear layer
        return self.fc(x.float())

My actual network’s forward method assumes the first feature is the one to be encoded (pclass is the name of the feature, and self.pclass_encoder is an instance of the helper class above):

def forward(self, x):
    # Everything after the first column is used as-is
    real_features = x[:, 1:]
    # pclass is 1-indexed, so shift to 0-indexed class ids before one-hot encoding
    encoded_features = torch.nn.functional.one_hot(x[:, 0].long() - 1)
    # Call the module directly rather than .forward() so hooks are handled properly
    pclass_out = self.pclass_encoder(encoded_features)
    x = torch.cat((pclass_out, real_features), 1)
    x = self.fc1(x)
    x = torch.sigmoid(self.out(x))
    return x
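
For completeness, the rest of the model is wired up roughly like this (the class name and layer sizes here are placeholders, not my actual values):

class Net(torch.nn.Module):
    def __init__(self, num_real_features=6, hidden_size=16):
        super().__init__()
        self.pclass_encoder = Encoder(num_classes=3)  # 3 distinct pclass values
        # +1 input feature for the encoded pclass value concatenated to the real features
        self.fc1 = torch.nn.Linear(num_real_features + 1, hidden_size)
        self.out = torch.nn.Linear(hidden_size, 1)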

In my training loop, I hoped gradients would propagate by simply adding an extra optimizer for the helper class and stepping it the same way I do for the main network:

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
encoder_optimizer = torch.optim.Adam(pclass_encoder.parameters(), lr=0.01)

EPOCHS = 4000

for i in range(EPOCHS):
    predictions = model(features)
    loss = loss_fn(predictions, labels)

    # Zero and step both optimizers so the helper class's weights also get updated
    optimizer.zero_grad()
    encoder_optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    encoder_optimizer.step()

The weight attribute of the encoder's linear layer does appear to be updating.

To me, the entire thing seems to be end-to-end trainable. Have you tried using only the model optimizer and checking whether the weights of the Encoder's fully connected layer are updated or not?
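
A quick way to check (just a sketch, assuming the encoder is assigned as self.pclass_encoder on the model): snapshot its weight, take a single step with only the model optimizer, and compare.

# Snapshot the encoder's weight before one optimization step
before = model.pclass_encoder.fc.weight.detach().clone()

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
predictions = model(features)
loss = loss_fn(predictions, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Prints False if the encoder's weight was updated by the model optimizer alone
print(torch.equal(before, model.pclass_encoder.fc.weight))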


Good call! Originally I used only the one optimizer, but I got sidetracked while troubleshooting and forgot to revisit that.

The weights of the Encoder do in fact get updated with just the model optimizer. This is what I was hoping for; I had assumed the model graph would include anything my input tensors were passed through.
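
For reference: since the encoder is assigned as self.pclass_encoder in the model's __init__, its parameters are registered on the model and are therefore included in model.parameters(), so the single optimizer covers them. A quick way to see that:

# The encoder's parameters are registered under the parent module automatically
for name, param in model.named_parameters():
    print(name, tuple(param.shape))
# The output includes entries such as pclass_encoder.fc.weight and pclass_encoder.fc.bias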