I build a unet in pytorch and keras, however it seems much more slower in pytorch.I used Nvidia 1080Ti and Tesla v100 GPU card. I search for the reason why pytorch is slower and I find that pytorch should faster than keras.I wonder if i made some mistakes in my code? so could someone show me how to accelerate the training of my code? Here is my pytorch code:
‘’’
import argparse
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from torch import optim
import numpy as np
from skimage import io
import torch
import os
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split
from SSIM import SSIM
import torchvision.transforms as transforms
import torchvision
from torch.autograd import Variable
from unet_model import *
import os
from patchify import *
def normlize(im):
for i in range(im.size(0)):
im[i] = (im[i] - im[i].min())/(im[i].max() - im[i].min())
return im
def standard(im,mean,var):
return (im - mean)/var
def dataaugment(inputs,target):
rotatetimes = np.random.randint(4)
fliplr = np.random.randint(2)
flipud = np.random.randint(2)
inputs = torch.rot90(inputs,rotatetimes+1,(2,3))
target = torch.rot90(target,rotatetimes+1,(2,3))
batch, width, height = inputs.size(0),inputs.size(2),inputs.size(3)
if fliplr:
for i in range(batch):
img_input = inputs[i][0]
img_target = target[i][0]
inputs[i][0] = torch.fliplr(img_input)
target[i][0] = torch.fliplr(img_target)
if flipud:
for i in range(batch):
img_input = inputs[i][0]
img_target = target[i][0]
inputs[i][0] = torch.flipud(img_input)
target[i][0] = torch.flipud(img_target)
return inputs, target
parser = argparse.ArgumentParser(description='Debackground')
parser.add_argument('--batch_size',type=int,default=16)
parser.add_argument('--epochs',type=int,default=100)
args = parser.parse_args()
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
input_ = io.imread('input_actin.tif')
gt = io.imread('gt_actin.tif')
input_ = torch.tensor(input_,dtype=torch.float32).unsqueeze_(dim=1)
gt = torch.tensor(gt,dtype=torch.float32).unsqueeze_(dim=1)
input_ = normlize(input_)
gt = normlize(gt)
x_train, x_test, y_train, y_test = train_test_split(input_,gt,test_size=0.001)
print(x_train.device)
print(x_test.device)
train_ds = TensorDataset(x_train,y_train)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=4)
def weight_init(module):
if isinstance(module,nn.Conv2d):
nn.init.xavier_normal_(module.weight)
elif isinstance(module,nn.Linear):
nn.init.xavier_normal_(module.weight)
elif isinstance(module,nn.BatchNorm2d):
nn.init.constant_(module.weight,1)
nn.init.constant_(module.bias,1)
model = UNet(1,1)
criterion = nn.MSELoss()
learning_rate = 1e-3
if cuda:
model = model.cuda()
criterion.cuda()
optimizer = optim.Adam(model.parameters(),lr=learning_rate)
milestone = [25,50,75]
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestone,gamma=0.5)
writer = SummaryWriter('runs/lightsheet_experiment')
step = 0
for epoch in range(args.epochs):
for j,(data,label) in enumerate(train_dl,0):
model.train()
model.zero_grad()
optimizer.zero_grad()
if cuda:
data = data.cuda()
label = label.cuda()
pred = model(data)
loss = 1000*criterion(pred,label)
loss.backward()
optimizer.step()
scheduler.step()
print("[epoch %d step %d loss %.4f]"%(epoch,j,loss.item()))
if step%10==0:
writer.add_scalar('train_loss', loss.item(),step)
step +=1
with torch.no_grad():
for jj,(x_test,y_test) in enumerate(test_dl,0):
noise_x = Variable(x_test, volatile=True)
target_y = Variable(y_test, volatile=True)
if torch.cuda.is_available():
noise_x = noise_x.cuda()
target_y = target_y.cuda()
y_val = model(noise_x)
val_loss = SSIM()(y_val,target_y)
ssim += val_loss.item()
recurrent += 1
ssim = ssim/recurrent
writer.add_scalar('ssim', ssim,epoch)
if (epoch+1)%50==0:
clean_grid = torchvision.utils.make_grid(normlize(y_test),nrow=4)
writer.add_image('clean image'+str(epoch+1),clean_grid,dataformats='CHW')
dirty_grid = torchvision.utils.make_grid(normlize(x_test),nrow=4)
writer.add_image('dirty image'+str(epoch+1),dirty_grid,dataformats='CHW')
debackground_grid = torchvision.utils.make_grid(normlize(y_val),nrow=4)
writer.add_image('debackground image'+str(epoch+1),debackground_grid,dataformats='CHW')
print("[epoch %d val_loss %.4f]"%(epoch,ssim))
del val_loss
del y_val
torch.save(model.state_dict(), os.path.join(os.getcwd(), 'net_latest.pth'))
if (epoch+1)%10==0:
path = os.path.join(os.getcwd(),'model','deback_epoch%d.pth'%(epoch+1))
torch.save(model.state_dict(),path)
‘’’