Extremely slow training, high single CPU usage

Hello there, this is my first time posting. I’m working with some code I found on GitHub, and training is extremely slow: monitoring shows that only the first CPU core is pinned at 100%. I’m not sure what to try next, since as far as I can tell everything is placed where it should be.

I apologize for the long code listing. Thanks for your attention.

# Imports (the original post omits them; these are inferred from how the code
# below uses rearrange/einsum, torchinfo's summary, torchvision's make_grid, etc.)
from typing import List, Tuple
from itertools import repeat
from math import sqrt

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import exp, tanh, sigmoid
from torch.nn.utils import spectral_norm as SN_
from torchvision.utils import make_grid
from torchinfo import summary
from einops import rearrange, einsum

enSN = True
enGP = True
device = 'cuda:0'

Tensor = torch.cuda.FloatTensor

if enSN:
    SN = SN_
else:
    SN = lambda x: x

Hidden = List[Tuple[Tensor, ...]]

def enlarge_as(src : Tensor, other : Tensor) -> Tensor:
    '''
        Add enough trailing singleton dimensions to `src` (i.e. **to the right**)
        so that its rank matches that of `other`. NOTE that plain broadcasting
        aligns dimensions from the right, i.e. it works in the opposite direction.
    '''
    return rearrange(src, f'... -> ...{" 1" * (other.dim() - src.dim())}').contiguous()
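# Quick illustration (shapes made up for this example):
#   enlarge_as(torch.ones(2, 3), torch.zeros(2, 3, 5, 7)).shape  # -> (2, 3, 1, 1)
# The (2, 3, 1, 1) result then broadcasts against the (2, 3, 5, 7) tensor over the
# trailing axes, which plain right-aligned broadcasting of a (2, 3) tensor would not allow.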

class CausalConv1d(nn.Conv1d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True
    ):
        self._padding = (kernel_size - 1) * dilation

        super(CausalConv1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self._padding,
            dilation=dilation,
            groups=groups,
            bias=bias)

    def forward(self, inp : Tensor) -> Tensor:
        # Handle the case where the input has only two dimensions:
        # we take them to mean (batch, length), so we insert the
        # missing channel dimension manually
        if inp.dim() == 2: inp = rearrange(inp, 'b i -> b 1 i')
        
        result = super(CausalConv1d, self).forward(inp)
        if self._padding != 0: return result[..., :-self._padding]
        return result
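# Shape-check sketch (hypothetical sizes): the layer pads by (kernel_size - 1) * dilation
# and then trims that many samples from the right, so the sequence length is preserved
# and position t only depends on inputs at positions <= t:
#   conv = CausalConv1d(1, 1, kernel_size=4)
#   conv(torch.randn(8, 1, 32)).shape  # -> torch.Size([8, 1, 32])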
    
class BlockLinear(nn.Module):
    def __init__(
        self,
        block_dims : List[int | List[int]],
        bias : bool = False,
    ):
        super(BlockLinear, self).__init__()
        
        self._blocks = nn.ParameterList([
            nn.Parameter(torch.randn(size, requires_grad=True))
            for size in block_dims
        ])
        
        self._bias = nn.Parameter(torch.zeros(sum(block_dims))) if bias else None
        
    def forward(self, inp : Tensor) -> Tensor:
        # Assemble the blocks into a block-diagonal matrix
        full = torch.block_diag(*self._blocks)
        
        out = torch.matmul(inp, full)
        
        if self._bias is not None:
            out = out + self._bias
        
        return out
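# Illustration (made-up sizes): BlockLinear([(2, 2)] * 3) builds a 6x6 block-diagonal
# weight from three independent 2x2 blocks, so a (..., 6) input is mapped to (..., 6)
# with no mixing across blocks; this is what implements the per-head recurrence below.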

class sLSTM(nn.Module):
    '''The scalar Long Short-Term Memory (sLSTM) module as
    originally introduced in Beck et al. (2024); see
    https://arxiv.org/abs/2405.04517.
    
    This model is a variant of the standard LSTM model and
    offers two major improvements:
    - Exponential gating with appropriate state normalization
        to avoid overflows induced by the exponential function.
    - A new memory mixing within heads but not across heads.
    '''
    
    def __init__(
        self,
        inp_dim : int,
        head_dim : int,
        head_num : int,
        ker_size : int = 4,
        p_factor : float = 4/3,
    ) -> None:
        super().__init__()
        
        self.inp_dim = inp_dim
        self.head_dim = head_dim
        self.head_num = head_num
        
        self.inp_norm = nn.LayerNorm(inp_dim)
        self.hid_norm = nn.GroupNorm(head_num, head_dim * head_num)
        
        self.causal_conv = CausalConv1d(1, 1, kernel_size=ker_size)
        
        self.W_z = nn.Linear(inp_dim, head_num * head_dim)
        self.W_i = nn.Linear(inp_dim, head_num * head_dim)
        self.W_o = nn.Linear(inp_dim, head_num * head_dim)
        self.W_f = nn.Linear(inp_dim, head_num * head_dim)
        
        self.R_z = BlockLinear([(head_dim, head_dim)] * head_num)
        self.R_i = BlockLinear([(head_dim, head_dim)] * head_num)
        self.R_o = BlockLinear([(head_dim, head_dim)] * head_num)
        self.R_f = BlockLinear([(head_dim, head_dim)] * head_num)
        
        # NOTE: The factor of two in the output dimension of the up_proj
        # is due to the fact that the output needs to branch into two
        # separate outputs to account for the gated GeLU connection.
        # See Fig. 9 in the paper.
        proj_dim = int(p_factor * head_num * head_dim)
        self.up_proj   = nn.Linear(head_num * head_dim, 2 * proj_dim)
        self.down_proj = nn.Linear(proj_dim, inp_dim)
        
    @property
    def device(self) -> torch.device:
        '''Get the device of the model.

        Returns:
            torch.device: The device of the model.
        '''
        return next(self.parameters()).device
        
    def init_hidden(self, bs : int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        '''Initialize the hidden state of the sLSTM model.

        Args:
            bs (int): The batch size of the input sequence.

        Returns:
            Tuple[Tensor, Tensor, Tensor, Tensor]: The hidden state tuple containing the cell state,
                normalizer state, hidden state, and stabilizer state.
        '''
        
        n_0 = torch.ones (bs, self.head_num * self.head_dim, device=self.device)
        c_0 = torch.zeros(bs, self.head_num * self.head_dim, device=self.device)
        h_0 = torch.zeros(bs, self.head_num * self.head_dim, device=self.device)
        m_0 = torch.zeros(bs, self.head_num * self.head_dim, device=self.device)
        
        return c_0, n_0, h_0, m_0
        
    def forward(
        self,
        seq: Tensor,
        hid: Tuple[Tensor, Tensor, Tensor, Tensor],
        use_conv : bool = False,    
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]:
        '''Forward pass of the sLSTM model.

        Args:
            seq (Tensor): The input sequence tensor of shape (batch_size, input_dim).
            hid (Tuple[Tensor, Tensor, Tensor, Tensor]): The hidden state tuple containing the cell state,
                normalizer state, hidden state, and stabilizer state.

        Returns:
            Tuple[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]: The output tensor with the residual
                connection and the newly updated hidden state tuple.
        '''
        
        b, d = seq.shape
        
        # Separate the hidden (previous) state into the cell state,
        # the normalizer state, the hidden state, and the stabilizer state.
        c_tm1, n_tm1, h_tm1, m_tm1 = hid
        
        x_t : Tensor = self.inp_norm(seq)
        
        # Optional causal convolution block for the input
        # and forget gates. See Fig. 9 in the paper.
        if use_conv:
            # FIXME: The causal conv branch is broken.
            x_c = self.causal_conv(x_t)
            x_c = F.silu(x_c).squeeze()
        else:
            x_c = x_t
        
        # Project the input to the different heads for all
        # the gates.
        # NOTE: For the input (i) and forget (f) gates we use
        # the output of the causal conv. See Fig. 9 in the paper.
        i_t: Tensor = self.W_i(x_c) + self.R_i(h_tm1) 
        f_t: Tensor = self.W_f(x_c) + self.R_f(h_tm1) 
        z_t: Tensor = self.W_z(x_t) + self.R_z(h_tm1)
        o_t: Tensor = self.W_o(x_t) + self.R_o(h_tm1)
        
        # Compute the gated outputs for the newly computed inputs
        m_t = torch.max(f_t + m_tm1, i_t)
        
        i_t = exp(i_t - m_t)         # Eq. (16) in ref. paper | or Eq. (38) in supp. mat.
        f_t = exp(f_t - m_t + m_tm1) # Eq. (17) in ref. paper | or Eq. (39) in supp. mat.
        
        z_t = tanh(z_t)              # Eq. (11) in ref. paper
        o_t = sigmoid(o_t)           # Eq. (14) in ref. paper
        
        # Update the internal states of the model
        c_t = f_t * c_tm1 + i_t * z_t # Eq. (8) in ref. paper
        n_t = f_t * n_tm1 + i_t       # Eq. (9) in ref. paper
        h_t = o_t * (c_t / n_t)       # Eq. (10) in ref. paper
        
        # Compute the output of the LSTM block
        out = self.hid_norm(h_t)
        
        # Perform up-and-down projection of the output with
        # projection factor 4/3. See Fig. (9) in supp. mat.
        out1, out2 = self.up_proj(out).chunk(2, dim=-1)
        
        out = out1 + F.gelu(out2)
        out = self.down_proj(out)
        
        # Return output with the residual connection and the
        # newly updated hidden state.
        return out + seq, (c_t, n_t, h_t, m_t)
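# Usage sketch for a single sLSTM step (batch size and dims chosen only to mirror
# the configuration used further below):
#   cell = sLSTM(inp_dim=256, head_dim=16, head_num=16)
#   hid  = cell.init_hidden(8)
#   y, hid = cell(torch.randn(8, 256, device=cell.device), hid)
# Each call consumes one token per batch element, so full sequences require an
# explicit Python loop over time (see xLSTM.forward below).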
        
class mLSTM(nn.Module):
    '''The matrix Long Short-Term Memory (mLSTM) module as
    originally introduced in Beck et al. (2024); see
    https://arxiv.org/abs/2405.04517.
    
    This model is a variant of the standard LSTM model and
    offers superior memory capacity by storing values in a
    matrix instead of a scalar. It is fully parallelizable
    and updates its internal memory with the covariance rule.
    '''
    
    def __init__(
        self,
        inp_dim : int,
        head_num : int,
        head_dim : int,
        p_factor : int = 2,
        ker_size : int = 4,
    ) -> None:
        super().__init__()
        
        self.inp_dim = inp_dim
        self.head_num = head_num
        self.head_dim = head_dim

        hid_dim = head_num * head_dim
        
        self.inp_norm = nn.LayerNorm(inp_dim)
        self.hid_norm = nn.GroupNorm(head_num, hid_dim)
        
        # NOTE: The input is up-projected twice because the output needs to
        # branch into two streams: one feeding the mLSTM core and one acting
        # as a gate on the output (r_t below).
        self.up_l_proj = nn.Linear(inp_dim, int(p_factor * inp_dim))
        self.up_r_proj = nn.Linear(inp_dim, hid_dim)
        self.down_proj = nn.Linear(hid_dim, inp_dim)
        
        self.causal_conv = CausalConv1d(1, 1, kernel_size=ker_size)
        
        self.skip = nn.Conv1d(int(p_factor * inp_dim), hid_dim, kernel_size=1, bias=False)
        
        self.W_i = nn.Linear(int(p_factor * inp_dim), head_num)
        self.W_f = nn.Linear(int(p_factor * inp_dim), head_num)
        self.W_o = nn.Linear(int(p_factor * inp_dim), hid_dim)
        
        self.W_q = nn.Linear(int(p_factor * inp_dim), hid_dim)
        self.W_k = nn.Linear(int(p_factor * inp_dim), hid_dim)
        self.W_v = nn.Linear(int(p_factor * inp_dim), hid_dim)
        
    @property
    def device(self) -> torch.device:
        '''Get the device of the model.

        Returns:
            torch.device: The device of the model.
        '''
        return next(self.parameters()).device
    
    def init_hidden(self, bs : int) -> Tuple[Tensor, Tensor, Tensor]:
        '''Initialize the hidden state of the mLSTM model.

        Args:
            bs (int): The batch size of the input sequence.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: The hidden state tuple containing the cell state,
                normalizer state, and stabilizer state.
        '''
        
        c_0 = torch.zeros(bs, self.head_num, self.head_dim, self.head_dim, device=self.device)
        n_0 = torch.ones (bs, self.head_num, self.head_dim               , device=self.device)
        m_0 = torch.zeros(bs, self.head_num                              , device=self.device)
        
        return c_0, n_0, m_0
    
    def forward(
        self,
        seq: Tensor,
        hid: Tuple[Tensor, Tensor, Tensor],
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]]:
        '''Forward pass of the mLSTM model.

        Args:
            seq (Tensor): The input sequence tensor of shape (batch_size, input_dim).
            hid (Tuple[Tensor, Tensor, Tensor]): The hidden state tuple containing the
                cell state, normalizer state, and stabilizer state.

        Returns:
            Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]]: The output tensor with the residual
                connection and the newly updated hidden state tuple.
        '''
        
        # Separate the hidden (previous) state into the cell state,
        # the normalizer state, and the stabilizer state.
        c_tm1, n_tm1, m_tm1 = hid
        
        x_n : Tensor = self.inp_norm(seq) # shape: b i
        
        x_t = self.up_l_proj(x_n) # shape: b (i * p_factor)
        r_t = self.up_r_proj(x_n) # shape: b (h d)
        
        # Compute the causal convolutional input (to be 
        # used for the query and key gates)
        x_c = self.causal_conv(x_t)                    # shape: b 1 (i * p_factor)
        x_c = rearrange(F.silu(x_c), 'b ... -> b (...)') # shape: b   (i * p_factor)
        
        q_t = rearrange(self.W_q(x_c), 'b (h d) -> b h d', h=self.head_num)
        k_t = rearrange(self.W_k(x_c), 'b (h d) -> b h d', h=self.head_num) / sqrt(self.head_dim)
        v_t = rearrange(self.W_v(x_t), 'b (h d) -> b h d', h=self.head_num)
        
        i_t: Tensor = self.W_i(x_c) # shape: b h
        f_t: Tensor = self.W_f(x_c) # shape: b h
        o_t: Tensor = self.W_o(x_t) # shape: b (h d)
        
        # Compute the gated outputs for the newly computed inputs
        m_t = torch.max(f_t + m_tm1, i_t)
        
        i_t = exp(i_t - m_t)         # Eq. (25) in ref. paper
        f_t = exp(f_t - m_t + m_tm1) # Eq. (26) in ref. paper
        o_t = sigmoid(o_t)           # Eq. (27) in ref. paper
        
        # Update the internal states of the model
        c_t = enlarge_as(f_t, c_tm1) * c_tm1 + enlarge_as(i_t, c_tm1) * einsum(v_t, k_t, 'b h d, b h p -> b h d p')
        n_t = enlarge_as(f_t, n_tm1) * n_tm1 + enlarge_as(i_t, k_t)   * k_t                    
        h_t = o_t * rearrange(
                einsum(c_t, q_t, 'b h d p, b h p -> b h d') /
                einsum(n_t, q_t, 'b h d, b h d -> b h').clamp(min=1).unsqueeze(-1),
                'b h d -> b (h d)'
            ) # Eq. (21) in ref. paper

        x_c = rearrange(x_c, 'b i -> b i 1')
        out = self.hid_norm(h_t) + self.skip(x_c).squeeze() # shape: b (h d)
        out = out * F.silu(r_t)                               # shape: b (h d)
        out = self.down_proj(out)                           # shape: b i
        
        # Return output with the residual connection and the
        # newly updated hidden state.
        return out + seq, (c_t, n_t, m_t)

class xLSTM(nn.Module):
    '''The extended Long Short-Term Memory (xLSTM) module as
    originally introduced in Beck et al. (2024); see
    https://arxiv.org/abs/2405.04517.
    
    This model stacks sLSTM and mLSTM modules with residual
    connections and offers superior memory and performance
    compared to the standard LSTM model, achieving performance
    and scaling competitive with or better than Transformer
    models and State-Space models.
    '''
    
    def __init__(
        self, 
        num_layers : int,
        signature : Tuple[int, int],
        inp_dim : int,
        head_dim : int,
        head_num : int,
        p_factor : Tuple[float, float] = (2, 4/3),
        ker_size : int = 4,
        out_dim : int = None
    ) -> None:
        '''Initialize the model.

        Args:
            num_layers (int): The number of layers in the model.
            signature (Tuple[int, int]): The signature of the model,
                which represents the ratio of mLSTM-to-sLSTM blocks.
            inp_dim (int): The dimension of the input tokens.
            head_dim (int): The dimension of each attention head.
            head_num (int): The number of attention heads.
            p_factor (Tuple[float, float], optional): The expansion factor
                for the MLP projection in the m|s-LSTM blocks. Defaults to (2, 4/3).
            ker_size (int, optional): The kernel size for the causal convolutional layers.
                Defaults to 4.
            out_dim (int, optional): The output dimension of the prediction head.
                Defaults to head_num * head_dim.
        '''
        super().__init__()

        # Default the output dimension to the hidden dimension if not provided
        if out_dim is None:
            out_dim = head_num * head_dim

        m_factor, s_factor = p_factor
        
        mlstm_par = {
            'inp_dim' : inp_dim,
            'head_dim' : head_dim,
            'head_num' : head_num,
            'p_factor' : m_factor,
            'ker_size' : ker_size,
        }
        
        slstm_par = {
            'inp_dim' : inp_dim,
            'head_dim' : head_dim,
            'head_num' : head_num,
            'p_factor' : s_factor,
            'ker_size' : ker_size,
        }

        m_num, s_num = signature
        which = [True] * m_num + [False] * s_num

        self.model : List[mLSTM | sLSTM] = nn.ModuleList([
            mLSTM(**mlstm_par) if v else sLSTM(**slstm_par)
            for w in repeat(which, num_layers) for v in w
        ]).to(device)

        # Prediction head to map the output of the xLSTM model to the output
        self.head = nn.Linear(inp_dim, out_dim, bias=False)

    def forward(
        self,
        x: Tensor,
        hid: Hidden | None = None,
        batch_first : bool = False,
    ) -> Tuple[Tensor, Hidden]:
        '''Forward pass of the xLSTM model.

        Args:
            x (Tensor): Input tensor representing the sequence tokens.
                Expected shape: (batch, seq_len, inp_dim) if batch_first=True,
                else (seq_len, batch, inp_dim).
            hid (Hidden, optional): Cache object for storing intermediate hidden
                values of the m|s-LSTM blocks of the model. If None, the hidden
                states are initialized by the models. Defaults to None.

        Returns:
            Tuple[Tensor, Hidden]: Returns a tensor of predicted outputs of shape
                (batch, seq_len, out_dim) if batch_first=True or of shape
                (seq_len, batch, out_dim) if batch_first=False, and the
                updated hidden model states.
        '''

        if batch_first: x = rearrange(x, 'b s i -> s b i')
        if hid is None: hid = [l.init_hidden(x.shape[1]) for l in self.model]
        
        # Pass the sequence through the mLSTM and sLSTM blocks
        out = []
        for inp in x:
            # Compute model output and update the hidden states
            for i, lstm in enumerate(self.model):
                inp, hid[i] = lstm(inp, hid[i])
            
            out.append(inp)
            
        out = torch.stack(out, dim=1 if batch_first else 0)
        out = self.head(out)
        
        return out, hid
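# Usage sketch (sizes are illustrative): tokens enter as (seq_len, batch, inp_dim)
# by default, or (batch, seq_len, inp_dim) with batch_first=True:
#   model = xLSTM(num_layers=1, signature=(2, 3), inp_dim=256, head_dim=16, head_num=16)
#   out, hid = model(torch.randn(4, 10, 256, device=device), batch_first=True)
# The time dimension is still consumed one step at a time by the Python loop above.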

class FourierFeatureMapper(nn.Module):
    def __init__(self, input_dim, num_features, scale=10.0):
        """
        Fourier feature mapping module.

        Args:
            input_dim (int): Dimension of the input.
            num_features (int): Number of Fourier features.
            scale (float): Scaling factor for random Fourier frequencies.
        """
        super(FourierFeatureMapper, self).__init__()
        self.B = nn.Parameter(
            scale * torch.randn((input_dim, num_features)), requires_grad=False
        )

    def forward(self, x):
        """
        Map input to Fourier features.

        Args:
            x (tensor): [batch_size, seq_len, input_dim]

        Returns:
            tensor: [batch_size, seq_len, num_features * 2]
        """
        # x: [batch_size, seq_len, input_dim]
        x_proj = 2 * torch.pi * torch.matmul(x, self.B)  # [batch_size, seq_len, num_features]
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)  # Concatenate along the feature dimension
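# Shape note (example values only): with input_dim=5 and num_features=128 the matrix B
# is 5x128, so a [batch, seq_len, 5] input becomes [batch, seq_len, 256] after the
# sin/cos concatenation (2 * num_features), which is why the xLSTM below uses inp_dim=256.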

# Define the Generator class
class Generator(nn.Module):
    def __init__(self, input_features=5, hidden_size=64, output_features=1):
        super(Generator, self).__init__()
        self.model = nn.ModuleDict({
            'fourier': FourierFeatureMapper(
                input_dim=input_features,
                num_features=128,
                scale=10.0
            ),
            'xLSTM': xLSTM(
                num_layers=1,
                signature=(2, 3),
                inp_dim=256,
                head_dim=16,
                head_num=16,
                ker_size=16,
                p_factor=(2, 4/3)
            ),
            'linear': nn.Linear(
                in_features=hidden_size,
                out_features=output_features
            ),
            'Sigmoid': nn.Sigmoid()
        }).to(device)

    def forward(self, x):
        # x: [batch_size, sequence_length, input_features]
        fourier_out = self.model['fourier'](x)
        xlstm_out, _ = self.model['xLSTM'](fourier_out)

        # Apply Linear layer to each time step
        linear_out = self.model['linear'](xlstm_out)
        linear_out = self.model['Sigmoid'](linear_out)
        # Return all time steps (not just the last one)
        return linear_out  # Shape will be [batch_size, sequence_length, output_features]

# Define the Discriminator class
class Discriminator(nn.Module):
    def __init__(self, input_features=5, hidden_size=64):
        super(Discriminator, self).__init__()
        self.model = nn.ModuleDict({
            'fourier': FourierFeatureMapper(
                input_dim=input_features,
                num_features=128,
                scale=10.0
            ),
            'xLSTM': xLSTM(
                num_layers=1,
                signature=(2, 3),
                inp_dim=256,
                head_dim=16,
                head_num=16,
                ker_size=16,
                p_factor=(2, 4/3)
            ),
            'linear': nn.Linear(
                in_features=hidden_size,
                out_features=1
            )
        }).to(device)

    def forward(self, x):
        # x: [batch_size, sequence_length, input_features]
        fourier_out = self.model['fourier'](x)
        xlstm_out, _ = self.model['xLSTM'](fourier_out)
        # Apply Linear layer to each time step
        linear_out = self.model['linear'](xlstm_out)
        # Return the output for every time step
        return linear_out

torch.cuda.init()
torch.cuda.current_device()
torch.cuda.set_device(0)  # For the first GPU

# Instantiate models
generator = Generator(input_features=5, hidden_size=256, output_features=5).to(device)
discriminator = Discriminator(input_features=5, hidden_size=256).to(device)

generator.to(device)
discriminator.to(device)

# Optimizers
alpha = 5e-5
g_optim = optim.RMSprop(generator.parameters(), lr=alpha)
c_optim = optim.RMSprop(discriminator.parameters(), lr=alpha)

g_losses = []
c_losses = []
images = []

# Data generation functions
def row(batch_size, sequence_length, input_features):
    """
    Generate structured input data with shape (batch_size, sequence_length, input_features).
    """
    # Create a random array for features of shape [batch_size, sequence_length, input_features]
    random0 = np.random.rand(batch_size, sequence_length, input_features) * 2 * np.pi
    
    # Generate a feature index array for the shape [1, 1, input_features]
    feature_indices = np.arange(input_features).reshape(1, 1, -1)  # Shape (1, 1, input_features)
    
    # Apply broadcasting to generate data of shape [batch_size, sequence_length, input_features]
    randarray = random0 * feature_indices  # Broadcast to [batch_size, sequence_length, input_features]
    
    # Create another random array for additional features, same shape [batch_size, sequence_length, input_features]
    random1 = np.random.rand(batch_size, sequence_length, input_features) * 2 * np.pi
    
    # Combine them to create a cyclic pattern
    output = random1 + randarray
    
    # Apply cosine transformation
    output = np.cos(output) * 0.5 + 0.5  # Normalize between 0 and 1
    
    # Convert the output to a PyTorch tensor and send it to the correct device
    return torch.from_numpy(output.astype(np.float32)).to(device)


def noise(batch_size, sequence_length, input_features):
    """
    Generate random noise with shape (batch_size, sequence_length, input_features).
    """
    return torch.rand(batch_size, sequence_length, input_features).to(device)

# Training discriminator
from torch.autograd import grad as tg
if enGP:
    def train_discriminator(optimizer, real_data, fake_data, c=0.01):
        optimizer.zero_grad()
        error_real = discriminator(real_data).mean()
        error_fake = discriminator(fake_data).mean()
        total_error = -(error_real - error_fake)

        batch_size, sequence_length, _ = real_data.size()  # Extract dimensions
        ep = torch.rand(batch_size, sequence_length, 1).to(device)  # Match shape with input

        middle_data = ep * real_data + (1 - ep) * fake_data
        middle_data.requires_grad_()
        discriminator_middle = discriminator(middle_data)

        # Adjust grad_outputs shape to match discriminator_middle
        grad_outputs = torch.ones_like(discriminator_middle, device=device)
        grad = tg(outputs=discriminator_middle, inputs=middle_data, 
              grad_outputs=grad_outputs, create_graph=True, retain_graph=True)[0]
    
        GP_COEF = 5.0
        loss = total_error + GP_COEF * (torch.norm(grad.view(batch_size, -1), dim=1)**2).mean()
        loss.backward()
        optimizer.step()
        return -total_error


# Training generator
def train_generator(optimizer, fake_data):
    optimizer.zero_grad()
    error = -discriminator(fake_data).mean()
    error.backward()
    optimizer.step()
    return error

# Training loop
num_epochs = 100
n_discriminator = 5
sequence_length = 1  # Set sequence length for LSTM

generator.train()
discriminator.train()

# Model summaries
print("Generator Summary:")
summary(generator, input_size=(64, 1, 5), col_names=["input_size","output_size","num_params"])
print("Discriminator Summary:")
summary(discriminator, input_size=(64, 1, 5), col_names=["input_size","output_size","num_params"])

for epoch in range(num_epochs + 1):
    g_error = 0.0
    c_error = 0.0
    STEP_PER_EPOCH = 1
    for i in range(n_discriminator * STEP_PER_EPOCH):
        batch_size = 64
        fake_data = generator(row(batch_size, sequence_length, input_features=5)).detach()
        real_data = noise(batch_size, sequence_length, input_features=5)
        c_error += train_discriminator(c_optim, real_data, fake_data)
        if (i + 1) % n_discriminator == 0:
            fake_data = generator(row(batch_size, sequence_length, input_features=5))
            g_error += train_generator(g_optim, fake_data)
    
    # Print losses for each epoch
    print(f"Epoch [{epoch}/{num_epochs}] - Discriminator Loss: {c_error.item():.4f}, Generator Loss: {g_error.item():.4f}")

    
    if epoch % 100 == 0:
        img = generator(row(batch_size, sequence_length, 5)).cpu().detach()
        img = make_grid(img)
        to_image = lambda x: (np.clip(x.numpy().copy().transpose(1,2,0),0,1)*255).astype(np.uint8)
        img = to_image(img)
        images.append(img)
        c_losses.append(c_error)
        
        r_img = to_image(make_grid(real_data.cpu().detach()))
                
        from IPython.display import clear_output
        clear_output()
        plt.clf()
        plt.plot([loss.cpu().detach().numpy() for loss in c_losses], label='Discriminator Losses')
        plt.plot([0,len(g_losses)],[0,0])
        plt.yscale("log")
        plt.legend()
        plt.savefig('loss.png', dpi=900)

    if epoch%100==0:
        with open('G_output_10_'+str(epoch)+'.txt', 'a') as G_output:
            for num in range(100):
                rand_img = generator(row(64,1,5)).cpu().detach().numpy().copy().flatten() 
                np.savetxt(G_output,rand_img)
        G_output.close()
        #output random number pictures
        rand_img = generator(row(64,1,5)).cpu().detach()
        rand_img_real = torch.rand(64,1,6,6).cpu().detach()
        #to array
        im_list = np.asarray(rand_img.view(-1,5))
        im_list_real = np.asarray(rand_img_real.view(-1,6*6))
        plt.figure(figsize=(10,10))
        plt.subplot(1,2,1)
        plt.imshow(im_list,cmap='gray',vmin=0, vmax=1)
        plt.colorbar()
        plt.subplot(1,2,2)
        plt.imshow(im_list_real,cmap='gray',vmin=0, vmax=1)
        plt.colorbar()
        plt.savefig('rand_image_'+str(epoch)+'.jpg', dpi=1200)

        
print('Training Finished')
# Save the generator's state dictionary
torch.save(generator.state_dict(), 'mnist_generator.pth')

# Export the generator to ONNX
dummy_input = row(1, 1, 5).to(device)
torch.onnx.export(generator, dummy_input, "generator.onnx", 
                  export_params=True, 
                  opset_version=11, 
                  do_constant_folding=True, 
                  input_names=['input'], 
                  output_names=['output'], 
                  dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})

# Save the discriminator's state dictionary
torch.save(discriminator.state_dict(), 'mnist_discriminator.pth')

# Export the discriminator to ONNX
dummy_input = noise(1, 1, 5).to(device)
torch.onnx.export(discriminator, dummy_input, "discriminator.onnx", 
                  export_params=True, 
                  opset_version=11, 
                  do_constant_folding=True, 
                  input_names=['input'], 
                  output_names=['output'], 
                  dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})

Profiling the code may be helpful: PyTorch Profiler — PyTorch Tutorials 2.5.0+cu124 documentation.
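
For reference, here is a minimal profiling sketch (it reuses the `generator` and `row` defined above, with the same batch and sequence sizes as the training loop); it should show whether the time goes into CUDA kernels or into Python/CPU overhead launching many small ops:

from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with record_function("generator_forward"):
        _ = generator(row(64, 1, 5))

# With one CPU core pinned at 100%, the suspicious entries are usually many small
# ops with a large total CPU time, so sort by CPU time first.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))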