I ran into a strange problem. When I run the following code with the trailing `* 1` commented out, it works fine, consuming about 2 GB of GPU memory; but when the `* 1` is retained, the code goes out of memory (OOM), even though I run it on a machine with 12 GB of GPU memory.
import torch
import torch.nn.functional as F


def func(x1, x2, max_disp=4):
    """Build a cost volume by correlating x1 against shifted windows of x2.

    Args:
        x1: tensor of shape (N, C, H, W).
        x2: tensor of the same shape as x1.
        max_disp: maximum displacement searched in each direction
            (default 4, giving a (2*max_disp+1)**2 = 81-channel volume,
            matching the original hard-coded behavior).

    Returns:
        Tensor of shape (N, (2*max_disp+1)**2, H, W); each channel is the
        channel-wise mean of x1 * (shifted x2) at one (dy, dx) offset.
    """
    _, _, H, W = x1.size()
    # Zero-pad x2 by max_disp on all four sides so every shift is in range.
    x2 = F.pad(x2, [max_disp] * 4)
    span = 2 * max_disp + 1
    cv = []
    for i in range(span):
        for j in range(span):
            # NOTE: do NOT append a redundant `* 1` here. With the plain
            # product, autograd's saved tensors are x1 and a *view* into the
            # padded x2 (cheap), and the full-size product is freed right
            # after the mean. A trailing `* 1` adds a second Mul node that
            # (in the PyTorch version this was reported on) keeps the
            # full-size intermediate product alive for backward in every one
            # of the span**2 iterations -- ~11 GB for the sizes below, hence
            # the OOM on a 12 GB GPU.
            cost = x1 * x2[:, :, i:(i + H), j:(j + W)]
            cv.append(torch.mean(cost, 1, keepdim=True))
    return torch.cat(cv, 1)


if __name__ == '__main__':
    device = torch.device("cuda")
    # Create the tensors as leaves directly on the device. The original
    # `torch.randn(..., requires_grad=True).to(device)` produces NON-leaf
    # tensors (the .to() output), so x1.grad / x2.grad would stay None
    # after backward.
    x1 = torch.randn(4, 256, 128, 256, device=device, requires_grad=True)
    x2 = torch.randn(4, 256, 128, 256, device=device, requires_grad=True)
    y = func(x1, x2)
    y.sum().backward()