How to write a new backward function for my own loss

I have some codes

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.distributions
#one_hot code
def one_hot(index, classes):
    size = index.size() + (classes,)
    view = index.size() + (1,)
    mask = torch.Tensor(*size).fill_(0)
    index = index.view(*view)
    ones = 1.
    if isinstance(index, Variable):
        ones = Variable(torch.Tensor(index.size()).fill_(1))
        mask = Variable(mask, volatile=index.volatile)
    return mask.scatter_(1, index, ones)
def switch1(flag):
    if flag == 0:
        indices = torch.LongTensor([0])
    elif flag == 1:
        indices = torch.LongTensor([0,1])
    elif flag == 2:
        indices = torch.LongTensor([0,1,2])
    elif flag == 3:                                                                   
        indices = torch.LongTensor([0,1,2,3])
    elif flag == 4:
        indices = torch.LongTensor([0,1,2,3,4])
    elif flag == 5:
        indices = torch.LongTensor([0,1,2,3,4,5])
    elif flag == 6:
        indices = torch.LongTensor([0,1,2,3,4,5,6])
    elif flag == 7:
        indices = torch.LongTensor([0,1,2,3,4,5,6,7])
    return indices
def switch2(flag):
    if flag == 0:
        indices = torch.LongTensor([1,2,3,4,5,6,7])
    elif flag == 1:
        indices = torch.LongTensor([2,3,4,5,6,7])
    elif flag == 2:
        indices = torch.LongTensor([3,4,5,6,7])
    elif flag == 3:                                                                   
        indices = torch.LongTensor([4,5,6,7])
    elif flag == 4:
        indices = torch.LongTensor([5,6,7])
    elif flag == 5:
        indices = torch.LongTensor([6,7])
    elif flag == 6:
        indices = torch.LongTensor([7])
    return indices

class EMD2Loss(nn.Module):
    def __init__(self, eps=1e-7):
        super(EMD2Loss, self).__init__()
        self.eps = eps
    def forward(self, input, target):
        yy = one_hot(target1, input.size(-1))
        Psoft1 = torch.nn.functional.softmax(input).cpu()
        for i in range(0,target1.size(0)-1):
            for j in range(0,flag+1):
                P1=torch.index_select(Psoft[i,], -1, indice)
                y1=torch.index_select(y[i,], -1, indice)
                Loss =Loss+(torch.sum(P1-y1))
            if flag!=7:
                for k in range(0,flag+1):
                    P2=torch.index_select(Psoft[i,], -1, indice)
                    y2=torch.index_select(y[i,], -1, indice)
        return Loss 

when i use this into training ,the first epoch,the first batch has losses ,but in next batch, loss has been zero
like this :
|Epoch: [0][0/226]||Loss 0.1984|Prec@1 34.375
|Epoch: [0][1/226]|Time 0.439 (1.098)|Data 0.001 (0.078)|Loss 0.0000 (0.0358)|Prec@1 4.688

the loss is like this:
the P is Psoft above the code,T is y.C=8
how to write this?Thanks!