GPU out of memory with torch.nn.Bilinear

Hello,

I am currently trying to train a model which includes, among other modules, torch.nn.Bilinear:

```python
self.batch_size = 1
input_len = 100
output_len = 2
self.bilinear = nn.Bilinear(
    input_len,
    input_len,
    output_len,
    bias=False,
)
self.left_indices = [
    index
    for index in range(config["max_seq_len"])
    for _ in range(config["max_seq_len"] - index)
]
self.right_indices = [
    higher_index
    for index in range(config["max_seq_len"])
    for higher_index in range(index, config["max_seq_len"])
]
```

which I use like this during training:

```python
...
output_tensor = torch.zeros(
    self.batch_size,
    self.max_seq_len,
    self.max_seq_len,
    self.output_len,
    device=used_device,
)
output_tensor[:, self.left_indices, self.right_indices] = self.bilinear(
    input_left[:, self.left_indices],
    input_right[:, self.right_indices],
)
...
```
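A small, self-contained version of this pattern can make the problem easier to reproduce and share, and can also confirm that gradients flow through the indexed assignment into the zeros tensor. This is only a sketch under assumptions: the class name `PairBilinear`, the sizes used below, and storing the index lists as LongTensor buffers (so they are converted once and follow `.to(device)`) are hypothetical choices, not the original code.

```python
import torch
import torch.nn as nn


class PairBilinear(nn.Module):
    """Hypothetical minimal module reproducing the pattern from the question."""

    def __init__(self, max_seq_len: int, input_len: int = 100, output_len: int = 2):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.output_len = output_len
        self.bilinear = nn.Bilinear(input_len, input_len, output_len, bias=False)
        # Same index lists as in the question, built once and stored as
        # non-persistent buffers so they move with the module's device.
        left = [i for i in range(max_seq_len) for _ in range(max_seq_len - i)]
        right = [j for i in range(max_seq_len) for j in range(i, max_seq_len)]
        self.register_buffer("left_indices", torch.tensor(left), persistent=False)
        self.register_buffer("right_indices", torch.tensor(right), persistent=False)

    def forward(self, input_left: torch.Tensor, input_right: torch.Tensor) -> torch.Tensor:
        batch_size = input_left.shape[0]
        output = input_left.new_zeros(
            (batch_size, self.max_seq_len, self.max_seq_len, self.output_len)
        )
        # Fill only the (i, j) pairs with i <= j; all other entries stay zero.
        output[:, self.left_indices, self.right_indices] = self.bilinear(
            input_left[:, self.left_indices],
            input_right[:, self.right_indices],
        )
        return output


# Usage with small illustrative sizes (batch, seq_len, features).
model = PairBilinear(max_seq_len=6, input_len=4, output_len=2)
out = model(torch.randn(1, 6, 4), torch.randn(1, 6, 4))
print(out.shape)  # torch.Size([1, 6, 6, 2])
```

Calling `out.sum().backward()` on such an output populates `model.bilinear.weight.grad`, so writing the module output into a preallocated tensor does not cut the autograd graph.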

However, I hit the usual “CUDA out of memory” error, which seems to happen when processing the second batch.

I tried varying the size of the two input tensors (from 1 to 100), and the memory allocated on the GPU goes up to 8 GB with a batch size of 1.
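A rough back-of-the-envelope estimate helps judge whether 8 GB is plausible for this layer. The sketch below counts only the large tensors created by the snippet above (the two gathered inputs, the bilinear output, and `output_tensor`) in float32, assuming the gathered copies are kept alive for the backward pass. It ignores the rest of the model, gradients, optimizer state, temporaries, and allocator overhead, so the real peak will be higher. `max_seq_len` is not given in the post; the values below are hypothetical.

```python
def pairwise_bilinear_bytes(max_seq_len, input_len=100, output_len=2,
                            batch_size=1, bytes_per_elem=4):
    """Rough float32 size of the large tensors created by the snippet."""
    # Number of (i, j) pairs with i <= j, i.e. len(self.left_indices).
    num_pairs = max_seq_len * (max_seq_len + 1) // 2
    sizes = {
        # input_left[:, left_indices] and input_right[:, right_indices]
        "gathered_inputs": 2 * batch_size * num_pairs * input_len * bytes_per_elem,
        "bilinear_out": batch_size * num_pairs * output_len * bytes_per_elem,
        "output_tensor": batch_size * max_seq_len**2 * output_len * bytes_per_elem,
    }
    return num_pairs, sizes


for n in (128, 512):
    num_pairs, sizes = pairwise_bilinear_bytes(n)
    total_mib = sum(sizes.values()) / 2**20
    print(f"max_seq_len={n}: {num_pairs} pairs, ~{total_mib:.1f} MiB")
```

Under these assumptions, `max_seq_len = 512` gives 131,328 pairs and about 103 MiB for these tensors, which is far from 8 GB; memory held across iterations would be a more likely suspect unless `max_seq_len` is much larger.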

Is this normal, or am I doing something wrong (for example, by assigning the module output to a tensor like that)?

Thank you for your help! =)

Hi, your question has multiple aspects; I’ll try to answer all of them.

  1. Nothing here seems way off, but it’s hard to tell whether this should result in a CUDA OOM without the real numbers (e.g. the config params). Could you share a reproducible example?
  2. The fact that it happens on the second batch (assuming you mean a second call to self.bilinear, not batch_size=2) indicates that some memory used in the first call is not getting freed. This might be unintentional, but again, the full code would help.
  3. Assigning the module output to a tensor is fine.
  4. You could replace those nested list comprehensions with PyTorch functions (e.g. see torch.arange and repeat).
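For point 4, here is a sketch of vectorized replacements for the two list comprehensions. `torch.repeat_interleave` over `torch.arange` reproduces `left_indices`, and `torch.triu_indices` (a different function from the ones named above) returns both index vectors at once, in the same row-major order as the list comprehensions. The helper names `pair_indices` and `left_indices_arange` are hypothetical; passing `device=` lets the indices be created directly on the GPU.

```python
import torch


def pair_indices(max_seq_len, device=None):
    # Upper-triangular (i <= j) pairs, ordered by row then column; this matches
    # the left_indices / right_indices list comprehensions from the question.
    rows, cols = torch.triu_indices(max_seq_len, max_seq_len, device=device)
    return rows, cols


def left_indices_arange(max_seq_len, device=None):
    # Each index i is repeated (max_seq_len - i) times.
    idx = torch.arange(max_seq_len, device=device)
    return torch.repeat_interleave(idx, max_seq_len - idx)


n = 5
left_list = [i for i in range(n) for _ in range(n - i)]
right_list = [j for i in range(n) for j in range(i, n)]
rows, cols = pair_indices(n)
print(rows.tolist() == left_list, cols.tolist() == right_list)
print(left_indices_arange(n).tolist() == left_list)
```

Building these once (e.g. in `__init__`) also avoids converting Python lists to index tensors on every forward pass.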