# Is this loss function differentiable?

Hi. I’m trying to build a loss function for per-pixel regression, where each pixel has an associated class and a target value.

The input arguments are y_pred [N,C,H,W], classes [N,H,W], and y [N,H,W].
Each pixel belongs to the class given by the second argument, and I want to compute the MSE loss between y_pred (at that pixel’s class) and y.
What I’d like to know is whether this function is differentiable for back propagation.
If not, is there a way to implement it? Thanks in advance.

``````
def categorical_mse_2d_loss(y_pred, classes, y):
    b, h, w = y.size()
    s = torch.zeros((b, h, w), requires_grad=True)
    for i in range(b):
        for j in range(h):
            for k in range(w):
                c = classes[i,j,k]
                s[i,j,k] = torch.square((y_pred[i,c,j,k] - y[i,j,k])) # is this differentiable?
    return s.mean()
``````

Hello Dongsup!

Yes, it is (but the `requires_grad=True` for `s` is unnecessary, and
might cause problems).
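
To make the `requires_grad=True` caveat concrete: in current PyTorch versions, writing into a leaf tensor that requires grad is no longer just risky, it raises a `RuntimeError` outright. A minimal repro (assuming a recent PyTorch):

```python
import torch

# a leaf tensor that requires grad, as in the posted loss function
s = torch.zeros(2, 2, requires_grad=True)

try:
    s[0, 0] = 1.0  # in-place write into a leaf that requires grad
except RuntimeError as e:
    print("in-place write failed:", e)
```

So on a modern PyTorch you would create `s` without `requires_grad=True`; autograd tracks it anyway once differentiable values are assigned into it.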

Also, note that you can get rid of the loop by using `gather()` to
line up the desired elements of `y_pred`, as specified by `classes`,
with `y`.
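
For readers on a recent PyTorch, here is a minimal sketch of that `gather()` approach (names and shapes follow the question; the quick check at the bottom is illustrative):

```python
import torch
import torch.nn.functional as F

def categorical_mse_2d_loss(y_pred, classes, y):
    # y_pred: [N, C, H, W] float, classes: [N, H, W] long, y: [N, H, W] float
    # pick, for each pixel, the prediction channel named by classes
    y_pred_by_class = y_pred.gather(1, classes.unsqueeze(1)).squeeze(1)  # [N, H, W]
    return F.mse_loss(y_pred_by_class, y)

# quick check: the loss is a scalar and gradients flow back to y_pred
y_pred = torch.randn(2, 3, 4, 5, requires_grad=True)
classes = torch.randint(0, 3, (2, 4, 5))
y = torch.randn(2, 4, 5)
loss = categorical_mse_2d_loss(y_pred, classes, y)
loss.backward()
```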

This is illustrated by the following pytorch version 0.3.0 script and
its output:

``````import torch
torch.__version__
torch.manual_seed (2020)

# your loss function with loop, tweaked for version 0.3.0
def categorical_mse_2d_loss(y_pred, classes, y):
    b, h, w = y.size()
    # s = torch.zeros((b, h, w), requires_grad=True)
    s = torch.autograd.Variable (torch.zeros((b, h, w)))  # you don't need / want requires_grad = True here
    for i in range(b):
        for j in range(h):
            for k in range(w):
                c = classes[i,j,k]
                # s[i,j,k] = torch.square((y_pred[i,c,j,k] - y[i,j,k])) # is this differentiable?
                s[i,j,k] = (y_pred[i,c,j,k] - y[i,j,k])**2 # is this differentiable?
    return s.mean()

nBatch = 2
nClass = 3
height = 4

# no width for simplicity
# we will unsqueeze() to add width of 1 for categorical_mse_2d_loss()

# test data
y_pred = torch.autograd.Variable (torch.randn (nBatch, nClass, height), requires_grad = True)
y = torch.autograd.Variable (torch.randn (nBatch, height))
classes = torch.autograd.Variable (torch.LongTensor ([[2, 2, 0, 0], [0, 1, 1, 2]]))

y_pred
y
classes

# we don't need the loop -- we can use gather() to get the relevant elements of y_pred
y_pred_by_class = y_pred.gather (1, classes.unsqueeze (1)).squeeze()
y_pred_by_class

lossA = ((y_pred_by_class - y)**2).mean()
# could use mse_loss()
# lossA = torch.nn.functional.mse_loss (y_pred_by_class, y)
lossA
lossA.backward()
y_pred.grad
y_pred.grad.zero_()

lossB = categorical_mse_2d_loss (y_pred.unsqueeze (3), classes.unsqueeze (2).data, y.unsqueeze (2))
lossB
# yes, loop version is differentiable
lossB.backward()
y_pred.grad
``````
``````>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>> torch.manual_seed (2020)
>>>
>>> # your loss function with loop, tweaked for version 0.3.0
... def categorical_mse_2d_loss(y_pred, classes, y):
...   b, h, w = y.size()
...   # s = torch.zeros((b, h, w), requires_grad=True)
...   s = torch.autograd.Variable (torch.zeros((b, h, w)))  # you don't need / want requires_grad = True here
...   for i in range(b):
...     for j in range(h):
...       for k in range(w):
...         c = classes[i,j,k]
...         # s[i,j,k] = torch.square((y_pred[i,c,j,k] - y[i,j,k])) # is this differentiable?
...         s[i,j,k] = (y_pred[i,c,j,k] - y[i,j,k])**2 # is this differentiable?
...   return s.mean()
...
...
>>> nBatch = 2
>>> nClass = 3
>>> height = 4
>>>
>>> # no width for simplicity
... # we will unsqueeze() to add width of 1 for categorical_mse_2d_loss()
...
>>> # test data
>>> y_pred = torch.autograd.Variable (torch.randn (nBatch, nClass, height), requires_grad = True)
>>> y = torch.autograd.Variable (torch.randn (nBatch, height))
>>> classes = torch.autograd.Variable (torch.LongTensor ([[2, 2, 0, 0], [0, 1, 1, 2]]))
>>>
>>> y_pred
Variable containing:
(0 ,.,.) =
1.2372 -0.9604  1.5415 -0.4079
0.8806  0.0529  0.0751  0.4777
-0.6759 -2.1489 -1.1463 -0.2720

(1 ,.,.) =
1.0066 -0.0416 -1.2853 -0.4948
-1.2964 -1.2502 -0.7693  1.6856
0.3546  0.7790  0.3257  0.4995
[torch.FloatTensor of size 2x3x4]

>>> y
Variable containing:
0.7705 -0.5920  0.5270  0.0807
0.9863  2.2251 -1.1789 -0.3879
[torch.FloatTensor of size 2x4]

>>> classes
Variable containing:
2  2  0  0
0  1  1  2
[torch.LongTensor of size 2x4]

>>>
>>> # we don't need the loop -- we can use gather() to get the relevant elements of y_pred
... y_pred_by_class = y_pred.gather (1, classes.unsqueeze (1)).squeeze()
>>> y_pred_by_class
Variable containing:
-0.6759 -2.1489  1.5415 -0.4079
1.0066 -1.2502 -0.7693  0.4995
[torch.FloatTensor of size 2x4]

>>>
>>> lossA = ((y_pred_by_class - y)**2).mean()
>>> # could use mse_loss()
... # lossA = torch.nn.functional.mse_loss (y_pred_by_class, y)
... lossA
Variable containing:
2.3522
[torch.FloatTensor of size 1]

>>> lossA.backward()
>>> y_pred.grad
Variable containing:
(0 ,.,.) =
0.0000  0.0000  0.2536 -0.1222
0.0000  0.0000  0.0000  0.0000
-0.3616 -0.3892  0.0000  0.0000

(1 ,.,.) =
0.0051  0.0000  0.0000  0.0000
0.0000 -0.8688  0.1024  0.0000
0.0000  0.0000  0.0000  0.2219
[torch.FloatTensor of size 2x3x4]

>>> y_pred.grad.zero_()
Variable containing:
(0 ,.,.) =
0  0  0  0
0  0  0  0
0  0  0  0

(1 ,.,.) =
0  0  0  0
0  0  0  0
0  0  0  0
[torch.FloatTensor of size 2x3x4]

>>>
>>> lossB = categorical_mse_2d_loss (y_pred.unsqueeze (3), classes.unsqueeze (2).data, y.unsqueeze (2))
>>> lossB
Variable containing:
2.3522
[torch.FloatTensor of size 1]

>>> # yes, loop version is differentiable
... lossB.backward()
>>> y_pred.grad
Variable containing:
(0 ,.,.) =
0.0000  0.0000  0.2536 -0.1222
0.0000  0.0000  0.0000  0.0000
-0.3616 -0.3892  0.0000  0.0000

(1 ,.,.) =
0.0051  0.0000  0.0000  0.0000
0.0000 -0.8688  0.1024  0.0000
0.0000  0.0000  0.0000  0.2219
[torch.FloatTensor of size 2x3x4]
``````

Best.

K. Frank


That’s exactly what I’d like to do! Thanks a lot for your detailed explanation.