Indexing for group lasso

I’m trying to implement a Group Lasso penalty on a batched input of shape (batch_size, *), given groups made of lists of coordinates.

As a simple example, if I have a batch of 3 grey-scale images of shape (3, 4, 4), I would define my groups as the four 2x2 blocks:

    [[(0, 0), (0, 1), (1, 0), (1, 1)],
     [(2, 0), (2, 1), (3, 0), (3, 1)],
     [(0, 2), (0, 3), (1, 2), (1, 3)],
     [(2, 2), (2, 3), (3, 2), (3, 3)]]
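Typing coordinates by hand gets tedious for larger images, so a small helper can generate this kind of block partition. The sketch below uses the hypothetical name `block_groups`, assumes the image height and width are multiples of the block size, and lists the blocks in the same order as the groups above (down each column of blocks, then across).

```python
def block_groups(height, width, block_h, block_w):
    """Hypothetical helper: tile a (height, width) grid into non-overlapping
    block_h x block_w blocks, returning each block as a list of (row, col)."""
    if height % block_h or width % block_w:
        raise ValueError("height and width must be multiples of the block size")
    groups = []
    # Walk the columns of blocks left to right, and within each column
    # walk the blocks top to bottom.
    for c0 in range(0, width, block_w):
        for r0 in range(0, height, block_h):
            groups.append([(r, c)
                           for r in range(r0, r0 + block_h)
                           for c in range(c0, c0 + block_w)])
    return groups


groups = block_groups(4, 4, 2, 2)
print(groups[0])  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```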

Then, I’d like to get the batch-wise group lasso penalty, of shape (batch_size,), where each entry corresponds to one datapoint. For datapoint 0, this would be

    torch.norm(torch.stack([x[0, 0, 0], x[0, 0, 1], x[0, 1, 0], x[0, 1, 1]])) + torch.norm(x[0, group2]) + ...
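Written as a formula, the quantity above is the following (the symbol $\Omega$ is just a name chosen here; $\mathcal{G}$ is the set of groups and $x_{b,g}$ is the vector of entries of datapoint $b$ at the coordinates in group $g$):

$$
\Omega(x)_b = \sum_{g \in \mathcal{G}} \lVert x_{b,g} \rVert_2 = \sum_{g \in \mathcal{G}} \sqrt{\sum_{(i,j) \in g} x_{b,i,j}^2}, \qquad b = 0, \dots, \mathrm{batch\_size} - 1
$$

Many formulations also weight each term by $\sqrt{|g|}$; with equal-size groups, as here, that only rescales the penalty by a constant.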

I can’t figure out how to index x to get this.
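One way to do this indexing, sketched under the assumption that all groups have the same number of coordinates (so they fit in one rectangular index tensor): convert the groups to a LongTensor, split it into one index tensor per non-batch dimension, and index with `:` on the batch dimension. Advanced indexing then gathers a (batch_size, n_groups, group_size) tensor in one step, and the penalty is a norm over the last dimension followed by a sum over groups. The name `group_lasso_penalty` is hypothetical; the coordinate tuples can have any length matching the number of non-batch dimensions of x, which covers the general (batch_size, *) case.

```python
import torch


def group_lasso_penalty(x, groups):
    """Hypothetical helper: group lasso penalty per datapoint.

    x:      tensor of shape (batch_size, *dims)
    groups: list of equal-size groups, each a list of coordinate tuples
            into `dims` (len(tuple) == x.dim() - 1).
    Returns a tensor of shape (batch_size,).
    """
    # (n_groups, group_size, n_dims)
    idx = torch.as_tensor(groups, dtype=torch.long, device=x.device)
    # One (n_groups, group_size) index tensor per non-batch dimension.
    coords = idx.unbind(dim=-1)
    # ':' keeps the batch dim; the coordinate tensors are broadcast together,
    # giving shape (batch_size, n_groups, group_size).
    vals = x[(slice(None), *coords)]
    # L2 norm within each group, then sum over groups.
    return torch.linalg.vector_norm(vals, ord=2, dim=-1).sum(dim=-1)


x = torch.randn(3, 4, 4, requires_grad=True)
groups = [[(0, 0), (0, 1), (1, 0), (1, 1)],
          [(2, 0), (2, 1), (3, 0), (3, 1)],
          [(0, 2), (0, 3), (1, 2), (1, 3)],
          [(2, 2), (2, 3), (3, 2), (3, 3)]]
penalty = group_lasso_penalty(x, groups)
print(penalty.shape)  # torch.Size([3])
```

Since the result is built only from indexing and norms, it is differentiable with respect to x, so something like `penalty.mean()` can be added to a loss. Groups of unequal size would not fit in one index tensor; looping over the groups is the simplest fallback there.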

As a second part, I’d like the groups to be defined over the trailing dimensions only.
Supposing that x has shape (batch_size, n_channels, M, N), I want to define my groups using only the last 2 dimensions, and sum the values over the channels.
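For the channel case, putting `...` in front of the coordinate tensors selects the (M, N) positions while keeping every leading dimension, which gives a (batch_size, n_channels, n_groups, group_size) tensor. "Sum over the channels" can be read two ways, so the sketch below (hypothetical name `channel_group_lasso`, hypothetical `merge_channels` flag) covers both: by default each channel's slice of a group gets its own L2 norm and those norms are summed over channels; with `merge_channels=True` a group spans all channels and gets a single norm over n_channels × group_size values. It assumes all groups have the same size.

```python
import torch


def channel_group_lasso(x, groups, merge_channels=False):
    """Hypothetical helper: group lasso on the last two dims of x.

    x:      tensor of shape (batch_size, n_channels, M, N)
    groups: list of equal-size groups of (m, n) coordinates.
    merge_channels=False: one L2 norm per (channel, group), summed over channels.
    merge_channels=True:  one L2 norm per group over all channels at once.
    Returns a tensor of shape (batch_size,).
    """
    idx = torch.as_tensor(groups, dtype=torch.long, device=x.device)
    rows, cols = idx.unbind(dim=-1)  # each (n_groups, group_size)
    # Index only the last two dims -> (batch_size, n_channels, n_groups, group_size)
    vals = x[..., rows, cols]
    if merge_channels:
        # Norm over channels and group members together -> (batch_size, n_groups)
        norms = torch.linalg.vector_norm(vals, dim=(1, 3))
    else:
        # Per-channel group norms (batch_size, n_channels, n_groups),
        # summed over channels -> (batch_size, n_groups)
        norms = torch.linalg.vector_norm(vals, dim=-1).sum(dim=1)
    # Sum over groups -> (batch_size,)
    return norms.sum(dim=-1)


x = torch.randn(2, 3, 4, 4)
groups = [[(0, 0), (0, 1), (1, 0), (1, 1)],
          [(2, 0), (2, 1), (3, 0), (3, 1)],
          [(0, 2), (0, 3), (1, 2), (1, 3)],
          [(2, 2), (2, 3), (3, 2), (3, 3)]]
print(channel_group_lasso(x, groups).shape)  # torch.Size([2])
```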
