def predict_batchwise(m, test_dataloader,device):
with torch.no_grad():
X= torch.ones([0, 64]).float()
X = X.to(device)
Y=[]
for i, data in enumerate(test_dataloader,0):
img0,label0 = data
img0= img0.to(device)
label0 = label0.to(device)
# pre trained resnet feature extractor
out0= m(img0)
for j in range(64):
actual_lab0 = label0[j]
Y.append([actual_lab0.item()])
X = torch.cat((out0.data, X), 0)
return torch.stack(X), torch.stack(Y)
error is
35 # calculate embeddings with model, also get labels (non-batch-wise)
—> 36 X, T = predict_batchwise(model, dataloader,device)
37
38 # calculate NMI with kmeans clustering
~/MoonVision/vid-api/vid-vision-ankit/src/graphs/pytorch/adversarial_network/proxy_nca/utils.py in predict_batchwise(m, test_dataloader, device)
9 X = X.to(device)
10 Y=[]
—> 11 for i, data in enumerate(test_dataloader,0):
12
13 img0,label0 = data
~/miniconda3/envs/fat-ml/lib/python3.6/site-packages/torch/utils/data/dataloader.py in next(self)
613 if self.num_workers == 0: # same-process loading
614 indices = next(self.sample_iter) # may raise StopIteration
–> 615 batch = self.collate_fn([self.dataset[i] for i in indices])
616 if self.pin_memory:
617 batch = pin_memory_batch(batch)
~/miniconda3/envs/fat-ml/lib/python3.6/site-packages/torch/utils/data/dataloader.py in default_collate(batch)
230 elif isinstance(batch[0], container_abcs.Sequence):
231 transposed = zip(*batch)
–> 232 return [default_collate(samples) for samples in transposed]
233
234 raise TypeError((error_msg.format(type(batch[0]))))
~/miniconda3/envs/fat-ml/lib/python3.6/site-packages/torch/utils/data/dataloader.py in (.0)
230 elif isinstance(batch[0], container_abcs.Sequence):
231 transposed = zip(*batch)
–> 232 return [default_collate(samples) for samples in transposed]
233
234 raise TypeError((error_msg.format(type(batch[0]))))
~/miniconda3/envs/fat-ml/lib/python3.6/site-packages/torch/utils/data/dataloader.py in default_collate(batch)
207 storage = batch[0].storage().new_shared(numel)
208 out = batch[0].new(storage)
–> 209 return torch.stack(batch, 0, out=out)
210 elif elem_type.module == ‘numpy’ and elem_type.name != 'str’
211 and elem_type.name != ‘string_’:
Please help me to figure out how to solve this