How to model multi-dimensional output?

Assume I think a reasonable shape for my output is n matrices of size m x m, and I believe my loss would be best expressed in terms of those matrices (e.g. I’d rather predict n/2 of them entirely correctly than predict every second element correctly but end up with zero fully correct matrices).

My first thought was to use a linear layer of size n * m * m as the output of my network, and then write a custom loss function that I apply to each m * m chunk of that output.
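To make that first idea concrete, here is a minimal sketch of the loss side only. It is written in NumPy just to show the shapes, since I'm not assuming any particular framework; in an autograd framework the same reshape and reductions would be applied to the layer's output. The names `per_matrix_norms`, `spread`, and `concentrated` are made up for illustration, and using the unsquared Frobenius norm per matrix is my assumption, not an established recipe.

```python
import numpy as np

def per_matrix_norms(flat_pred, target, n, m):
    """Hypothetical helper: Frobenius norm of the error of each m x m matrix.

    flat_pred: (batch, n * m * m) flat output, as a linear layer would give
    target:    (batch, n, m, m)
    returns:   (batch, n), one error value per matrix
    """
    pred = flat_pred.reshape(-1, n, m, m)
    diff = pred - target
    return np.sqrt((diff ** 2).sum(axis=(2, 3)))

n, m = 4, 2
target = np.zeros((1, n, m, m))

# Two predictions with the same total squared error (16):
# one spreads it over every element, one packs it into half the matrices.
spread = np.full((1, n * m * m), 1.0)
concentrated = np.zeros((1, n * m * m))
concentrated[:, : (n // 2) * m * m] = np.sqrt(2.0)

# Element-wise MSE cannot tell the two apart.
mse_spread = ((spread.reshape(target.shape) - target) ** 2).mean()
mse_conc = ((concentrated.reshape(target.shape) - target) ** 2).mean()

# Summing the per-matrix norms scores the concentrated error lower,
# i.e. it rewards getting n/2 matrices exactly right.
loss_spread = per_matrix_norms(spread, target, n, m).sum(axis=1).mean()
loss_conc = per_matrix_norms(concentrated, target, n, m).sum(axis=1).mean()
```

One caveat with this choice: the square root has no finite gradient at exactly zero error, so a differentiable version would presumably need a small epsilon inside the root.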

However, I assume there might be a better way to approach this kind of problem, so I’m curious whether anyone has ideas or experience doing this.

If anyone knows how to do this for the more general case (e.g. something that applies to a list of arrays of arbitrary dimensions, rather than to a list of 2-dimensional arrays as in my example), I’d be even more curious to hear about it.
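For the general case, the naive version I can imagine is to keep one flat output and cut it up according to a list of target shapes. Again this is only a NumPy sketch of the bookkeeping; `split_into_arrays`, `per_array_norms`, and the example `shapes` are hypothetical names and values I picked for illustration.

```python
import numpy as np

def split_into_arrays(flat_pred, shapes):
    """Cut a (batch, total) output into one array per entry of `shapes`."""
    sizes = [int(np.prod(s)) for s in shapes]
    offsets = np.cumsum(sizes)[:-1]          # split points along axis 1
    chunks = np.split(flat_pred, offsets, axis=1)
    return [c.reshape(c.shape[0], *s) for c, s in zip(chunks, shapes)]

def per_array_norms(flat_pred, targets, shapes):
    """One error value per target array: the L2 norm of its flattened error."""
    preds = split_into_arrays(flat_pred, shapes)
    norms = [
        np.sqrt(((p - t) ** 2).reshape(p.shape[0], -1).sum(axis=1))
        for p, t in zip(preds, targets)
    ]
    return np.stack(norms, axis=1)           # (batch, len(shapes))

# Three target arrays of different rank per sample.
shapes = [(2, 2), (3,), (2, 1, 2)]
batch = 5
total = sum(int(np.prod(s)) for s in shapes)  # size the output layer would need

flat_pred = np.zeros((batch, total))
targets = [np.ones((batch, *s)) for s in shapes]

norms = per_array_norms(flat_pred, targets, shapes)
loss = norms.sum(axis=1).mean()
```

Here the output layer would need `total` units, and each column of `norms` corresponds to one target array regardless of its rank.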

Are there any built-in loss functions that could be used for this type of problem, or would I be better off constructing one from scratch?