I want to understand the autograd graph, so I wrote a code snippet using `+=`.
```python
import torch
from collections import deque

a = torch.tensor([1, 2, 3], dtype=torch.float32, requires_grad=True)
b = torch.tensor([4, 5, 6], dtype=torch.float32)
b[0] += a[0]
print(b)

# Walk the autograd graph breadth-first, starting from b.grad_fn.
# Each entry in `graph` lists the next_functions of one node;
# AccumulateGrad nodes are shown together with the leaf tensor they feed.
graph = [[(b.grad_fn,)]]
queue = deque([b.grad_fn])
while queue:
    item = queue.popleft()
    if item is None:
        continue  # input that does not require grad
    if type(item).__name__ == 'AccumulateGrad':
        continue  # leaf node: nothing further upstream
    temp = []
    for fn, _ in item.next_functions:
        queue.append(fn)
        if type(fn).__name__ == 'AccumulateGrad':
            temp.append((fn, fn.variable))
        else:
            temp.append((fn,))
    graph.append(temp)

for x in graph:
    print(x)
```
The result is below.
To see it more clearly, I also visualized it.
What puzzles me is why there is an `AsStridedBackward0` node and so many `CopySlices` nodes. I hope someone can help me. Thanks a lot.
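For reference, Python itself expands an augmented assignment to a subscript such as `b[0] += a[0]` into three calls: `b.__getitem__(0)`, an in-place `__iadd__` on the result, and `b.__setitem__(0, ...)` to write the result back. The pure-Python sketch below uses hypothetical `LoggingList` and `LoggingItem` classes (nothing from PyTorch) only to record that call order. For tensors this suggests the snippet performs both an in-place add on the view `b[0]` and a separate write back into `b`, which may be related to the repeated `CopySlices` nodes.

```python
class LoggingItem:
    """Hypothetical wrapper for one element; logs in-place adds."""

    def __init__(self, value, calls):
        self.value = value
        self.calls = calls

    def __iadd__(self, other):
        self.calls.append("__iadd__")
        self.value += other
        return self


class LoggingList:
    """Hypothetical container that logs item reads and writes."""

    def __init__(self, data):
        self.data = list(data)
        self.calls = []

    def __getitem__(self, index):
        self.calls.append(f"__getitem__({index})")
        return LoggingItem(self.data[index], self.calls)

    def __setitem__(self, index, value):
        self.calls.append(f"__setitem__({index})")
        self.data[index] = value.value


b = LoggingList([4, 5, 6])
b[0] += 1          # read, in-place add, then write back
print(b.calls)     # ['__getitem__(0)', '__iadd__', '__setitem__(0)']
print(b.data)      # [5, 5, 6]
```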