Hi Miturian!
I think you’re trying to say something slightly different, so let me
rephrase it: x is a batchSize-2 “input” to the “model”; w is the (scalar)
weight that defines this very simple “model”.
Yes.
Yes, but I think it’s more helpful to think of this as a “tensor” chain
rule, rather than a chain rule broken down into scalar pieces. Thus:

dL / dw = sum_i { dL / dy_i * dy_i / dw },

where sum_i {... * ...} is a tensor contraction.
That is:
>>> import torch
>>> torch.__version__
'1.9.0'
>>>
>>> w = torch.tensor ([2.], requires_grad = True) # "requires_grad = True" leaf tensor
>>> x = torch.tensor ([1., 2.]) # "requires_grad = False" leaf tensor
>>> y = torch.mul (x, w) # non-leaf tensor (same as x * w)
>>> L = y.mean() # root (non-leaf) final result
>>>
>>> y.retain_grad() # so we can examine y.grad
>>> L.retain_grad() # so we can examine L.grad
>>>
>>> w
tensor([2.], requires_grad=True)
>>> x
tensor([1., 2.])
>>> y
tensor([2., 4.], grad_fn=<MulBackward0>)
>>> L
tensor(3., grad_fn=<MeanBackward0>)
>>>
>>> L.backward (retain_graph = True)
>>>
>>> L.grad # unit "seed" to start chain rule
tensor(1.)
>>> y.grad # [dL / dy_0, dL / dy_1]
tensor([0.5000, 0.5000])
>>> x.grad # None, because x is a leaf with requires_grad = False
>>> w.grad # [dL / dw] = sum_i {dL / dy_i * dy_i / dw}
tensor([1.5000])
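(As a quick sanity check that isn’t part of your example: for this model
dy_i / dw is just x_i, so you can reproduce w.grad by contracting y.grad
with x by hand:

>>> (y.grad * x).sum()   # sum_i {dL / dy_i * dy_i / dw} = 0.5 * 1.0 + 0.5 * 2.0
tensor(1.5000)

which agrees with the w.grad that autograd computed.)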
If I understand you correctly that “these factors” are things like
dy_i / dw, then, no, they are not stored in the graph, waiting for
.backward() to be called.
Instead, when y = x * w is called (in the “forward pass”), only the fact
that * (torch.mul()) was called is stored (by using a MulBackward0
object), together with any torch.mul()-specific context that mul()'s
companion backward() function will need to calculate dy_i / dw at the
time that the backward pass is run. To emphasize, dy_i / dw is neither
computed nor stored during the forward pass (but information sufficient
to compute it during the backward pass is stored).
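As an aside (not something you asked about directly), you can see that
the individual dy_i / dw only get produced when a backward pass actually
asks for them, for example with torch.autograd.grad(). (This works here
because retain_graph = True kept the graph alive.)

>>> torch.autograd.grad (y[0], w, retain_graph = True)   # dy_0 / dw = x_0, computed now, not during the forward pass
(tensor([1.]),)
>>> torch.autograd.grad (y[1], w, retain_graph = True)   # dy_1 / dw = x_1
(tensor([2.]),)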
I prefer to say that y = x * w creates a single arm, but that the arm in
question is a “tensor arm.” This is not purely semantic, in that autograd
stores a single “tensor arm” for this operation, rather than two separate
“scalar arms.”
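You can see this in the graph itself (this little comparison is mine, not
part of your example): the whole tensor multiply is recorded as one
MulBackward0 node, whereas building y element by element records one node
per scalar piece:

>>> y.grad_fn.__class__.__name__      # one backward node covers both elements of y
'MulBackward0'
>>> y0 = x[0] * w                     # doing it in scalar pieces ...
>>> y1 = x[1] * w
>>> y0.grad_fn is y1.grad_fn          # ... records a separate node for each piece
False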
Just to be clear, the graph only contains those computations that lead
back to a requires_grad = True leaf tensor. It doesn’t waste time or
storage on computations that lead back solely to requires_grad = False
leaf tensors.
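For example (again, my addition, not from your post), a computation that
only involves requires_grad = False leaves doesn’t get recorded at all:

>>> (x * 3.0).grad_fn is None        # only involves the requires_grad = False leaf x, so no node is recorded
True
>>> (x * w).grad_fn is None          # leads back to w, so a MulBackward0 node is recorded
False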
If by this you mean that a lot of scalar computations are “linked to
the same variables” because those scalar computations are part
of the same tensor computation, the graph, in a sense, contains
all of those scalar computations, but in an efficient way, because
those scalar computations are packaged together as a single tensor
operation.
If by “symmetries” you mean that the several (in this example, two)
scalar operations are related to one another because they are part
of the same tensor operation, then, yes, autograd takes advantage
of this structure – and does so, as described above, by storing these
several scalar operations together as a single tensor operation.
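To make that concrete (a small check of mine, not from your example),
doing the same computation in explicit scalar pieces gives the same
gradient; the graph just ends up with one node per scalar operation
rather than one node for the whole tensor operation:

>>> w2 = torch.tensor ([2.], requires_grad = True)
>>> L2 = (x[0] * w2 + x[1] * w2) / 2     # same math as L, but done piece by piece
>>> L2.backward()
>>> w2.grad                              # same result as w.grad above
tensor([1.5000])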
I don’t know how to, but I think I’ve seen posts that talk about probing
the details of the graph. (I would have to imagine that the pytorch api
supports this, even if it isn’t well-documented or a part of what we
think of as the “public-facing” api, but, again, I don’t know how to do it.)
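For what it’s worth, one thing that does seem to work (though grad_fn,
next_functions, and friends aren’t really documented, public-facing api,
so treat this as an exploratory sketch) is walking the graph by hand:

>>> def walk (fn, depth = 0):          # crude graph walker, just for poking around
...     if fn is None:                 # inputs that don't require grad show up as None
...         return
...     print ('  ' * depth + fn.__class__.__name__)
...     for next_fn, _ in fn.next_functions:
...         walk (next_fn, depth + 1)
...
>>> walk (L.grad_fn)
MeanBackward0
  MulBackward0
    AccumulateGrad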
Best.
K. Frank