I’m trying to compute a fast Walsh-Hadamard transform (FWHT). In theory this operation can be done entirely in place, allocating zero extra memory, but the implementation below is nowhere close to that.
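For context, the textbook iterative FWHT is just a sequence of butterfly stages that overwrite their inputs, which is why zero extra allocation should be achievable in principle. A minimal plain-Python sketch (a hypothetical helper, ignoring vectorization and autograd):

```python
def fwht_inplace(a):
    """In-place fast Walsh-Hadamard transform of a length-2**k sequence."""
    n, h = len(a), 1
    while h < n:
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                u, v = a[j], a[j + h]
                a[j], a[j + h] = u + v, u - v  # butterfly overwrites its inputs
        h *= 2
    return a
```

My PyTorch version: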
```python
import numpy as np
import torch


def fast_walsh_hadamard_torched(x, axis: int = 0, normalize: bool = False):
    orig_shape = x.size()
    assert axis >= 0 and axis < len(orig_shape), (
        "For a vector of shape %s, axis must be in [0, %d] but it is %d"
        % (orig_shape, len(orig_shape) - 1, axis)
    )
    h_dim = orig_shape[axis]
    h_dim_exp = int(round(np.log(h_dim) / np.log(2)))
    assert h_dim == 2 ** h_dim_exp, (
        "hadamard can only be computed over axis with size that is a power of two, but"
        " chosen axis %d has size %d" % (axis, h_dim)
    )

    # Split the transform axis into h_dim_exp axes of size 2.
    working_shape_pre = [int(torch.prod(torch.tensor(orig_shape[:axis])))]
    working_shape_post = [int(torch.prod(torch.tensor(orig_shape[axis + 1:])))]
    working_shape_mid = [2] * h_dim_exp
    working_shape = working_shape_pre + working_shape_mid + working_shape_post

    ret = x.view(working_shape)
    print(ret.size())

    # One butterfly stage per size-2 axis.
    for dim in range(1, h_dim_exp + 1):
        left, right = torch.chunk(ret, 2, dim=dim)
        ret = torch.cat([left + right, left - right], dim=dim)
        del left, right

    if normalize:
        # torch.sqrt only accepts tensors, so take the sqrt of the plain int with np
        ret = ret / float(np.sqrt(h_dim))
    ret = ret.view(orig_shape)
    return ret
```
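The profile below can be reproduced with roughly the following harness (a sketch, not my exact script: the input size comes from the printed shape, and I’m assuming the legacy `torch.autograd.profiler` API, whose table format matches the output):

```python
x = torch.randn(2 ** 26)
with torch.autograd.profiler.profile(profile_memory=True) as prof:
    ret = fast_walsh_hadamard_torched(x)
print(prof.key_averages().table(sort_by="self_cpu_memory_usage"))
```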
```
torch.Size([67108864])

----------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- ---------------
Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg CPU Mem Self CPU Mem CUDA Mem Self CUDA Mem Number of Calls
----------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- ---------------
empty 0.09% 3.494ms 0.09% 3.494ms 20.315us 13.50 Gb 13.50 Gb 0 b 0 b 172
resize_ 0.04% 1.461ms 0.04% 1.461ms 27.565us 13.00 Gb 13.00 Gb 0 b 0 b 53
empty_strided 0.00% 76.222us 0.00% 76.222us 38.111us 512.00 Mb 512.00 Mb 0 b 0 b 2
zeros 0.00% 24.376us 0.00% 66.824us 66.824us 400 b 0 b 0 b 0 b 1
zero_ 0.00% 9.836us 0.00% 30.566us 30.566us 0 b 0 b 0 b 0 b 1
fill_ 0.38% 15.447ms 0.38% 15.450ms 2.575ms 0 b 0 b 0 b 0 b 6
to 0.00% 48.789us 1.14% 46.184ms 5.132ms 512.00 Mb 0 b 0 b 0 b 9
size 0.07% 2.735ms 0.07% 2.735ms 0.453us 0 b 0 b 0 b 0 b 6036
constant_pad_nd 0.00% 120.489us 0.38% 15.590ms 15.590ms 256.00 Mb 0 b 0 b 0 b 1
narrow 0.02% 664.211us 0.06% 2.371ms 15.099us 0 b 0 b 0 b 0 b 157
----------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- ---------------
```
As you can see, the implementation allocates far more memory than it needs. I narrowed this down to the `cat` call in the for loop: `cat` allocates a fresh output tensor every time it is called (recorded as the `empty` op in the profiler). I also need gradient history to be recorded, so I can’t simply preallocate a buffer and use the `out=` argument of `cat`.
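A minimal repro of both behaviors (toy tensors, not from my actual code):

```python
import torch

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)

# cat always writes into freshly allocated storage:
c = torch.cat([a, b])
print(c.data_ptr() != a.data_ptr())  # True: a new buffer on every call

# and cat with out= refuses to record gradient history:
buf = torch.empty(8)
torch.cat([a, b], out=buf)
# RuntimeError: cat(): functions with out=... arguments don't support
# automatic differentiation, but one of the arguments requires grad.
```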
Any ideas?