I ran into a strange problem. When I run the following code with the multiplication on the line marked # * 1 removed, it works fine and consumes about 2 GB of memory. But with that line kept as written, the code goes OOM, even on a machine with 12 GB of GPU memory.
import torch
import torch.nn.functional as F


def func(x1, x2):
    # Build a cost volume: correlate x1 against x2 over a 9x9 displacement window.
    _, _, H, W = x1.size()
    x2 = F.pad(x2, [4] * 4)  # pad the last two dims by 4 on each side
    cv = []
    for i in range(9):
        for j in range(9):
            cost = x1 * x2[:, :, i:(i + H), j:(j + W)]  # * 1
            cv.append(torch.mean(cost, 1, keepdim=True))
    return torch.cat(cv, 1)


if __name__ == '__main__':
    x1 = torch.randn(4, 256, 128, 256, requires_grad=True).to(torch.device("cuda"))
    x2 = torch.randn(4, 256, 128, 256, requires_grad=True).to(torch.device("cuda"))
    y = func(x1, x2)
    y.sum().backward()
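For scale, each cost tensor holds 4 x 256 x 128 x 256 float32 values, which is about 128 MiB, and the double loop runs 9 x 9 = 81 times. So if autograd keeps one tensor of that size alive per iteration for the backward pass, that alone would be about 81 x 128 MiB, roughly 10 GiB, which would match the OOM on a 12 GB card. Without the multiplication, the slice is just a view of the padded x2, so nothing of that size should need to be materialized.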
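To narrow it down, here is a minimal instrumentation sketch (my own debugging variant, not part of the repro; func_instrumented is a name I made up): it prints the CUDA memory currently held by tensors after every iteration, which should show whether usage grows by a fixed step per iteration only when the multiplication is present.

def func_instrumented(x1, x2):
    # Same computation as func, with a memory readout per iteration.
    _, _, H, W = x1.size()
    x2 = F.pad(x2, [4] * 4)
    cv = []
    for i in range(9):
        for j in range(9):
            cost = x1 * x2[:, :, i:(i + H), j:(j + W)]  # * 1
            cv.append(torch.mean(cost, 1, keepdim=True))
            # torch.cuda.memory_allocated() reports bytes currently allocated to tensors.
            print(i, j, torch.cuda.memory_allocated() / 2**20, "MiB")
    return torch.cat(cv, 1)

Running the same function under torch.no_grad() would be the obvious comparison, to confirm whether it is autograd's saved state, rather than the forward pass itself, that retains the memory.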