Hi everyone,
I am looking for some help in fixing an error in my model. I utilized a custom dataloader that looks like this:
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
def __init__(self, dat, labels):
self.labels = labels
self.dat = dat
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
label = self.labels[idx]
dat = self.dat[idx]
sample = {"Sample": dat, "Class": label}
return sample
Then create my train and validation data like this:
train_loader = CustomDataset(X_train, y_train)
valid_loader = CustomDataset(X_test, y_test)
The error arises when I run the model and occurs in my training loop:
def train(model, device, train_loader, valid_loader, epochs, learning_rate):
for idx, batch in enumerate(train_loader):
text = batch["Sample"].to(device)
target = batch['Class'].to(device)
target = torch.autograd.Variable(target).long()**
target = text.to(device), target.to(device)
The error I am getting is as follows:
48 for idx, batch in enumerate(train_loader):
---> 49 text = batch["Sample"].to(device)
50 target = batch['Class'].to(device)
51 print(type(text), text.shape)
TypeError: to() received an invalid combination of arguments - got (DataLoader), but expected one of:
* (torch.device device, torch.dtype dtype, bool non_blocking, bool copy, *, torch.memory_format memory_format)
* (torch.dtype dtype, bool non_blocking, bool copy, *, torch.memory_format memory_format)
* (Tensor tensor, bool non_blocking, bool copy, *, torch.memory_format memory_format)
Below is the code I use to initialze the model, train, and test:
set_seed(SEED)
vanilla_rnn_model = VanillaRNN(output_size, input_size, RNN_size, fc_size, DEVICE)
vanilla_rnn_model.to(DEVICE)
vanilla_rnn_start_time = time.time()
vanilla_train_loss, vanilla_train_acc, vanilla_validation_loss, vanilla_validation_acc = train(vanilla_rnn_model,
train_loader,
valid_loader,
DEVICE,
epochs = epochs,
learning_rate = learning_rate
)
Obviously the model does not like the DataLoader, but I am confused on how to remediate the issue. Any help would be appreciated.