I'm having a hard time understanding the torch.utils.data.Dataset class: it doesn't seem to behave like a regular Python class. Below is a simple example (the output of each call is shown in a comment). In short:
- A direct subclass of Dataset cannot access its own custom methods or attributes normally (see TestDataset)
- But a class that in turn inherits from that subclass can! (see TestDataset2)
Is there something I'm missing (static methods, …)? Thanks in advance.
import torch
from torch.utils.data import Dataset
# A basic python class
class TestObject:
    """Plain Python class storing a single attribute ``a``.

    Used as the baseline for comparison with the torch ``Dataset`` subclasses.
    """

    def __init__(self, a) -> None:
        # Stored as an ordinary instance attribute, so it lands in ``self.__dict__``.
        self.a = a

    def get_a(self):
        """Return the value passed to the constructor."""
        return self.a
# print("testobject")
# testobject=TestObject(99)
# print(list(testobject.__dict__.keys()))# ['a']
# print(testobject.a)# 99
# print(testobject.get_a())# 99
# A torch "Dataset" subclass: can't access attributes/methods normally
class TestDataset(Dataset):
    """Minimal map-style torch Dataset with one attribute ``a`` and one random sample.

    Instances behave like any other Python object: ``TestDataset(99).a == 99``
    and ``TestDataset(99).get_a() == 99``.

    NOTE(review): the errors reported for this class in the accompanying demo,
    "type object 'TestDataset' has no attribute 'a'" and
    "get_a() missing 1 required positional argument: 'self'", are the messages
    Python produces when the lookup happens on the class object itself rather
    than on an instance (e.g. ``testdataset = TestDataset`` without calling it).
    Worth re-checking how ``testdataset`` was bound in the original script.
    """

    def __init__(self, a) -> None:
        super().__init__()
        self.a = a

    def get_a(self):
        """Return the value passed to the constructor."""
        return self.a

    def __len__(self) -> int:
        """Return the number of samples, which is always 1."""
        return 1

    def __getitem__(self, idx) -> torch.Tensor:
        """Return a fresh random tensor of shape (1, 1, 1) for a valid index.

        Raises:
            IndexError: if ``idx`` is outside ``range(-len(self), len(self))``.
        """
        n = len(self)
        # Dataset defines no __iter__, so ``for x in ds`` / ``list(ds)`` falls
        # back to calling __getitem__(0), __getitem__(1), ... until IndexError.
        # Without this check that iteration would never terminate.
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of length {n}")
        return torch.randn(1, 1, 1)
# print("testdataset")
# testdataset=TestDataset(99)
# print(list(testdataset.__dict__.keys()))# ['a']
# print(testdataset.a)# AttributeError: type object 'TestDataset' has no attribute 'a'
# print(testdataset.get_a())# TypeError: TestDataset.get_a() missing 1 required positional argument: 'self'
## A basic python inherited class
class TestObject2(TestObject):
    """Plain Python subclass of ``TestObject`` that adds a second attribute ``b``."""

    def __init__(self, a, b) -> None:
        # Let the parent set ``self.a`` before adding this class's own attribute.
        super().__init__(a)
        self.b = b

    def get_b(self):
        """Return the second value passed to the constructor."""
        return self.b
# print("testobject2")
# testobject2=TestObject2(99,42)
# print(testobject2.a)# 99
# print(testobject2.b)# 42
# A torch "Dataset" inherited class: now everything works normally !?
class TestDataset2(TestDataset):
    """Subclass of ``TestDataset`` that adds a second attribute ``b``.

    ``__len__`` and ``__getitem__`` are inherited unchanged from ``TestDataset``.
    """

    def __init__(self, a, b) -> None:
        # The parent constructor sets ``self.a``; this class adds ``self.b``.
        super().__init__(a)
        self.b = b

    def get_b(self):
        """Return the second value passed to the constructor."""
        return self.b
# print("testdataset2")
# testdataset2=TestDataset2(99,42)
# print(list(testdataset2.__dict__.keys()))# ['a','b']
# print(testdataset2.a)# 99, it works now !?!
# print(testdataset2.get_a())# 99
# print(testdataset2.b)# 42
# print(testdataset2.get_b())# 42