Returning a `UserDict` from `collate_fn`, but getting a plain `dict` from the DataLoader

Here is my custom dict:

class FeedData(UserDict):
    """Dict-like container for a collated batch.

    Differences from a plain ``UserDict``:

    * missing keys read as ``None`` (``feed["x"]`` and ``feed.get("x")``)
      instead of raising ``KeyError``;
    * stored keys are also readable as attributes (``feed.sparse``);
    * ``to(device)`` moves every value in place and returns ``self``, so
      nested ``FeedData`` values are moved too.

    The constructor accepts an optional positional mapping as well as keyword
    arguments, like ``dict``/``UserDict``.  Code that rebuilds a mapping via
    ``type(batch)(mapping)`` can therefore recreate a ``FeedData`` instead of
    falling back to a plain ``dict``.  Some PyTorch collate/pin-memory helpers
    rebuild batches that way.
    NOTE(review): older PyTorch releases convert every Mapping batch to a
    plain ``dict`` when the DataLoader has ``pin_memory=True``.  If the loader
    sets that flag, verify the behaviour against the installed torch version.
    """

    def __init__(self, *args, **kwargs):
        # Start empty, then fill through ``update``.  This accepts both a
        # positional mapping and keyword arguments: ``FeedData(a=1)`` behaves
        # as before, and ``FeedData({"a": 1})`` now works too.
        super().__init__()
        self.update(*args, **kwargs)

    def __getitem__(self, key):
        # Missing keys deliberately read as None rather than raising KeyError.
        if key not in self.data:
            return None
        return self.data[key]

    def get(self, key, default=None):
        """Return the value stored under *key*, or *default* if it is absent.

        Defined explicitly because the inherited ``Mapping.get`` relies on
        ``__getitem__`` raising ``KeyError``.  This class never raises it, so
        *default* would otherwise be ignored.
        """
        return self.data.get(key, default)

    def to(self, device):
        """Replace every value with ``value.to(device)`` and return ``self``.

        Each stored value must provide ``.to(device)`` returning the moved
        object.  Returning ``self`` matters for nesting: the parent stores
        whatever a nested ``FeedData.to`` returns, so a ``None`` return would
        wipe the nested entry.
        """
        for key, value in self.data.items():
            # Re-assigning existing keys does not resize the dict, so this is
            # safe while iterating.
            self.data[key] = value.to(device)
        return self

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails.  ``data`` is fetched
        # with __getattribute__ so that an instance without ``data`` yet
        # (e.g. mid-unpickling or copying) raises AttributeError instead of
        # recursing forever.
        data = super().__getattribute__("data")
        if item not in data:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, item)
            )
        return data[item]

And my `collate_fn`:

    def _collate_fn(self, batch):
        """Collate a list of per-sample dicts into a single ``FeedData`` batch.

        The optional entries below are added only when the matching feature
        list on ``self`` is non-empty:

        * ``"sparse"`` -- each sparse feature collated across the batch, plus
          ``self._offset[key]``, stacked along dim 1;
        * ``"dense"`` -- each dense feature collated across the batch, stacked
          along dim 1;
        * ``"varlen_feat"`` -- a ``FeedData`` with ``var_tensor`` and
          ``var_length`` sub-``FeedData``s, filled from ``self._pack_variable``.

        ``"label"`` is always collated and added last.
        """
        def column(key):
            # Collate a single feature across every sample in the batch.
            return default_collate([sample[key] for sample in batch])

        feed_data = FeedData()

        if self.sparse_feat:
            shifted = [column(key) + self._offset[key] for key in self.sparse_feat]
            feed_data["sparse"] = torch.stack(shifted, dim=1)

        if self.dense_feat:
            feed_data["dense"] = torch.stack(
                [column(key) for key in self.dense_feat], dim=1
            )

        if self.variable_feat:
            var_tensor, var_length = FeedData(), FeedData()
            for key in self.variable_feat:
                packed, lengths = self._pack_variable(
                    [sample[key] for sample in batch]
                )
                var_tensor[key] = packed
                var_length[key] = lengths
            feed_data["varlen_feat"] = FeedData(
                var_tensor=var_tensor, var_length=var_length
            )

        feed_data["label"] = column("label")

        print(type(feed_data))
        return feed_data

And my test code:

dataset = FMDataset(file_list)
dataloader = FMDataloader(dataset, config)

# Inspect only the first batch.  ``break`` stops the loop without relying on
# the ``exit`` helper, which the ``site`` module provides for interactive
# use only.
# NOTE(review): if FMDataloader enables ``pin_memory=True``, PyTorch's
# pin-memory step may rebuild Mapping batches (including UserDict subclasses)
# as plain dicts, depending on the torch version.  Check that setting first.
for data in dataloader:
    print(type(data))
    break

And the output:

<class 'bin.dataset.fm_dataset.FeedData'>
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'bin.dataset.fm_dataset.FeedData'>
<class 'dict'>

The line

<class 'bin.dataset.fm_dataset.FeedData'>

is printed in `collate_fn`, and

<class 'dict'>

is the type of the data coming out of the DataLoader.

How can I get my `UserDict` subclass back from the DataLoader? Any help is appreciated!

Edit 1:

I wrote a minimal test:

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from collections import UserDict


class MyDataset(Dataset):
    """Toy dataset of ten random 4-feature rows, each with an integer label in [0, 10)."""

    def __init__(self):
        super().__init__()
        # Draw the features before the labels so the RNG sequence is unchanged.
        self.x = torch.rand(10, 4)
        self.y = torch.randint(0, 10, (10,))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, item):
        # One sample: its feature row under "data" and its label under "target".
        return dict(data=self.x[item], target=self.y[item])
    

class MyDict(UserDict):
    """Minimal UserDict subclass used to check whether its type survives the DataLoader.

    Accepts an optional positional mapping as well as keyword arguments, so
    code that rebuilds a mapping with ``type(obj)(mapping)`` gets a ``MyDict``
    back instead of failing with TypeError.
    """

    def __init__(self, *args, **kwargs):
        # Start empty, then fill through ``update``; this accepts both
        # ``MyDict(a=1)`` and ``MyDict({"a": 1})``.
        super().__init__()
        self.update(*args, **kwargs)


def collate_fn(batch):
    """Collate *batch* with ``default_collate`` and wrap the result in a ``MyDict``."""
    collated = default_collate(batch)
    return MyDict(**collated)


dataset = MyDataset()
dataloader = DataLoader(dataset=dataset, collate_fn=collate_fn)

# collate_fn returns a MyDict; print each batch's type to see whether that
# type reaches the loop unchanged.
for data in dataloader:
    print(type(data))
    print(data)

This seems to work correctly.

So I'm confused why my project code only gets a plain `dict` from the DataLoader.

Edit 2:

If I use another class:

class MyData:
    """Plain placeholder object with no Mapping behaviour, used as a collate_fn return value."""

and modify my `collate_fn`:

def _collate_fn(self, batch):
    """Experimental collate_fn that ignores *batch* and returns an empty ``MyData``.

    In this variant, all the logic that fills the batch (the sparse, dense,
    varlen_feat and label entries of the FeedData version) is disabled.  The
    aim is to check whether a non-Mapping return value survives the
    DataLoader unchanged.
    """
    feed_data = MyData()
    print(type(feed_data))
    return feed_data

Then I do get `MyData` back from the DataLoader.
It's so weird!

1 Like

Any suggestions would be appreciated :pray: :pray: