I have three weight maps that are output by a conv layer. They all have the same size, [B, C, H, W].
My goal is to apply sparsemax (a variant of softmax) normalization at every pixel across the three weight maps, so that after normalization their values sum to one and follow a sparse distribution. But my implementation above doesn’t seem to work. So I wonder: does torch.cat() propagate gradients? Or is something else causing the mistake?
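For reference, `torch.cat` is differentiable: autograd routes each slice of the output gradient back to the corresponding input. Below is a minimal sketch, assuming the goal is for the three values at each (b, c, h, w) position to sum to one. It concatenates the maps along a new leading axis of size 3 and applies sparsemax over that axis. `sparsemax` and `fuse_weight_maps` are hypothetical helpers written for illustration, not torch functions. The sparsemax uses the sort-and-threshold formulation from Martins & Astudillo (2016) and relies on autograd for the backward pass.

```python
import torch


def sparsemax(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Sparsemax along `dim`, built only from differentiable torch ops."""
    z = z.movedim(dim, -1)  # work on the last axis, restore it at the end
    n = z.size(-1)
    z_sorted, _ = torch.sort(z, dim=-1, descending=True)
    k = torch.arange(1, n + 1, device=z.device, dtype=z.dtype)
    z_cumsum = z_sorted.cumsum(dim=-1)
    # Index k is in the support if 1 + k * z_(k) > sum_{j<=k} z_(j)
    support = (1 + k * z_sorted) > z_cumsum
    k_z = support.sum(dim=-1, keepdim=True)  # support size, always >= 1
    # Threshold tau = (sum of the k_z largest entries - 1) / k_z
    tau = (z_cumsum.gather(-1, k_z - 1) - 1) / k_z.to(z.dtype)
    out = torch.clamp(z - tau, min=0)  # entries below tau become exactly 0
    return out.movedim(-1, dim)


def fuse_weight_maps(a, b, c):
    """Normalize three [B, C, H, W] maps so the 3 values at each position sum to 1."""
    # New leading axis of size 3 indexes the maps: [3, B, C, H, W]
    stacked = torch.cat([a.unsqueeze(0), b.unsqueeze(0), c.unsqueeze(0)], dim=0)
    return sparsemax(stacked, dim=0)


if __name__ == "__main__":
    torch.manual_seed(0)
    # Random stand-ins for the three conv outputs (shapes are illustrative only)
    a, b, c = (torch.randn(2, 4, 5, 5, requires_grad=True) for _ in range(3))
    w = fuse_weight_maps(a, b, c)
    print(w.shape)  # torch.Size([3, 2, 4, 5, 5])
    print(torch.allclose(w.sum(dim=0), torch.ones(2, 4, 5, 5)))  # True
    (2.0 * w[0] + w[1]).sum().backward()
    print(a.grad is not None, b.grad is not None)  # True True
```

One possible cause worth checking: if the maps are concatenated along `dim=1` and sparsemax is then taken over `dim=1`, the normalization runs over all 3*C channels at once rather than over the three maps for each channel, so the per-map values will not sum to one.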