I’m trying to implement a Group Lasso penalty on a batched input of shape `(batch_size, *)`, given groups made of lists of coordinates.

As a simple example, if I have a batch of 3 grey-scale images, i.e. `x` of shape `(3, 4, 4)`, I would define my groups as

```
[[(0, 0), (0, 1), (1, 0), (1, 1)],
 [(2, 0), (2, 1), (3, 0), (3, 1)],
 [(0, 2), (0, 3), (1, 2), (1, 3)],
 [(2, 2), (2, 3), (3, 2), (3, 3)]]
```

Then, I’d like to get the batch-wise group lasso penalty of shape `(batch_size,)`, where each entry corresponds to one datapoint. For datapoint 0, this would be

`torch.norm([x[0, 0, 0], x[0, 0, 1], x[0, 1, 0], x[0, 1, 1]]) + torch.norm(x[0, group2]) + ...`

I can’t figure out how to index x to get this.
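To pin down the result I expect, here is a minimal sketch of the computation, assuming each group is a list of `(row, col)` tuples that tile a 4x4 image into 2x2 blocks. The name `group_lasso_penalty` is just illustrative (it is not a PyTorch function), and this may well not be the most efficient way to do it.

```python
import torch

def group_lasso_penalty(x, groups):
    """Sum of per-group L2 norms for each datapoint.

    x: tensor of shape (batch_size, M, N)
    groups: list of groups, each a list of (row, col) tuples (assumed format)
    Returns a tensor of shape (batch_size,).
    """
    penalty = x.new_zeros(x.shape[0])
    for group in groups:
        rows, cols = zip(*group)
        rows = torch.tensor(rows, device=x.device)
        cols = torch.tensor(cols, device=x.device)
        # Advanced indexing picks the group's entries for every datapoint:
        # result has shape (batch_size, len(group)).
        penalty = penalty + x[:, rows, cols].norm(dim=1)
    return penalty

groups = [
    [(0, 0), (0, 1), (1, 0), (1, 1)],
    [(2, 0), (2, 1), (3, 0), (3, 1)],
    [(0, 2), (0, 3), (1, 2), (1, 3)],
    [(2, 2), (2, 3), (3, 2), (3, 3)],
]
x = torch.arange(48, dtype=torch.float32).reshape(3, 4, 4)
penalty = group_lasso_penalty(x, groups)  # shape (3,)
```

The loop runs over groups only, so the batch dimension stays vectorized.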

As a second part, I’d like the groups to be defined starting from the last dimensions.

Supposing here that `x` has shape `(batch_size, n_channels, M, N)`, I want to be able to define my groups using only the last 2 dimensions, and sum the values over the channels.
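A rough sketch of what I mean for this second part, under one possible reading: the values are summed over the channel dimension first, and the group norms are then taken on the resulting `(batch_size, M, N)` tensor. The function name `group_lasso_last_dims` and the two example groups are only illustrative.

```python
import torch

def group_lasso_last_dims(x, groups):
    """Group lasso with groups defined on the last two dimensions.

    x: tensor of shape (batch_size, n_channels, M, N)
    groups: list of groups, each a list of (row, col) tuples (assumed format)
    Returns a tensor of shape (batch_size,).
    """
    # Assumption: "sum over the channels" means summing values before the norm.
    summed = x.sum(dim=1)  # (batch_size, M, N)
    penalty = summed.new_zeros(summed.shape[0])
    for group in groups:
        rows, cols = zip(*group)
        rows = torch.tensor(rows, device=x.device)
        cols = torch.tensor(cols, device=x.device)
        # The ellipsis keeps all leading dims; only the last two are indexed.
        penalty = penalty + summed[..., rows, cols].norm(dim=-1)
    return penalty

groups = [
    [(0, 0), (0, 1), (1, 0), (1, 1)],
    [(2, 2), (2, 3), (3, 2), (3, 3)],
]
x = torch.ones(2, 3, 4, 4)
penalty = group_lasso_last_dims(x, groups)  # shape (2,)
```

If the intended reading is instead to sum the per-channel norms, the channel sum would move after the `norm` call rather than before it.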