Here is my custom dict:
from collections import UserDict

class FeedData(UserDict):
    def __init__(self, **kwargs):
        super().__init__(kwargs)

    def __getitem__(self, key):
        # Missing keys return None instead of raising KeyError.
        if key not in self.data:
            return None
        return self.data[key]

    def to(self, device):
        # Move every stored tensor to the given device in place.
        for key, value in self.data.items():
            self.data[key] = value.to(device)

    def __getattr__(self, item):
        # Allow attribute-style access to stored keys (e.g. feed.sparse).
        data = super().__getattribute__("data")
        if item not in data:
            raise AttributeError(item)
        return data[item]
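For reference, this is how the class behaves in isolation (a minimal sketch; the key names and tensors are invented for illustration):

import torch

feed = FeedData(sparse=torch.tensor([1, 2, 3]))
print(feed["sparse"])   # tensor([1, 2, 3]): normal item access
print(feed.sparse)      # the same tensor, via __getattr__
print(feed["missing"])  # None: __getitem__ swallows missing keys
feed.to("cpu")          # moves every stored tensor via .to(device)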
and my collate_fn:
def _collate_fn(self, batch):
    feed_data = FeedData()
    if self.sparse_feat:
        sparse = torch.stack([default_collate([d[key] for d in batch]) + self._offset[key]
                              for key in self.sparse_feat], dim=1)
        feed_data["sparse"] = sparse
    if self.dense_feat:
        dense = torch.stack([default_collate([d[key] for d in batch])
                             for key in self.dense_feat], dim=1)
        feed_data["dense"] = dense
    if self.variable_feat:
        varlen_feat = FeedData(**{
            "var_tensor": FeedData(),
            "var_length": FeedData()
        })
        for key in self.variable_feat:
            seq_tensor, seq_lengths = self._pack_variable([d[key] for d in batch])
            varlen_feat.var_tensor[key] = seq_tensor
            varlen_feat.var_length[key] = seq_lengths
        feed_data["varlen_feat"] = varlen_feat
    label = default_collate([d["label"] for d in batch])
    feed_data["label"] = label
    print(type(feed_data))
    return feed_data
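To make the sparse branch concrete, here is a standalone sketch of the default_collate/stack pattern (the feature names, offsets, and batch below are invented for illustration):

import torch
from torch.utils.data.dataloader import default_collate

# a hypothetical batch of three samples with two sparse features "a" and "b"
batch = [{"a": torch.tensor(0), "b": torch.tensor(10)},
         {"a": torch.tensor(1), "b": torch.tensor(11)},
         {"a": torch.tensor(2), "b": torch.tensor(12)}]
offset = {"a": 0, "b": 100}  # invented per-feature index offsets

# default_collate stacks the per-sample scalars into a (batch,) tensor;
# stacking those along dim=1 gives one column per feature: shape (3, 2)
sparse = torch.stack([default_collate([d[k] for d in batch]) + offset[k]
                      for k in ("a", "b")], dim=1)
print(sparse)  # tensor([[  0, 110], [  1, 111], [  2, 112]])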
and my test code:
dataset = FMDataset(file_list)
dataloader = FMDataloader(dataset, config)
for data in dataloader:
    print(type(data))
    exit(-1)
and the output:
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'dict'>
The "<class 'bin.dataset.fm_dataset.FeedData'>" lines are printed inside _collate_fn, and the final "<class 'dict'>" is the type of the data yielded by the dataloader.
How can I get my FeedData (a UserDict subclass) back from the dataloader? Any help would be appreciated!
Edit 1:
I have done a minimal test:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from collections import UserDict


class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.x = torch.rand(10, 4)
        self.y = torch.randint(0, 10, (10,))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, item):
        data = self.x[item]
        target = self.y[item]
        return {"data": data, "target": target}


class MyDict(UserDict):
    def __init__(self, **kwargs):
        super().__init__(kwargs)


def collate_fn(batch):
    ret = MyDict(**default_collate(batch))
    return ret


dataset = MyDataset()
dataloader = DataLoader(dataset, collate_fn=collate_fn)
for data in dataloader:
    print(type(data))
    print(data)
This seems to work correctly, so I'm confused about why my project code only gets a plain dict from the dataloader.
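One difference worth probing: the mini test runs with the DataLoader defaults (num_workers=0, pin_memory=False), while FMDataloader presumably takes those options from config, and the eight FeedData prints before the first batch look like worker prefetching. In some PyTorch versions the pin-memory path rebuilds Mapping containers as plain dicts, so a possible diagnostic (an assumption on my part, since I don't know what config contains) is to re-run the mini test with those options turned on:

from torch.utils.data import DataLoader

# reuse MyDataset / MyDict / collate_fn from the mini test above;
# only the loader options change
if __name__ == "__main__":  # required once num_workers > 0 on spawn platforms
    dataset = MyDataset()
    dataloader = DataLoader(dataset, collate_fn=collate_fn,
                            num_workers=2,    # a guess at what config enables
                            pin_memory=True)  # pinning needs a CUDA build
    for data in dataloader:
        print(type(data))  # <class 'dict'> here would point at the loader
        break              # options rather than at the collate_fn itself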
Edit 2:
If I use another class:
class MyData:
    pass
and modify my collate_fn:
def _collate_fn(self, batch):
    # feed_data = FeedData()
    feed_data = MyData()
    # if self.sparse_feat:
    #     sparse = torch.stack([default_collate([d[key] for d in batch]) + self._offset[key]
    #                           for key in self.sparse_feat], dim=1)
    #     feed_data["sparse"] = sparse
    # if self.dense_feat:
    #     dense = torch.stack([default_collate([d[key] for d in batch])
    #                          for key in self.dense_feat], dim=1)
    #     feed_data["dense"] = dense
    # if self.variable_feat:
    #     varlen_feat = FeedData(**{
    #         "var_tensor": FeedData(),
    #         "var_length": FeedData()
    #     })
    #     for key in self.variable_feat:
    #         seq_tensor, seq_lengths = self._pack_variable([d[key] for d in batch])
    #         varlen_feat.var_tensor[key] = seq_tensor
    #         varlen_feat.var_length[key] = seq_lengths
    #     feed_data["varlen_feat"] = varlen_feat
    # label = default_collate([d["label"] for d in batch])
    # feed_data["label"] = label
    print(type(feed_data))
    return feed_data
I get a MyData instance back from the dataloader. It's so weird!
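This would actually fit the pin-memory hypothesis from Edit 1: in some PyTorch versions the pin-memory helper special-cases Mapping containers and rebuilds them key by key as plain dicts, while objects it does not recognize pass through untouched. A simplified pure-Python mimic of that branch (my assumption about the behavior, not PyTorch's actual code):

from collections import UserDict
from collections.abc import Mapping

def pin_like(sample):
    # Mappings are rebuilt as plain dicts, so the subclass is lost;
    # anything unrecognized is returned unchanged.
    if isinstance(sample, Mapping):
        return {k: pin_like(v) for k, v in sample.items()}
    return sample

class MyData:
    pass

print(type(pin_like(UserDict(a=1))))  # <class 'dict'>: UserDict subclass lost
print(type(pin_like(MyData())))       # <class '__main__.MyData'>: untouched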