Here is my code. I want to know why the embedding layers receive gradients (`weight.grad` is populated) but their weights are never updated by the optimizer.
import numpy as np
import random
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn import Linear
from SSL_data import SSLDataset
from sklearn.metrics import roc_auc_score
# Fix all RNG seeds (torch, numpy, stdlib random) for reproducible runs.
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
class SSL(nn.Module):
    """Attention model over per-feature embeddings.

    Categorical-feature embeddings are attended by a pooled binary-feature
    embedding; the concatenated result feeds a small MLP that outputs a
    sigmoid score in (0, 1).

    Args:
        embedding_size: dimension of every feature embedding.
        layer_size: hidden width of the prediction MLP.
        cat_num: number of categorical feature columns (vocab 10 each).
        one_num: number of binary feature columns (vocab 2, padding_idx=0).
    """

    def __init__(self, embedding_size, layer_size, cat_num, one_num):
        super(SSL, self).__init__()
        self.embedding_size = embedding_size
        self.cat_num = cat_num
        self.one_num = one_num
        self.softmax = nn.Softmax(dim=2)
        # BUG FIX: the embedding tables were kept in plain Python lists, so
        # nn.Module never registered them as sub-modules. Their weights still
        # accumulated .grad through the autograd graph (the forward pass uses
        # them), but they were missing from model.parameters(), so the
        # optimizer never stepped them — "has grad but cannot update".
        # nn.ModuleList registers each embedding properly.
        self.cat_table = nn.ModuleList(
            nn.Embedding(10, embedding_size) for _ in range(cat_num))
        self.one_table = nn.ModuleList(
            nn.Embedding(2, embedding_size, padding_idx=0)
            for _ in range(one_num))
        self.layer1 = Linear(2 * embedding_size, layer_size)
        self.layer2 = Linear(layer_size, 1)

    def forward(self, cat_feature, one_feature):
        """Score a batch.

        Args:
            cat_feature: (batch, cat_num) tensor of categorical indices.
            one_feature: (batch, one_num) tensor of 0/1 indices.

        Returns:
            (score, user_embedding): score is (batch, 1) in (0, 1);
            user_embedding is (batch, 2 * embedding_size).
        """
        # (batch, emb, cat_num): one embedding column per categorical feature.
        cat_embedding = []
        for i in range(self.cat_num):
            cat_embedding.append(self.cat_table[i](cat_feature[:, i].long()))
        cat_embedding = torch.stack(cat_embedding, dim=2)
        # (batch, emb, one_num) pooled over features -> (batch, emb, 1).
        one_embedding = []
        for i in range(self.one_num):
            one_embedding.append(self.one_table[i](one_feature[:, i].long()))
        one_embedding = torch.stack(one_embedding, dim=2)
        one_embedding = F.avg_pool2d(one_embedding, (1, self.one_num))
        # Broadcast the pooled binary embedding against every categorical
        # column and build an (emb x emb) attention map per sample.
        one_embedding_repeat = one_embedding.expand(
            one_embedding.size(0), self.embedding_size, self.cat_num)
        attention = torch.bmm(one_embedding_repeat,
                              cat_embedding.transpose(1, 2))
        attention = self.softmax(attention)
        basic_embedding = torch.bmm(attention, cat_embedding)
        basic_embedding = F.avg_pool2d(
            basic_embedding, (1, self.cat_num)).squeeze(2)
        one_embedding = one_embedding.squeeze(2)
        user_embedding = torch.cat((basic_embedding, one_embedding), 1)
        x = torch.sigmoid(self.layer1(user_embedding))
        x = torch.sigmoid(self.layer2(x))
        return x, user_embedding
def train_model():
    """Train the SSL model and export the learned user embeddings.

    Reads module-level globals: ``dataset``, ``batch_size``, ``lr``,
    ``binary_classify_number`` and ``embedding_path``. For each binary label
    column it trains with BCE loss, prints the full-dataset ROC-AUC after
    each pass, then writes the latent embeddings to ``embedding_path``.
    """
    model = SSL(
        embedding_size=20,
        layer_size=10,
        cat_num=8,
        one_num=20,
    )
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # weight_decay=
    optimizer = Adam(model.parameters(), lr=lr)
    # Construct the loss module once instead of once per batch.
    bce = nn.BCELoss()
    for epoch in range(10):
        print('epoch: {}'.format(epoch))
        for index in range(binary_classify_number):
            # `cat_feat`/`bin_feat` avoid shadowing the builtin `bin`.
            for batch_idx, (cat_feat, bin_feat) in enumerate(train_loader):
                label = bin_feat[:, index].float().view(-1, 1)
                predict, _ = model(cat_feat, bin_feat)
                predict = predict.view(-1, 1)
                loss = bce(predict, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Evaluate ROC-AUC on the whole dataset; no_grad() since no
            # gradients are needed for evaluation.
            category = dataset.x[:, :8].copy()
            binary = dataset.x[:, 8:].copy()
            labels = binary[:, index].tolist()
            with torch.no_grad():
                predict, _ = model(torch.from_numpy(category),
                                   torch.from_numpy(binary))
            predict = predict.view(-1, 1).numpy().tolist()
            print(index, roc_auc_score(labels, predict))
    # Export the learned latent user embeddings for downstream use.
    with torch.no_grad():
        _, latent = model(
            torch.from_numpy(dataset.x[:, :8]),
            torch.from_numpy(dataset.x[:, 8:]))
    np.savetxt(embedding_path, latent.numpy(), fmt='%0.8f')
if __name__ == "__main__":
    # NOTE(review): `cuda`/`device` are computed but nothing below is moved
    # to the device — training runs on CPU regardless. Confirm intent.
    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")
    # Output path for the learned user embeddings written by train_model().
    embedding_path = 'ssl/benci_latent.txt'
    dataset = SSLDataset()
    batch_size = 256
    lr = 0.001
    # Number of binary label columns to train against.
    binary_classify_number = 1
    # Total feature count from the dataset (not used below).
    fnum = dataset.get_feature_num()
    # train_model() reads the globals defined above.
    train_model()