Ben
June 22, 2017, 10:40am
I ran into a problem similar to https://github.com/fchollet/keras/issues/2115
It uses a weight matrix inside the loss function, and it can be implemented in Keras. How can I implement it in PyTorch?
Here is the Keras code copied from the link above:
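```python
from itertools import product
from keras import backend as K

def w_categorical_crossentropy(y_true, y_pred, weights):
    # weights[c_t, c_p]: loss scale when the true class is c_t and the predicted class is c_p
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    # one-hot mask marking the arg-max prediction of each sample
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
    y_pred_max_mat = K.cast(K.equal(y_pred, y_pred_max), K.floatx())
    # accumulate the per-sample weight over every (predicted, true) class pair
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        final_mask += (weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t])
    # Keras 2 signature is categorical_crossentropy(target, output); Keras 1 used (output, target)
    return K.categorical_crossentropy(y_true, y_pred) * final_mask
```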
Any suggestions are welcome!
Thanks in advance!
Ben
smth
June 22, 2017, 2:58pm
You can implement it exactly the way you implemented your Keras code.
Ben
June 22, 2017, 3:04pm
Because I’m new to PyTorch, I can’t find the PyTorch counterparts of these functions, especially for the for loop. I’d appreciate it if you could show me some code!
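A minimal sketch of a line-by-line port, assuming `y_pred` holds softmax probabilities, `y_true` is one-hot, and `weights` is a `(n_classes, n_classes)` tensor indexed as `weights[true, pred]`. The Keras backend calls map roughly as `K.max` → `torch.max`, `K.equal`/`K.cast` → `==` plus `.to(dtype)`, and the Python for loop stays unchanged. The `eps` clamp is my own assumption, loosely mirroring the clipping Keras applies before taking the log.

```python
from itertools import product

import torch


def w_categorical_crossentropy(y_true, y_pred, weights):
    """Loop-based PyTorch port of the Keras function (a sketch).

    y_true:  (batch, n_classes) one-hot float targets
    y_pred:  (batch, n_classes) predicted probabilities (softmax output)
    weights: (n_classes, n_classes) tensor; weights[c_t, c_p] scales the loss
             when the true class is c_t and the predicted class is c_p
    """
    nb_cl = len(weights)
    final_mask = torch.zeros_like(y_pred[:, 0])
    # K.max + K.reshape -> max over classes, kept as a (batch, 1) column
    y_pred_max, _ = torch.max(y_pred, dim=1, keepdim=True)
    # K.cast(K.equal(...)) -> one-hot mask of the arg-max prediction
    y_pred_max_mat = (y_pred == y_pred_max).to(y_pred.dtype)
    # The loop over class pairs is plain Python, exactly as in the Keras version
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        final_mask += weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t]
    # Categorical cross entropy on probabilities: -sum(y_true * log(y_pred))
    eps = 1e-7  # assumed clamp to avoid log(0)
    ce = -(y_true * torch.log(y_pred.clamp(min=eps))).sum(dim=1)
    # One weighted loss value per sample
    return ce * final_mask
```

The function returns one value per sample, like the Keras loss; call `.mean()` on the result to get a scalar for `backward()`.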
Ben
June 22, 2017, 3:25pm
I think the major steps are:
1. Calculate the cross entropy for each sample in a batch.
2. Calculate the weight for each sample, which is like a lookup table in a for loop.
3. loss = sum(cross_entropy_tensor * weight_tensor) / batch_size
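The three steps can be vectorized so that no Python loop is needed. This is a sketch assuming raw logits, integer class targets, and a `(n_classes, n_classes)` weight tensor indexed as `weights[true, pred]`; the function name `weighted_cross_entropy` is hypothetical. It relies on the `reduction='none'` argument of `F.cross_entropy`, which exists in current PyTorch releases (older versions used `reduce=False` instead).

```python
import torch
import torch.nn.functional as F


def weighted_cross_entropy(logits, target, weights):
    """Vectorized sketch of steps 1-3 (hypothetical helper).

    logits:  (batch, n_classes) raw scores, no softmax applied
    target:  (batch,) integer class indices
    weights: (n_classes, n_classes) tensor, weights[true, pred]
    """
    # Step 1: per-sample cross entropy; reduction='none' keeps one value per sample
    ce = F.cross_entropy(logits, target, reduction='none')
    # Step 2: the lookup table becomes advanced indexing with (true, predicted) pairs
    pred = logits.argmax(dim=1)
    w = weights.to(logits.device)[target, pred]
    # Step 3: weighted sum divided by the batch size
    return (ce * w).sum() / logits.size(0)
```

One difference from the Keras mask: when two classes tie for the maximum, `argmax` picks a single index, while the Keras version adds the weights of every tied class.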
Now I can get a log-softmax tensor of shape batch_size × num_class by using nn.LogSoftmax. But I’m a little confused about how to implement steps 1 and 2.
nn.NLLLoss seems to combine steps 1 and 3, but without a per-sample weight.
I guess you can split the cross-entropy loss into [softmax, log, NLLLoss].
Then you can multiply by a weight matrix after the log operation and pass the weighted log(p(x)) to NLLLoss.
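A minimal sketch of the [softmax, log, NLLLoss] split, assuming the "weight matrix" is a per-sample weight broadcast across each row of log-probabilities (my interpretation). Because `NLLLoss` only picks `-log_p[i, target[i]]` from each row, scaling the whole row scales exactly that term. The helper name `weighted_nll` and the `weights[true, pred]` lookup are assumptions carried over from the Keras code.

```python
import torch
import torch.nn as nn

log_softmax = nn.LogSoftmax(dim=1)  # softmax + log in one numerically stable step
nll = nn.NLLLoss()  # default reduction: mean over the batch


def weighted_nll(logits, target, weights):
    """Hypothetical helper: weight log(p(x)) per sample, then apply NLLLoss."""
    log_p = log_softmax(logits)  # (batch, n_classes) = log(p(x))
    pred = log_p.argmax(dim=1)
    # One weight per sample, shaped (batch, 1) so it broadcasts across its row
    sample_w = weights[target, pred].unsqueeze(1)
    # NLLLoss selects -log_p[i, target[i]], so each sample's loss is scaled by its weight
    return nll(log_p * sample_w, target)
```

With no class weights, `nn.NLLLoss` averages over the batch, so this matches `sum(cross_entropy * weight) / batch_size` from step 3.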
Ben
June 27, 2017, 10:43am
Yup. I’m working on my own cross entropy using the softmax; it’s pretty confusing!
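A hand-written cross entropy from the softmax can be sketched in two lines: take `log_softmax` (the numerically stable form of `log(softmax(x))`), then `gather` the log-probability of each sample's target class. The name `my_cross_entropy` is hypothetical; it should agree with PyTorch's built-in `F.cross_entropy(..., reduction='none')`.

```python
import torch
import torch.nn.functional as F


def my_cross_entropy(logits, target):
    """Per-sample cross entropy built from the softmax by hand (a sketch).

    logits: (batch, n_classes) raw scores
    target: (batch,) integer class indices
    """
    # Stable equivalent of torch.log(F.softmax(logits, dim=1))
    log_p = F.log_softmax(logits, dim=1)
    # Pick log p(target) from each row along the class dimension, then negate
    return -log_p.gather(1, target.unsqueeze(1)).squeeze(1)
```

Keeping the result per sample (instead of averaging) is what makes it easy to multiply by a per-sample weight afterwards.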
Ben
August 25, 2017, 6:56am
I implemented the loss function. Here is the gist:
gistfile1.txt (excerpt):
````python
def one_hot(size, index):
    """ Creates a matrix of one hot vectors.
    ```
    import torch
    import torch_extras
    setattr(torch, 'one_hot', torch_extras.one_hot)
    size = (3, 3)
    index = torch.LongTensor([2, 0, 1]).view(-1, 1)
    torch.one_hot(size, index)
    # [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
````
(This file has been truncated; the full code is in the original gist.)
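The gist is truncated, so its actual `one_hot` body is not shown here. As a hypothetical re-implementation that reproduces the docstring example, `Tensor.scatter_` can write a 1 at column `index[i]` of each row `i`. The `(size, index)` signature is copied from the gist; everything else is an assumption.

```python
import torch


def one_hot(size, index):
    """Hypothetical sketch matching the gist's docstring example.

    size:  (n_rows, n_classes) tuple
    index: (n_rows, 1) LongTensor of class indices
    """
    # Start from zeros on the same device as the indices
    mask = torch.zeros(size, dtype=torch.long, device=index.device)
    # Along dim 1, put a 1 at column index[i, 0] of row i
    return mask.scatter_(1, index, 1)


index = torch.LongTensor([2, 0, 1]).view(-1, 1)
print(one_hot((3, 3), index).tolist())
```

Newer PyTorch releases also ship `torch.nn.functional.one_hot`, which takes a 1-D index tensor and `num_classes` instead of a size tuple.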
Known issue: part of the program must run on the CPU, so it may be slow. In my own tests, the speed is OK. If you have a better solution, please let me know!