Getting a memory allocation error — how can I fix this?

I keep running into memory problems trying to train a neural network in PyTorch. The partition I’m using has 250 GB of RAM and the GPU has 16 GB of memory. This is the error I got:

Epoch:   0%|          | 1/1000 [00:00<00:00, 41527.76it/s]Current run is terminating due to exception: [enforce fail at CPUAllocator.cpp:64] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 69553870848 bytes. Error code 12 (Cannot allocate memory)
frame #0: c10::ThrowEnforceNotMet(char const*, int, char const*, std::string const&, void const*) + 0x47 (0x2aaaee5fd957 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::alloc_cpu(unsigned long) + 0x1ba (0x2aaaee5e2e4a in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x1821a (0x2aaaee5e521a in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #3: at::native::empty_cpu(c10::ArrayRef<long>, c10::TensorOptions const&, c10::optional<c10::MemoryFormat>) + 0x1ac (0x2aaac7ff14ac in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0xde1b2b (0x2aaac825ab2b in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0xdd3cd7 (0x2aaac824ccd7 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0xdd3999 (0x2aaac824c999 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0xdd3cd7 (0x2aaac824ccd7 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0xb9aa3e (0x2aaac8013a3e in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #9: at::TensorIterator::fast_set_up() + 0x5cf (0x2aaac80148bf in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #10: at::TensorIterator::build() + 0x4c (0x2aaac801517c in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #11: at::TensorIterator::binary_op(at::Tensor&, at::Tensor const&, at::Tensor const&, bool) + 0x146 (0x2aaac8015826 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #12: at::native::add(at::Tensor const&, at::Tensor const&, c10::Scalar) + 0x45 (0x2aaac7d346b5 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #13: <unknown function> + 0xddc265 (0x2aaac8255265 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #14: <unknown function> + 0xe23fab (0x2aaac829cfab in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #15: <unknown function> + 0x296aa38 (0x2aaac9de3a38 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #16: <unknown function> + 0xe23fab (0x2aaac829cfab in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #17: <unknown function> + 0x2af7abd (0x2aaac9f70abd in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #18: <unknown function> + 0x2af9110 (0x2aaac9f72110 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #19: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0xf5b (0x2aaac9f5dd7b in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #20: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x3d2 (0x2aaac9f5f2f2 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #21: torch::autograd::Engine::thread_init(int) + 0x39 (0x2aaac9f57969 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #22: torch::autograd::python::PythonEngine::thread_init(int) + 0x38 (0x2aaac69ae9f8 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #23: <unknown function> + 0xc819d (0x2aaac61e219d in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/../../../.././libstdc++.so.6)
frame #24: <unknown function> + 0x7e25 (0x2aaaaacd6e25 in /lib64/libpthread.so.0)
frame #25: clone + 0x6d (0x2aaaaafe334d in /lib64/libc.so.6)
.
Engine run is terminating due to exception: [enforce fail at CPUAllocator.cpp:64] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 69553870848 bytes. Error code 12 (Cannot allocate memory)
frame #0: c10::ThrowEnforceNotMet(char const*, int, char const*, std::string const&, void const*) + 0x47 (0x2aaaee5fd957 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::alloc_cpu(unsigned long) + 0x1ba (0x2aaaee5e2e4a in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x1821a (0x2aaaee5e521a in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #3: at::native::empty_cpu(c10::ArrayRef<long>, c10::TensorOptions const&, c10::optional<c10::MemoryFormat>) + 0x1ac (0x2aaac7ff14ac in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0xde1b2b (0x2aaac825ab2b in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0xdd3cd7 (0x2aaac824ccd7 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0xdd3999 (0x2aaac824c999 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0xdd3cd7 (0x2aaac824ccd7 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0xb9aa3e (0x2aaac8013a3e in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #9: at::TensorIterator::fast_set_up() + 0x5cf (0x2aaac80148bf in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #10: at::TensorIterator::build() + 0x4c (0x2aaac801517c in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #11: at::TensorIterator::binary_op(at::Tensor&, at::Tensor const&, at::Tensor const&, bool) + 0x146 (0x2aaac8015826 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #12: at::native::add(at::Tensor const&, at::Tensor const&, c10::Scalar) + 0x45 (0x2aaac7d346b5 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #13: <unknown function> + 0xddc265 (0x2aaac8255265 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #14: <unknown function> + 0xe23fab (0x2aaac829cfab in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #15: <unknown function> + 0x296aa38 (0x2aaac9de3a38 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #16: <unknown function> + 0xe23fab (0x2aaac829cfab in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #17: <unknown function> + 0x2af7abd (0x2aaac9f70abd in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #18: <unknown function> + 0x2af9110 (0x2aaac9f72110 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #19: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0xf5b (0x2aaac9f5dd7b in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #20: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x3d2 (0x2aaac9f5f2f2 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #21: torch::autograd::Engine::thread_init(int) + 0x39 (0x2aaac9f57969 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #22: torch::autograd::python::PythonEngine::thread_init(int) + 0x38 (0x2aaac69ae9f8 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #23: <unknown function> + 0xc819d (0x2aaac61e219d in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/../../../.././libstdc++.so.6)
frame #24: <unknown function> + 0x7e25 (0x2aaaaacd6e25 in /lib64/libpthread.so.0)
frame #25: clone + 0x6d (0x2aaaaafe334d in /lib64/libc.so.6)
.
Traceback (most recent call last):
  File "model_xvec.py", line 162, in <module>
    trainer.run(train_loader, max_epochs=1000)
  File "/users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/ignite/engine/engine.py", line 659, in run
    return self._internal_run()
  File "/users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/ignite/engine/engine.py", line 723, in _internal_run
    self._handle_exception(e)
  File "/users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/ignite/engine/engine.py", line 438, in _handle_exception
    raise e
  File "/users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/ignite/engine/engine.py", line 698, in _internal_run
    time_taken = self._run_once_on_dataset()
  File "/users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/ignite/engine/engine.py", line 789, in _run_once_on_dataset
    self._handle_exception(e)
  File "/users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/ignite/engine/engine.py", line 438, in _handle_exception
    raise e
  File "/users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/ignite/engine/engine.py", line 772, in _run_once_on_dataset
    self.state.output = self._process_function(self, self.state.batch)
  File "/users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/ignite/engine/__init__.py", line 95, in _update
    loss.backward()
  File "/users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/tensor.py", line 198, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/autograd/__init__.py", line 98, in backward
    Variable._execution_engine.run_backward(
RuntimeError: [enforce fail at CPUAllocator.cpp:64] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 69553870848 bytes. Error code 12 (Cannot allocate memory)
frame #0: c10::ThrowEnforceNotMet(char const*, int, char const*, std::string const&, void const*) + 0x47 (0x2aaaee5fd957 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::alloc_cpu(unsigned long) + 0x1ba (0x2aaaee5e2e4a in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x1821a (0x2aaaee5e521a in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #3: at::native::empty_cpu(c10::ArrayRef<long>, c10::TensorOptions const&, c10::optional<c10::MemoryFormat>) + 0x1ac (0x2aaac7ff14ac in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0xde1b2b (0x2aaac825ab2b in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0xdd3cd7 (0x2aaac824ccd7 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0xdd3999 (0x2aaac824c999 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0xdd3cd7 (0x2aaac824ccd7 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0xb9aa3e (0x2aaac8013a3e in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #9: at::TensorIterator::fast_set_up() + 0x5cf (0x2aaac80148bf in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #10: at::TensorIterator::build() + 0x4c (0x2aaac801517c in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #11: at::TensorIterator::binary_op(at::Tensor&, at::Tensor const&, at::Tensor const&, bool) + 0x146 (0x2aaac8015826 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #12: at::native::add(at::Tensor const&, at::Tensor const&, c10::Scalar) + 0x45 (0x2aaac7d346b5 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #13: <unknown function> + 0xddc265 (0x2aaac8255265 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #14: <unknown function> + 0xe23fab (0x2aaac829cfab in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #15: <unknown function> + 0x296aa38 (0x2aaac9de3a38 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #16: <unknown function> + 0xe23fab (0x2aaac829cfab in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #17: <unknown function> + 0x2af7abd (0x2aaac9f70abd in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #18: <unknown function> + 0x2af9110 (0x2aaac9f72110 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #19: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0xf5b (0x2aaac9f5dd7b in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #20: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x3d2 (0x2aaac9f5f2f2 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #21: torch::autograd::Engine::thread_init(int) + 0x39 (0x2aaac9f57969 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #22: torch::autograd::python::PythonEngine::thread_init(int) + 0x38 (0x2aaac69ae9f8 in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #23: <unknown function> + 0xc819d (0x2aaac61e219d in /users/barendale/.conda/envs/capstone/lib/python3.8/site-packages/torch/lib/../../../.././libstdc++.so.6)
frame #24: <unknown function> + 0x7e25 (0x2aaaaacd6e25 in /lib64/libpthread.so.0)
frame #25: clone + 0x6d (0x2aaaaafe334d in /lib64/libc.so.6)

Epoch:   0%|          | 1/1000 [00:46<12:47:58, 46.12s/it]

This is the code I’m using:

import torch
import torch.nn as nn
from collections import OrderedDict
from torch.utils.data import TensorDataset, DataLoader
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import EarlyStopping
from ignite.contrib.engines.common import setup_common_training_handlers
from sklearn.metrics import accuracy_score, confusion_matrix

def _load_concat(*paths):
    """Load the tensor shards at *paths* with ``torch.load`` and concatenate them along dim 0.

    The result is ``detach()``-ed from autograd. ``detach`` returns a view that
    shares storage, so this costs no extra memory.

    Why: ``torch.save`` keeps a tensor's ``requires_grad`` flag. If these shards
    were produced while gradients were being tracked, every ``TensorDataset``
    lookup and ``.to(device)`` on them becomes part of the training graph.
    ``loss.backward()`` then builds zero-filled CPU gradients as large as the
    whole dataset. That is presumably the ~65 GiB CPU allocation that fails
    inside ``backward()`` even though the model is on the GPU.
    """
    return torch.cat([torch.load(path) for path in paths]).detach()


# Concatenate vectors from parts
X_train = _load_concat('tensors/filterbanks/filterbanks_train_a.pt',
                       'tensors/filterbanks/filterbanks_train_b.pt')
X_valid = _load_concat('tensors/filterbanks/filterbanks_valid_a.pt',
                       'tensors/filterbanks/filterbanks_valid_b.pt')
y_train = _load_concat('tensors/labels/labels_train_a.pt',
                       'tensors/labels/labels_train_b.pt')
y_valid = _load_concat('tensors/labels/labels_valid_a.pt',
                       'tensors/labels/labels_valid_b.pt')
# Change from one-hot to integer labels
y_train = y_train.argmax(dim=1)
y_valid = y_valid.argmax(dim=1)

train_ds = TensorDataset(X_train, y_train)
valid_ds = TensorDataset(X_valid, y_valid)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(valid_ds, batch_size=32, shuffle=False)

# Pooling function for x-vector network
def mean_std_pooling(x, eps=1e-9):
    """Statistics pooling: concatenate the mean and standard deviation over dim 2.

    For a 3-D input ``(batch, channels, frames)`` the result is
    ``(batch, 2 * channels)``. The first half holds the per-channel means and
    the second half the per-channel population standard deviations. ``eps`` is
    added to the variance before the square root, so the value and its
    gradient stay finite when a channel is constant.
    """
    mu = x.mean(dim=2, keepdim=True)
    variance = ((x - mu) ** 2).mean(dim=2)
    sigma = (variance + eps).sqrt()
    return torch.cat([mu.squeeze(2), sigma], dim=1)

# courtesy of Daniel Garcia-Romero
class xvec(nn.Module):
    """x-vector style TDNN embedding extractor.

    The network has four stages:

    * four dilated 1-D convolution blocks (see ``conv_block``);
    * a 1x1 "expansion" convolution with optional channel dropout;
    * mean+std statistics pooling over the last axis (``mean_std_pooling``);
    * a linear embedding layer.

    The first layer is an ``nn.Conv1d`` over ``input_dim`` channels, so inputs
    are presumably ``(batch, input_dim, frames)``. The output is then
    ``(batch, embedding_dim)``.

    Args:
        input_dim: number of input feature channels.
        layer_dim: width of every block's dilated conv, and the output width of block 4.
        mid_dim: output width of blocks 1-3.
        embedding_dim: size of the returned embedding.
        expansion_rate: channel multiplier of the expansion layer.
        drop_p: dropout probability before pooling; 0 disables the layer.
        bn_affine: whether the conv-block BatchNorms learn scale/shift. The
            expansion BatchNorm is always non-affine.
    """

    def __init__(self, input_dim, layer_dim, mid_dim, embedding_dim, expansion_rate=3, drop_p=0.0, bn_affine=False):
        super().__init__()

        # (index, in_channels, out_channels, kernel_size, dilation); every block's
        # dilated conv is layer_dim wide and no block uses padding.
        block_specs = (
            (1, input_dim, mid_dim, 5, 1),
            (2, mid_dim, mid_dim, 3, 2),
            (3, mid_dim, mid_dim, 3, 3),
            (4, mid_dim, layer_dim, 3, 4),
        )
        named_layers = []
        for idx, c_in, c_out, width, dil in block_specs:
            named_layers += self.conv_block(idx, c_in, layer_dim, c_out, width, dil, 0, bn_affine)

        expanded = layer_dim * expansion_rate
        named_layers += [
            ('expand_linear', nn.Conv1d(layer_dim, expanded, kernel_size=1)),
            ('expand_relu', nn.LeakyReLU(inplace=True)),
            ('expand_bn', nn.BatchNorm1d(expanded, affine=False)),
        ]

        if drop_p > 0.0:
            named_layers.append(('drop_pre_pool', nn.Dropout2d(p=drop_p, inplace=True)))

        self.prepooling_layers = nn.Sequential(OrderedDict(named_layers))

        # Statistics pooling concatenates mean and std, doubling the channel count.
        self.embedding = nn.Linear(2 * expanded, embedding_dim)

        self.init_weight()

    def conv_block(self, index, in_channels, mid_channels, out_channels, kernel_size, dilation, padding, bn_affine=False):
        """Return one TDNN block as a list of ``(name, module)`` pairs.

        The block is a dilated ``Conv1d`` (in_channels -> mid_channels) followed
        by a 1x1 ``Conv1d`` (mid_channels -> out_channels). Each conv is followed
        by LeakyReLU and BatchNorm. Layer names carry ``index`` as a suffix.
        """
        suffix = '%d' % index
        return [
            ('conv' + suffix, nn.Conv1d(in_channels, mid_channels, kernel_size, dilation=dilation, padding=padding)),
            ('relu' + suffix, nn.LeakyReLU(inplace=True)),
            ('bn' + suffix, nn.BatchNorm1d(mid_channels, affine=bn_affine)),
            ('linear' + suffix, nn.Conv1d(mid_channels, out_channels, kernel_size=1)),
            ('relu' + suffix + 'a', nn.LeakyReLU(inplace=True)),
            ('bn' + suffix + 'a', nn.BatchNorm1d(out_channels, affine=bn_affine)),
        ]

    def init_weight(self):
        """Kaiming-normal initialise the weight of every Conv1d and Linear submodule.

        ``a=0.01`` is the negative slope passed to the leaky-ReLU gain
        computation. Biases keep PyTorch's default initialisation.
        """
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            print("Initializing %s with kaiming_normal" % str(module))
            nn.init.kaiming_normal_(module.weight, a=0.01)

    def extract_pre_pooling(self, x):
        """Run the convolutional stack that precedes statistics pooling."""
        return self.prepooling_layers(x)

    def extract_post_pooling(self, x):
        """Return the mean+std pooled statistics of the pre-pooling features."""
        return mean_std_pooling(self.extract_pre_pooling(x))

    def extract_embedding(self, x):
        """Return the embedding produced by the final linear layer."""
        return self.embedding(self.extract_post_pooling(x))

    def forward(self, x):
        """Compute the embedding for ``x``."""
        return self.extract_embedding(x)

# E-TDNN architecture: the x-vector extractor yields a 512-d embedding, which a
# small classifier head maps to log-probabilities over 8 classes.
embedding_net = xvec(input_dim=64, layer_dim=512, mid_dim=198, embedding_dim=512)
classifier_head = [
    nn.LeakyReLU(inplace=True),
    nn.Linear(512, 512),
    nn.LeakyReLU(inplace=True),
    nn.Linear(512, 8),
    nn.LogSoftmax(dim=1),
]
# Positional children keep the numeric names '0'..'5' in the state_dict.
model = nn.Sequential(embedding_net, *classifier_head)

device = torch.device("cuda:0")
model = model.to(device)  # Module.to moves parameters in place and returns the same module

optimizer = torch.optim.Adam(model.parameters())
# LogSoftmax output + NLLLoss is equivalent to cross-entropy on the logits.
loss = torch.nn.NLLLoss()

trainer = create_supervised_trainer(model, optimizer, loss, device=device)
val_metrics = {
    'accuracy': Accuracy(),
    'nll': Loss(loss),
}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)

# Setup checkpoints and progress bar; model/optimizer/trainer state is saved
# every 10000 iterations under checkpoints/xvec_model/.
setup_common_training_handlers(
    trainer=trainer,
    to_save={'model': model, 'optimizer': optimizer, 'trainer': trainer},
    save_every_iters=10000,
    output_path='checkpoints/xvec_model/',
)


def run_evaluator(trainer):
    """Evaluate on ``val_loader`` and print the validation loss and accuracy."""
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print(f"Validation loss: {metrics['nll']}, accuracy: {metrics['accuracy']}")


# Same registration the @trainer.on(Events.EPOCH_COMPLETED) decorator performs.
trainer.add_event_handler(Events.EPOCH_COMPLETED, run_evaluator)


def score_function(engine):
    """Early-stopping score: the negated validation NLL, so higher is better."""
    return -engine.state.metrics['nll']


# Checked after every validation run; stops the trainer after 10 runs without improvement.
handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, handler)

trainer.run(train_loader, max_epochs=1000)

# Save model
torch.save(model.state_dict(), 'model_xvec_weights.pt')

# Test results
X_test = torch.cat([
    torch.load('tensors/filterbanks/filterbanks_test_a.pt'),
    torch.load('tensors/filterbanks/filterbanks_test_b.pt')
]).detach()  # drop any requires_grad flag carried over by torch.save (no copy)
y_test = torch.cat([
    torch.load('tensors/labels/labels_test_a.pt'),
    torch.load('tensors/labels/labels_test_b.pt')
])
y_test = y_test.argmax(dim=1)

