Sampler error in PyTorch

Hi, I have been writing a custom sampler in PyTorch, but when I run it, it raises an error.


import torch
from torch.utils.data import Sampler

def CustomSampler(Sampler):
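    def __init__(self, data_source, batch_size, replacement=False):
        self.data_source = data_source
        # convert the dict views to lists before building tensors
        self.ids = torch.tensor(list(data_source.keys()))
        self.numSmall = torch.tensor(list(data_source.values()), dtype=torch.float)
        # normalize the counts into sampling probabilities
        self.normal_numSmall = self.numSmall / self.numSmall.sum()

        self.num_samples = self.numSmall.numel()
        self.batch_size = batch_size
        self.replacement = replacement

    def __iter__(self):
        # draw batch_size indices weighted by the counts, then map them back to ids
        idx = torch.multinomial(self.normal_numSmall, self.batch_size, replacement=self.replacement)
        # __iter__ must return an iterator, not a tensor
        return iter(self.ids[idx].tolist())

    def __len__(self):
        return self.batch_size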

ids = [11, 12, 156, 256, 36, 26, 21, 25]
small = [13, 7, 6, 8, 12, 15, 3, 20]

data_source = dict(zip(ids, small))
sample = CustomSampler(data_source, batch_size = 3)


TypeError                                 Traceback (most recent call last)
<ipython-input-9-d53f4a02f4a7> in <module>()
----> 1 sample = CustomSampler(data_source, batch_size = 3)

TypeError: CustomSampler() got an unexpected keyword argument 'batch_size'

Could you please tell me what is causing this error, and how I can fix it?
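The core of the __iter__ method above is weighted sampling: torch.multinomial treats the normalized counts as probabilities and returns indices, which are then mapped back to ids. A standalone sketch of just that step, using the id and count values from the post (the variable names probs, idx and batch_ids are illustrative):

```python
import torch

ids = torch.tensor([11, 12, 156, 256, 36, 26, 21, 25])
small = torch.tensor([13, 7, 6, 8, 12, 15, 3, 20], dtype=torch.float)

# Normalize the counts into probabilities (multinomial also accepts unnormalized non-negative weights)
probs = small / small.sum()

# Draw 3 distinct indices; entries with larger counts are more likely to be picked
idx = torch.multinomial(probs, 3, replacement=False)

# Map the drawn indices back to the original ids
batch_ids = ids[idx].tolist()
print(batch_ids)
```

The printed ids change from run to run; call torch.manual_seed first if you need a reproducible draw.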


batch_size is not a keyword argument in your __init__ function; it is just the second positional one :slight_smile:
Either call it as CustomSampler(data_source, 3), or declare the init function as __init__(self, data_source, batch_size=1, replacement=False).

Sorry @albanD. I did not understand. What do you mean?

Yes, my answer is not really clear, sorry :confused:

When you make the call sample = CustomSampler(data_source, batch_size = 3),
it passes data_source as the first argument to the constructor and 3 as the named argument batch_size.

If you look at how you declared your constructor, def __init__(self, data_source, batch_size, replacement=False):, it takes self first, as always, then a first argument stored in the variable data_source, a second argument stored in the variable batch_size, and a named argument replacement, stored in a variable of the same name, with default value False.

You are trying to supply the second argument as a named argument, while it is a positional argument in your function definition.
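How call arguments are matched to the constructor's parameters can be checked with inspect.signature(...).bind, which performs Python's argument binding without calling the function. A minimal sketch, using a hypothetical stand-in function init_params that mirrors the parameter list of the __init__ above (minus self):

```python
import inspect

# Hypothetical stand-in with the same parameters as CustomSampler.__init__ (self omitted)
def init_params(data_source, batch_size, replacement=False):
    pass

sig = inspect.signature(init_params)

# Bind a positional call and a keyword call without executing the function
positional = sig.bind("data", 3)
keyword = sig.bind("data", batch_size=3)

print(dict(positional.arguments))  # {'data_source': 'data', 'batch_size': 3}
print(dict(keyword.arguments))     # {'data_source': 'data', 'batch_size': 3}
print(sig.parameters["batch_size"].kind.name)  # POSITIONAL_OR_KEYWORD
```

Both calls bind batch_size to 3: a parameter declared without a default is positional-or-keyword, so it can be passed either way. This suggests the TypeError does not come from the __init__ signature itself.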

Thanks for your response. I have changed the initialization code to:

sample = CustomSampler(data_source, 3)

but I still get an error :frowning:
TypeError Traceback (most recent call last)
in ()
----> 1 sample = CustomSampler(data_source, 3)

TypeError: CustomSampler() takes 1 positional argument but 2 were given

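Ok, looking at your code in more detail, CustomSampler should be a class :slight_smile: Declared with def, it is just a function taking a single argument (named Sampler), which is why both of your calls fail. Change: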
def CustomSampler(Sampler): -> class CustomSampler(Sampler):
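The two tracebacks in this thread can be reproduced without PyTorch. A minimal sketch, assuming a hypothetical stand-in Sampler class and the illustrative names BuggySampler and FixedSampler:

```python
class Sampler:
    """Hypothetical stand-in for torch.utils.data.Sampler, so this runs without PyTorch."""
    pass

# Buggy version: `def` creates a plain function whose only parameter is named Sampler;
# the nested __init__ is just an unused local function
def BuggySampler(Sampler):
    def __init__(self, data_source, batch_size, replacement=False):
        self.batch_size = batch_size

# Fixed version: `class` creates a subclass of Sampler, and __init__ becomes its constructor
class FixedSampler(Sampler):
    def __init__(self, data_source, batch_size, replacement=False):
        self.data_source = data_source
        self.batch_size = batch_size

data_source = {11: 13, 12: 7}

# Collect the error messages from both calls made in the thread
errors = []
for call in (lambda: BuggySampler(data_source, batch_size=3),
             lambda: BuggySampler(data_source, 3)):
    try:
        call()
    except TypeError as exc:
        errors.append(str(exc))

for msg in errors:
    print(msg)

# With `class`, the keyword call works as originally written
sample = FixedSampler(data_source, batch_size=3)
print(sample.batch_size)  # 3
```

The first message matches "unexpected keyword argument 'batch_size'" and the second "takes 1 positional argument but 2 were given", the same errors shown in the tracebacks.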

Sorry for my silly mistake. :expressionless: :frowning: