Hi, I’m using the gdl loss function from https://github.com/ginobilinie/medSynthesisV1/blob/master/nnBuildUnits.py, when I call and run the code on my own, it shows the problem:
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
I’m new in pytroch, and have no idea what’s wrong with anywhere and what should I do for the next step. Any help will be highly appreciated!!
I call it in this way:
from model.unet2 import UNet
device = torch.device(“cuda” if torch.cuda.is_available() else “cpu”)
net = UNet(in_channels=1, out_channels=1, init_features=32)
net.to(device)
ctimgs= read_img(“k1/train/ct_data/patient_002_ct.gipl”)
mrimgs = read_img(k1/train/mri_data/patient_002_mr_T1.gipl")
ctimgs1 = ctimgs[0,:,:]
mrimgs1 = mrimgs[0,:,:]
ctimgs1 = ctimgs1[np.newaxis,np.newaxis,:,:]
mrimgs1 = mrimgs1[np.newaxis,np.newaxis,:,:]
ctimgs1 = torch.from_numpy(ctimgs1)
mrimgs1 = torch.from_numpy(mrimgs1)
out = net(mrimgs1)
gdlloss = gdl_loss()
loss = gdlloss(out,ctimgs1)
this is the gdl_loss
class gdl_loss(nn.Module):
def __init__(self, pNorm=2):
super(gdl_loss, self).__init__()
self.convX = nn.Conv2d(1, 1, kernel_size=(1, 2), stride=1, padding=(0, 1), bias=False)
self.convY = nn.Conv2d(1, 1, kernel_size=(2, 1), stride=1, padding=(1, 0), bias=False)
filterX = torch.FloatTensor([[[[-1, 1]]]]) # 1x2
filterY = torch.FloatTensor([[[[1], [-1]]]]) # 2x1
self.convX.weight = torch.nn.Parameter(filterX,requires_grad=False)
self.convY.weight = torch.nn.Parameter(filterY,requires_grad=False)
self.pNorm = pNorm
def forward(self, pct, ct):
assert not ct.requires_grad
assert pct.dim() == 4
assert ct.dim() == 4
assert pct.size() == ct.size(), "{0} vs {1} ".format(pct.size(), ct.size())
pred_dx = torch.abs(self.convX(pct))
pred_dy = torch.abs(self.convY(pct))
gt_dx = torch.abs(self.convX(ct))
gt_dy = torch.abs(self.convY(ct))
grad_diff_x = torch.abs(gt_dx - pred_dx)
grad_diff_y = torch.abs(gt_dy - pred_dy)
mat_loss_x = grad_diff_x ** self.pNorm
mat_loss_y = grad_diff_y ** self.pNorm # Batch x Channel x width x height
shape = ct.shape
mean_loss = (torch.sum(mat_loss_x) + torch.sum(mat_loss_y)) / (shape[0] * shape[1] * shape[2] * shape[3])
return mean_loss