I have some codes
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.distributions
#one_hot code
def one_hot(index, classes):
index=index.cpu()
size = index.size() + (classes,)
view = index.size() + (1,)
mask = torch.Tensor(*size).fill_(0)
index = index.view(*view)
ones = 1.
if isinstance(index, Variable):
ones = Variable(torch.Tensor(index.size()).fill_(1))
mask = Variable(mask, volatile=index.volatile)
return mask.scatter_(1, index, ones)
def switch1(flag):
if flag == 0:
indices = torch.LongTensor([0])
elif flag == 1:
indices = torch.LongTensor([0,1])
elif flag == 2:
indices = torch.LongTensor([0,1,2])
elif flag == 3:
indices = torch.LongTensor([0,1,2,3])
elif flag == 4:
indices = torch.LongTensor([0,1,2,3,4])
elif flag == 5:
indices = torch.LongTensor([0,1,2,3,4,5])
elif flag == 6:
indices = torch.LongTensor([0,1,2,3,4,5,6])
elif flag == 7:
indices = torch.LongTensor([0,1,2,3,4,5,6,7])
return indices
def switch2(flag):
if flag == 0:
indices = torch.LongTensor([1,2,3,4,5,6,7])
elif flag == 1:
indices = torch.LongTensor([2,3,4,5,6,7])
elif flag == 2:
indices = torch.LongTensor([3,4,5,6,7])
elif flag == 3:
indices = torch.LongTensor([4,5,6,7])
elif flag == 4:
indices = torch.LongTensor([5,6,7])
elif flag == 5:
indices = torch.LongTensor([6,7])
elif flag == 6:
indices = torch.LongTensor([7])
return indices
class EMD2Loss(nn.Module):
def __init__(self, eps=1e-7):
super(EMD2Loss, self).__init__()
self.eps = eps
def forward(self, input, target):
target1=target.detach()
yy = one_hot(target1, input.size(-1))
Psoft1 = torch.nn.functional.softmax(input).cpu()
y=torch.tensor(yy)
Psoft=torch.tensor(Psoft1)
Loss=0.0
for i in range(0,target1.size(0)-1):
flag=int(target1[i].item())
for j in range(0,flag+1):
indice=switch1(j)
P1=torch.index_select(Psoft[i,], -1, indice)
y1=torch.index_select(y[i,], -1, indice)
Loss =Loss+(torch.sum(P1-y1))
if flag!=7:
for k in range(0,flag+1):
indice=switch2(k)
P2=torch.index_select(Psoft[i,], -1, indice)
y2=torch.index_select(y[i,], -1, indice)
Loss=Loss+(torch.sum(P2-y2))
Loss=Loss/target1.size(0)
return Loss
when i use this into training ,the first epoch,the first batch has losses ,but in next batch, loss has been zero
like this :
|Epoch: [0][0/226]||Loss 0.1984|Prec@1 34.375
|Epoch: [0][1/226]|Time 0.439 (1.098)|Data 0.001 (0.078)|Loss 0.0000 (0.0358)|Prec@1 4.688
the loss is like this:
the P is Psoft above the code,T is y.C=8
how to write this?Thanks!