Torch "Dataset" class quirks?

I'm having a hard time understanding the torch.utils.data.Dataset class; it doesn’t seem to behave like a regular Python class. Below is a simple example (the output of each print is shown in the trailing comment). In short:

  • A Dataset subclass cannot access its custom methods or attributes normally (see TestDataset)
  • But a class that inherits from that subclass can! (see TestDataset2)

Is there something I'm missing (static methods, …)? Thanks in advance.
import torch
from torch.utils.data import Dataset


# A basic python class
class TestObject():
    def __init__(self, a):
        self.a = a
    def get_a(self):
        return self.a
# print("testobject")
# testobject=TestObject(99)
# print(list(testobject.__dict__.keys()))# ['a']
# print(testobject.a)# 99
# print(testobject.get_a())# 99 

# A torch "Dataset" class: can't access attributes/methods normally
class TestDataset(Dataset):
    def __init__(self, a):
        self.a = a
    def get_a(self):
        return self.a
    def __len__(self):
        return 1
    def __getitem__(self, idx):
        return torch.randn(1,1,1)
# print("testdataset")
# testdataset=TestDataset(99)
# print(list(testdataset.__dict__.keys()))# ['a']
# print(testdataset.a)# AttributeError: type object 'TestDataset' has no attribute 'a'
# print(testdataset.get_a())# TypeError: TestDataset.get_a() missing 1 required positional argument: 'self'

# A basic Python inherited class
class TestObject2(TestObject):
    def __init__(self, a,b):
        super().__init__(a)
        self.b = b
    def get_b(self):
        return self.b
# print("testobject2")
# testobject2=TestObject2(99,42)
# print(testobject2.a)# 99
# print(testobject2.b)# 42

# A torch "Dataset" inherited class: now everything works normally !?
class TestDataset2(TestDataset):
    def __init__(self, a,b):
        super().__init__(a)
        self.b = b
    def get_b(self):
        return self.b
# print("testdataset2")
# testdataset2=TestDataset2(99,42)
# print(list(testdataset2.__dict__.keys()))# ['a','b']
# print(testdataset2.a)# 99, it works now !?!
# print(testdataset2.get_a())# 99
# print(testdataset2.b)# 42
# print(testdataset2.get_b())# 42

I don’t get any AttributeError or TypeError when running your code snippet; the output is:

testobject
['a']
99
99
testdataset
['a']
99
99
testobject2
99
42
testdataset2
['a', 'b']
99
99
42
42

Thanks for the quick reply! Well, I feel stupid: it works after cleaning up my notebook and retesting. I probably had a typo or just named something “Dataset”.
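For anyone landing here with the same two messages: one possible cause (an assumption, since the original notebook state is gone) is that the variable ended up bound to the class itself rather than to an instance, for example by leaving off the parentheses and argument. The sketch below uses a plain Python class as a stand-in so it runs without torch; the `describe` helper and the `not_instance` name are made up for illustration. The same thing should happen with a Dataset subclass, because this is ordinary Python attribute lookup rather than anything Dataset-specific.

```python
class TestDataset:  # plain stand-in; no torch needed to show the effect
    def __init__(self, a):
        self.a = a

    def get_a(self):
        return self.a


def describe(obj):
    """Try obj.a and obj.get_a(); return each result or error as a string."""
    results = []
    for action in (lambda: obj.a, lambda: obj.get_a()):
        try:
            results.append(str(action()))
        except (AttributeError, TypeError) as exc:
            results.append(f"{type(exc).__name__}: {exc}")
    return results


instance = TestDataset(99)   # correct: calling the class creates an instance
not_instance = TestDataset   # hypothetical slip: this binds the class itself

print(describe(instance))      # both lookups succeed
print(describe(not_instance))  # an AttributeError, then a TypeError about 'self'
```

A quick check in a notebook is `type(testdataset)`: if it prints `<class 'type'>`, the name refers to a class, not an instance. Restarting the kernel, as done above, also clears any stale or shadowed names.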