I have now tried with
torch._logging.set_logs(recompiles=True, recompiles_verbose=True, fusion=True)
It repeatedly prints fusion attempts, which suggests that the graph is recompiled on each call, yet recompiles reports nothing.
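To cross-check the logs, I could also query Dynamo's own bookkeeping. A minimal sketch, assuming torch._dynamo.utils.counters and compile_times are usable for this (they are internal helpers and may change between versions):

import torch
from torch._dynamo.utils import counters, compile_times

# ... after running the workload through the compiled function ...
print(dict(counters["stats"]))  # captured calls / unique graphs, if those keys are populated
print(compile_times())          # per-function compile-time report as a string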
Here is the complete code, where I added timing of the compilation and a benchmark of the compiled function.
import torch
import torch.nn as nn
from torch import Tensor
import logging

# torch._logging.set_logs(guards=True)
torch._logging.set_logs(recompiles=True, recompiles_verbose=True, fusion=True)


class A(nn.Module):
    def __init__(self, v: float):
        super().__init__()
        self.b = v

    def forward(self, x):
        return (x.sin() + self.b).sin()


def forward(net, x: Tensor):
    return net(x).sum()


dev = torch.device('cuda:0')
compile_args = dict(
    fullgraph=True,
    dynamic=True,
    backend="inductor",
    options={
        'group_fusion': True,
        'force_same_precision': True,
        'disable_cpp_codegen': False,
        'trace.graph_diagram': True,
        "triton.cudagraphs": False,
    },
)
cforward = torch.compile(forward, **compile_args)


def compute(net):
    y = cforward(net, x)
    y.backward()


start_e = torch.cuda.Event(enable_timing=True)
end_e = torch.cuda.Event(enable_timing=True)

x = torch.rand((100, 100, 100), device=dev, requires_grad=True)
nets = [A(float(j)) for j in range(3)]
# nets = [nn.Sequential(*[A(float(j)) for j in range(3)]) for i in range(3)]

# WARMUP / COMPILE
s = torch.cuda.Stream(device=dev)
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.device(dev):
    with torch.cuda.stream(s):  # warmup on a side stream, according to examples
        for i in range(3):
            print(f"_______________Net {i}_______________")
            x.grad = None
            start_e.record()
            compute(nets[i])
            end_e.record()
            torch.cuda.synchronize()
            t = start_e.elapsed_time(end_e)
            print(f'Time: {t} ms')
            print(x.grad.mean())
torch.cuda.current_stream().wait_stream(s)

# capture all three nets into one CUDA graph
G = torch.cuda.CUDAGraph()
with torch.cuda.graph(G):
    x.grad = None
    for i in range(3):
        compute(nets[i])

# replay the captured graph and average the timing over 100 runs
for i in range(3):
    print(f"_______________Compiled Graphed Iteration {i}_______________")
    tt = 0
    for j in range(100):
        start_e.record()
        G.replay()
        end_e.record()
        torch.cuda.synchronize()
        t = start_e.elapsed_time(end_e)
        tt += t
    print(f'Time: {tt/100} ms')
    print(x.grad.mean())
It prints
_______________Net 0_______________
V0902 13:17:17.227000 140492323473216 torch/_inductor/scheduler.py:1820] [0/0] [__fusion] ===== attempting fusion (1/10): 2 nodes =====
V0902 13:17:17.227000 140492323473216 torch/_inductor/scheduler.py:627] [0/0] [__fusion] cannot fuse buf0 with buf1: numel/rnumel mismatch (reduce) (123, 1), (((s0**3 + 122)//123), 123)
V0902 13:17:17.227000 140492323473216 torch/_inductor/scheduler.py:2133] [0/0] [__fusion] found 0 possible fusions
V0902 13:17:17.228000 140492323473216 torch/_inductor/scheduler.py:1827] [0/0] [__fusion] completed fusion round (1/10): fused 2 nodes into 2 nodes
V0902 13:17:17.228000 140492323473216 torch/_inductor/scheduler.py:1827] [0/0] [__fusion]
V0902 13:17:17.228000 140492323473216 torch/_inductor/scheduler.py:1834] [0/0] [__fusion] ===== fusion complete (1 iterations) =====
V0902 13:17:18.419000 140492323473216 torch/_inductor/scheduler.py:1820] [0/0] [__fusion] ===== attempting fusion (1/10): 1 nodes =====
V0902 13:17:18.419000 140492323473216 torch/_inductor/scheduler.py:2133] [0/0] [__fusion] found 0 possible fusions
V0902 13:17:18.419000 140492323473216 torch/_inductor/scheduler.py:1827] [0/0] [__fusion] completed fusion round (1/10): fused 1 nodes into 1 nodes
V0902 13:17:18.419000 140492323473216 torch/_inductor/scheduler.py:1827] [0/0] [__fusion]
V0902 13:17:18.419000 140492323473216 torch/_inductor/scheduler.py:1834] [0/0] [__fusion] ===== fusion complete (1 iterations) =====
Time: 2500.05810546875 ms
tensor(0.7457, device='cuda:0')
_______________Net 1_______________
V0902 13:17:19.046000 140492323473216 torch/_inductor/scheduler.py:1820] [0/1] [__fusion] ===== attempting fusion (1/10): 2 nodes =====
V0902 13:17:19.046000 140492323473216 torch/_inductor/scheduler.py:627] [0/1] [__fusion] cannot fuse buf0 with buf1: numel/rnumel mismatch (reduce) (123, 1), (((s0**3 + 122)//123), 123)
V0902 13:17:19.046000 140492323473216 torch/_inductor/scheduler.py:2133] [0/1] [__fusion] found 0 possible fusions
V0902 13:17:19.047000 140492323473216 torch/_inductor/scheduler.py:1827] [0/1] [__fusion] completed fusion round (1/10): fused 2 nodes into 2 nodes
V0902 13:17:19.047000 140492323473216 torch/_inductor/scheduler.py:1827] [0/1] [__fusion]
V0902 13:17:19.047000 140492323473216 torch/_inductor/scheduler.py:1834] [0/1] [__fusion] ===== fusion complete (1 iterations) =====
V0902 13:17:19.260000 140492323473216 torch/_inductor/scheduler.py:1820] [0/1] [__fusion] ===== attempting fusion (1/10): 1 nodes =====
V0902 13:17:19.260000 140492323473216 torch/_inductor/scheduler.py:2133] [0/1] [__fusion] found 0 possible fusions
V0902 13:17:19.260000 140492323473216 torch/_inductor/scheduler.py:1827] [0/1] [__fusion] completed fusion round (1/10): fused 1 nodes into 1 nodes
V0902 13:17:19.260000 140492323473216 torch/_inductor/scheduler.py:1827] [0/1] [__fusion]
V0902 13:17:19.260000 140492323473216 torch/_inductor/scheduler.py:1834] [0/1] [__fusion] ===== fusion complete (1 iterations) =====
Time: 517.2654418945312 ms
tensor(0.1222, device='cuda:0')
_______________Net 2_______________
V0902 13:17:19.534000 140492323473216 torch/_inductor/scheduler.py:1820] [0/2] [__fusion] ===== attempting fusion (1/10): 2 nodes =====
V0902 13:17:19.534000 140492323473216 torch/_inductor/scheduler.py:627] [0/2] [__fusion] cannot fuse buf0 with buf1: numel/rnumel mismatch (reduce) (123, 1), (((s0**3 + 122)//123), 123)
V0902 13:17:19.534000 140492323473216 torch/_inductor/scheduler.py:2133] [0/2] [__fusion] found 0 possible fusions
V0902 13:17:19.534000 140492323473216 torch/_inductor/scheduler.py:1827] [0/2] [__fusion] completed fusion round (1/10): fused 2 nodes into 2 nodes
V0902 13:17:19.534000 140492323473216 torch/_inductor/scheduler.py:1827] [0/2] [__fusion]
V0902 13:17:19.535000 140492323473216 torch/_inductor/scheduler.py:1834] [0/2] [__fusion] ===== fusion complete (1 iterations) =====
V0902 13:17:19.746000 140492323473216 torch/_inductor/scheduler.py:1820] [0/2] [__fusion] ===== attempting fusion (1/10): 1 nodes =====
V0902 13:17:19.746000 140492323473216 torch/_inductor/scheduler.py:2133] [0/2] [__fusion] found 0 possible fusions
V0902 13:17:19.746000 140492323473216 torch/_inductor/scheduler.py:1827] [0/2] [__fusion] completed fusion round (1/10): fused 1 nodes into 1 nodes
V0902 13:17:19.746000 140492323473216 torch/_inductor/scheduler.py:1827] [0/2] [__fusion]
V0902 13:17:19.746000 140492323473216 torch/_inductor/scheduler.py:1834] [0/2] [__fusion] ===== fusion complete (1 iterations) =====
Time: 446.4669189453125 ms
I expected to see compilation for the first net only, with the same compiled code reused for the rest. However, the later runs are much faster than the first one, so some compiled code probably is being reused. That's confusing. Am I misunderstanding something, or is this really an issue (or several issues)?
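One experiment I might try, on the assumption (not verified) that the per-instance Python float self.b is what forces re-tracing for each net: store b as a tensor buffer instead, and check whether the later nets still show the compile-like overhead.

# hypothetical variant of A: register b as a buffer so it enters the traced
# graph as a tensor input instead of being baked in as a per-instance constant
# (this is an assumption about the cause, not a confirmed fix)
class A2(nn.Module):
    def __init__(self, v: float):
        super().__init__()
        self.register_buffer("b", torch.tensor(v))

    def forward(self, x):
        return (x.sin() + self.b).sin()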