I want to attach a forward hook to a module whose input is a generator. In the hook, I want to do an operation that uses the number of elements the generator yields and the shape of those elements. How can this be done?
The generator is finite.
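```python
import torch
import torch.nn as nn


class Sum(nn.Module):
    def __init__(self):
        super(Sum, self).__init__()

    def forward(self, x):
        # Built-in sum(); iterating here consumes a generator
        return sum(x)


def sum_hook(mod, input, output):
    # `input` is the tuple of positional arguments, so input[0] is the generator (or list)
    print(type(input))
    print(len(list(input[0])))  # 0 for the generator (already consumed), 3 for the list


s = Sum()
h = s.register_forward_hook(sum_hook)
i = torch.randn((1, 2, 3, 4))
x = s((i for _ in range(3)))  # generator
x = s([i for _ in range(3)])  # list
h.remove()
```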
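The same effect can be reproduced without PyTorch: a generator can only be iterated once, so anything that reads it after `sum()` has run finds it empty, while a list can be iterated again. A minimal sketch, where `total` is a hypothetical stand-in for `Sum.forward`:

```python
def total(xs):
    # Stand-in for Sum.forward: iterating xs consumes it if it is a generator
    return sum(xs)


gen = (n for n in range(3))
print(total(gen))  # 3
print(list(gen))   # [] : sum() already exhausted the generator

lst = [n for n in range(3)]
print(total(lst))  # 3
print(list(lst))   # [0, 1, 2] : a list can be iterated again
```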
My explanation is that `forward` has already consumed the generator by the time the forward hook runs, and that's why its length is zero. Is there a workaround? (I know that passing a list works, but using the generator is faster in my case.)
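One possible workaround, sketched under a few assumptions: a forward pre-hook runs before `forward`, and if it returns a tuple, that tuple replaces the positional arguments. The pre-hook below swaps the generator for a wrapper that passes each element through unchanged while recording how many elements it has seen and their shapes. The forward hook then receives that same wrapper in its `input` tuple and can read the statistics. `CountingIterator`, `wrap_pre_hook` and `stats_hook` are hypothetical names, and the sketch assumes the generator is the first positional argument and yields tensors.

```python
import torch
import torch.nn as nn


class Sum(nn.Module):
    def forward(self, x):
        return sum(x)


class CountingIterator:
    """Hypothetical wrapper: yields items unchanged, recording their shapes."""

    def __init__(self, iterable):
        self._it = iter(iterable)
        self.shapes = []

    def __iter__(self):
        return self

    def __next__(self):
        item = next(self._it)  # propagates StopIteration when the source is exhausted
        self.shapes.append(tuple(item.shape))
        return item

    @property
    def count(self):
        return len(self.shapes)


def wrap_pre_hook(mod, args):
    # Runs before forward(); the returned tuple replaces the positional arguments
    return (CountingIterator(args[0]),) + args[1:]


seen = []


def stats_hook(mod, args, output):
    # args[0] is the wrapper that forward() consumed, so its statistics are filled in
    counter = args[0]
    seen.append((counter.count, counter.shapes[-1] if counter.shapes else None))


s = Sum()
h_pre = s.register_forward_pre_hook(wrap_pre_hook)
h_post = s.register_forward_hook(stats_hook)

i = torch.randn((1, 2, 3, 4))
out = s(i for _ in range(3))  # still a generator, consumed lazily by forward()
print(seen)  # [(3, (1, 2, 3, 4))]

h_pre.remove()
h_post.remove()
```

Because the wrapper only observes elements as `forward` pulls them, the generator is still consumed once and lazily. If laziness does not matter, a simpler pre-hook could return `(tuple(args[0]),)` to materialize the generator, but that builds the same sequence that passing a list would.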
Thank you in advance.