Please let me know: how can I get the grad of v_a?
I tried to run code like this, but failed to calculate the grad of v_a.
Note: t_b = t_a * 2 is just an example; the real operation is complex and can't be done on a Variable.
import torch
from torch import Tensor
from torch.autograd import Variable
from torch.nn import functional as F

target = Variable(torch.LongTensor([2, 1]))
v_a = Variable(Tensor([[1, 2, 3], [3, 4, 5]]), requires_grad=True)
v_b = v_a * 2
loss = F.cross_entropy(v_b, target)
loss.backward()
print(v_a.grad)  # gradient of the loss w.r.t. v_a
A few points:
Crit doesn't exist, as far as I can tell; I've replaced it with cross_entropy
torch.nn contains class-based modules, but you can find their functional equivalents in torch.nn.functional, which I've used here
you need to provide everything in mini-batches, so:
– the input to cross_entropy should be two-dimensional, (N, C), where N is the size of the minibatch and C is the number of classes
– in practice, this meant I had to make your v_a two-dimensional too, so it forms a minibatch
target, for cross_entropy, is a one-dimensional LongTensor of class labels, one per minibatch example (see the sketch after this list)
everything should be Variables. Forget that Tensors exist
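To make those shape requirements concrete, here's a minimal sketch (the sizes N and C are just illustrative):

import torch
from torch.autograd import Variable
from torch.nn import functional as F

N, C = 4, 3  # illustrative: minibatch of 4 examples, 3 classes
scores = Variable(torch.randn(N, C), requires_grad=True)   # (N, C) input
labels = Variable(torch.LongTensor(N).random_(0, C))       # 1-D, one label per example
loss = F.cross_entropy(scores, labels)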
I feel like that statement can be misleading for people learning. Variables are wrappers for tensors: you should think of a Variable as a tensor wrapped with PyTorch's Variable class so that gradients can be computed automatically.
Yes… the flaw with that is, let's say you feed a Tensor into a net. You don't need the gradient from that tensor, since no backprop to the input is required, so logically it shouldn't need to be a Variable. Except it does.
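In other words, a minimal sketch (with nn.Linear standing in for whatever net you have):

import torch
import torch.nn as nn
from torch.autograd import Variable

net = nn.Linear(3, 2)
x = Variable(torch.randn(4, 3))  # requires_grad defaults to False
out = net(x)  # no gradient will flow back to x, but it must still be a Variable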
Thanks for your reply, but you ignored this line, “t_b = 2 * (v_a.data)”, in my code.
Now I'm sure that this way I can't get v_a.grad.
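To illustrate what I mean (a sketch): going through .data cuts the computation off from the autograd graph, so backward never reaches v_a:

import torch
from torch.autograd import Variable

v_a = Variable(torch.Tensor([1.0, 2.0]), requires_grad=True)
t_b = 2 * v_a.data                   # t_b is a plain Tensor; the graph is cut here
v_b = Variable(t_b, requires_grad=True)
v_b.sum().backward()
print(v_a.grad)                      # None: no gradient flowed back to v_a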
So I want to know how to achieve my operation without jumping out of Variables. It's similar to scatter_nd in TensorFlow:
import tensorflow as tf

indices = tf.constant([[0], [3]])    # one (1-element) index per update
updates = tf.constant([0.2, 0.6])
shape = tf.constant([4])
scatter = tf.scatter_nd(indices, updates, shape)
# -> [0.2, 0.0, 0.0, 0.6]
As you can see, each entry in indices picks the position where the corresponding value in updates is placed.
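In PyTorch terms, I imagine something like this minimal sketch, staying inside Variables (assuming out-of-place scatter is supported by autograd in this version):

import torch
from torch.autograd import Variable

indices = Variable(torch.LongTensor([0, 3]))
updates = Variable(torch.Tensor([0.2, 0.6]), requires_grad=True)
base = Variable(torch.zeros(4))
scatter = base.scatter(0, indices, updates)  # [0.2, 0.0, 0.0, 0.6]
scatter.sum().backward()
print(updates.grad)  # gradients flow back to updates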
You can't access it with .grad: gradients for intermediate (non-leaf) Variables are not retained.
You can use var.register_hook(fn) (doc here), where fn is a function that will be given the gradient of var as input. You can then use this function to monitor the gradient or store it in a global variable so it's available somewhere else in your program.
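For example, a minimal sketch of that pattern (grads and save_grad are just illustrative names):

import torch
from torch.autograd import Variable

grads = {}  # global store for gradients captured by hooks

def save_grad(name):
    def hook(grad):
        grads[name] = grad  # called with the gradient of the hooked Variable
    return hook

v_a = Variable(torch.Tensor([[1, 2, 3], [3, 4, 5]]), requires_grad=True)
v_b = v_a * 2
v_b.register_hook(save_grad('v_b'))
v_b.sum().backward()
print(grads['v_b'])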