Hi all,
I was wondering if there is a way of grouping multiple Parameter instances in a collection (e.g. a ParameterList) such that they are registered correctly as Module parameters and can be indexed with a boolean mask or an array/list of indices (like a NumPy array). Nesting parameters in a ParameterList registers them correctly, but only integer (list-like) indexing is supported.
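For concreteness, a minimal sketch of what I mean (names are just illustrative; the commented-out lines are the kind of indexing that currently fails):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, n_tasks, dim):
        super().__init__()
        # each per-task tensor is registered as a module parameter
        self.task_weights = nn.ParameterList(
            nn.Parameter(torch.randn(dim)) for _ in range(n_tasks)
        )

net = Net(n_tasks=8, dim=16)
p = net.task_weights[3]                      # integer indexing: works
# net.task_weights[[0, 3, 5]]                # list of indices: TypeError
# net.task_weights[torch.tensor([0, 3, 5])]  # index tensor: also not supported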
If you need more specific use case info, I would be happy to provide it.
Thank you!
I would be interested in hearing more about your use case.
If I understand it correctly, you would like to be able to use something like:
params = nn.ParamArray([[param1, param2], [param3, param4]])
p = params[0, 1]
I'm sorry for the delayed reply.
My use case might best be described as multitask learning: I have a network hard-shared between tasks and a per-task 1-d tensor that linearly combines the shared network's outputs. Say I have n tasks in total and each batch contains samples from k tasks (k << n), so I want:
1) To efficiently index the 'collection' of all per-task weights, such that for an input of shape (b, d), where b is the batch size, I end up with a tensor of shape (b, t), where t is the size of the per-task tensor.
This is where the implementation wrapping each per-task tensor in a Parameter inside a ParameterList fails: I can't index the ParameterList with a boolean mask or a list of indices (as with torch or NumPy arrays).
2) That the loss.backward() time doesn't scale with the total number of tasks: backward computation time in a batch should depend only on the number of distinct tasks in the current batch, i.e. gradients should not be computed for the per-task tensors that are not used in the current batch's forward computation.
This is what prevents me from using one large Parameter wrapping a matrix with the per-task weights as rows: that way I can index efficiently, but the backward computation scales with the total number of tasks, as in the sketch below.
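To make that concrete, here is a minimal sketch of the matrix variant (shapes and names are just illustrative):

import torch
import torch.nn as nn

n_tasks, dim, batch = 100, 16, 4
# one large parameter whose rows are the per-task weight vectors
task_weights = nn.Parameter(torch.randn(n_tasks, dim))

task_ids = torch.tensor([0, 7, 7, 42])  # tasks appearing in this batch (k distinct)
shared_out = torch.randn(batch, dim)    # stand-in for the shared network's output
w = task_weights[task_ids]              # fancy indexing works: shape (batch, dim)
loss = (shared_out * w).sum()           # per-sample linear combination, reduced
loss.backward()
# task_weights.grad is a dense (n_tasks, dim) tensor, zero in the unused rows,
# so the backward cost grows with n_tasks rather than with k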
(This seems similar to embedding layers. I tried implementing it with both torch.nn.Embedding and torch.nn.functional.embedding, but both implementations also seem to scale with the number of tasks; a sketch of that attempt follows.)
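The embedding variant I tried looks roughly like this (again just a sketch, with the per-task weights stored as rows of emb.weight):

import torch
import torch.nn as nn

n_tasks, dim, batch = 100, 16, 4
emb = nn.Embedding(n_tasks, dim)  # emb.weight plays the role of the per-task matrix
task_ids = torch.tensor([0, 7, 7, 42])
shared_out = torch.randn(batch, dim)
w = emb(task_ids)                 # lookup: shape (batch, dim)
loss = (shared_out * w).sum()
loss.backward()
# with the default sparse=False, emb.weight.grad is again a dense
# (n_tasks, dim) tensor, so this variant showed the same scaling for me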
I would be very thankful for any insights on how to implement this while achieving both 1) and 2).