Hi all,
I was wondering if there is a way to group multiple Parameter instances in a collection (such as ParameterList) so that they are properly registered as Module parameters and can also be indexed with a boolean mask or an array/list of indices (like a NumPy array). Nesting parameters in a ParameterList registers them correctly, but it only supports list-like indexing (a single integer or a slice).
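A minimal sketch of what is being asked for, assuming a hypothetical wrapper class `IndexableParameterList` (not an existing torch.nn class). It keeps the Parameters in an `nn.ParameterList` so they stay registered, and implements `__getitem__` so that a list of indices or a boolean mask returns the selected Parameters stacked with `torch.stack`.

```python
import torch
import torch.nn as nn


class IndexableParameterList(nn.Module):
    """Hypothetical wrapper (not part of torch.nn): one Parameter per entry,
    registered through nn.ParameterList, plus NumPy-style fancy indexing."""

    def __init__(self, params):
        super().__init__()
        # ParameterList registers every entry, so .parameters() sees them all
        self.params = nn.ParameterList(params)

    def __len__(self):
        return len(self.params)

    def __getitem__(self, idx):
        if isinstance(idx, int):
            return self.params[idx]
        if isinstance(idx, torch.Tensor):
            if idx.dtype == torch.bool:
                # Boolean mask -> positions of the True entries
                idx = idx.nonzero(as_tuple=True)[0]
            idx = idx.tolist()
        # List of indices: stack the selected Parameters into one tensor.
        # Gradients flow back only to the Parameters that were selected.
        return torch.stack([self.params[i] for i in idx])


plist = IndexableParameterList(
    [nn.Parameter(torch.full((3,), float(i))) for i in range(5)]
)
mask = torch.tensor([True, False, True, False, False])
print(plist[mask].shape)            # torch.Size([2, 3])
print(plist[[4, 1]][:, 0].tolist())  # [4.0, 1.0]
```

The indexed result is a new tensor rather than a Parameter, so writing into it in place does not update the stored weights; it is meant to be used inside the forward pass.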
If you need more specific use case info, I would be happy to provide it.
Thank you!
I would be interested in hearing more about your use case.
If I understand it correctly, you would like to be able to use something like:
params = nn.ParamArray([[param1, param2], [param3, param4]])
p = params[0, 1]
I am sorry for this delayed reply.
My use case might be best described as multitask learning: I have a network that is hard-shared between tasks, plus a per-task 1D tensor that linearly combines the shared network's outputs. Let's say I have n tasks in total, and each batch contains samples from k (k << n) tasks, so I want:

1) to efficiently index the 'collection' of all per-task weights, so that for an input of shape (b, d), where b is the batch size, I end up with a tensor of shape (b, t), where t is the size of each per-task tensor.
This is where the implementation using ParameterList/ModuleList (wrapping the individual Parameter(per_shape_tensor) objects) fails, as I can't index the ParameterList with a boolean mask or a list of indices (as I could with torch or NumPy arrays).

2) that the loss.backward() time doesn't scale with the total number of tasks (the backward computation time for a batch should depend only on the number of distinct tasks in that batch, i.e. gradients should not be computed for per-task tensors that are not used in the current batch's forward pass).
This is what prevents me from using one large Parameter wrapping a matrix with the per-task weights as rows: that way I can index efficiently, but the backward computation scales with the number of tasks.
(This seems similar to embedding layers. I tried implementing it with both torch.nn.Embedding and torch.nn.functional.embedding, but both implementations seem to scale with the number of tasks.)
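One possible explanation for the scaling, assuming the default `sparse=False` was used: indexing a single (n_tasks, t) Parameter produces a dense gradient of the full matrix, while `nn.Embedding(..., sparse=True)` (or `F.embedding(..., sparse=True)`) produces a sparse gradient that only stores the looked-up rows. The sketch below contrasts the two; the sizes and the random `h` standing in for the shared network's output are placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_tasks, t, b = 1000, 8, 16           # hypothetical sizes
task_ids = torch.randint(0, 5, (b,))  # this batch only uses tasks 0..4
h = torch.randn(b, t)                 # stand-in for the shared network's output

# (a) One big Parameter, rows = per-task weights, indexed with a LongTensor.
W = nn.Parameter(torch.randn(n_tasks, t))
(h * W[task_ids]).sum().backward()
# The gradient is a dense (n_tasks, t) tensor: rows of unused tasks are
# materialized as zeros, so memory and work grow with n_tasks.
print(W.grad.is_sparse, tuple(W.grad.shape))  # False (1000, 8)

# (b) nn.Embedding with sparse=True: backward yields a sparse COO gradient
# that only stores rows for the looked-up indices.
emb = nn.Embedding(n_tasks, t, sparse=True)
(h * emb(task_ids)).sum().backward()
g = emb.weight.grad
print(g.is_sparse)  # True

# Only some optimizers accept sparse gradients, e.g. optim.SparseAdam.
opt = torch.optim.SparseAdam(emb.parameters(), lr=1e-3)
opt.step()
```

With sparse gradients the choice of optimizer matters: the nn.Embedding documentation lists only a few optimizers (such as optim.SGD and optim.SparseAdam) as supporting them, and a standard dense optimizer step would again touch every row.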
I would be very thankful for any insights on implementing this while achieving both 1) and 2).
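A minimal sketch of one way to get both 1) and 2) with a ParameterList, assuming per-sample task ids in a LongTensor `task_ids` of shape (b,) and using a hypothetical helper `gather_task_weights`. It stacks only the k distinct tasks' Parameters and expands them back to (b, t) with the inverse indices returned by `torch.unique`, so autograd never sees the unused Parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_tasks, t, b = 1000, 8, 16  # hypothetical sizes

# One Parameter per task; ParameterList registers all of them with the Module.
task_weights = nn.ParameterList(
    [nn.Parameter(torch.randn(t)) for _ in range(n_tasks)]
)


def gather_task_weights(plist, task_ids):
    """Hypothetical helper: return a (b, t) tensor whose row i is the
    weight vector of task task_ids[i]. Only the k distinct tasks in the
    batch are stacked, so autograd never touches the other Parameters."""
    uniq, inverse = torch.unique(task_ids, return_inverse=True)
    w_unique = torch.stack([plist[i] for i in uniq.tolist()])  # (k, t)
    return w_unique[inverse]                                   # (b, t)


task_ids = torch.randint(0, 5, (b,))  # this batch only uses tasks 0..4
h = torch.randn(b, t)                 # stand-in for the shared network's output
w = gather_task_weights(task_weights, task_ids)
y = (h * w).sum(dim=1)                # per-sample linear combination, shape (b,)
y.sum().backward()

used = set(task_ids.tolist())
print(all(task_weights[i].grad is None
          for i in range(n_tasks) if i not in used))  # True
```

Tasks absent from the batch keep `grad is None`, and optimizers such as torch.optim.SGD and Adam skip parameters whose gradient is None, provided gradients are reset with `zero_grad(set_to_none=True)`. The Python loop over k Parameters has its own per-batch overhead, so whether this beats the sparse-embedding route likely depends on k.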