Training with custom, quantized datatype

I want to train a model that has weights in a custom, quantized datatype (< 16 bits) to bring down memory usage. To avoid precision loss, I still want the gradient to be computed in torch.bfloat16. A torch.float32 copy of the weight would be used for accumulation.

```python
import torch


class QuantizedModuleFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, w_quantized):
        # will eventually be some actual computation
        return torch.zeros(128, 128, dtype=input.dtype, device=input.device)

    @staticmethod
    def backward(ctx, grad_output):
        # will eventually be some actual computation
        grad_input = torch.zeros(128, 128, dtype=grad_output.dtype, device=grad_output.device)
        # the weight gradient is deliberately computed in bfloat16 (see question 2)
        grad_w = torch.zeros(128, 128, dtype=torch.bfloat16, device=grad_output.device)
        return grad_input, grad_w


class QuantizedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # will eventually be torch.<some custom dtype>
        # See question 1
        self.w_quantized = torch.nn.Parameter(torch.zeros(128, 128, dtype=torch.complex64))
        # fp32 copy of the weight used for accumulation; not tracked by autograd
        self.w_fp32 = torch.zeros(128, 128, requires_grad=False)

    def forward(self, input):
        return QuantizedModuleFunction.apply(input, self.w_quantized)


def custom_optimizer(w_quantized, w_fp32, grad):
    """will eventually update w_fp32 and w_quantized and zero out grad"""


def test_training_loop():
    input = torch.zeros((128, 128), requires_grad=True, dtype=torch.bfloat16)
    model = QuantizedModule()
    for i in range(10):
        ans = model(input)
        loss = ans.sum()
        loss.backward()  # populates model.w_quantized.grad
        if model.w_quantized.grad.dtype != torch.bfloat16:
            # see question 2
            print(f"gradient dtype expected to be torch.bfloat16 but is instead {model.w_quantized.grad.dtype}")
        if (i + 1) % 2 == 0:
            # update weights every 2 microbatches
            custom_optimizer(model.w_quantized, model.w_fp32, model.w_quantized.grad)


test_training_loop()
```
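One way the `custom_optimizer` stub might eventually be filled in, shown only as a sketch under assumptions: plain SGD applied to the fp32 copy, followed by re-quantization into the low-precision weight. Because the custom format does not exist yet, symmetric per-tensor int8 (the hypothetical helper `quantize_int8`) stands in for it, and `w_quantized` is a plain int8 tensor rather than a Parameter, since integer tensors cannot require grad. The `lr` argument and the returned scale are illustrative additions, not part of the original signature.

```python
import torch


def quantize_int8(w):
    # Hypothetical stand-in for the custom <16-bit format: symmetric per-tensor int8.
    scale = w.abs().max().clamp(min=1e-8) / 127
    q = torch.round(w / scale).clamp(-127, 127).to(torch.int8)
    return q, scale


@torch.no_grad()
def custom_optimizer(w_quantized, w_fp32, grad, lr=1e-2):
    """Sketch: SGD on the fp32 copy, then re-quantize into w_quantized.

    w_quantized: int8 tensor updated in place (stand-in for the custom dtype).
    w_fp32:      fp32 accumulation copy, updated in place.
    grad:        bf16 gradient, zeroed after the update.
    Returns the new quantization scale.
    """
    # Apply the update to the fp32 copy; accumulating there keeps updates that are
    # smaller than the bf16/int8 resolution from being rounded away.
    w_fp32.add_(grad.float(), alpha=-lr)
    q, scale = quantize_int8(w_fp32)
    w_quantized.copy_(q)
    grad.zero_()  # start the next accumulation window from zero
    return scale
```

Keeping the fp32 copy as the source of truth means the quantized weight is always re-derived from full-precision values, so quantization error does not compound across steps.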
  1. Currently, I am using torch.complex64 for the weight, but I would like to use a custom, quantized datatype. Are there any suggestions on how to add Python bindings for a new datatype?
  2. Even though the backward function returns torch.bfloat16, model.w_quantized.grad.dtype is torch.complex64, because an implicit cast is performed here: pytorch/engine.cpp at d03d43df527e48771875537ad20212d5cb333215 · pytorch/pytorch · GitHub. Is there a preferred way to compute the gradient in a different dtype from the weight? One option would be to add a flag (e.g. requires_grad_diff_type) to the TensorOptions object.
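The cast described in question 2 can be reproduced in isolation. This minimal sketch uses float32/float64 instead of complex64/bfloat16 so that it runs on any build: a custom Function whose backward returns a float64 gradient for a float32 parameter still ends up with a float32 `.grad`. `ReturnsDoubleGrad` is a hypothetical name used only for illustration.

```python
import torch


class ReturnsDoubleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        return x * w

    @staticmethod
    def backward(ctx, grad_output):
        # Deliberately hand back the weight gradient in float64.
        grad_w = torch.ones_like(grad_output, dtype=torch.float64)
        return None, grad_w  # no gradient for x


x = torch.ones(3)
w = torch.nn.Parameter(torch.ones(3))  # float32 parameter
ReturnsDoubleGrad.apply(x, w).sum().backward()
print(w.grad.dtype)  # torch.float32: the engine cast the float64 gradient
```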

Alternatives I’ve considered:

  • An alternative is to define the model weight in torch.bfloat16 and have the forward pass internally create a quantized copy. This is less desirable since it increases memory usage; there would be both a bfloat16 and a quantized copy in memory. Additionally, the quantized copy would be recreated every microbatch even though the weights only change every 2 microbatches.
  • I looked at PyTorch's quantization feature (Quantization — PyTorch 1.11.0 documentation), but in my use case I want the weights themselves to be trained in the quantized datatype, which doesn't seem to be supported.
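To put numbers on the memory argument in the first bullet: for a single 128x128 weight, and assuming an 8-bit format (int8 stands in for the custom <16-bit type), keeping only the quantized copy can be compared with keeping a bfloat16 copy plus a quantized copy. Only weight storage is counted here; the fp32 accumulation copy is left out.

```python
import torch


def nbytes(dtype, shape=(128, 128)):
    # Bytes for one dense tensor: bytes per element times number of elements.
    return torch.empty((), dtype=dtype).element_size() * torch.Size(shape).numel()


quantized_only = nbytes(torch.int8)
bf16_plus_quantized = nbytes(torch.bfloat16) + nbytes(torch.int8)
print(quantized_only, bf16_plus_quantized)  # 16384 49152
```

A format narrower than 8 bits would widen the gap further, since the bfloat16 copy stays at 2 bytes per element.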


  1. Adding a new native dtype is pretty tricky and not really possible today, I'm afraid. You can, however, use a custom Tensor subclass to do something like that: What (and Why) is __torch_dispatch__? - frontend API - PyTorch Dev Discussions, and a collection of examples in GitHub - albanD/subclass_zoo.
  2. The constraint in the autograd engine exists mostly because a tensor and its gradient having different dtypes is not something that is used today, and returning a gradient of the wrong dtype is a common error when implementing backward functions. So it is safer to have this check. We are open to lifting it for particular Tensor types (in particular, the subclasses I mentioned above).
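A minimal sketch of the subclass approach from point 1, written in the general style of the wrapper-subclass examples in subclass_zoo but not taken from them. The tensor stores int8 data plus a per-tensor scale (a hypothetical stand-in for the custom format), reports bfloat16 as its dtype through the private `torch.Tensor._make_wrapper_subclass`, and runs every ATen op on a dequantized copy inside `__torch_dispatch__`. `Int8Weight`, `from_float`, `dequantize`, `qdata`, and `qscale` are illustrative names, and the underscore-prefixed APIs are private and may change between releases.

```python
import torch
from torch.utils._pytree import tree_map


class Int8Weight(torch.Tensor):
    """Hypothetical quantized weight: int8 payload plus a per-tensor scale.

    Autograd and other callers see a bfloat16 tensor (the dtype passed to
    _make_wrapper_subclass); each op runs on a dequantized bfloat16 copy.
    """

    # Skip __torch_function__ so ops go straight to __torch_dispatch__.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, qdata, qscale):
        r = torch.Tensor._make_wrapper_subclass(
            cls, qdata.shape, dtype=torch.bfloat16, device=qdata.device
        )
        r.qdata = qdata    # int8 payload: the only full-size buffer kept
        r.qscale = qscale  # 0-dim scale tensor
        return r

    @classmethod
    def from_float(cls, w):
        qscale = w.abs().max().clamp(min=1e-8) / 127
        qdata = torch.round(w / qscale).clamp(-127, 127).to(torch.int8)
        return cls(qdata, qscale)

    def dequantize(self):
        return (self.qdata.float() * self.qscale).to(torch.bfloat16)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Replace every Int8Weight argument with its dequantized copy, then run the op.
        def unwrap(t):
            return t.dequantize() if isinstance(t, Int8Weight) else t

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))


torch.manual_seed(0)
w = Int8Weight.from_float(torch.randn(4, 3))
x = torch.randn(2, 4, dtype=torch.bfloat16)
y = torch.matmul(x, w)  # plain bfloat16 tensor of shape (2, 3)
```

Op results come back as plain bfloat16 tensors. Making the subclass trainable, so that gradients flow back into it, would need more work than this sketch attempts.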

Thanks @albanD. Assuming I create a Python subclass, how would its Python type be retrieved here: pytorch/engine.cpp at d03d43df527e48771875537ad20212d5cb333215 · pytorch/pytorch · GitHub

Today, this dtype will be the “true” dtype of the subclass you created.
It will be either the dtype of the Tensor you gave to make_subclass or the dtype value passed to make_wrapper_subclass.

We are looking into ways to extend this, but since C++ is strictly typed, it is not very easy to do so.
But a check like this one can be bypassed when tensor.is_tensor_subclass() is true, similar to the one below it.
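A small check of the statement above, using the two constructors named in the reply (in PyTorch their actual names carry a leading underscore and are private): a `_make_subclass` tensor reports the dtype of the tensor it was built from, while a `_make_wrapper_subclass` tensor reports whatever dtype was passed in, regardless of what it stores internally. `Sub` and `Wrapper` are hypothetical names, and `Wrapper` deliberately supports no ops, so only metadata is inspected.

```python
import torch


class Sub(torch.Tensor):
    pass


class Wrapper(torch.Tensor):
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem, reported_dtype):
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=reported_dtype, device=elem.device
        )
        r.elem = elem  # what is actually stored
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # This sketch only inspects metadata; running ops on it is unsupported.
        raise NotImplementedError(f"{func} is not supported by Wrapper")


a = torch.Tensor._make_subclass(Sub, torch.zeros(4, dtype=torch.float16))
b = Wrapper(torch.zeros(4, dtype=torch.int8), torch.bfloat16)
print(a.dtype, b.dtype)  # torch.float16 torch.bfloat16
```

The reported dtype is the one the autograd engine compares gradients against, which is why a wrapper subclass storing int8 data but reporting bfloat16 would receive bfloat16 gradients.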


@albanD Thanks for your suggestions. I've created [WIP] quantized tensor by ashari4 · Pull Request #39 · albanD/subclass_zoo with some questions, and I would appreciate a review.
