Accessing weight corresponding to a single element in a batch

Hi there,

I haven't dug deep into the PyTorch codebase, but I assume that in order to perform batched inference/training, the weights are duplicated and loaded onto the GPU so that the inputs can be processed in parallel, with the number of copies depending on the batch size. Let's say I'd like to access the weights corresponding to each element in a batch, i.e. my model is an adaptive one and I'd like to modify the duplicated weights independently w.r.t. the current input. What's the best way to do this?

Thanks!

The parameters are not duplicated; batched operations are used instead, i.e. each input sample uses the same parameters and the output is again a batch.
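A minimal sketch of this, assuming a plain `nn.Linear` as the example layer (the sizes are arbitrary): the layer holds a single weight matrix whose shape does not depend on the batch size, and applying it to the whole batch gives the same result as applying the same parameters to each sample in turn.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One shared layer: a single weight of shape (out_features, in_features)
layer = nn.Linear(in_features=4, out_features=3)

x = torch.randn(8, 4)  # a batch of 8 samples

# Batched call: one matmul over the whole batch, no per-sample weight copies
y_batched = layer(x)  # shape (8, 3)

# The same computation sample by sample, reusing the *same* parameters
y_loop = torch.stack([layer(sample) for sample in x])

print(layer.weight.shape)                            # torch.Size([3, 4])
print(y_batched.shape)                               # torch.Size([8, 3])
print(torch.allclose(y_batched, y_loop, atol=1e-6))  # True
```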

I see, thanks! So it's not possible to have multiple sets of duplicated weights for a given batch of inputs in the current version of PyTorch.
I assume it's technically feasible, since the weights would be loaded into separate CUDA cores before the different inputs are processed concurrently. Am I right on this part?

I don’t fully understand why you would like to duplicate the weights.
It would surely be possible to clone the weights and apply them one-by-one for each input sample, but this is usually not necessary, so could you explain your use case a bit more?
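A minimal sketch of the cloning idea, assuming a single `nn.Linear` layer and a hypothetical random per-sample update `delta` standing in for the input-dependent adaptation. It builds one weight matrix per sample, applies them one by one, and checks the result against a single batched `torch.bmm` call, which removes the Python loop but not the per-sample weight memory.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(4, 3)
x = torch.randn(8, 4)
batch_size = x.size(0)

# Hypothetical input-dependent update: just a random per-sample perturbation here.
delta = 0.01 * torch.randn(batch_size, *layer.weight.shape)

# expand() is a view (no copy), but adding delta materializes B weight
# matrices of shape (out, in), so weight memory grows linearly with B.
# Gradients still flow back to the shared layer.weight.
w = layer.weight.unsqueeze(0).expand(batch_size, -1, -1) + delta  # (B, out, in)

# Option 1: apply each sample's own weights one by one.
y_loop = torch.stack([w[i] @ x[i] + layer.bias for i in range(batch_size)])

# Option 2: the same computation as one batched matmul:
# (B, out, in) @ (B, in, 1) -> (B, out, 1)
y_bmm = torch.bmm(w, x.unsqueeze(-1)).squeeze(-1) + layer.bias

print(w.shape)                                   # torch.Size([8, 3, 4])
print(torch.allclose(y_loop, y_bmm, atol=1e-6))  # True
```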

It's a new idea that I'm trying to implement. It consists of adapting the weights w.r.t. the inputs. Since a batch can contain inputs of a different nature, i.e. different classes in the case of classification, the adaptation would result in multiple sets of weights for a single batch. A workaround is to restrict each batch to inputs that are closely distributed.

Thanks for the explanation.
In that case you could probably either clone the parameter or initialize multiple layers for the current input, e.g. via nn.ModuleList, and select the desired one.
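A sketch of the `nn.ModuleList` variant, assuming each sample comes with a hypothetical integer `group_ids` entry (e.g. a known or predicted class) that selects its layer; `GroupedLinear` is an illustrative name, not a PyTorch class. Samples that share a group are processed in one batched call, and the number of stored layers depends on the number of groups rather than on the batch size.

```python
import torch
import torch.nn as nn


class GroupedLinear(nn.Module):
    """Hypothetical module: one nn.Linear per group, each sample routed to its group's layer."""

    def __init__(self, in_features, out_features, num_groups):
        super().__init__()
        self.out_features = out_features
        self.layers = nn.ModuleList(
            [nn.Linear(in_features, out_features) for _ in range(num_groups)]
        )

    def forward(self, x, group_ids):
        out = x.new_zeros(x.size(0), self.out_features)
        # Run all samples of the same group through their layer in one batched call
        for g in group_ids.unique().tolist():
            mask = group_ids == g
            out[mask] = self.layers[g](x[mask])
        return out


torch.manual_seed(0)
model = GroupedLinear(4, 3, num_groups=5)
x = torch.randn(8, 4)
group_ids = torch.tensor([0, 2, 2, 1, 4, 0, 3, 1])  # hypothetical per-sample group index
y = model(x, group_ids)
print(y.shape)  # torch.Size([8, 3])
```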

Thanks for the reply. This is one way to do it, but memory grows linearly with the batch size, e.g. a batch size of 128 results in 128 different layers/modules being stored. In the meantime, I'll think of a more general approach that doesn't require replicating the weights.
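One restricted direction, offered as an assumption rather than the approach discussed in this thread: if the per-sample adaptation can be expressed as scaling the rows and columns of a shared weight, i.e. W_b = diag(s_b) · W · diag(r_b), it can be applied without ever storing B weight matrices. The scaling vectors `r` and `s` are hypothetical (random here); in practice they would have to be computed from the input. The sketch checks the implicit version against an explicitly materialized one.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

layer = nn.Linear(4, 3)
x = torch.randn(8, 4)
batch_size = x.size(0)

# Hypothetical per-sample scaling vectors (would normally be derived from the input).
# Extra storage: B * (in + out) values instead of B * in * out.
r = 1 + 0.1 * torch.randn(batch_size, 4)  # scales the columns (inputs)
s = 1 + 0.1 * torch.randn(batch_size, 3)  # scales the rows (outputs)

# Implicit per-sample weight diag(s_b) @ W @ diag(r_b), never materialized:
# the shared weight is applied once to the whole (rescaled) batch.
y_implicit = s * F.linear(r * x, layer.weight) + layer.bias

# Reference: build all B weight matrices explicitly (memory grows with B).
w_explicit = s.unsqueeze(2) * layer.weight.unsqueeze(0) * r.unsqueeze(1)  # (B, out, in)
y_explicit = torch.bmm(w_explicit, x.unsqueeze(-1)).squeeze(-1) + layer.bias

print(torch.allclose(y_implicit, y_explicit, atol=1e-6))  # True
```

The trade-off is expressiveness: only weights of this scaled form can be represented per sample, in exchange for memory that no longer scales with in × out per sample.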