Hi,
I am using multiple embedding layers for categorical values and I want to concatenate them with dense values. On CPUs my code works fine, but I keep getting the following error when running on GPU (single gpu idx:0):
RuntimeError: Expected object of backend CPU but got backend CUDA for argument #3 ‘index’
Model is here:
class CategoricalEmbedFNN(nn.Module):
def __init__(self, config, categ_embs_dims):
super().__init__()
self.config = config
self.hidden_size = self.config["hidden_size"]
self.hidden_layers = self.config["hidden_layers"]
self.activation_fn = nn.ReLU(inplace=True)
self.input_dim = len(self.config["dense_features"])
self.embedding_layers = {}
for k, v in categ_embs_dims.items():
self.embedding_layers[f"{k}_emb"] = nn.Embedding(v, self.config["cat_embs"][k])
self.input_dim += self.config["cat_embs"][k]
self.input_linear = nn.Linear(self.input_dim, self.hidden_size)
self.middle_linear = nn.Linear(self.hidden_size, self.hidden_size)
self.output_linear = nn.Linear(self.hidden_size, len(self.config["output_features"]))
def forward(self, x, x_cat):
x_cat = [emb_layer(x_cat[:, idx]) for idx, (k, emb_layer) in enumerate(self.embedding_layers.items())]
x_cat = torch.cat(x_cat, 1)
x = torch.cat((x, x_cat), 1)
x = self.input_linear(x)
x = self.activation_fn(x)
for i in range(self.hidden_layers):
x = self.middle_linear(x)
x = self.activation_fn(x)
out = self.output_linear(x)
return out
I initialize CUDA like this:
self.device = torch.device("cuda")
torch.cuda.set_device(0)
self.model = self.model.to(self.device)
self.loss = self.loss.to(self.device)
Additionally, in batching (with tqdm), I run with the following code:
for X_batch, X_cat, y_batch in tqdm_batch:
# Put data on device
X_batch, X_cat, y_batch = X_batch.to(self.device), X_cat.to(self.device), y_batch.to(self.device)
# Make predictions
self.optimizer.zero_grad()
y_pred = self.model(X_batch, X_cat)
Error comes from the line
x_cat = [emb_layer(x_cat[:, idx]) for idx, (k, emb_layer) in enumerate(self.embedding_layers.items())]
What is the issue and how can I solve it ?