Thank you for your reply.
Based on your suggestions, I moved the threshold to the nn.Module
, like:
class SpikingBasicBlock(nn.Module):
    """Spiking ResNet basic block with a learnable firing threshold.

    Two 3x3 conv+BN stages, each followed by a spiking membrane update,
    plus a residual shortcut. The firing threshold is an ``nn.Parameter``
    shared by both membrane updates so it can be learned end-to-end.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        image_size: kept for interface compatibility (not used here).
        batch_size: kept for interface compatibility (not used here).
        stride: stride of the first convolution (downsamples when > 1).
        option: shortcut type when shape changes — 'A' = parameter-free
            subsample + zero-pad channels, 'B' = 1x1 conv projection + BN.
        init_threshold: initial value of the learnable firing threshold.
    """

    expansion = 1

    def __init__(self, inplanes, planes, image_size, batch_size,
                 stride=1, option='A', init_threshold=1.0):
        super().__init__()
        # Learnable firing threshold, shared by both spiking stages.
        # NOTE(review): if mem_update uses a hard (Heaviside) spike
        # function, d(spike)/d(threshold) is 0 almost everywhere, so this
        # parameter receives (near-)zero gradient unless mem_update applies
        # a surrogate gradient — verify mem_update's backward behavior.
        self.threshold = nn.Parameter(
            torch.tensor(init_threshold, dtype=torch.float))
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = BatchNorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = BatchNorm(planes)
        self.drop2 = nn.Dropout(0.2)
        self.planes = planes
        self.stride = stride
        self.option = option
        # BUGFIX: the original code called self.shortcut(x) in forward()
        # without ever defining it, raising AttributeError. Register the
        # projection layers here for option 'B'; option 'A' and the
        # identity case need no parameters.
        self._needs_downsample = stride != 1 or inplanes != planes
        if self._needs_downsample and option == 'B':
            self.shortcut_conv = nn.Conv2d(inplanes, planes, kernel_size=1,
                                           stride=stride, bias=False)
            self.shortcut_bn = BatchNorm(planes)

    def shortcut(self, x):
        """Residual shortcut: identity, or downsample per ``option``."""
        if not self._needs_downsample:
            return x
        if self.option == 'B':
            # Option B: learned 1x1 projection.
            return self.shortcut_bn(self.shortcut_conv(x))
        # Option A (parameter-free): spatially subsample, then zero-pad
        # the channel dimension up to ``planes``.
        y = x[:, :, ::self.stride, ::self.stride]
        pad_channels = self.planes - y.size(1)
        if pad_channels > 0:
            zeros = y.new_zeros(y.size(0), pad_channels, y.size(2), y.size(3))
            y = torch.cat([y, zeros], dim=1)
        return y

    def forward(self, x, c1_mem, c1_spike, c2_mem, c2_spike):
        """Run one timestep; returns updated spikes and membrane states."""
        out = self.bn1(self.conv1(x))
        c1_mem, c1_spike = mem_update(out, c1_mem, c1_spike, self.threshold)
        out = self.bn2(self.conv2(c1_spike))
        out += self.shortcut(x)
        c2_mem, c2_spike = mem_update(out, c2_mem, c2_spike, self.threshold)
        c2_spike = self.drop2(c2_spike)
        return c2_spike, c1_mem, c1_spike, c2_mem, c2_spike
So far, the threshold does get updated during training (via the backward
pass), but its changes are minimal. Could this be because the gradient of
self.threshold is zero almost everywhere?
Thank you again for your support.
Best regards.