Chainer framework with custom made loss function

Hello!

I am trying to run an old example written in Chainer, replacing its loss with either a PyTorch loss function or, preferably, my own custom-made loss. The training loss, however, won't decrease.

I am running the following code, replacing Chainer's built-in loss function softmax_cross_entropy() with a custom-made one. Although I have included "loss.requires_grad=True", the network doesn't seem to learn. Can anyone help me adapt the loss function to Chainer?

PS: I disclose that I have posted this question on Stack Overflow, too.

!pip install chainer

from __future__ import print_function  # future imports must come before any other import

import chainer
import chainer.functions as F
import chainer.links as L
import matplotlib.pyplot as plt
from chainer.datasets import mnist
import torch
import torch.nn as nn
from chainer import backend
from chainer import backends
from chainer.backends import cuda

from chainer import Function, FunctionNode, gradient_check, report, training, utils, Variable

from chainer import datasets, initializers, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions

# Download the MNIST data if you haven't downloaded it yet
train, test = mnist.get_mnist(withlabel=True, ndim=1)
#_______________________________________________________________________________

from chainer import iterators

# Choose the minibatch size.
batchsize = 128

train_iter = iterators.SerialIterator(train, batchsize)
test_iter = iterators.SerialIterator(test, batchsize,
                                     repeat=False, shuffle=False)
#_______________________________________________________________________________


class CrossEntropyLossManual:
    """
    y0 is the vector with shape (batch_size,C)
    x shape is the same (batch_size), whose entries are integers from 0 to C-1
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index=ignore_index
    
    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        # print(n_class)
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:
                n_batch -= 1
                continue
            loss = loss + torch.log(torch.exp(y1[class_index])/(torch.exp(y1).sum()))
        loss = - loss/n_batch
        return loss

def lossfun(x, t):
    loss_fn = CrossEntropyLossManual()
    return loss_fn(torch.Tensor(x.data), torch.Tensor(t.data).long())

class MyNetwork(Chain):

    def __init__(self, n_mid_units=100, n_out=10):
        super(MyNetwork, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, n_mid_units)
            self.l2 = L.Linear(n_mid_units, n_mid_units)
            self.l3 = L.Linear(n_mid_units, n_out)

    def __call__(self, x):
        h = F.relu(self.l1(x))
        h = F.relu(self.l2(h))
        return self.l3(h)

model = MyNetwork()
#_______________________________________________________________________________


from chainer import optimizers

# Choose an optimizer algorithm
optimizer = optimizers.MomentumSGD(lr=0.01, momentum=0.9)

# Give the optimizer a reference to the model so that it
# can locate the model's parameters.
optimizer.setup(model)
#_______________________________________________________________________________


import numpy as np
from chainer.dataset import concat_examples
from chainer.cuda import to_cpu

max_epoch = 10
gpu_id = 0
while train_iter.epoch < max_epoch:

    # ---------- One iteration of the training loop ----------
    train_batch = train_iter.next()
    image_train, target_train = concat_examples(train_batch)

    # Calculate the prediction of the network
    prediction_train = model(image_train)

    # Calculate the loss with softmax_cross_entropy
    loss = lossfun(prediction_train, target_train)
    loss.requires_grad=True
    # Calculate the gradients in the network
    model.cleargrads()
    loss.backward()

    # Update all the trainable parameters
    optimizer.update()
    # --------------------- until here ---------------------

    # Check the validation accuracy of prediction after every epoch
    if train_iter.is_new_epoch:  # If this iteration is the final iteration of the current epoch

        # Display the training loss
        print('epoch:{:02d} train_loss:{:.04f} '.format(
            train_iter.epoch, loss.data), end='')    

This is not a valid approach. If loss.requires_grad returned False after lossfun was called, it means this tensor is detached from the computation graph, and setting loss.requires_grad=True won't re-attach it.
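You can verify this right after the loss is computed (an illustrative check, reusing the names from your training loop):

loss = lossfun(prediction_train, target_train)
print(loss.requires_grad)  # False -> created outside any autograd graph
print(loss.grad_fn)        # None  -> nothing for backward() to traverse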

You are absolutely right!! It is False! Can you think of any way to integrate the PyTorch loss into Chainer?

Kostas

I don’t see any detaching operations besides creating a new tensor via:

torch.Tensor(x.data)

If x is a tensor, pass it directly to the loss function.

It is actually a Variable (what Chainer uses). I am converting it to a tensor via torch.Tensor(x.data). Is this where it is being detached?
I have used the following modification:

def lossfun(x, t):
    loss_fn = CrossEntropyLossManual()
    return loss_fn(torch.tensor(x.data, requires_grad=True), torch.Tensor(t.data).long())

Now the print returns True for requires_grad, but it still won't learn!

Kostas

I think it depends on your point of view. Recreating a new tensor from another PyTorch tensor will detach it:

x = torch.randn(1, 10)
lin = nn.Linear(10, 10)

out = lin(x)
# out is attached to the computation graph and shows a valid .grad_fn
print(out.grad_fn)
# <AddmmBackward0 object at 0x7f3527ce58d0>

# recreating a tensor detaches it
y = torch.tensor(out)
# UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

print(y.grad_fn)
# None

# setting .requires_grad=True does not fix it
y.requires_grad_()
print(y.requires_grad)
# True
y.mean().backward()

print(lin.weight.grad)
# None
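
# As a side note: if you just need a copy while keeping the graph,
# clone() preserves the attachment (rough continuation of the snippet above)
z = out.clone()
print(z.grad_fn)
# <CloneBackward0 object at ...>

z.mean().backward()
print(lin.weight.grad is not None)
# True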

However, in your code it seems you are passing a Chainer variable to PyTorch, so you obviously won't be detaching anything in PyTorch, as that is the first PyTorch operation.
In this case you would need to check how Chainer interacts with different frameworks.
