Is there a way to provide arguments to a collate_fn() or in some other way get that function access to information other than ‘batch’?
The use case is this: I have a fairly large, custom data set in need of some sort of normalization. Exactly what normalization will be most effective is not obvious. It would therefore be convenient to use sklearn to construct a bunch of pre-populated scaler objects (say, a maxabs scaler, and a standardization scaler, just for starters) and use a collate_fn to perform the scaling on the fly.
But to do this, the collate_fn needs some way to get the file location of those scaler files, and how to provide it eludes me.
(The brute force alternative of simply replicating the dataset multiple times, normalizing each one separately, is unpalatable due to the sheer size of the database.)
Is this sufficiently impossible that it warrants a feature request?
I’m not sure I understand everything, but one possibility is to create a class with a __call__ method and pass an instance of it to your DataLoader as the collate_fn:
    class MyCollator:
        def __init__(self, *params):
            self.params = params

        def __call__(self, batch):
            # do something with batch and self.params
            ...

    my_collator = MyCollator(param1, param2, ...)
    data_loader = torch.utils.data.DataLoader(..., collate_fn=my_collator)
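For the scaler use case in the question, a minimal sketch of this idea might look like the following. MaxAbsCollator and its max_abs parameter are hypothetical names, the plain-Python division stands in for a fitted sklearn scaler, and samples are assumed to be flat lists of floats.

```python
class MaxAbsCollator:
    """Callable collator that carries its own scaling parameters."""

    def __init__(self, max_abs):
        # One scale factor per feature, e.g. taken from a fitted scaler.
        self.max_abs = max_abs

    def __call__(self, batch):
        # batch is a list of samples; divide each feature by its scale.
        return [
            [x / m for x, m in zip(sample, self.max_abs)]
            for sample in batch
        ]


collate = MaxAbsCollator(max_abs=[2.0, 10.0])
scaled = collate([[1.0, 5.0], [-2.0, 10.0]])
print(scaled)
```

An instance such as collate can then be passed as collate_fn=collate; the DataLoader calls it with the list of samples for each batch.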
Depending on what you want, I’d probably try one of the following:
- If the normalization is per example, the preferred way is to do it in the dataset and keep track of it there.
- There are various other ways to achieve something similar to what @fmassa suggested (and even more variants when you search for “currying in python”). The lazy person’s way would be a lambda:
    dl = torch.utils.data.DataLoader(..., collate_fn=lambda b: my_collator_with_param(b, 3))
or, using a default argument in the lambda to bind params when the lambda is defined rather than looking it up each time the lambda is called:
    dl = torch.utils.data.DataLoader(..., collate_fn=lambda b, params=params: my_collator_with_param(b, params))
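The difference between looking params up at call time and binding it as a default argument can be seen in a small sketch. The body of my_collator_with_param here is made up for illustration, and the batch is assumed to be a plain list of numbers.

```python
def my_collator_with_param(batch, params):
    # Hypothetical collator: multiply every value in the batch by params.
    return [x * params for x in batch]


params = 2
# Looks up the name params every time the lambda is called.
late = lambda b: my_collator_with_param(b, params)
# Captures the current value of params when the lambda is defined.
bound = lambda b, params=params: my_collator_with_param(b, params)

params = 10  # rebinding the name afterwards only affects the first lambda

late_result = late([1, 2, 3])
bound_result = bound([1, 2, 3])
print(late_result, bound_result)
```

If params is never reassigned after the DataLoader is built, both forms behave the same.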
Between doing things in the dataset and having one of the collator variants from this thread, you should be able to do most things.
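The per-example route could be sketched roughly as below. ScaledDataset, scale_fn and standardize are hypothetical names, and the mean and std values are made up. In real code the class would typically subclass torch.utils.data.Dataset and the transform would wrap a fitted scaler; plain Python is used here so the snippet runs on its own.

```python
class ScaledDataset:
    """Map-style dataset that normalizes each example on access."""

    def __init__(self, samples, scale_fn):
        self.samples = samples
        self.scale_fn = scale_fn  # hypothetical per-example transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # Scale lazily, so the stored samples are never modified or copied.
        return self.scale_fn(self.samples[index])


def standardize(sample, mean=2.0, std=2.0):
    # Stand-in for a fitted scaler; mean and std are made-up values.
    return [(x - mean) / std for x in sample]


dataset = ScaledDataset([[0.0, 4.0], [2.0, 6.0]], scale_fn=standardize)
first = dataset[0]
print(len(dataset), first)
```

Trying a different normalization then means wrapping the same samples with another scale_fn rather than replicating the data.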
Almost the same as the above answers, but functools.partial enables you to specify arguments other than batch, like this:
    from functools import partial

    def my_collate_fn(batch, a, b):
        ...  # use batch, a and b here

    DataLoader(..., collate_fn=partial(my_collate_fn, a=foo, b=bar))
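A runnable version of the same pattern, as a sketch: the function body and the values chosen for a and b are assumptions for illustration, and the batch is taken to be a plain list of floats.

```python
from functools import partial


def my_collate_fn(batch, a, b):
    # Hypothetical body: shift every value by a, then scale by b.
    return [(x - a) * b for x in batch]


# partial fixes a and b, leaving a callable that only needs the batch.
collate = partial(my_collate_fn, a=1.0, b=0.5)
result = collate([1.0, 3.0, 5.0])
print(result)
```

The resulting collate object takes a single batch argument, which is the shape the DataLoader expects from a collate_fn.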