L-BFGS optimizer doesn't work properly on CUDA

Hi, I have some problems when I use the L-BFGS optimizer in PyTorch.

My problems are below.

  • The L-BFGS optimizer on CUDA either doesn't converge or stops too early (it settles at a high loss value).
  • L-BFGS on the CPU works perfectly.
  • If I set the data type of all tensors to float64, it works.

I think the problem might depend on my environment, because the same code runs fine on Google Colab.
But the same problem happens even in a fresh conda virtual environment.

My environment is below.

  • Windows 10 64-bit
  • Python 3.8.10
  • PyTorch 1.10.0
  • NVIDIA RTX 3070
  • CUDA 11.1 & cuDNN v8.2.1
  • Latest NVIDIA GPU driver

This is my example code.

import torch
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(1234)
np.random.seed(1234)

device = torch.device("cuda")  # Or torch.device("cpu")
dtype = torch.float32  # Or torch.float64

torch.set_default_dtype(dtype)

x_data = np.linspace(0, 1, 100)[:, None]
y_data = np.sin(x_data * 2.0 * np.pi) + np.sin(x_data * 20.0 * np.pi) * 0.1

# move the training data to the target device in the chosen dtype
x_input = torch.tensor(x_data, dtype=dtype, device=device)
y_target = torch.tensor(y_data, dtype=dtype, device=device)

model = torch.nn.Sequential(
    torch.nn.Linear(1, 20),
    torch.nn.Tanh(),
    torch.nn.Linear(20, 1),
).to(device)

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.LBFGS(model.parameters(), lr=1.0, max_iter=5000)


y_out_init = model(x_input)

print(f"Device : {device}, dtype : {dtype}")
iteration = 0


def closure():
    global iteration
    optimizer.zero_grad()
    y_out = model(x_input)
    loss = loss_fn(y_out, y_target)
    loss.backward()
    iteration += 1
    print(f"\rIteration : {iteration}, Loss : {loss.item():.5e}", end="")
    if iteration % 500 == 0:
        print("")
    return loss


optimizer.step(closure)
y_out_final = model(x_input).detach().cpu().numpy()

# visualization
plt.title("Curve fitting")
plt.xlabel("x")
plt.ylabel("y")

plt.plot(x_data, y_data, label="ground truth")
plt.plot(x_data, y_out_init.detach().cpu().numpy(), label="init curve")
plt.plot(x_data, y_out_final, label="optimized curve")

plt.legend()
plt.show()

Here are my outputs:

Device : cpu, dtype : torch.float32
Iteration : 421, Loss : 4.79544e-03

Device : cpu, dtype : torch.float64
Iteration : 500, Loss : 4.79593e-03
Iteration : 625, Loss : 4.79562e-03

Device : cuda, dtype : torch.float32
Iteration : 500, Loss : 4.91581e-03
Iteration : 1000, Loss : 4.93076e-03
Iteration : 1500, Loss : 4.92084e-03
Iteration : 2000, Loss : 4.91559e-03
Iteration : 2465, Loss : 2.16753e+17

Device : cuda, dtype : torch.float64
Iteration : 400, Loss : 4.79128e-03

Does anyone have an idea what might be the problem?

Could you disable TF32 via torch.backends.cuda.matmul.allow_tf32 = False and rerun the code, please?
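
Something like this, placed right after the imports and before the model and optimizer are created, should be enough (both flags are process-wide):

# TF32 is enabled by default for float32 matmuls on Ampere GPUs such as the
# RTX 3070 in PyTorch 1.7-1.11. Turning it off makes matmuls run in full
# float32 precision instead of TF32.
torch.backends.cuda.matmul.allow_tf32 = False

# Corresponding switch for cuDNN convolutions; it shouldn't matter for this
# Linear/Tanh model, but is included for completeness.
torch.backends.cudnn.allow_tf32 = False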


Oh, now it seems to work fine. Would you mind explaining this?

Also, should I always disable TF32 when using the L-BFGS optimizer?

Thanks for your comments anyway.