[Solved] How to implement custom loss function?

(YoungMin Park) #1

I need to implement a custom loss function, and I found the following tutorial.

from torch.autograd import Function

class LinearFunction(Function):

    # forward() and backward() must be static methods; ctx carries state between them
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # Compute only the gradients that are actually needed
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        # One gradient per forward() input, in the same order
        return grad_input, grad_weight, grad_bias

But I don’t know how to apply this guide to my situation.

I have ground-truth (GT) data, but it is not label data.

I apply math formulas to the GT data over all image pixels and obtain a single float number.

I want to update the CNN weights so that this float loss value is minimized.

How should I implement forward() and backward() in this case?

(Masaki Kozuki) #2

IMO, you don’t need to implement a class like the one you pasted.

In your case, just implement your loss function using torch and/or torch.nn.functional functions.
Then PyTorch will do backprop automatically.
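
As a minimal sketch of that advice: the loss below is written only with torch operations, so autograd derives the backward pass and no custom Function is needed. The formula in `pixel_loss`, the tiny `conv` network and the random tensors are all hypothetical stand-ins, not the poster's actual formula or model.

```python
import torch

def pixel_loss(pred, gt):
    # Hypothetical per-pixel formula: log-scaled absolute error, averaged to one scalar
    diff = torch.abs(pred - gt)
    return torch.mean(torch.log1p(diff))

torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)  # stand-in for the CNN
image = torch.rand(2, 3, 8, 8)                          # stand-in input batch
gt = torch.rand(2, 1, 8, 8)                             # stand-in ground truth (not labels)

loss = pixel_loss(conv(image), gt)
loss.backward()                      # autograd computes the gradients

print(loss.grad_fn is not None)      # the loss is connected to the graph
print(conv.weight.grad.shape)        # gradients reached the CNN weights
```

Because every step from the network output to the scalar is a torch operation, the scalar keeps a grad_fn and backward() can reach the weights.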

(YoungMin Park) #3

I can obtain the loss value like this:

# tensor([3.0082], device='cuda:0')

But I can’t see a grad value from

print(last_loss_v.grad) # None

And when I try this


this error shows up:

element 0 of tensors does not require grad and does not have a grad_fn

What I want is to update the CNN weights until this loss, tensor([3.0082], device='cuda:0'), becomes small enough.

Could you point out what I did incorrectly and how to solve the errors above?
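
For reference, a minimal sketch of how this error typically arises, assuming the loss tensor was built from plain numbers rather than from network outputs. The names `no_graph` and `w` are hypothetical and not taken from the code in this thread.

```python
import torch

# A tensor created from plain data has no computation graph behind it.
no_graph = torch.tensor([3.0082])
print(no_graph.requires_grad, no_graph.grad_fn)   # False None
print(no_graph.grad)                              # None

message = None
try:
    no_graph.backward()
except RuntimeError as err:
    message = str(err)
print(message)

# The same value computed from a tensor that requires grad can be backpropagated.
w = torch.tensor([2.0], requires_grad=True)
loss = (w * 3.0).sum()
loss.backward()
print(w.grad)
```

The error therefore usually means that the graph was cut somewhere between the network output and the loss, not that the loss value itself is wrong.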

(Masaki Kozuki) #4

Oh, really?

Here is one place where I define loss functions using torch functions, and it backprops fine: https://github.com/crcrpar/pytorch.sngan_projection/blob/master/losses.py

(YoungMin Park) #5

Thanks. I’ll try it based on your loss function.

(YoungMin Park) #6

I’ve read your loss function,
but I can’t see what the difference is, except that you use torch and F operations (like torch.mean() and F.softplus()) at the end of each function.

This is the code I’m writing:


Would you tell me what is the difference?

Generating network


Obtaining prediction


Calculating loss


Updating network


(Masaki Kozuki) #7

I think these 2 lines are wrong.

The documentation for detach() says:

Returns a new Tensor, detached from the current graph.

The result will never require gradient.
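
A small sketch of what that means in practice, using hypothetical tensors `w` and `y` (the two lines referred to above are not reproduced here). It assumes NumPy is installed alongside PyTorch.

```python
import torch

w = torch.ones(3, requires_grad=True)
y = w * 2.0
print(y.requires_grad, y.grad_fn is not None)     # True True

# detach() returns a tensor that shares data with y but has no graph history.
detached = y.detach()
print(detached.requires_grad, detached.grad_fn)   # False None

# A round trip through NumPy loses the graph as well.
roundtrip = torch.from_numpy(y.detach().numpy() * 1.0)
print(roundtrip.requires_grad)                    # False
```

Any loss computed from `detached` or `roundtrip` has no path back to `w`, so backward() has nothing to update.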


(YoungMin Park) #8

Yeah. I’ll find out how to resize a torch tensor image without using detach() and numpy().
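
One possible way to do that, sketched under the assumption that a bilinear resize is acceptable: torch.nn.functional.interpolate works on tensors directly and is differentiable, so the graph stays intact. The `conv` layer, the tensor sizes and the target size are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)   # stand-in for the CNN
pred = conv(torch.rand(1, 3, 16, 16))                    # shape (1, 1, 16, 16)

# Resize inside the graph instead of going through detach() and NumPy.
resized = F.interpolate(pred, size=(8, 8), mode="bilinear", align_corners=False)
print(resized.shape, resized.grad_fn is not None)

resized.mean().backward()
print(conv.weight.grad is not None)   # gradients flow back through the resize
```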

But without using detach() and numpy(),
and after I change this statement




all epochs run to completion.

But the loss value stays constant, which means the CNN is not being updated, as shown here:

Total number of parameters: 224481
en_i_loss tensor(8.7888, device='cuda:0', requires_grad=True)
en_i_loss.grad tensor(1., device='cuda:0')
Saved model at end of iteration
en_i_loss tensor(8.7888, device='cuda:0', requires_grad=True)
en_i_loss.grad tensor(1., device='cuda:0')
Saved model at end of iteration
en_i_loss tensor(8.7888, device='cuda:0', requires_grad=True)
en_i_loss.grad tensor(1., device='cuda:0')
Saved model at end of iteration
Saved model at end of epoch
Train finished

This problem of the CNN not being updated was my original issue.

Implementing a custom loss function was my attempt to resolve it.

(YoungMin Park) #9

I think I’ve solved this issue.
I shouldn’t have broken the backpropagation graph by unnecessarily re-wrapping torch tensors in Variable inside loss_functions.py.

Re-wrapping a torch tensor with Variable() discards its gradient information.
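
A hedged sketch of this failure mode. The actual loss_functions.py is not shown in the thread, so `loss.detach().requires_grad_()` is used here as an assumed stand-in for re-wrapping the loss as a new tensor, and `model` and `x` are hypothetical.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)       # stand-in for the CNN
x = torch.rand(8, 4)

loss = model(x).pow(2).mean()

# Assumed stand-in for re-wrapping: a new leaf tensor with the same value.
rewrapped = loss.detach().requires_grad_()
rewrapped.backward()
print(rewrapped.grad)               # tensor(1.): gradient of the loss w.r.t. itself
weight_grad_before = model.weight.grad
print(weight_grad_before)           # None: nothing reached the model

# Backpropagating through the original, connected loss does reach the weights.
loss.backward()
print(model.weight.grad is not None)
```

A loss that reports requires_grad=True and a .grad of 1 while the weights never change, as in the log above, is consistent with this kind of disconnected leaf.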

In the end, I checked that every torch Variable tensor prints a grad_fn, using this:

print(torch_Variable_tensor.grad_fn) # <MeanBackward1 object at 0x7f1ebc889da0>

If every tensor computed along the way prints grad_fn information, the backpropagation graph is well connected,
and you can perform backpropagation with loss.backward().
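
To round this off, a minimal training-loop sketch that applies the grad_fn check before each backward pass. The one-layer `model`, the mean squared error loss, the learning rate and the random data are all assumptions for illustration, not the setup used in this thread.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)    # stand-in for the CNN
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
image = torch.rand(4, 1, 8, 8)
gt = torch.rand(4, 1, 8, 8)

losses = []
for step in range(20):
    optimizer.zero_grad()
    pred = model(image)
    loss = torch.mean((pred - gt) ** 2)    # built only from torch operations
    if loss.grad_fn is None:
        raise RuntimeError("loss is disconnected from the graph")
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])               # the loss should now change between steps
```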