Why I can't update the embedding layer?


(bdy) #1

here is my code:

i want to know why the embedding layer has grad but cannot update

import numpy as np
import random
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn import Linear
from SSL_data import SSLDataset
from sklearn.metrics import roc_auc_score

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)


class SSL(nn.Module):
    def __init__(self, embedding_size, layer_size, cat_num, one_num):
        super(SSL, self).__init__()
        self.embedding_size = embedding_size
        self.cat_num = cat_num
        self.one_num = one_num
        self.cat_table = []
        self.one_table = []
        self.softmax = nn.Softmax(dim=2)
        for i in range(cat_num):
            self.cat_table.append(nn.Embedding(10, embedding_size))
        for i in range(one_num):
            self.one_table.append(nn.Embedding(2, embedding_size, padding_idx=0))
        self.layer1 = Linear(2*embedding_size, layer_size)
        self.layer2 = Linear(layer_size, 1)

    def forward(self, cat_feature, one_feature):
        cat_embedding = []
        for i in range(self.cat_num):
            cat_embedding.append(self.cat_table[i](cat_feature[:, i].long()))
        cat_embedding = torch.stack(cat_embedding, dim=2)

        one_embedding = []
        for i in range(self.one_num):
            one_embedding.append(self.one_table[i](one_feature[:, i].long()))
        one_embedding = torch.stack(one_embedding, dim=2)
        one_embedding = F.avg_pool2d(one_embedding, (1, self.one_num))
        one_embedding_repeat = one_embedding.expand(one_embedding.size(0), self.embedding_size, self.cat_num)

        attention = torch.bmm(one_embedding_repeat, cat_embedding.transpose(1, 2))
        attention = self.softmax(attention)

        basic_embedding = torch.bmm(attention, cat_embedding)
        basic_embedding = F.avg_pool2d(basic_embedding, (1, self.cat_num)).squeeze(2)
        one_embedding = one_embedding.squeeze(2)
        user_embedding = torch.cat((basic_embedding, one_embedding), 1)
        if user_embedding.size(0) > 1000:
            print(self.one_table[0].weight)
        #     print(user_embedding[0, :])
        #     print(self.cat_table[0].weight.grad) # has grad
        #     print(self.cat_table[0].weight)         # cannot update

        x = torch.sigmoid(self.layer1(user_embedding))
        x = torch.sigmoid(self.layer2(x))
        return x, user_embedding


def train_model():
    model = SSL(
        embedding_size=20,
        layer_size=10,
        cat_num=8,
        one_num=20
    )

    train_loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False)
    # weight_decay=
    optimizer = Adam(model.parameters(), lr=lr)

    for epoch in range(10):
        print('epoch: {}'.format(epoch))
        for index in range(binary_classify_number):
            for batch_idx, (cat, bin) in enumerate(train_loader):
                label = bin[:, index].float().view(-1, 1)
                predict, _ = model(cat, bin)
                predict = predict.view(-1, 1)

                bce = nn.BCELoss()
                loss = bce(predict, label)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            category = dataset.x[:, :8].copy()
            binary = dataset.x[:, 8:].copy()
            labels = binary[:, index].tolist()
            cat = torch.from_numpy(category)
            bin = torch.from_numpy(binary)
            predict, _ = model(cat, bin)
            predict = predict.view(-1, 1).detach().numpy().tolist()
            print(index, roc_auc_score(labels, predict))

    _, latent = model(
        torch.from_numpy(dataset.x[:, :8]),
        torch.from_numpy(dataset.x[:, 8:]))
    np.savetxt(embedding_path, latent.detach().numpy(), fmt='%0.8f')


if __name__ == "__main__":
    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    embedding_path = 'ssl/benci_latent.txt'
    dataset = SSLDataset()
    batch_size = 256
    lr = 0.001
    binary_classify_number = 1
    fnum = dataset.get_feature_num()

    train_model()

#2

Try

for i in range(cat_num):
    self.cat_table.append(nn.Embedding(10, embedding_size))
for i in range(one_num):
    self.one_table.append(nn.Embedding(2, embedding_size, padding_idx=0))
self.cat_table = torch.nn.ModuleList(self.cat_table)
self.one_table = torch.nn.ModuleList(self.one_table)

(bdy) #3

that’s it, thx:smiley::smiley::smiley: