Hi,
I am running into an issue with this code — how can I fix it?
This is my forward function:
def forward(self, input, time_window=100):
    """Run the spiking network for `time_window` simulation steps.

    The input image is treated as firing probabilities: at every step a
    Bernoulli spike train is sampled from it, pushed through three
    conv + membrane-update stages (each followed by average pooling) and
    two fully connected stages. The output is the firing rate of the
    final layer, i.e. spike counts averaged over the time window.

    Args:
        input: image batch, assumed shape (N, C, 28, 28) with values in
            [0, 1] so they can be used as firing probabilities —
            TODO confirm against the caller.
        time_window: number of simulation time steps.

    Returns:
        Tensor of shape (N, 10): mean firing rate of the output layer.
    """
    # Derive batch size and device from the actual input rather than
    # module-level globals, so the last (possibly smaller) batch of a
    # DataLoader and CPU/GPU moves both work.
    batch = input.size(0)
    dev = input.device

    # Membrane potentials and spike trains per layer. Each state gets
    # its OWN zero tensor — chained assignment (`mem = spike = zeros`)
    # would alias both names to one tensor, which breaks if any update
    # is ever done in place.
    c1_mem = torch.zeros(batch, 16, 28, 28, device=dev)
    c1_spike = torch.zeros(batch, 16, 28, 28, device=dev)
    c2_mem = torch.zeros(batch, 32, 14, 14, device=dev)
    c2_spike = torch.zeros(batch, 32, 14, 14, device=dev)
    c3_mem = torch.zeros(batch, 64, 7, 7, device=dev)
    c3_spike = torch.zeros(batch, 64, 7, 7, device=dev)
    h1_mem = torch.zeros(batch, 512, device=dev)
    h1_spike = torch.zeros(batch, 512, device=dev)
    h1_sumspike = torch.zeros(batch, 512, device=dev)
    h2_mem = torch.zeros(batch, 10, device=dev)
    h2_spike = torch.zeros(batch, 10, device=dev)
    h2_sumspike = torch.zeros(batch, 10, device=dev)

    for step in range(time_window):  # simulation time steps
        # Stochastic (Poisson-like) input encoding: pixel intensity is
        # the probability of emitting a spike at this step.
        x = input > torch.rand(input.size(), device=dev)

        c1_mem, c1_spike = mem_update(self.conv1, x.float(), c1_mem, c1_spike)
        x = F.avg_pool2d(c1_spike, 2)   # 28x28 -> 14x14

        c2_mem, c2_spike = mem_update(self.conv2, x, c2_mem, c2_spike)
        x = F.avg_pool2d(c2_spike, 2)   # 14x14 -> 7x7

        c3_mem, c3_spike = mem_update(self.conv3, x, c3_mem, c3_spike)
        x = F.avg_pool2d(c3_spike, 4)   # 7x7 -> 1x1 (floor division)

        # Flatten while KEEPING the batch dimension. The original
        # `x.view(-1, 512)` silently re-partitions the batch whenever
        # the per-sample feature count is not 512 (here it is 64*1*1
        # after the 4x4 pool), corrupting sample alignment. This form
        # instead raises a clear error if self.fc1's in_features does
        # not match — NOTE(review): check that fc1 was built with the
        # flattened size this pooling actually produces.
        x = x.view(x.size(0), -1)

        h1_mem, h1_spike = mem_update(self.fc1, x, h1_mem, h1_spike)
        h1_sumspike += h1_spike
        h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem, h2_spike)
        h2_sumspike += h2_spike

    # Rate coding: average spike count per output neuron over time.
    outputs = h2_sumspike / time_window
    return outputs