Training Neural Networks with BFloat16

Hello, I’m trying to train neural networks using the bfloat16 data type in PyTorch.

I’ve started with a simple example: training LeNet-5 on the MNIST dataset.
First, I build the datasets and data loaders with the following code:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# renamed to avoid shadowing the torchvision.transforms module
transform = transforms.Compose([transforms.Resize((32, 32)),
                                transforms.ToTensor(),
                                transforms.ConvertImageDtype(dtype=torch.bfloat16)])

# download and create datasets
train_dataset = datasets.MNIST(root='mnist_data',
                               train=True,
                               transform=transform,
                               download=True)

valid_dataset = datasets.MNIST(root='mnist_data',
                               train=False,
                               transform=transform)

# define the data loaders
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True)

valid_loader = DataLoader(dataset=valid_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=False)

Apparently this works, and the dataset tensors have bfloat16 dtype.
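As a quick sanity check (a minimal sketch using a random tensor in place of an MNIST image), `ConvertImageDtype` is effectively a dtype cast for floating-point inputs:

```python
import torch

# a random float32 "image" standing in for an MNIST sample
img = torch.rand(1, 32, 32)
img_bf16 = img.to(torch.bfloat16)  # the cast ConvertImageDtype performs for float inputs
print(img_bf16.dtype)  # torch.bfloat16
```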

The problem appears when I try to run the model.
This is my train function:

def train(train_loader, model, criterion, optimizer, device):
    model.train() 
    running_loss = 0 
    
    for X, y_true in train_loader: 

        optimizer.zero_grad() 
      
        X = X.to(device) 
        y_true = y_true.to(device) 
        
        # Forward pass
        y_hat, _ = model(X) 
        loss = criterion(y_hat, y_true) 
        running_loss += loss.item() * X.size(0) 

        # Backward pass
        loss.backward() 
        optimizer.step() 
        
    epoch_loss = running_loss / len(train_loader.dataset)
    return model, optimizer, epoch_loss 

And when I execute the training loop, I get an error.
This is how I configure the model and start training:

torch.manual_seed(RANDOM_SEED) 
model = LeNet5(N_CLASSES).to(DEVICE) 
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 
criterion = nn.CrossEntropyLoss() 

model, optimizer, _ = training_loop(model, criterion, optimizer, train_loader, valid_loader, N_EPOCHS, DEVICE)

And this is the error I get:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-45-d0a5e863d7fc> in <module>
----> 1 model, optimizer, _ = training_loop(model, criterion, optimizer, train_loader, valid_loader, N_EPOCHS, DEVICE)

<ipython-input-37-247a8affb032> in training_loop(model, criterion, optimizer, train_loader, valid_loader, epochs, device, print_every)
     13 
     14         # training
---> 15         model, optimizer, train_loss = train(train_loader, model, criterion, optimizer, device) # call the training function
     16         train_losses.append(train_loss) # append this epoch's training loss to the list
     17 

<ipython-input-35-c9cf685f7d11> in train(train_loader, model, criterion, optimizer, device)
     18 
     19         # Forward pass
---> 20         y_hat, _ = model(X) # apply the model to X and get the prediction
     21         loss = criterion(y_hat, y_true) # compute the loss against the targets
     22         running_loss += loss.item() * X.size(0) # accumulate the running loss, weighted by batch size

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

<ipython-input-38-7278d3f7d36a> in forward(self, x)
     23 
     24     def forward(self, x): # LeNet-5 implementation
---> 25         x = self.feature_extractor(x) # extract the features of x
     26         x = torch.flatten(x, 1) # flatten x from dimension 1 onward
     27         logits = self.classifier(x) # classify the extracted features

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    117     def forward(self, input):
    118         for module in self:
--> 119             input = module(input)
    120         return input
    121 

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py in forward(self, input)
    397 
    398     def forward(self, input: Tensor) -> Tensor:
--> 399         return self._conv_forward(input, self.weight, self.bias)
    400 
    401 class Conv3d(_ConvNd):

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    393                             weight, bias, self.stride,
    394                             _pair(0), self.dilation, self.groups)
--> 395         return F.conv2d(input, weight, bias, self.stride,
    396                         self.padding, self.dilation, self.groups)
    397 

RuntimeError: Expected object of scalar type BFloat16 but got scalar type Float for argument #2 'weight' in call to _thnn_conv2d_forward

How can I solve this?
Thank you

You need to convert your model to bfloat16 as well: `model.to(dtype=torch.bfloat16, device=DEVICE)`.
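For reference, a minimal sketch of what that does to the parameters (using a stand-in conv layer, not the actual LeNet5 class from the question):

```python
import torch
import torch.nn as nn

model = nn.Conv2d(1, 6, kernel_size=5)  # stand-in for LeNet5
print(model.weight.dtype)               # torch.float32 by default
model = model.to(dtype=torch.bfloat16)  # casts all parameters and buffers
print(model.weight.dtype)               # torch.bfloat16
```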

Best regards

Thomas


Thanks for the help!

Now I’ve tried changing the model’s data type, but I get a new error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-53-d0a5e863d7fc> in <module>
----> 1 model, optimizer, _ = training_loop(model, criterion, optimizer, train_loader, valid_loader, N_EPOCHS, DEVICE)

<ipython-input-37-247a8affb032> in training_loop(model, criterion, optimizer, train_loader, valid_loader, epochs, device, print_every)
     13 
     14         # training
---> 15         model, optimizer, train_loss = train(train_loader, model, criterion, optimizer, device) # call the training function
     16         train_losses.append(train_loss) # append this epoch's training loss to the list
     17 

<ipython-input-35-c9cf685f7d11> in train(train_loader, model, criterion, optimizer, device)
     18 
     19         # Forward pass
---> 20         y_hat, _ = model(X) # apply the model to X and get the prediction
     21         loss = criterion(y_hat, y_true) # compute the loss against the targets
     22         running_loss += loss.item() * X.size(0) # accumulate the running loss, weighted by batch size

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

<ipython-input-38-7278d3f7d36a> in forward(self, x)
     23 
     24     def forward(self, x): # LeNet-5 implementation
---> 25         x = self.feature_extractor(x) # extract the features of x
     26         x = torch.flatten(x, 1) # flatten x from dimension 1 onward
     27         logits = self.classifier(x) # classify the extracted features

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    117     def forward(self, input):
    118         for module in self:
--> 119             input = module(input)
    120         return input
    121 

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py in forward(self, input)
    397 
    398     def forward(self, input: Tensor) -> Tensor:
--> 399         return self._conv_forward(input, self.weight, self.bias)
    400 
    401 class Conv3d(_ConvNd):

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    393                             weight, bias, self.stride,
    394                             _pair(0), self.dilation, self.groups)
--> 395         return F.conv2d(input, weight, bias, self.stride,
    396                         self.padding, self.dilation, self.groups)
    397 

RuntimeError: at::cuda::blas::gemm: not implemented for N3c108BFloat16E

Apparently, the bfloat16 GEMM is not implemented in my CUDA build. So I tried running it on the CPU instead, and got this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-60-d0a5e863d7fc> in <module>
----> 1 model, optimizer, _ = training_loop(model, criterion, optimizer, train_loader, valid_loader, N_EPOCHS, DEVICE)

<ipython-input-37-247a8affb032> in training_loop(model, criterion, optimizer, train_loader, valid_loader, epochs, device, print_every)
     13 
     14         # training
---> 15         model, optimizer, train_loss = train(train_loader, model, criterion, optimizer, device) # call the training function
     16         train_losses.append(train_loss) # append this epoch's training loss to the list
     17 

<ipython-input-57-9cb64ce1b7a6> in train(train_loader, model, criterion, optimizer, device)
     18 
     19         # Forward pass
---> 20         y_hat, _ = model(X) # apply the model to X and get the prediction
     21         loss = criterion(y_hat, y_true) # compute the loss against the targets
     22         running_loss += loss.item() * X.size(0) # accumulate the running loss, weighted by batch size

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

<ipython-input-38-7278d3f7d36a> in forward(self, x)
     23 
     24     def forward(self, x): # LeNet-5 implementation
---> 25         x = self.feature_extractor(x) # extract the features of x
     26         x = torch.flatten(x, 1) # flatten x from dimension 1 onward
     27         logits = self.classifier(x) # classify the extracted features

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    117     def forward(self, input):
    118         for module in self:
--> 119             input = module(input)
    120         return input
    121 

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/pooling.py in forward(self, input)
    613 
    614     def forward(self, input: Tensor) -> Tensor:
--> 615         return F.avg_pool2d(input, self.kernel_size, self.stride,
    616                             self.padding, self.ceil_mode, self.count_include_pad, self.divisor_override)
    617 

RuntimeError: "avg_pool2d_out_frame" not implemented for 'BFloat16'

Does this mean it is not possible to train with bfloat16?

Thank you

You do hit PyTorch operations that don’t support bfloat16 on the selected backend yet, so it’s not as easy.
There are two parts to this:

  • Personally, I tend to look into fixing the things I need that don’t work. If the backend libraries (e.g. cuBLAS) support it, it might be doable (I think a relatively new CUDA/cuBLAS/cuDNN version might help).
  • The other option is to manually cast only the supported parts of the model to bfloat16 and cast the activations in between. This is similar to automatic mixed precision (AMP), but driven by operator support and less automatic.
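A minimal sketch of that second option, using a hypothetical wrapper (`Float32Fallback` is not a PyTorch class) that runs an unsupported op in float32 and casts the result back:

```python
import torch
import torch.nn as nn

class Float32Fallback(nn.Module):
    """Hypothetical helper: run `module` in float32, return bfloat16 activations."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return self.module(x.float()).to(torch.bfloat16)

# e.g. wrap the AvgPool2d layer that lacked a bfloat16 CPU kernel
pool = Float32Fallback(nn.AvgPool2d(kernel_size=2))
x = torch.rand(1, 6, 28, 28, dtype=torch.bfloat16)
print(pool(x).dtype)  # torch.bfloat16
```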

Best regards

Thomas

Thank you for your answer.
I’ve tried using Automatic Mixed Precision, and it works.
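For anyone finding this later, a minimal sketch of bfloat16 autocast (assuming PyTorch ≥ 1.10, which added the `torch.autocast` context manager; use `device_type="cuda"` on GPU):

```python
import torch

a = torch.rand(4, 4)  # float32
b = torch.rand(4, 4)

# autocast runs eligible ops (e.g. matmul) in bfloat16, keeping others in float32
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    c = a @ b
print(c.dtype)  # torch.bfloat16
```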

Best Regards,
Rodrigo