I want to train a model that has weights in a custom, quantized datatype (< 16 bits) to bring down memory usage. To avoid precision loss, I still want the gradient to be computed in `torch.bfloat16`. A `torch.float32` copy of the weight would be used for accumulation.
```python
import torch

class QuantizedModuleFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, w_quantized):
        # will eventually be some actual computation
        return torch.zeros(128, 128, dtype=input.dtype, device=input.device)

    @staticmethod
    def backward(ctx, grad_output):
        # will eventually be some actual computation
        grad_input = torch.zeros(128, 128, dtype=grad_output.dtype, device=grad_output.device)
        grad_w = torch.zeros(128, 128, dtype=torch.bfloat16, device=grad_output.device)
        return grad_input, grad_w

class QuantizedModule(torch.nn.Module):
    def __init__(self):
        super(QuantizedModule, self).__init__()
        # will eventually be torch.<some custom dtype>
        # See question 1
        self.w_quantized = torch.nn.Parameter(torch.zeros(128, 128, dtype=torch.complex64))
        self.w_fp32 = torch.zeros(128, 128, requires_grad=False)

    def forward(self, input):
        return QuantizedModuleFunction.apply(input, self.w_quantized)

def custom_optimizer(w_quantized, w_fp32, grad):
    """Will eventually update w_fp32 and w_quantized and zero out grad."""
    pass

def test_training_loop():
    input = torch.zeros((128, 128), requires_grad=True, dtype=torch.bfloat16)
    model = QuantizedModule()
    for i in range(10):
        ans = model(input)
        loss = ans.sum()
        loss.backward()
        if model.w_quantized.grad.dtype != torch.bfloat16:
            # see question 2
            print(f"gradient dtype expected to be torch.bfloat16 but is instead {model.w_quantized.grad.dtype}")
        if (i + 1) % 2 == 0:
            # update weights every 2 microbatches
            custom_optimizer(model.w_quantized, model.w_fp32, model.w_quantized.grad)
```
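
To make the intended update step concrete, here is a minimal sketch of how the `custom_optimizer` stub could implement the fp32 master-weight scheme described above. The `quantize` helper is a placeholder of my own for the (not yet existing) conversion into the custom dtype, with `torch.complex64` standing in for that dtype as in the code above, and the sketch assumes the gradient arrives in a real floating dtype such as `torch.bfloat16`.

```python
# Hypothetical sketch only: one way custom_optimizer could implement the
# fp32 master-weight accumulation described above. quantize() is a placeholder
# for the fp32 -> custom-dtype conversion; complex64 stands in for the custom
# dtype, matching the stub code above.
def quantize(w_fp32):
    return w_fp32.to(torch.complex64)  # placeholder conversion

def custom_optimizer(w_quantized, w_fp32, grad, lr=1e-3):
    with torch.no_grad():
        # accumulate the update in full precision
        # (assumes grad is already in a real floating dtype such as bfloat16)
        w_fp32 -= lr * grad.to(torch.float32)
        # write the re-quantized master weight back into the parameter
        w_quantized.copy_(quantize(w_fp32))
        # zero out the accumulated gradient
        w_quantized.grad = None
```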
1. Currently, I am using `torch.complex64` for the weight, but I would like to use a custom, quantized datatype. Are there any suggestions on how to add Python bindings for a new datatype?
2. Despite the backward function returning `torch.bfloat16`, `model.w_quantized.grad.dtype` is `torch.complex64`, because an implicit cast is performed here: pytorch/engine.cpp at d03d43df527e48771875537ad20212d5cb333215 · pytorch/pytorch · GitHub. Is there a preferred method for computing the gradient in a different dtype than the weight? One option would be to add a flag (e.g. `requires_grad_diff_type`) to the `TensorOptions` object. (A possible workaround is sketched right after this list.)
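
A workaround I can think of, sketched below, is to bypass `.grad` entirely: the backward returns `None` for the quantized weight and stashes the `torch.bfloat16` gradient on the parameter manually, so the engine never accumulates into `.grad` and never performs the cast. The `bf16_grad` attribute name is just illustrative, and the custom optimizer would then have to read and zero that attribute itself instead of `.grad`.

```python
# Rough sketch of a possible workaround (bf16_grad is an illustrative name):
# return None for the weight's gradient so autograd never accumulates into
# .grad (and never casts), and stash the bfloat16 gradient manually instead.
class QuantizedModuleFunctionNoCast(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, w_quantized):
        ctx.w_quantized = w_quantized  # keep a reference for backward
        return torch.zeros(128, 128, dtype=input.dtype, device=input.device)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = torch.zeros(128, 128, dtype=grad_output.dtype,
                                 device=grad_output.device)
        grad_w = torch.zeros(128, 128, dtype=torch.bfloat16,
                             device=grad_output.device)
        # accumulate into a side attribute instead of .grad
        if getattr(ctx.w_quantized, "bf16_grad", None) is None:
            ctx.w_quantized.bf16_grad = grad_w
        else:
            ctx.w_quantized.bf16_grad += grad_w
        # None: no gradient flows to the weight through autograd
        return grad_input, None
```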
Alternatives I’ve considered:
- One alternative is to define the model weight in `torch.bfloat16` and have the forward pass internally create a quantized copy (roughly sketched after this list). This is less desirable since it increases memory usage: both a `bfloat16` copy and a quantized copy would be in memory. Additionally, the quantized copy would be recreated every microbatch even though the weights only change every 2 microbatches.
- I also took a look at the torch quantization feature (Quantization — PyTorch 1.11.0 documentation). But in my usage scenario, I want the weights to be trained in the quantized datatype, which doesn't seem possible.
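
For reference, the first alternative would look roughly like the sketch below: the trainable parameter stays in `torch.bfloat16`, so the gradient dtype matches the weight and no implicit cast occurs, but a quantized copy is materialized on every forward call. `torch.complex64` again just stands in for the custom dtype, and the class names are my own.

```python
# Sketch of the bfloat16-master-weight alternative. The .to(torch.complex64)
# call is a placeholder for the real bfloat16 -> custom-dtype conversion.
class BF16WeightFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, w_bf16):
        # a quantized copy is created on every call, so both copies coexist;
        # the real kernel would use w_quantized for its computation
        w_quantized = w_bf16.detach().to(torch.complex64)
        return torch.zeros(128, 128, dtype=input.dtype, device=input.device)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = torch.zeros(128, 128, dtype=grad_output.dtype,
                                 device=grad_output.device)
        # the gradient is produced directly for the bfloat16 weight,
        # so its dtype matches the parameter and no implicit cast happens
        grad_w = torch.zeros(128, 128, dtype=torch.bfloat16,
                             device=grad_output.device)
        return grad_input, grad_w

class BF16WeightModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.zeros(128, 128, dtype=torch.bfloat16))

    def forward(self, input):
        return BF16WeightFunction.apply(input, self.w)
```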