BatchNorm1d - CUDA error: an illegal memory access was encountered

Hi all,

I am trying to use a BatchNorm1d layer in the linear output layer of my model as follows:

class output_layers(torch.nn.Module):
  def __init__(self, param):
    super(output_layers, self).__init__()
    self.param = param
    self.num_layers = self.param["decoder_num_layers"] 
    self.activation = self.param["activation_function"]
    self.input_size = self.param["lstm_hidden_size"] * self.param["seq_len"]
    self.factor = math.exp(np.log(1 / self.input_size) / self.num_layers)
    self.activation_layer = torch.nn.Mish()
    self.decoder = torch.nn.Sequential()
    in_size = self.input_size
    for l in range(self.num_layers):
      if l != (self.num_layers - 1):
        self.decoder.add_module(f"decoder_block{l}", torch.nn.Sequential(torch.nn.Linear(in_size, round(in_size * self.factor)),
                                                                        self.activation_layer,
                                                                        torch.nn.BatchNorm1d(round(in_size * self.factor))))
        in_size = round(in_size*self.factor)
      else:
        self.decoder.add_module("output", torch.nn.Linear(in_size, 1))

  def forward(self, x):
    return self.decoder(x)
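
For reference, self.factor is the geometric ratio that shrinks the layer width from input_size toward 1 over num_layers steps (before rounding, input_size * factor ** num_layers == 1). A minimal sketch of the resulting widths, assuming the input_size = 2000 and num_layers = 10 used in the repro further down this thread:

import math

# Geometric shrink factor: (1 / input_size) ** (1 / num_layers), here ~0.4676
input_size, num_layers = 2000, 10  # assumed values, taken from the repro below
factor = math.exp(math.log(1 / input_size) / num_layers)

widths = [input_size]
for _ in range(num_layers - 1):  # the final block maps straight to 1 output
  widths.append(round(widths[-1] * factor))
print(widths)  # [2000, 935, 437, 204, 95, 44, 21, 10, 5, 2]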

I plugged this module into another main module with the forward implementation below:

  def forward(self, x_con, x_cat):
    x = self.embedding(x_con, x_cat)
    x = x.view(-1, self.seq_len, self.input_size)
    if self.training:
      x = self.noise_layer(x)
    x = self.autoencoder(x)
    x = x.contiguous().view(-1, self.seq_len, self.encoded_final_size)
    output, (h_n, c_n) = self.lstm(x)
    output = output.contiguous().view(self.batch_size,-1)
    output = self.output_layer(output).squeeze()
    return output, (h_n, c_n)
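
For context, by the time the tensor reaches output_layer it has already been flattened to 2D, which is the (N, C) layout torch.nn.BatchNorm1d expects. A small shape sketch, assuming hypothetical sizes batch_size = 64, seq_len = 10 and lstm_hidden_size = 200:

import torch

batch_size, seq_len, hidden = 64, 10, 200  # hypothetical sizes for illustration
lstm_out = torch.rand(batch_size, seq_len, hidden)  # LSTM output: (N, T, H)
flat = lstm_out.contiguous().view(batch_size, -1)   # (N, T * H) = (64, 2000)
bn = torch.nn.BatchNorm1d(2000)                     # normalizes each of the 2000 features
print(bn(flat).shape)                               # torch.Size([64, 2000])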

While running through this output layer, the following error message pops up:

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/optuna/_optimize.py", line 216, in _run_trial
    value_or_values = func(trial)
  File "<ipython-input-18-78249d7efa36>", line 154, in __call__
    output, (h_n, c_n) = model(X_con, X_cat)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "<ipython-input-15-ab96b0392a9e>", line 35, in forward
    output = self.output_layer(output).squeeze()
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "<ipython-input-14-061a1e634053>", line 29, in forward
    return self.decoder(x)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/batchnorm.py", line 178, in forward
    self.eps,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 2282, in batch_norm
    input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: CUDA error: an illegal memory access was encountered

The code was run with CUDA_LAUNCH_BLOCKING = 1.
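
For reference, a minimal sketch of how the flag is usually enabled. It has to be set before the first CUDA call, since the CUDA runtime only reads the variable when the context is created; setting it afterwards has no effect:

import os
# Must run before CUDA is initialized (i.e. before the first .cuda() call)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
x = torch.rand(8, 10).cuda()  # from here on kernel launches are synchronous,
                              # so the stack trace points at the failing op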

The CUDA version is 11.0 and the GPU is a Tesla P100 on Google Colab. I tried both PyTorch 1.9.0 and the PyTorch nightly; the same error message pops up. It only disappears if I delete the BatchNorm1d layer.

I also tried setting “torch.backends.cudnn.enabled = False”, but the error remains.

I would be most grateful for any solution to this. Much obliged.

Could you post an executable code snippet using random tensors, so that we could try to reproduce this issue, please?

Thank you for the reply!

Sorry, I should have provided the code snippet in my question. Here it is:

import gc
import math
import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data.dataset import Dataset

class Autoencoder(torch.nn.Module):
  def __init__(self):
    super(Autoencoder, self).__init__()
    self.num_layers = 8 
    self.factor = 0.8
    self.drop_out_percent = 0.2
    self.input_size = 100
    self.activation_layer = torch.nn.Mish()
    self.encoder = torch.nn.Sequential()
    in_size = self.input_size
    for l in range(self.num_layers):
      self.encoder.add_module(f"encoder_block{l}", torch.nn.Sequential(torch.nn.Linear(in_size, round(in_size * self.factor)),
                                                                      self.activation_layer,
                                                                      torch.nn.LayerNorm(round(in_size * self.factor)),
                                                                      torch.nn.Dropout(self.drop_out_percent)))
      in_size = round(in_size*self.factor)
    self.encoder.add_module(f"encoder_block_last", torch.nn.Sequential(torch.nn.Linear(in_size, 2000),
                                                                       self.activation_layer,
                                                                       torch.nn.LayerNorm(2000),
                                                                       torch.nn.Dropout(self.drop_out_percent)))

    self.final_size = in_size

  def forward(self, x):
    return self.encoder(x)

class Output_Layers(torch.nn.Module):
  def __init__(self):
    super(Output_Layers, self).__init__()
    self.num_layers = 10
    self.input_size = 2000
    self.factor = math.exp(np.log(1 / self.input_size) / self.num_layers)
    self.activation_layer = torch.nn.Mish()
    self.decoder = torch.nn.Sequential()
    in_size = self.input_size
    for l in range(self.num_layers):
      if l != (self.num_layers - 1):
        self.decoder.add_module(f"decoder_block{l}", torch.nn.Sequential(torch.nn.Linear(in_size, round(in_size * self.factor)),
                                                                        self.activation_layer,
                                                                        torch.nn.BatchNorm1d(round(in_size * self.factor))))
        in_size = round(in_size*self.factor)
      else:
        self.decoder.add_module("output", torch.nn.Linear(in_size, 1))

  def forward(self, x):
    return self.decoder(x)

class Driver_Model(torch.nn.Module):
  def __init__(self,  autoencoder, output_layer):
    super(Driver_Model, self).__init__()
    self.autoencoder = autoencoder
    self.output_layer = output_layer

  def forward(self, x):
    output = self.autoencoder(x).contiguous()
    output = self.output_layer(output).squeeze()
    return output

class My_Dataset(Dataset):
  def __init__(self, X):
    self.X = X
    
  def __len__(self):
    return len(self.X)

  def __getitem__(self, idx):
    return self.X[idx]

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
data_length = 10000
data = []
for i in range(data_length):
  data.append(torch.rand(100))

autoencoder = Autoencoder()
output = Output_Layers()

model_parameter_list = list(autoencoder.parameters()) + list(output.parameters())

model = Driver_Model(autoencoder, output)
model.cuda().double()

test_dataset = My_Dataset(data)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size = 64, shuffle = False, num_workers = 4, pin_memory = True, drop_last = True)
print(model)

for batch in test_dataloader:
  batch = batch.cuda().double()
  output = model(batch)

Many thanks!

Further update: it seems that if I tune the input size of Output_Layers to a lower number, the issue disappears:

class Output_Layers(torch.nn.Module):
  def __init__(self):
    super(Output_Layers, self).__init__()
    self.num_layers = 10
    self.input_size = 200 #Tuned from 2000 to 200
    self.factor = math.exp(np.log(1 / self.input_size) / self.num_layers)
    self.activation_layer = torch.nn.Mish()
    self.decoder = torch.nn.Sequential()
    in_size = self.input_size
    for l in range(self.num_layers):
      if l != (self.num_layers - 1):
        self.decoder.add_module(f"decoder_block{l}", torch.nn.Sequential(torch.nn.Linear(in_size, round(in_size * self.factor)),
                                                                        self.activation_layer,
                                                                        torch.nn.BatchNorm1d(round(in_size * self.factor))))
        in_size = round(in_size*self.factor)
      else:
        self.decoder.add_module("output", torch.nn.Linear(in_size, 1))

  def forward(self, x):
    return self.decoder(x)

Thanks for the code snippet!
While I was able to reproduce the memory violation in 1.9.0, I wasn’t able to see it in the nightly binary or in a fairly recent source build, so I guess it might have been a known issue which was already fixed. Could you install the nightly and verify it, please?
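
If it helps, a quick way to confirm which build is actually active after installing the nightly (the printed values below are only illustrative examples):

import torch

print(torch.__version__)              # a nightly reports e.g. "1.10.0.devYYYYMMDD+cu111"
print(torch.version.cuda)             # CUDA toolkit the binary was built against
print(torch.cuda.get_device_name(0))  # e.g. "Tesla P100-PCIE-16GB" on Colab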

Thank you for the reply!

I tried running the code again with the nightly version installed (CUDA 11.1). The error remains :(

The GPU is a Tesla P100 on Google Colab.

I have a similar problem. Both the P100 and the T4 show it, but the K80 works.

Same driver version, same CUDA version (10.2), same code, just different hardware.

Could you post an executable code snippet to reproduce the issue, as well as the output of python -m torch.utils.collect_env, as I wasn’t able to reproduce the issue last time using the nightly binary and a source build?

Hi ptrblck,

I can confirm it’s fixed in the nightly build of 1.10.

Will the fix be in soon to be released 1.9.1?

I’m not aware of a planned 1.9.1 release, but based on the past release cadence I would expect to see a branch cut for 1.10.0 soon.