Hi, I have some questions about torch.cuda.amp. I want to use mixed precision training in my project, and I noticed that the gradients of the network parameters are still float32, which confuses me. My understanding was that in mixed precision training fp16 params should have fp16 gradients and fp32 params should have fp32 gradients. Is my understanding wrong? Can anyone help me?
The following is my test code:
import torch
import torch.nn as nn
import torch.optim as optim

class CustomModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(CustomModel, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out

# Initialize the model
# (X_train_tensor and y_train_tensor already exist as CUDA tensors in my script.)
input_size = X_train_tensor.shape[1]
hidden_size = 16000
num_classes = 2000
model = CustomModel(input_size, hidden_size, num_classes).to('cuda')
scaler = torch.cuda.amp.GradScaler()

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
num_epochs = 1
for epoch in range(num_epochs):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(dtype=torch.float16, enabled=True):
        outputs = model(X_train_tensor)
        loss = criterion(outputs, y_train_tensor)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

print(outputs.dtype)
print(model.fc1.weight.grad.dtype)
print(model.fc2.weight.grad.dtype)
print("Done!")
and the output is
torch.float16
torch.float32
torch.float32
Done!
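To isolate what I mean, here is a smaller self-contained sketch (just a toy nn.Linear with arbitrary sizes; the dtypes in the comments are only my expectation of what autocast does, not something from the docs):

import torch
import torch.nn as nn

# Toy layer and input, only to isolate the dtype question (sizes are arbitrary).
lin = nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device='cuda')

with torch.cuda.amp.autocast(dtype=torch.float16, enabled=True):
    out = lin(x)
    print(lin.weight.dtype)   # expected torch.float32: the stored parameter itself
    print(out.dtype)          # expected torch.float16: output of the autocast matmul

out.sum().backward()
print(lin.weight.grad.dtype)  # prints torch.float32, same as in my training code above

So even in this minimal case the .grad dtype follows the fp32 leaf parameter rather than the fp16 compute dtype, which is exactly the behavior I am asking about.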