Object has no attribute 'train'

Hi, do you know why this happens? I defined `train` at the beginning, but when I try to use it, an error pops up.


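import os

import imageio
import numpy as np
import torch
from sklearn import preprocessing
from sklearn.model_selection import train_test_split

# `preprocess` is assumed to come from the recfaces project; it is not shown here.


class LFW1(object):  # original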
    def __init__(self, root, specific_folder, img_extension, preprocessing_method=None, crop_size=(96, 112), train=True):
        """
        Dataloader of the LFW dataset.

        root: path to the dataset to be used.
        specific_folder: specific folder inside the same dataset.
        img_extension: extension of the dataset images.
        preprocessing_method: string with the name of the preprocessing method.
        crop_size: retrieval network specific crop size.
        train: if True, use the training split; otherwise use the test split.
        """

        self.preprocessing_method = preprocessing_method
        self.crop_size = crop_size
        self.imgl_list = []
        self.classes = []
        self.people = []
        self.model_align = None

        # read the file with the names and the number of images of each people in the dataset
        with open(os.path.join(root, 'people.txt')) as f:
            people = f.read().splitlines()[1:]

        # get only the people that have more than 20 images
        for p in people:
            p = p.split('\t')
            if len(p) > 1:
                if int(p[1]) >= 20:
                    for num_img in range(1, int(p[1]) + 1):
                        self.imgl_list.append(os.path.join(root, specific_folder, p[0], p[0] + '_' +
                                                           '{:04}'.format(num_img) + '.' + img_extension))
                        self.classes.append(p[0])
                        self.people.append(p[0])

        le = preprocessing.LabelEncoder()
        self.classes = le.fit_transform(self.classes)
        #print(len(self.imgl_list), len(self.classes), len(self.people))
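        # A fixed random_state keeps the split identical across instances, so the
        # train=True and train=False datasets do not share images.
        self.i_train, self.i_test, self.c_train, self.c_test = train_test_split(
            self.imgl_list, self.classes, test_size=0.2, random_state=0)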
        print(len(self.i_train), len(self.i_test), len(self.c_train), len(self.c_test))
       
        

    def __getitem__(self, index):
      if self.train:
        imgl = imageio.imread(self.i_train[index])
        cl = self.c_train[index]

        # if image is grayscale, transform into rgb by repeating the image 3 times
        if len(imgl.shape) == 2:
            imgl = np.stack([imgl] * 3, 2)

        imgl, bb = preprocess(imgl, self.preprocessing_method, crop_size=self.crop_size,
                              is_processing_dataset=True, return_only_largest_bb=True, execute_default=True)

        # append image with its reverse
        imglist = [imgl, imgl[:, ::-1, :]]

        # normalization
        for i in range(len(imglist)):
            imglist[i] = (imglist[i] - 127.5) / 128.0
            imglist[i] = imglist[i].transpose(2, 0, 1)
        imgs = [torch.from_numpy(i).float() for i in imglist]
        
       

        return imgs, cl
        

      else:
        imgl1 = imageio.imread(self.i_test[index])
        cl1 = self.c_test[index]

        # if image is grayscale, transform into rgb by repeating the image 3 times
        if len(imgl1.shape) == 2:
            imgl1 = np.stack([imgl1] * 3, 2)

        imgl1, bb1 = preprocess(imgl1, self.preprocessing_method, crop_size=self.crop_size,
                              is_processing_dataset=True, return_only_largest_bb=True, execute_default=True)

        # append image with its reverse
        imglist1 = [imgl1, imgl1[:, ::-1, :]]

        # normalization
        for i in range(len(imglist1)):
            imglist1[i] = (imglist1[i] - 127.5) / 128.0
            imglist1[i] = imglist1[i].transpose(2, 0, 1)
        imgs1 = [torch.from_numpy(i).float() for i in imglist1]
        
        return imgs1, cl1
        

    def __len__(self):
        return len(self.imgl_list)
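One more problem in the class above, independent of the error in this thread: `__len__` returns the size of the whole image list, while `__getitem__` only indexes `i_train` (about 80% of it) or `i_test` (about 20%). A DataLoader that requests every index up to `len(dataset)` will run past the end of the split and raise an `IndexError`. A minimal sketch of one way to keep them in sync is to pick the split once in `__init__` and let both methods use it. `SplitDataset` and its arguments are hypothetical names, and plain strings stand in for the image paths; in the real class `__getitem__` would also read and preprocess the image.

```python
import random


class SplitDataset:
    """Hypothetical stand-in for LFW1: only the split logic is shown."""

    def __init__(self, items, labels, train=True, test_size=0.2, seed=0):
        # Store the flag on the instance so other methods can read it.
        self.train = train

        # Shuffle indices with a fixed seed so a train instance and a test
        # instance built from the same data get complementary splits.
        indices = list(range(len(items)))
        random.Random(seed).shuffle(indices)
        n_test = int(len(items) * test_size)
        chosen = indices[n_test:] if train else indices[:n_test]

        # Keep only the selected split; __getitem__ and __len__ both use it.
        self.items = [items[i] for i in chosen]
        self.labels = [labels[i] for i in chosen]

    def __getitem__(self, index):
        # The real dataset would read and preprocess the image here.
        return self.items[index], self.labels[index]

    def __len__(self):
        # Length of the selected split, not of the full list.
        return len(self.items)


paths = [f"img_{i:04}.jpg" for i in range(10)]
labels = [i % 2 for i in range(10)]
train_set = SplitDataset(paths, labels, train=True)
test_set = SplitDataset(paths, labels, train=False)
print(len(train_set), len(test_set))  # 8 2
```

The same idea applies to sklearn's `train_test_split`: passing a fixed `random_state` makes separately constructed train and test datasets agree on the split.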

I set train=True for the training set and train=False for the test data, but when I try to load the data, an error pops up:

dataset1 = LFW1('/content/drive/My Drive/recfaces13/recfaces/datasets/LFW/try','lfw', 'jpg', 'sphereface', (96, 112),train=True)
dataloader1 = torch.utils.data.DataLoader(dataset1, batch_size=1, shuffle=False, num_workers=2, drop_last=False)
for i, data in enumerate(dataloader1):
  inps, labs = data
  print(len(inps))

It says unexpected keyword argument 'train'. I'm very confused: I defined train from the beginning. Did I do something wrong here?
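A likely explanation for the `unexpected keyword argument 'train'` message, assuming the code runs in a notebook such as Colab: the notebook still holds an older definition of `LFW1` whose `__init__` had no `train` parameter, because the cell with the updated class was not re-run. Python uses whichever definition was executed most recently. The sketch below reproduces this with a hypothetical `Demo` class.

```python
# First version of the class, as it might have been run in an earlier cell.
class Demo:
    def __init__(self, root):
        self.root = root


try:
    Demo("data", train=True)
except TypeError as err:
    message = str(err)
    print(message)  # ...got an unexpected keyword argument 'train'


# The updated definition only takes effect once this cell is executed.
class Demo:
    def __init__(self, root, train=True):
        self.root = root
        self.train = train


print(Demo("data", train=True).train)  # True
```

Restarting the runtime and running all cells from top to bottom rules this out.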

Hi @Archy_dragon, why are you setting train to True or False? Just call:

#for training
model.train()
#for eval
model.eval()

@Usama_Hasan I don't think it has anything to do with nn.Module(); it's about how the dataset is initialized.

@Archy_dragon It would be helpful if you could provide the error log. Why does the LFW1 dataset not inherit from the Dataset class?

This is the full error log:

AttributeError                            Traceback (most recent call last)
in <module>()
1 x=[]
----> 2 for i, data in enumerate(dataloader1):
3 inps, labs = data
4 x.append
5

3 frames
/usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
393 # (https://bugs.python.org/issue2651), so we work around it.
394 msg = KeyErrorMessage(msg)
--> 395 raise self.exc_type(msg)

AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "", line 45, in __getitem__
    if self.train:
AttributeError: 'LFW1' object has no attribute 'train'
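The last line of the traceback points at the cause: `__init__` accepts `train` as a parameter but never stores it, so `self.train` does not exist when `__getitem__` runs. A parameter is a local variable that disappears when `__init__` returns; only attributes assigned on `self` survive. In `LFW1` the fix would be one line in `__init__`, `self.train = train`. The error is reported from "DataLoader worker process 0" only because `num_workers=2` runs `__getitem__` in worker processes. A minimal sketch with hypothetical class names, stripped down to just this behavior:

```python
class Broken:
    def __init__(self, train=True):
        # `train` is only a local variable here; it is never saved.
        self.items = ["a", "b"]

    def __getitem__(self, index):
        return self.items[index] if self.train else None


class Fixed:
    def __init__(self, train=True):
        self.train = train  # store the flag on the instance
        self.items = ["a", "b"]

    def __getitem__(self, index):
        return self.items[index] if self.train else None


try:
    Broken(train=True)[0]
except AttributeError as err:
    message = str(err)
    print(message)  # 'Broken' object has no attribute 'train'

print(Fixed(train=True)[0])  # a
```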

So when it's True it will be the training data, and when it's False it will be the test data. I want to make a function where, if I pass in the data or its path, the data will load.

So True will be train and False will be test. Can you tell me how model.train() and model.eval() work?

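Hi @Archy_dragon, sorry for the late reply. When you call model.eval(), some layers behave differently: dropout is disabled, and batch normalization uses its running statistics instead of the statistics of the current batch. model.train() switches them back to training behavior. It's good practice to call model.eval() during evaluation.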
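A small sketch of what the reply describes, assuming PyTorch is installed: `model.train()` and `model.eval()` only set the `training` flag on the module and all of its submodules, and layers such as `nn.Dropout` check that flag in `forward`. Neither call loads data or selects a split, so they are unrelated to the dataset's `train` argument. The one-layer model is a hypothetical toy.

```python
import torch
from torch import nn

# Toy model with a dropout layer whose behavior depends on the mode.
model = nn.Sequential(nn.Dropout(p=0.5))
x = torch.ones(1000)

model.train()  # sets training = True on the model and every submodule
train_out = model(x)
# In training mode dropout zeroes some elements and scales the rest by 1/(1-p) = 2.
print(model.training, sorted(set(train_out.tolist())))

model.eval()  # sets training = False on the model and every submodule
eval_out = model(x)
# In evaluation mode dropout passes the input through unchanged.
print(model.training, torch.equal(eval_out, x))  # False True
```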

Can you put your code inside ``` ``` code fences? That will make it easier to debug.

Hi @bsridatta, you're right. I made that comment based on the title of the thread.

  1. Assume the dataset creation is wrong, though I don't see why it says the argument is not valid.
  2. If it's wrong, then the error should be in

and not here.

So I would guess it has something to do with the DataLoader too. That is also why I asked why the dataset class does not inherit from the torch Dataset class. Could you please change that and see if the error still exists?

Where should I change it? Sorry, I'm pretty confused.

Here. Check the custom dataset tutorial; you can keep it as a plain object, but try:

from torch.utils.data import Dataset


class LFW1(Dataset):
    ...  # rest of the class unchanged
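A minimal sketch of the suggested change, assuming PyTorch is installed. Subclassing `torch.utils.data.Dataset` is the pattern the custom dataset tutorial uses, but `DataLoader` also accepts any object with `__getitem__` and `__len__`, so inheriting alone would not remove the `AttributeError`; what removes it is assigning `self.train = train` in `__init__`, as done here. `ToyLFW` and its in-memory tensors are hypothetical stand-ins for the real images.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyLFW(Dataset):
    def __init__(self, train=True):
        self.train = train  # stored so other methods can use it
        # Hypothetical data: 8 training and 2 test "images" of shape 3x4x4.
        n = 8 if train else 2
        self.images = torch.zeros(n, 3, 4, 4)
        self.labels = torch.arange(n)

    def __getitem__(self, index):
        return self.images[index], self.labels[index]

    def __len__(self):
        return len(self.labels)


train_loader = DataLoader(ToyLFW(train=True), batch_size=4, shuffle=False)
test_loader = DataLoader(ToyLFW(train=False), batch_size=4, shuffle=False)

for imgs, labs in train_loader:
    print(imgs.shape, labs.tolist())
```

With the default `num_workers=0`, any exception raised in `__getitem__` surfaces directly in the loop, which can make debugging easier than with worker processes.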

Thank you, I will give it a try and get back to you.