Balanced batches using only some classes

Hello everyone,
Does PyTorch provide a function to balance a batch using only N classes? I am interested in balancing each batch using only some of the classes, cycling through them, for instance:
Batch 0 [5,5,5,0,0,0] (“5 instances each of classes 0, 1, and 2, and 0 instances of every other class”)
Batch 1 [0,0,0,5,5,5]
Epoch finished…
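To make the intended layout concrete, here is a small sketch that builds the per-class count vectors for one cycle. `cyclic_class_schedule` is a hypothetical helper, and it assumes the classes are grouped in consecutive order (classes 0-2, then 3-5). Note that the batch size is `classes_per_batch * instances_per_class` (15 in the example), so this is what bounds GPU memory, not the total number of classes.

```python
def cyclic_class_schedule(num_classes, classes_per_batch, instances_per_class):
    """Return one per-class count vector per batch for a single cycle.

    Hypothetical helper: classes are taken in consecutive groups of
    `classes_per_batch`; the last group is smaller if they don't divide evenly.
    """
    schedule = []
    for start in range(0, num_classes, classes_per_batch):
        counts = [0] * num_classes
        # Only the classes in the current group get instances in this batch.
        for c in range(start, min(start + classes_per_batch, num_classes)):
            counts[c] = instances_per_class
        schedule.append(counts)
    return schedule
```

For 6 classes, 3 classes per batch, and 5 instances per class this reproduces the two batches above.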

I would like to use this approach because I need many instances per class while keeping each batch balanced. The reason I want 0 instances of some classes in a batch is that including all classes leads to CUDA out-of-memory errors.
Can you help me please?

I think the best approach would be to write a custom sampler: pass the targets to it and apply your sampling logic to return the right indices.
You could use one of the already implemented samplers from here as the base class and add your filtering to it.
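A minimal sketch of such a sampler, assuming each batch should hold `instances_per_class` samples from each of `classes_per_batch` classes and that successive batches cycle through the class groups. `CyclicClassBatchSampler` is a hypothetical class, not part of PyTorch. It yields whole batches of indices, so it is meant for the `batch_sampler` argument of `torch.utils.data.DataLoader`, which accepts any iterable of index lists; subclassing `torch.utils.data.Sampler` is optional and is skipped here so the sketch runs without torch. Another assumption: the smallest class limits how many rounds fit in an epoch, so larger classes are undersampled.

```python
import math
import random
from collections import defaultdict


class CyclicClassBatchSampler:
    """Hypothetical batch sampler (not part of PyTorch).

    Each batch holds `instances_per_class` samples from each of
    `classes_per_batch` classes; successive batches cycle through the
    class groups, so every class is seen equally often per epoch.
    """

    def __init__(self, targets, classes_per_batch, instances_per_class,
                 shuffle=True, seed=0):
        self.classes_per_batch = classes_per_batch
        self.instances_per_class = instances_per_class
        self.shuffle = shuffle
        self.rng = random.Random(seed)
        # Map each class label to the dataset indices carrying that label.
        self.class_to_indices = defaultdict(list)
        for idx, label in enumerate(targets):
            self.class_to_indices[int(label)].append(idx)
        self.classes = sorted(self.class_to_indices)

    def _num_rounds(self):
        # Each round takes `instances_per_class` samples from every class,
        # so the smallest class limits the number of rounds per epoch.
        smallest = min(len(v) for v in self.class_to_indices.values())
        return smallest // self.instances_per_class

    def __iter__(self):
        # Shuffle a copy of each class's indices once per epoch.
        pools = {}
        for label, indices in self.class_to_indices.items():
            pool = list(indices)
            if self.shuffle:
                self.rng.shuffle(pool)
            pools[label] = pool
        # Split the (optionally shuffled) classes into consecutive groups;
        # the last group is smaller if the classes do not divide evenly.
        classes = list(self.classes)
        if self.shuffle:
            self.rng.shuffle(classes)
        groups = [classes[i:i + self.classes_per_batch]
                  for i in range(0, len(classes), self.classes_per_batch)]
        k = self.instances_per_class
        for r in range(self._num_rounds()):
            # One batch per group per round: group 0, group 1, ..., repeat.
            for group in groups:
                batch = []
                for label in group:
                    batch.extend(pools[label][r * k:(r + 1) * k])
                yield batch

    def __len__(self):
        num_groups = math.ceil(len(self.classes) / self.classes_per_batch)
        return self._num_rounds() * num_groups
```

Usage would look like `DataLoader(dataset, batch_sampler=CyclicClassBatchSampler(targets, 3, 5))`; since `batch_sampler` is mutually exclusive with `batch_size`, `shuffle`, `sampler`, and `drop_last`, leave those at their defaults. If dropping samples from larger classes is undesirable, sampling the smaller classes with replacement would be an alternative design.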

Thanks, @ptrblck, I just did it, and I'm testing the sampling logic right now.
I would like to publish the code in the PyTorch repo. Any suggestions?

Do you mean as an additional sampler into the PyTorch code base?
If so, please create a feature request with the description of the use case and your changes on GitHub, where it can be discussed further. :slight_smile:

Yes, thanks a lot @ptrblck. I will create a feature request as you said.

Hi @ptrblck,
I tried to follow the guide in order to submit our DataLoader approach. However, I get the following error if I start

ImportError: cannot import name 'cpp_extensionp'

If I try to run only a single test file, I get:

ERROR collecting test/ ___________________
ImportError while importing test module '/home/nataraja/Desktop/pytorch/test/'.
Hint: make sure your test modules/packages have valid Python names.
/usr/lib/python3.6/importlib/ in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
test/ in <module>
    from import (_utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset,
E   ImportError: cannot import name 'BufferedShuffleDataset'

I read the guide and followed all the steps.
Can you help me?
Thanks a lot.

This issue should have been fixed in the latest master version as described here. Are you using the stable 1.7.0 code base? If so, could you switch to master?

Yes, I was on the master branch, not the stable 1.7.0 code base. I forked yesterday, so I think I have the latest version of master. Any suggestions?

Could you post the so that I could reproduce it with the current master?

This is my fork (I haven't modified anything).

So you are seeing these errors while only building from source and running the

if i run the error is:

    from torch.utils import cpp_extensionp
ImportError: cannot import name 'cpp_extensionp'

There seems to be a typo (extensionp instead of extension), so let me grep for it and check what's going on.
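For readers without `grep` at hand, a rough Python equivalent of `grep -rn cpp_extensionp --include='*.py' .` is sketched here. `find_occurrences` is a hypothetical helper; by default it only scans `.py` files and silently skips files that are not valid UTF-8.

```python
from pathlib import Path


def find_occurrences(root, needle, pattern="*.py"):
    """Return (path, line_number, stripped_line) for each line under `root`
    whose text contains `needle`. Hypothetical helper, similar to grep -rn."""
    hits = []
    for path in sorted(Path(root).rglob(pattern)):
        try:
            text = path.read_text(encoding="utf-8")
        except (UnicodeDecodeError, OSError):
            continue  # skip unreadable or non-UTF-8 files
        for lineno, line in enumerate(text.splitlines(), start=1):
            if needle in line:
                hits.append((str(path), lineno, line.strip()))
    return hits
```

Running `find_occurrences("pytorch", "cpp_extensionp")` from the directory containing the checkout would list every line that still imports the misspelled module.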

Yes, there was a typo (extensionp instead of extension).
I submitted the request, and the pull request is:

Thanks a lot @ptrblck :slight_smile: