A function compiled with `torch.compile` hangs forever unless it contains a `print` statement. When the `print` statement in `mul()` is uncommented, the function runs fine for any input size (`n` in `main()`). If the `print` statement is left commented out, only `n <= 8` works. Here is a reduced example:
```python
import random
import torch

def mul(a, b):
    #print("") #uncomment to unhang
    n = len(a)
    c = [[[0]*n]*n]*n
    for i in range(n):
        for j in range(n):
            for k in range(n):
                c[i][j][k] = 0
                for l in range(n):
                    c[i][j][k] += a[i][j][l] * b[i][j][k]
    return c

def main():
    n = 16 #change to 8 to unhang
    t1,t2 = [[[random.random() for _ in range(n)] for _ in range(n)] for _ in range(n)],[[[random.random() for _ in range(n)] for _ in range(n)] for _ in range(n)]
    fun = torch.compile(mul)
    fun(t1,t2)

if __name__ == "__main__":
    main()
```
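For comparison, here is a sketch of the same per-element computation expressed with tensors instead of nested Python lists, since `torch.compile` is aimed at tensor programs. This is a possible workaround only and has not been checked against the hang itself. It assumes the intent is `c[i][j][k] = sum over l of a[i][j][l] * b[i][j][k]` for every `(i, j, k)`; note that in the repro above `[[[0]*n]*n]*n` makes every `c[i][j]` the same inner list object, so its actual output differs. The names `mul_loops` and `mul_tensor` are hypothetical.

```python
import torch

def mul_loops(a, b):
    # Hypothetical reference: the loop from the report, but with a freshly
    # allocated output so each c[i][j] is a distinct list (no aliasing).
    n = len(a)
    c = [[[0.0] * n for _ in range(n)] for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for k in range(n):
                for l in range(n):
                    c[i][j][k] += a[i][j][l] * b[i][j][k]
    return c

def mul_tensor(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # sum_l a[i,j,l] * b[i,j,k] == (sum_l a[i,j,l]) * b[i,j,k]:
    # reduce the last axis of a to shape (n, n, 1), then broadcast against b.
    return a.sum(dim=-1, keepdim=True) * b
```

The tensor version can be passed to `torch.compile` in place of `mul`, e.g. `torch.compile(mul_tensor)(torch.tensor(t1), torch.tensor(t2))`.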