Consider the following purely illustrative C++ code:
/*Create PyTorch Model and Send to GPU*/
Foo model; /*Note: "Foo model();" would declare a function, not an object*/
model.to(torch::kCUDA);
/*Make Optimizer*/
torch::optim::Adam adam_opt(model.parameters(), torch::optim::AdamOptions(1e-3));
/*Use Training Mode*/
model.train();
/*Random Training/Label Batch*/
torch::Tensor input_batch = torch::rand({BATCH_SIZE, 1, HEIGHT, WIDTH}).to(torch::kCUDA);
torch::Tensor label_batch = torch::rand({BATCH_SIZE}).to(torch::kCUDA);
/*Train for 100 Iterations*/
for (int i = 0; i < 100; i++) {
    /*Zero the gradient buffers*/
    adam_opt.zero_grad();
    /*Forward Through Model*/
    torch::Tensor output = model.forward(input_batch);
    /*Calculate Loss*/
    torch::Tensor loss = torch::binary_cross_entropy(output, label_batch);
    /*Backprop then update using optimizer*/
    loss.backward();
    adam_opt.step();
}
The following is the Python equivalent of the C++ code above:
#Create PyTorch Model and Send to GPU
model = Foo()
model = model.to(torch.device("cuda:0"))
#Loss Function
loss_fn = torch.nn.BCELoss().to(torch.device("cuda:0"))
#Make Optimizer
optimizer = torch.optim.Adam(model.parameters())
#Use Training Mode
model.train()
#Random Training/Label Batch
input_batch = torch.rand(BATCH_SIZE, 1, HEIGHT, WIDTH).to(torch.device("cuda:0"))
label_batch = torch.rand(BATCH_SIZE).to(torch.device("cuda:0"))
#Train for 100 Iterations
for i in range(100):
    #Zero the gradient buffers
    optimizer.zero_grad()
    #Forward Through Model
    output = model(input_batch)
    #Calculate Loss
    loss = loss_fn(output, label_batch)
    #Backprop then update using optimizer
    loss.backward()
    optimizer.step()
If I comment out the loss.backward() call (and with it the optimizer.step() call) in the Python code, the Python and C++ versions use identical amounts of GPU memory. However, if I run the code as originally presented, with loss.backward() and optimizer.step() present in both the C++ and Python versions, the Python code uses significantly less memory.
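For anyone trying to reproduce the comparison, here is a minimal sketch of how the Python-side figures could be read from PyTorch's caching allocator. The helper name gpu_memory_mib is hypothetical, and this is not necessarily how the numbers above were obtained: a tool such as nvidia-smi also counts the CUDA context, so its readings will be higher than these.

```python
# Optional import so the sketch still loads on a machine without PyTorch.
try:
    import torch
except ImportError:
    torch = None


def gpu_memory_mib(device=None):
    """Return (allocated, reserved) GPU memory in MiB, or None if CUDA is unavailable.

    'allocated' is memory currently occupied by live tensors;
    'reserved' is what the caching allocator holds from the driver.
    """
    if torch is None or not torch.cuda.is_available():
        return None
    mib = 1024 ** 2
    allocated = torch.cuda.memory_allocated(device) / mib
    reserved = torch.cuda.memory_reserved(device) / mib
    return allocated, reserved


# Example: call once after the training loop, with and without loss.backward().
print(gpu_memory_mib())
```

Calling this at the same point in both variants of the Python script (with and without the backward pass) would separate memory held by live tensors from memory that is only cached by the allocator.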
An explanation for why backward() reduces memory usage was given here: "Calling loss.backward() reduce memory usage?"
However, this does not explain why the equivalent C++ code does not show a similar memory reduction. Is there a way to reduce the C++ code's memory usage to match that of the Python version? This seems like an important issue for those of us trying to port our code over.