Supplying Arguments to Collate_fn

Is there a way to provide arguments to a collate_fn() or in some other way get that function access to information other than ‘batch’?

The use case is this: I have a fairly large, custom data set in need of some sort of normalization. Exactly what normalization will be most effective is not obvious. It would therefore be convenient to use sklearn to construct a bunch of pre-populated scaler objects (say, a maxabs scaler, and a standardization scaler, just for starters) and use a collate_fn to perform the scaling on the fly.

To do this, though, the collate_fn needs some way to get the file location of those scaler files, and how to provide that eludes me.

(The brute-force alternative of simply replicating the dataset multiple times, normalizing each copy separately, is unpalatable due to the sheer size of the database.)


Anyone?

Is this sufficiently impossible that it warrants a feature request?

I’m not sure I understand everything, but one possibility is to create a class with a __call__ method and pass an instance of it to your DataLoader:

import torch

class MyCollator(object):
    def __init__(self, *params):
        # extra arguments are stored on the instance
        self.params = params

    def __call__(self, batch):
        # do something with batch and self.params, then return the collated batch
        ...

my_collator = MyCollator(param1, param2, ...)
data_loader = torch.utils.data.DataLoader(..., collate_fn=my_collator)
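For the scaler use case from the question, a minimal sketch of such a collator could look like the following. MaxAbsStandIn is a hypothetical stand-in for a fitted scaler (anything with a transform method would do), ScalingCollator is an illustrative name, and each sample is assumed to be a (features, label) pair. Since the DataLoader simply calls collate_fn with a list of samples, the collator can be tried out on a plain list:

```python
class MaxAbsStandIn:
    """Hypothetical stand-in for a fitted scaler: divides each feature by a stored maximum."""

    def __init__(self, max_abs):
        self.max_abs = max_abs

    def transform(self, rows):
        return [[x / m for x, m in zip(row, self.max_abs)] for row in rows]


class ScalingCollator:
    """Callable collate_fn that carries its own scaler."""

    def __init__(self, scaler):
        self.scaler = scaler

    def __call__(self, batch):
        # batch is the list of samples gathered for one step; each sample
        # is assumed here to be a (features, label) pair.
        features = [f for f, _ in batch]
        labels = [y for _, y in batch]
        return self.scaler.transform(features), labels


collate = ScalingCollator(MaxAbsStandIn([2.0, 10.0]))
scaled, labels = collate([([1.0, 5.0], 0), ([-2.0, 10.0], 1)])
# scaled == [[0.5, 0.5], [-1.0, 1.0]], labels == [0, 1]
```

Swapping in a different scaler then only means constructing another ScalingCollator; the dataset itself stays untouched.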

Depending on what you want, I’d probably try one of the following:

  • If the normalization is per example, the preferred way is to apply it in the dataset and keep track of it there.
  • There are various other ways to achieve something similar to what @fmassa suggested (and even more variants when you search for “currying in python”). The lazy person’s way would be using a lambda:
dl = torch.utils.data.DataLoader(..., collate_fn = lambda b: my_collator_with_param(b, 3))

or, giving the lambda a params default argument so that the value is fixed when the lambda is defined rather than looked up when it runs:

dl = torch.utils.data.DataLoader(..., collate_fn = lambda b, params=params: my_collator_with_param(b, params))
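The difference between the two lambdas only shows up if params is rebound after the lambda has been created. A minimal sketch, assuming a hypothetical body for my_collator_with_param (the thread does not define it) and using plain lists in place of real batches:

```python
def my_collator_with_param(batch, params):
    # Hypothetical collator: multiply every value in the batch by params.
    return [x * params for x in batch]


params = 2

# Looks up the name params each time the lambda is called.
late = lambda b: my_collator_with_param(b, params)

# Stores the current value of params as a default argument.
bound = lambda b, params=params: my_collator_with_param(b, params)

params = 10  # rebinding the name after both lambdas exist

late_result = late([1, 2, 3])    # [10, 20, 30]: sees the new value
bound_result = bound([1, 2, 3])  # [2, 4, 6]: keeps the value from definition time
```

If params never changes after the DataLoader is built, both forms behave the same.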

Between doing things in the dataset and having one of the collator variants from this thread, you should be able to do most things.

Best regards

Thomas


Much the same as the answers above, but functools.partial enables you to specify arguments other than batch, like this:

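from functools import partial

from torch.utils.data import DataLoader


def my_collate_fn(batch, a, b):
    ...

# foo and bar are whatever extra values the collate function needs
DataLoader(..., collate_fn=partial(my_collate_fn, a=foo, b=bar))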
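As a small worked illustration of the partial approach, here is a sketch with a hypothetical collate body and hypothetical parameter names scale and offset. The partial object is called with the batch alone, which is how a DataLoader would call it:

```python
from functools import partial


def my_collate_fn(batch, scale, offset):
    # Hypothetical collate body: apply a linear rescaling to every value.
    return [x * scale + offset for x in batch]


# Bind the extra arguments by keyword so that batch stays the only
# remaining positional parameter.
collate = partial(my_collate_fn, scale=0.5, offset=1.0)

result = collate([2.0, 4.0])  # [2.0, 3.0]
```

Binding by keyword matters here: partial(my_collate_fn, 0.5) would fill the first positional slot, which is batch, rather than scale.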