# The model lives on `device` while X_test is on the CPU. Calling model(X_test)
# directly raises a device-mismatch error, and a single forward pass over the
# whole test set would not fit on the GPU anyway. Predict in mini-batches and
# collect the log-probabilities back on the CPU.
test_loader = DataLoader(TensorDataset(X_test), batch_size=32, shuffle=False)
model.eval()
with torch.no_grad():
    y_pred = torch.cat([model(xb.to(device)).cpu() for (xb,) in test_loader])
torch.save(y_pred, 'model_xvec_predictions.pt')
# Get predicted class
y_pred = y_pred.argmax(dim=1)
# sklearn expects array-likes on the host.
print('Test accuracy:', accuracy_score(y_test.numpy(), y_pred.numpy()))
print('Confusion matrix:', confusion_matrix(y_test.numpy(), y_pred.numpy()))

X_train is 66 GB, X_valid is 13 GB, and the labels are only about 100 MB combined. That doesn't seem like it should be enough to cause a memory error, even including the model weights and gradients. It seems to be a CPU RAM issue rather than a GPU one, so I don't know where the problem is.

@barendale You are using fully connected layers, and that is why your network requires so much memory. There are two things you can do: reduce the number of nodes in the model's fully connected (`nn.Linear`) layers, or, if you need to keep the same model, try reducing the batch size.