Hi,
I got this error, couldn’t reproduce this. Even Though, I checked all my model dimensions.
Can anyone tell me where the issue exactly?
Below is my code:
class SCNN(nn.Module):
def init(self, class_num=NUM_CHARS):
super(SCNN, self).init()
self.classes = class_num
# in_planes, out_planes, stride, padding, kernel_size = cfg_cnn[0]
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
# in_planes, out_planes, stride, padding, kernel_size = cfg_cnn[1]
self.avg_poo11 = nn.AvgPool2d(kernel_size=1, stride=1, padding=0)
self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)
self.avg_poo12 = nn.AvgPool2d(kernel_size=1, stride=1, padding=0)
self.fc1 = nn.Linear(6272, cfg_fc[0])
self.fc2 = nn.Linear(cfg_fc[0], cfg_fc[1])
def forward(self, input, time_window=20):
c1_mem = c1_spike = torch.zeros(args.batch_size, 32, 28, 28, device=DEVICE)
c2_mem = c2_spike = torch.zeros(args.batch_size, 32, 14, 14, device=DEVICE)
h1_mem = h1_spike = h1_sumspike = torch.zeros(args.batch_size, cfg_fc[0], device=DEVICE)
h2_mem = h2_spike = h2_sumspike = torch.zeros(args.batch_size, NUM_CHARS, device=DEVICE)
for step in range(time_window): # simulation time steps
x = input > torch.rand(input.size(), device=DEVICE) # prob. firing
# print(x.shape)
c1_mem, c1_spike = mem_update(self.conv1, x.float(), c1_mem, c1_spike)
# print(c1_mem.shape)
# x = nn.AvgPool2d(kernel_size=1, stride=1, padding=0)
x = self.avg_poo11(c1_spike)
# print(x.shape)
c2_mem, c2_spike = mem_update(self.conv2, x, c2_mem, c2_spike)
# print(c2_mem.shape)
x = self.avg_poo12(c2_spike)
# print(x.shape)
x = x.view(args.batch_size, -1)
# print(x.shape)
h1_mem, h1_spike = mem_update(self.fc1, x, h1_mem, h1_spike)
# print(h1_mem.shape)
h1_sumspike += h1_spike
h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem, h2_spike)
# print(h2_mem.shape)
h2_sumspike += h2_spike
outputs = h2_sumspike / time_window
return outputs
names = ‘spiking_model’
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
acc_record = list([])
loss_train_record = list([])
loss_test_record = list([])
snn = SCNN(class_num=NUM_CHARS)
snn.to(DEVICE)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(snn.parameters(), lr=args.learning_rate)
num_epochs = 50
for epoch in range(num_epochs):
running_loss = 0
start_time = time.time()
for i, (images, labels) in enumerate(dataloaders[“train”]):
snn.zero_grad()
optimizer.zero_grad()
images = images.float().to(DEVICE)
outputs = snn(images)
labels_ = torch.zeros(args.batch_size, NUM_CHARS).scatter_(1, labels.view(-1, 1), 1)
loss = criterion(outputs.cpu(), labels_)
running_loss += loss.item()
loss.backward()
optimizer.step()
if (i + 1) % 100 == 0:
print('Epoch [%d/%d], Step [%d/%d], Loss: %.5f'
% (epoch + 1, num_epochs, i + 1, len(datasets['train']) // args.batch_size, running_loss))
running_loss = 0
print('Time elasped:', time.time() - start_time)
correct = 0
total = 0
optimizer = lr_scheduler(optimizer, epoch, args.learning_rate, 40)
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(dataloaders['val']):
inputs = inputs.to(DEVICE)
optimizer.zero_grad()
outputs = snn(inputs)
labels_ = torch.zeros(args.batch_size, NUM_CHARS).scatter_(1, targets.view(-1, 1), 1)
loss = criterion(outputs.cpu(), labels_)
_, predicted = outputs.cpu().max(1)
total += float(targets.size(0))
correct += float(predicted.eq(targets).sum().item())
if batch_idx % 100 == 0:
acc = 100. * float(correct) / float(total)
print(batch_idx, len(dataloaders['val']), ' Acc: %.5f' % acc)
print('Iters:', epoch, '\n\n\n')
print('Test Accuracy of the model on the 10000 test images: %.3f' % (100 * correct / total))
acc = 100. * float(correct) / float(total)
acc_record.append(acc)
if epoch % 5 == 0:
print(acc)
print('Saving..')
state = {
'net': snn.state_dict(),
'acc': acc,
'epoch': epoch,
'acc_record': acc_record,
}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/ckpt' + names + '.t7')
best_acc = acc