I am pretty new in Pytorch and is trying to build a network with embedding for float type value.
I am mixing some numerical features with the the category features so they are not all integers.
When I run the embedding, I get the following error: Expected tensor for argument #1 ‘indices’ to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor
And if I change the type for those category features as int, the numerical features part will give me issues
Is there any chance that I can solve the issue?
Here is my code:
class EmbeddingAmodule(nn.Module):
def __init__(self):
super().__init__()
self.Emb = torch.nn.Embedding(344, 16)
def forward(self, x):
x = self.Emb(x)
return x
class EmbeddingBmodule(nn.Module):
def __init__(self):
super().__init__()
self.Emb = torch.nn.Embedding(344, 16)
def forward(self, x):
x = self.Emb(x)
return x
class Env(nn.Module):
def __init__(self):
super().__init__()
self.env = nn.Linear(62,62)
def forward(self, x):
x = F.relu(self.env(x))
return x
class Combinednetwork(nn.Module):
def __init__(self, modelA1, modelB1, modelenv):
super().__init__()
self.modelA1 = modelA1
self.modelB1 = modelB1
self.modelenv = modelenv
self.fc2 = nn.Linear(262,256, bias=False)
...
self.Logit = nn.Linear(64,1)
def forward(self, xa1, xa2, xa3, xa4, xa5, xa6, xb1, xb2, xb3, xb4, xb5, xb6, xenv):
xa1 = self.modelA1(xa1)
xa2 = self.modelA1(xa2)
...
xb5 = self.modelB1(xb5)
xb6 = self.modelB1(xb6)
xenv = self.modelenv(xenv)
x = torch.cat((xa1, xa2, xa3, xa4, xa5, xa6, xb1, xb2, xb3, xb4, xb5, xb6, xenv), dim=0)
...
x = self.fc5(x)
x = self.bn5(x) # layer 5
value = self.Logit(x)
return torch.tanh(value)
params = {'batch_size': 64,
'num_workers': 6}
max_epochs = 10
dataset = CustomIterableDatasetv1('cleaned.csv')
dataloader = DataLoader(dataset, **params, pin_memory=True)
modela = EmbeddingAmodule().to(device)
modelb = EmbeddingBmodule().to(device)
modelenv = Env().to(device)
model = Combinednetwork(modela,modelb, modelenv)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
criterion2 = torch.nn.MSELoss()