I’m trying to implement a Group Lasso penalty on a batched input of shape (batch_size, *), given groups made of lists of coordinates.
As a simple example, if I have a batch of 3 grey-scale 4x4 images, so x of shape (3, 4, 4), I would define my groups as the four 2x2 blocks:
[[(0, 0), (0, 1), (1, 0), (1, 1)],
 [(2, 0), (2, 1), (3, 0), (3, 1)],
 [(0, 2), (0, 3), (1, 2), (1, 3)],
 [(2, 2), (2, 3), (3, 2), (3, 3)]]
Then, I’d like to get the batch-wise group lasso penalty of shape (batch_size,), where each coordinate corresponds to one datapoint: the sum over groups of the L2 norm of that group’s entries. For datapoint 0, this would be
torch.norm(torch.stack([x[0, 0, 0], x[0, 0, 1], x[0, 1, 0], x[0, 1, 1]])) + torch.norm(x[0, group1]) + ...
I can’t figure out how to index x to get this.
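For reference, here is one way I could imagine it working with advanced indexing, assuming all groups have the same size (the variable names and the 2x2-block groups are just my toy example from above):

```python
import torch

# Toy setup: batch of 3 grey-scale 4x4 images, groups are the four 2x2 blocks.
x = torch.randn(3, 4, 4)
groups = [
    [(0, 0), (0, 1), (1, 0), (1, 1)],
    [(2, 0), (2, 1), (3, 0), (3, 1)],
    [(0, 2), (0, 3), (1, 2), (1, 3)],
    [(2, 2), (2, 3), (3, 2), (3, 3)],
]

# Turn the groups into index tensors of shape (n_groups, group_size).
idx = torch.tensor(groups)             # (n_groups, group_size, 2)
rows, cols = idx[..., 0], idx[..., 1]  # each (n_groups, group_size)

# Advanced indexing broadcasts over the batch dimension:
# x[:, rows, cols] has shape (batch_size, n_groups, group_size).
gathered = x[:, rows, cols]

# Per-group L2 norm, then sum over groups -> shape (batch_size,).
penalty = gathered.norm(dim=-1).sum(dim=-1)
```

This relies on equal-sized groups so that `torch.tensor(groups)` produces a rectangular index tensor; ragged groups would need padding or a loop.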
As a second part, I’d like the groups to be defined starting from the last dimensions. Supposing here that x has shape (batch_size, n_channels, M, N), I want to be able to define my groups using only the last 2 dimensions, and sum the values over the channels.
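Sketching what I mean with a small example (shapes and names are hypothetical; I’m reading "sum over the channels" as summing the per-group norms across channels):

```python
import torch

# Hypothetical setup: batch of 3 inputs with 2 channels, spatial size 4x4.
x = torch.randn(3, 2, 4, 4)
groups = [
    [(0, 0), (0, 1), (1, 0), (1, 1)],
    [(2, 2), (2, 3), (3, 2), (3, 3)],
]

idx = torch.tensor(groups)             # (n_groups, group_size, 2)
rows, cols = idx[..., 0], idx[..., 1]  # each (n_groups, group_size)

# Indexing only the last two dims broadcasts over batch AND channel:
# shape (batch_size, n_channels, n_groups, group_size).
gathered = x[:, :, rows, cols]

# Per-group norm over the group coordinates, then sum over channels
# and groups -> shape (batch_size,).
penalty = gathered.norm(dim=-1).sum(dim=(1, 2))
```

If instead the channels should be folded into each group’s norm, that would be a norm over both the channel and coordinate dims before summing over groups.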