Custom Loss Function From Scratch

Hi,
I want to write a custom loss function, so first I tried to mimic the built-in MSE loss as below, but it gives much lower losses than the built-in function (nn.MSELoss()). It is important for me to learn how to extract the (i, j)-th element like this and do something with it, e.g. calculating the gradient in the x-direction. Can you please explain where I am going wrong? Also, I have 3 output channels and I wonder how nn.MSELoss() calculates the loss for them (does it calculate each channel separately and then add them together?). Thanks!

    def custom(self, pred, truth):      
        for j in range(1,255):
            for i in range(1,255):
                loss = torch.mean((truth[:,:, i, j] - pred[:,:, i, j])**2)
        return loss
                                                        
        self.custom.to(self.device)

It’s lower because you’re resetting it to just the (i, j)-th error value on every iteration; you should be accumulating the loss over all i and j.

Thanks, I had forgotten that, but this time the code does not give any result (I waited for a long time). I still couldn’t understand how nn.MSELoss() works for a multichannel output, and whether my function deals with it correctly.

    def custom(self, pred, truth):     
        loss = 0
        for j in range(1,255):
            for i in range(1,255):
                loss += (truth[:,:, i, j] - pred[:,:, i, j])**2
        return loss.mean()

Well, your nested loops aren’t doing you any favours here. It’s slow because you’re doing roughly 65,000 pixel-wise MSEs, which I’m guessing is making your autograd graph extremely long. Not only that, but there’s no parallelism possible for these calculations.
Why not just take the diff between truth and pred wholesale?
    loss = torch.square(truth[:, :, 0:255, 0:255] - pred[:, :, 0:255, 0:255]).mean()

or if your images are already 256 pixels wide and high (I’m assuming 255 is an off-by-one error?), then just
    loss = torch.square(truth - pred).mean()
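
For reference, with the default reduction='mean', nn.MSELoss() averages the squared error over every element at once (batch, channels, height and width together), so the 3 channels are not handled separately and then summed. A minimal sketch with made-up random tensors, checking that the wholesale version matches the built-in:

    import torch
    import torch.nn as nn

    # made-up shapes: batch of 4, 3 channels, 256x256 images
    pred  = torch.randn(4, 3, 256, 256)
    truth = torch.randn(4, 3, 256, 256)

    builtin = nn.MSELoss()                        # default reduction='mean'
    manual  = torch.square(truth - pred).mean()   # mean over batch, channels, H and W together

    print(builtin(pred, truth))  # same value as manual, up to float rounding
    print(manual)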

You’re right that those loops make it very slow, but I need something like them because I am going to try a loss function which uses the information of the neighbours of each element.
By the way, I tried your indexing, but I noticed that when I change the upper limit of the width and height to a very large value, it still gives me a result. I am using 4 images (3x256x256) with a batch size of 4, just to learn how to implement a custom loss function. Can you explain why it still gives a result? Thanks
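
As an aside, a neighbour-based term of the kind mentioned above does not need Python loops either: shifted slices give the same per-pixel differences for the whole batch at once. A minimal illustrative sketch (not the actual intended function, just an assumed forward difference in the x-direction):

    import torch

    def gradient_x_loss(pred, truth):
        # forward difference along the width axis: f[..., j+1] - f[..., j],
        # computed for all batch elements and channels in one slicing operation
        dpred  = pred[:, :, :, 1:]  - pred[:, :, :, :-1]
        dtruth = truth[:, :, :, 1:] - truth[:, :, :, :-1]
        # MSE between the x-gradients of prediction and ground truth
        return torch.square(dtruth - dpred).mean()

Because this stays a handful of tensor operations, autograd records a short graph instead of one node per pixel.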

I just noticed that if you exceed the length of an array in a slice, it won’t raise an error; instead it will give a result like:

    x = [1, 2, 3, 4, 5]

    In [17]: x[0:5]
    Out[17]: [1, 2, 3, 4, 5]

    In [18]: x[0:15]
    Out[18]: [1, 2, 3, 4, 5]
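
The same clipping happens with PyTorch tensors, which is why the over-sized slice bounds still produced a loss value: the slice is silently truncated to the tensor’s actual 256-pixel extent. A quick check:

    import torch

    t = torch.arange(5)
    print(t[0:15])  # tensor([0, 1, 2, 3, 4]) -- the stop index is clipped, no error is raised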