I am training a deep complex-valued network. The complex convolution, complex batch normalization, and other layers come from this repo: https://github.com/wavefrontshaping/complexPyTorch. I am seeing an odd problem: when I train the model on 2 GPUs (Titan RTX), the training accuracy increases, whereas the validation accuracy stays stuck at around 17%. The same model shows good validation accuracy (increasing with training) when trained on a single GPU.
I also ran an experiment in which I trained a model from torchvision on multiple GPUs, and it worked fine.
For the multi-GPU run I change only one line:

    model = nn.DataParallel(complex_model)
I suspect that there might be an issue with parallel computation and the complex network.
The tentative code is given below:

    import torch
    import torch.nn as nn

    # complex_model, epoch, and the two data loaders are defined elsewhere.
    model = nn.DataParallel(complex_model)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for j in range(epoch):
        # Training pass
        model.train()
        total = 0
        correct = 0
        for i, (real, imag, label) in enumerate(training_dataset_loader):
            real = real.cuda().float()
            imag = imag.cuda().float()
            label = label.cuda().view(-1).long()
            optimizer.zero_grad()
            outputs = model(real, imag)
            loss = criterion(outputs, label)
            print('iteration == %d epoch == %d loss == %f ' % (i + 1, j + 1, loss.item()))
            loss.backward()
            optimizer.step()
            # Softmax is monotonic, so argmax over the raw outputs gives the same prediction.
            predicted = outputs.argmax(dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
        print('Accuracy of training set after %d epoch is %f ' % (j + 1, correct / total * 100))

        # Validation pass: no gradients or optimizer calls are needed here.
        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            for real, imag, label in validation_dataset_loader:
                real = real.cuda().float()
                imag = imag.cuda().float()
                label = label.cuda().view(-1).long()
                outputs = model(real, imag)
                predicted = outputs.argmax(dim=1)
                total += label.size(0)
                correct += (predicted == label).sum().item()
        print('Accuracy of validation set after %d epoch is %f ' % (j + 1, correct / total * 100))

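One way to probe the suspicion about DataParallel is to check whether the model's buffers (for example, batch-norm running statistics) actually change after a training forward pass. The `nn.DataParallel` documentation notes that the module is replicated on every forward and that only in-place updates to parameters or buffers on the first device are recorded on the base module. Whether the complex batch-norm layers update their statistics in place is not confirmed here, so treat that as an assumption to test. The helper name `changed_buffers` is hypothetical, and the stock `nn.BatchNorm1d` model below is only a CPU stand-in for the real network.

```python
import torch
import torch.nn as nn


def changed_buffers(model, run_forward):
    """Return the names of buffers whose values differ after run_forward()."""
    before = {name: buf.detach().clone() for name, buf in model.named_buffers()}
    run_forward()
    return [
        name
        for name, buf in model.named_buffers()
        if not torch.equal(before[name], buf.detach())
    ]


# CPU demo with a stock layer (an assumption for illustration only); swap in
# the DataParallel-wrapped complex model and one real training batch.
torch.manual_seed(0)
demo = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))
batch = torch.randn(8, 4)

demo.train()
train_changes = changed_buffers(demo, lambda: demo(batch))

demo.eval()
eval_changes = changed_buffers(demo, lambda: demo(batch))

print("changed in train mode:", train_changes)
print("changed in eval mode:", eval_changes)
```

If the same call on the wrapped model, for example `changed_buffers(model, lambda: model(real, imag))` in train mode, lists running statistics on one GPU but returns an empty list on two, the statistics are probably being updated on the replicas and then discarded. That would be consistent with training accuracy rising while validation accuracy stays flat in eval mode.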
For the input, I am giving a spectrogram after separating its real and imaginary parts.
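For reference, a minimal sketch of how such a real/imaginary pair could be produced. The post does not say how the spectrogram is computed, so the use of `torch.stft`, the Hann window, and the `n_fft` and `hop_length` values are all assumptions, and `split_spectrogram` is a hypothetical helper name.

```python
import torch


def split_spectrogram(signal, n_fft=256, hop_length=128):
    """Compute a complex STFT and return its real and imaginary parts."""
    window = torch.hann_window(n_fft)
    spec = torch.stft(
        signal,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        return_complex=True,
    )
    # spec has shape (n_fft // 2 + 1, n_frames) for a 1-D input signal.
    return spec.real, spec.imag


torch.manual_seed(0)
signal = torch.randn(16000)  # placeholder waveform, not real data
real, imag = split_spectrogram(signal)
print(real.shape, imag.shape, real.dtype)
```

Both tensors have the same shape and a real floating-point dtype, which matches the `(real, imag, label)` batches that the training loop expects.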
Am I missing something when using multiple GPUs?