for i, data in enumerate(dataset):
img = torch.squeeze(data['image'].to(device))
save_path = data['save_path']
h, w = img.shape
num = list(range(0,int(w / 2 - 10))) + list(range(int(w / 2 + 10), w))
# motion_events = np.random.randint(30, high=150)
motion_events = 1
k_space_lines = random.sample(num, motion_events)
class AffineTransform(torch.nn.Module):
def __init__(self, h, w, n) -> None:
super().__init__()
self.n = n
self.I = torch.nn.Parameter(torch.randn((1, 1, h, w), dtype=torch.float))
self.X = torch.nn.Parameter(torch.randn(n, dtype=torch.float))
self.Y = torch.nn.Parameter(torch.randn(n, dtype=torch.float))
self.THETA = torch.nn.Parameter(torch.randn(n, dtype=torch.float))
self.tm = torch.stack([torch.cos(self.THETA), torch.sin(-self.THETA), self.X, torch.sin(self.THETA), torch.cos(self.THETA), self.Y], dim=1).reshape(n, 1, 2, 3)
def forward(self, i):
self.tm = self.tm.to(device)
grid = F.affine_grid(self.tm[i], torch.Size(self.I.size()), align_corners=False)
return F.grid_sample(self.I, grid, align_corners=False).to(device)
# return self.I
def update(self):
self.tm = torch.stack([torch.cos(self.THETA), torch.sin(-self.THETA), self.X, torch.sin(self.THETA), torch.cos(self.THETA), self.Y], dim=1).reshape(self.n, 1, 2, 3)
stn = AffineTransform(h, w, motion_events).to(device)
optimz = torch.optim.Adam(stn.parameters(), lr=0.001)
print("processing: {}".format(save_path))
for epoch in range(10000):
corrupted_img = stn(0)
lsf = torch.nn.L1Loss().to(device)
optimz.zero_grad()
diff = lsf(img, corrupted_img[0, 0])
diff.backward()
optimz.step()
stn.update()
I debug find that all tensor in cuda, but it throw me a RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:6 and cpu!
I don’t know why.