Use a pretrained model as a loss

The goal is to train a multi-feature generative model under the guidance of another model, in this example a score estimator.


The dataset is categorical and consists of N features, each with its own set of attributes.
A numerical feature acts as the score of each combination of feature attributes.
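For concreteness, here is one way such a dataset could be laid out in PyTorch. This is a sketch under assumptions: each categorical feature is integer-coded, and the cardinalities, row count and random values are hypothetical placeholders, not the real data.

```python
import torch

# Hypothetical sizes: N = 3 categorical features with different numbers of attributes.
cardinalities = [4, 6, 5]
num_rows = 8

torch.manual_seed(0)
# Each row stores one attribute index per feature (integer-coded categories).
rows = torch.stack(
    [torch.randint(0, k, (num_rows,)) for k in cardinalities], dim=1
)  # shape: (num_rows, N)

# The numerical column: one score per row (random placeholder values here).
scores = torch.rand(num_rows)
```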

Pretrained Estimator Model

I trained a regression model on this data: it takes a whole row as input and outputs that row's estimated score.
The whole row is represented as the concatenation of the embeddings of each feature attribute.

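import torch

# EstimatorNet is the pretrained regression model (its definition is not shown)
estimator = EstimatorNet()

# A row is the concatenation of the embeddings of its feature attributes
row = torch.concat([embeddings["feature_1"], embeddings["feature_2"],...])

estimated_score = estimator(row)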
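The post does not show EstimatorNet itself, so the following is only a hypothetical sketch of an estimator with that shape: one nn.Embedding table per feature, the embeddings concatenated into a row vector, then an MLP. The sizes are assumptions; latent_dim=64 is picked to match the nn.Linear(64, 64) projection head used later, and the to_latent method is included because the Aligner further down calls estimator_model.to_latent.

```python
import torch
import torch.nn as nn

class EstimatorNet(nn.Module):
    """Hypothetical estimator: one embedding table per feature, concatenated, then an MLP."""

    def __init__(self, cardinalities, embed_dim=16, latent_dim=64):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [nn.Embedding(k, embed_dim) for k in cardinalities]
        )
        self.body = nn.Sequential(
            nn.Linear(embed_dim * len(cardinalities), latent_dim), nn.ReLU()
        )
        self.head = nn.Linear(latent_dim, 1)

    def to_latent(self, rows):
        # rows: (batch, N) integer attribute indices -> (batch, latent_dim)
        row = torch.cat(
            [emb(rows[:, i]) for i, emb in enumerate(self.embeddings)], dim=-1
        )
        return self.body(row)

    def forward(self, rows):
        # Estimated score per row: (batch,)
        return self.head(self.to_latent(rows)).squeeze(-1)

estimator = EstimatorNet(cardinalities=[4, 6, 5])
rows = torch.tensor([[0, 1, 2], [3, 5, 4]])
estimated_score = estimator(rows)
```

Note that the embedding lookup inside to_latent takes integer indices, which is exactly why no gradient can flow from the estimator back to whatever produced those indices.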

The goal here is to have the estimator capture the complex feature interactions as well as predict the associated score.

Generator Model to Train

The goal is to generate features given a context such that the score is maximized. It is a kind of recommendation problem in which you cannot afford to run the estimator on every possible feature combination and then take the top k. So we have to extract "scoring" knowledge from the estimator to guide the generator during its training and push it to generate high-scoring combinations.
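To make the brute-force cost concrete: the number of candidate rows per context is the product of the feature cardinalities, so it grows exponentially with the number of features. The cardinalities below are hypothetical, chosen only to illustrate the arithmetic.

```python
import math

# Hypothetical cardinalities: 8 categorical features with 20 attributes each.
cardinalities = [20] * 8

# Every combination would need one estimator call per context.
num_combinations = math.prod(cardinalities)
print(num_combinations)  # 25600000000
```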

I have already trained such a generator, and it works reasonably well. Its last layers are as follows:

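latent = first_layers(context)
# One classification head (logits) per feature
feature1_logits = fully_connected_1(latent)
feature2_logits = fully_connected_2(latent)

# Pick the highest-scoring attribute of each feature
generated_feature1 = feature1_logits.argmax(dim=-1)
generated_feature2 = feature2_logits.argmax(dim=-1)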

I use logits and argmax here for demonstration purposes; in reality I apply softmax and then argmax. Both select the same attribute, since softmax preserves the ordering of the logits.
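A minimal, self-contained version of that generator head, assuming a shared trunk (first_layers) followed by one nn.Linear head per feature. The class name, layer sizes and context dimension are hypothetical. It also checks the point above: argmax over the logits and argmax over their softmax pick the same attribute.

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Hypothetical generator: shared trunk, one classification head per feature."""

    def __init__(self, context_dim, cardinalities, hidden_dim=64):
        super().__init__()
        self.first_layers = nn.Sequential(nn.Linear(context_dim, hidden_dim), nn.ReLU())
        self.heads = nn.ModuleList([nn.Linear(hidden_dim, k) for k in cardinalities])

    def forward(self, context):
        latent = self.first_layers(context)
        # One logits tensor per feature: (batch, cardinality_i)
        return [head(latent) for head in self.heads]

torch.manual_seed(0)
generator = Generator(context_dim=10, cardinalities=[4, 6, 5])
logits = generator(torch.randn(2, 10))

# Softmax is monotonic along the class dimension, so both give the same indices.
generated = [lg.argmax(dim=-1) for lg in logits]
generated_from_softmax = [lg.softmax(dim=-1).argmax(dim=-1) for lg in logits]
```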

My idea is that, during training, when the generator encounters a high-scoring row it tries to reproduce it, and when it encounters a low-scoring row it tries to avoid it AND to come up with one that is estimated to be high-scoring.


I used a softmax cross-entropy loss between each row feature and the corresponding generated feature, summed over all features and weighted by the inverse of the estimated score, so that low-scoring generations get a high loss.

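import torch.nn.functional as F

# F.cross_entropy applies log-softmax internally, so it expects raw logits
loss1 = F.cross_entropy(feature1_logits, row_feature1)
loss2 = F.cross_entropy(feature2_logits, row_feature2)

reproductionloss = loss1 + loss2
# Weight by the inverse estimated score: low-scoring rows get a larger loss
loss = reproductionloss / estimator(row_features)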
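With mini-batches, dividing a batch-averaged cross-entropy by a vector of scores does not weight each row by its own score. A sketch of per-row weighting, assuming the estimator returns one positive score per row; the function name and the eps guard are hypothetical additions. The score is detached because the estimator is frozen and, as noted below, no gradient reaches the generator through it anyway.

```python
import torch
import torch.nn.functional as F

def weighted_reproduction_loss(logits_per_feature, targets, estimated_scores, eps=1e-6):
    """Sum of per-feature cross-entropies, weighted per row by 1 / estimated score.

    logits_per_feature: list of (batch, cardinality_i) raw logits
    targets: (batch, N) integer attribute indices of the training rows
    estimated_scores: (batch,) estimator outputs for those rows (assumed positive)
    """
    per_row = sum(
        F.cross_entropy(logits, targets[:, i], reduction="none")
        for i, logits in enumerate(logits_per_feature)
    )  # (batch,)
    # The estimator is frozen: detach so no gradient is sent into it.
    weights = 1.0 / (estimated_scores.detach() + eps)
    return (per_row * weights).mean()

torch.manual_seed(0)
logits = [torch.randn(3, 4, requires_grad=True), torch.randn(3, 6, requires_grad=True)]
targets = torch.tensor([[0, 1], [2, 3], [1, 5]])
scores = torch.tensor([0.5, 1.0, 2.0])

loss = weighted_reproduction_loss(logits, targets, scores)
loss.backward()
```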

The problem

So far so good. Even though the estimator's first layer is an embedding layer (which breaks the gradient flow) and the generator's last layer is an argmax, the loss used to train the generator is differentiable: the cross-entropy terms depend only on the generator's logits, and the estimated score acts as a constant weight.

However, the feature interactions are so complex that swapping a single feature can turn a high-scoring row into a low-scoring one. Naively summing the cross-entropy losses fails badly in such cases, because both rows give almost the same reproductionloss. The only thing that changes is the estimated score, which is not differentiable with respect to the generator's outputs.
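A toy illustration of that failure mode, with hypothetical cardinalities and hand-made confident logits: two generations that each differ from the target row in exactly one feature receive the same summed cross-entropy, no matter which feature was swapped. The loss therefore cannot tell a harmless swap from one that ruins the score.

```python
import torch
import torch.nn.functional as F

def reproduction_loss(logits_per_feature, target_row):
    # Sum of per-feature cross-entropies against one target row.
    return sum(
        F.cross_entropy(logits.unsqueeze(0), target_row[i : i + 1])
        for i, logits in enumerate(logits_per_feature)
    )

def confident_logits(row, cardinalities, confidence=10.0):
    # Logits that put almost all probability mass on the attributes of `row`.
    out = []
    for attribute, k in zip(row, cardinalities):
        logits = torch.zeros(k)
        logits[attribute] = confidence
        out.append(logits)
    return out

cardinalities = [4, 4, 4]
target = torch.tensor([0, 0, 0])

# Two generations, each differing from the target in exactly one feature.
loss_first = reproduction_loss(confident_logits([1, 0, 0], cardinalities), target)
loss_last = reproduction_loss(confident_logits([0, 0, 1], cardinalities), target)
```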

My idea is to learn a new metric that can detect whether a feature swap is meaningful enough to change the row's behaviour.

I trained a row-features aligner that projects the estimator's embeddings into a new representation in which cosine similarity tells whether two rows are similar in terms of behaviour (not only in terms of exact feature-attribute matching).

import torch.nn as nn
import torch.nn.functional as F

class Aligner(nn.Module):
    def __init__(self, estimator_model):
        super().__init__()
        self.estimator_model = estimator_model
        self.projection_head = nn.Linear(64, 64)

    def forward(self, row1, row2):
        # Map both rows into the estimator's latent space
        latent1 = self.estimator_model.to_latent(row1)
        latent2 = self.estimator_model.to_latent(row2)
        # Project into the space where cosine similarity reflects behaviour
        projection1 = self.projection_head(latent1)
        projection2 = self.projection_head(latent2)
        return F.cosine_similarity(projection1, projection2, dim=-1)

aligner = Aligner(estimator)
cosine = aligner(row1, row2)

This Aligner actually captures such differences better than the sum of cross-entropies.

However, I have no idea how to plug it into the generator's training without breaking the gradient flow. This Aligner also starts with an embedding layer (from the pretrained estimator), so if I remove the cross-entropy loss from the generator loss entirely, nothing will be differentiable.

My current idea is to feed the generator's softmax outputs directly to the aligner and use them as weights for a weighted sum of each feature's embeddings inside the aligner.
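A minimal sketch of that soft lookup, assuming each feature's pretrained embedding is an nn.Embedding (sizes here are hypothetical): the soft embedding is the probability-weighted average of the embedding rows, i.e. probs @ weight. It reduces to the ordinary lookup when the softmax is one-hot, and gradients reach the generator logits even with the embedding table frozen. One caveat worth keeping in mind: the estimator and aligner were trained on single embeddings, so during training they see mixtures they never saw before, which only become ordinary lookups as the softmax sharpens.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Pretrained, frozen embedding table for one feature (hypothetical sizes).
embedding = nn.Embedding(num_embeddings=5, embedding_dim=3).requires_grad_(False)

def soft_embed(logits, embedding):
    # Expected embedding under the generator's softmax: probs @ E.
    probs = logits.softmax(dim=-1)    # (batch, num_embeddings)
    return probs @ embedding.weight   # (batch, embedding_dim)

logits = torch.randn(2, 5, requires_grad=True)
soft = soft_embed(logits, embedding)
soft.sum().backward()  # gradients reach the generator logits
```

Per feature, the resulting vectors would be concatenated in place of the hard lookups before entering the aligner's layers.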

I tried a straight-through estimator to bypass the argmax + embedding block. The gradients do flow, but I suspect they are meaningless, since the gradients at the embedding level indicate how each embedding dimension should move, not which other embedding to switch to.
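The post does not say which straight-through variant was used. For comparison, here is a sketch of one common variant that applies the straight-through trick to the one-hot vector before the embedding matmul (sizes hypothetical). The forward pass is exactly the argmax lookup; the backward pass uses the softmax gradient, so the gradient with respect to the one-hot vector is a vector over the attribute vocabulary (one entry per candidate embedding) rather than over embedding dimensions. Whether that yields more useful updates than the version already tried is an open question.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Frozen pretrained embedding table (hypothetical sizes).
embedding = nn.Embedding(5, 3).requires_grad_(False)

def straight_through_embed(logits, embedding):
    probs = logits.softmax(dim=-1)
    hard = F.one_hot(probs.argmax(dim=-1), probs.shape[-1]).to(probs.dtype)
    # Forward: numerically the one-hot (argmax) choice.
    # Backward: only `probs` carries gradient, so the softmax gradient is used.
    one_hot_st = hard - probs.detach() + probs
    return one_hot_st @ embedding.weight

logits = torch.randn(2, 5, requires_grad=True)
out = straight_through_embed(logits, embedding)
out.sum().backward()
```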

I also tried coding my own version of the embeddings with one-hot indexing (a linear layer instead of a lookup table of embeddings), coupled with a differentiable version of argmax (Gumbel-softmax).
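A sketch of that combination using PyTorch's built-in F.gumbel_softmax with hard=True (straight-through Gumbel-softmax). It assumes the embedding is rebuilt as a bias-free nn.Linear whose weight columns are the embedding vectors; a pretrained nn.Embedding weight of shape (num_attributes, dim) could be copied in via its transpose. Sizes are hypothetical. Unlike argmax, the forward pass here is a random draw, and tau controls how sharp the relaxed sample used in the backward pass is.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_attributes, embed_dim = 5, 3

# Embedding as a bias-free linear layer: linear(one_hot) selects one weight column.
embed_linear = nn.Linear(num_attributes, embed_dim, bias=False)

logits = torch.randn(2, num_attributes, requires_grad=True)
# hard=True: the forward value is one-hot, the backward pass uses the soft sample.
sample = F.gumbel_softmax(logits, tau=1.0, hard=True)
row_embedding = embed_linear(sample)  # (2, embed_dim)
row_embedding.sum().backward()
```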

Do you have any other suggestions on how you would approach this?