Hello there, this is my first time posting. I'm having trouble with some code I found on GitHub and have been adapting. Training is extremely slow, and monitoring shows that only the first core of the CPU is at 100% utilization. From what I can see, everything should already be in the right place (the model and the data are all moved to 'cuda:0'), so I have no idea what to try next.
I apologize for the long code. Thanks for your attention.
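For completeness, the code below assumes roughly these imports (reconstructed from the calls that appear later; in particular I am assuming repeat comes from itertools, rearrange/einsum from einops, summary from torchinfo, and make_grid from torchvision):

from typing import List, Tuple
from itertools import repeat            # expands the block signature in xLSTM
from math import sqrt

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import exp, tanh, sigmoid

from einops import rearrange, einsum    # einops >= 0.6 provides einsum
from torchvision.utils import make_grid
from torchinfo import summary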
enSN = True
enGP = True
device = 'cuda:0'
Tensor = torch.cuda.FloatTensor
from torch.nn.utils import spectral_norm as SN_
if enSN:
SN=SN_
else:
SN=lambda x:x
Hidden = List[Tuple[Tensor, ...]]
def enlarge_as(src : Tensor, other : Tensor) -> Tensor:
'''
Add a sufficient number of singleton dimensions
to tensor `src` **to the right** so as to match the
rank of tensor `other`. NOTE that plain broadcasting
works in the opposite direction.
'''
return rearrange(src, f'... -> ...{" 1" * (other.dim() - src.dim())}').contiguous()
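# Illustrative example (hypothetical shapes): with src of shape (b, h) and
# other of shape (b, h, d, p), enlarge_as(src, other) returns a view of shape
# (b, h, 1, 1) that broadcasts cleanly against other -- this is how the scalar
# gates are applied to the matrix cell state in the mLSTM update further below.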
class CausalConv1d(nn.Conv1d):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
dilation=1,
groups=1,
bias=True
):
self._padding = (kernel_size - 1) * dilation
super(CausalConv1d, self).__init__(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=self._padding,
dilation=dilation,
groups=groups,
bias=bias)
def forward(self, inp : Tensor) -> Tensor:
# Handle the case where the input has only two dimensions:
# we insert a singleton channel dimension so that Conv1d
# sees (batch, channels=1, length).
if inp.dim() == 2: inp = rearrange(inp, 'b i -> b 1 i')
result = super(CausalConv1d, self).forward(inp)
if self._padding != 0: return result[..., :-self._padding]
return result
class BlockLinear(nn.Module):
def __init__(
self,
block_dims : List[int | Tuple[int, ...]],
bias : bool = False,
):
super(BlockLinear, self).__init__()
self._blocks = nn.ParameterList([
nn.Parameter(torch.randn(size, requires_grad=True))
for size in block_dims
])
self._bias = nn.Parameter(torch.zeros(sum(block_dims))) if bias else None
def forward(self, inp : Tensor) -> Tensor:
# Assemble the blocks into a block-diagonal matrix
full = torch.block_diag(*self._blocks)
out = torch.matmul(inp, full)
if self._bias is not None:
out = out + self._bias
return out
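# Illustrative example (not part of the original script): BlockLinear([(2, 2), (2, 2)])
# assembles two independent 2x2 blocks into a 4x4 block-diagonal weight, so an input
# of shape (batch, 4) is only mixed within each 2-dim block. This is what gives the
# sLSTM its per-head recurrent matrices R_z/R_i/R_o/R_f: memory mixing within a head,
# but none across heads.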
class sLSTM(nn.Module):
'''The scalar-Long Short Term Memory (sLSTM) module as
originally introduced in Beck et al. (2024); see:
(https://arxiv.org/abs/2405.04517).
This model is a variant of the standard LSTM model and
offers two major improvements:
- Exponential gating with appropriate state normalization
to avoid overflows induced by the exponential function.
- A new memory mixing within heads but not across heads.
'''
def __init__(
self,
inp_dim : int,
head_dim : int,
head_num : int,
ker_size : int = 4,
p_factor : float = 4/3,
) -> None:
super().__init__()
self.inp_dim = inp_dim
self.head_dim = head_dim
self.head_num = head_num
self.inp_norm = nn.LayerNorm(inp_dim)
self.hid_norm = nn.GroupNorm(head_num, head_dim * head_num)
self.causal_conv = CausalConv1d(1, 1, kernel_size=ker_size)
self.W_z = nn.Linear(inp_dim, head_num * head_dim)
self.W_i = nn.Linear(inp_dim, head_num * head_dim)
self.W_o = nn.Linear(inp_dim, head_num * head_dim)
self.W_f = nn.Linear(inp_dim, head_num * head_dim)
self.R_z = BlockLinear([(head_dim, head_dim)] * head_num)
self.R_i = BlockLinear([(head_dim, head_dim)] * head_num)
self.R_o = BlockLinear([(head_dim, head_dim)] * head_num)
self.R_f = BlockLinear([(head_dim, head_dim)] * head_num)
# NOTE: The factor of two in the output dimension of the up_proj
# is due to the fact that the output needs to branch into two
# separate outputs to account for the gated GeLU connection.
# See Fig. 9 in the paper.
proj_dim = int(p_factor * head_num * head_dim)
self.up_proj = nn.Linear(head_num * head_dim, 2 * proj_dim)
self.down_proj = nn.Linear(proj_dim, inp_dim)
@property
def device(self) -> str:
'''Get the device of the model.
Returns:
str: The device of the model.
'''
return next(self.parameters()).device
def init_hidden(self, bs : int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
'''Initialize the hidden state of the sLSTM model.
Args:
bs (int): The batch size of the input sequence.
Returns:
Tuple[Tensor, Tensor, Tensor, Tensor]: The hidden state tuple containing the cell state,
normalizer state, hidden state, and stabilizer state.
'''
n_0 = torch.ones (bs, self.head_num * self.head_dim, device=self.device)
c_0 = torch.zeros(bs, self.head_num * self.head_dim, device=self.device)
h_0 = torch.zeros(bs, self.head_num * self.head_dim, device=self.device)
m_0 = torch.zeros(bs, self.head_num * self.head_dim, device=self.device)
return c_0, n_0, h_0, m_0
def forward(
self,
seq: Tensor,
hid: Tuple[Tensor, Tensor, Tensor, Tensor],
use_conv : bool = False,
) -> Tuple[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]:
'''Forward pass of the sLSTM model.
Args:
seq (Tensor): The input sequence tensor of shape (batch_size, input_dim).
hid (Tuple[Tensor, Tensor, Tensor, Tensor]): The hidden state tuple containing the cell state,
normalizer state, hidden state, and stabilizer state.
Returns:
Tuple[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]: The output tensor with the residual
connection and the newly updated hidden state tuple.
'''
b, d = seq.shape
# Separate the hidden (previous) state into the cell state,
# the normalizer state, the hidden state, and the stabilizer state.
c_tm1, n_tm1, h_tm1, m_tm1 = hid
x_t : Tensor = self.inp_norm(seq)
# Optional causal convolution block for the input
# and forget gates. See Fig. 9 in the paper.
if use_conv:
# FIXME: The causal conv branch is broken.
x_c = self.causal_conv(x_t)
x_c = F.silu(x_c).squeeze()
else:
x_c = x_t
# Project the input to the different heads for all
# the gates.
# NOTE: For input (i) and forget (f) inputs we use
# the output of the causal conv. See Fig. 9 in the paper.
i_t: Tensor = self.W_i(x_c) + self.R_i(h_tm1)
f_t: Tensor = self.W_f(x_c) + self.R_f(h_tm1)
z_t: Tensor = self.W_z(x_t) + self.R_z(h_tm1)
o_t: Tensor = self.W_o(x_t) + self.R_o(h_tm1)
# Compute the gated outputs for the newly computed inputs
m_t = torch.max(f_t + m_tm1, i_t)
i_t = exp(i_t - m_t) # Eq. (16) in ref. paper | or Eq. (38) in supp. mat.
f_t = exp(f_t - m_t + m_tm1) # Eq. (17) in ref. paper | or Eq. (39) in supp. mat.
z_t = tanh(z_t) # Eq. (11) in ref. paper
o_t = sigmoid(o_t) # Eq. (14) in ref. paper
# Update the internal states of the model
c_t = f_t * c_tm1 + i_t * z_t # Eq. (8) in ref. paper
n_t = f_t * n_tm1 + i_t # Eq. (9) in ref. paper
h_t = o_t * (c_t / n_t) # Eq. (10) in ref. paper
# Compute the output of the LSTM block
out = self.hid_norm(h_t)
# Perform up-and-down projection of the output with
# projection factor 4/3. See Fig. (9) in supp. mat.
out1, out2 = self.up_proj(out).chunk(2, dim=-1)
out = out1 + F.gelu(out2)
out = self.down_proj(out)
# Return output with the residual connection and the
# newly updated hidden state.
return out + seq, (c_t, n_t, h_t, m_t)
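# Illustrative single-step usage of sLSTM (shapes only, not part of the original script):
#   slstm = sLSTM(inp_dim=256, head_dim=16, head_num=16)
#   hid = slstm.init_hidden(bs=4)        # (c, n, h, m), each of shape (4, 256)
#   x = torch.randn(4, 256)              # one time step: (batch, inp_dim)
#   y, hid = slstm(x, hid)               # y: (4, 256), output with residual connection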
class mLSTM(nn.Module):
'''The matrix-Long Short Term Memory (mLSTM) module as
originally introduced in Beck et al. (2024); see:
(https://arxiv.org/abs/2405.04517).
This model is a variant of the standard LSTM model and
offers superior memory due to its storing values in a
matrix instead of a scalar. It is fully parallelizable
and updates internal memory with the covariance rule.
'''
def __init__(
self,
inp_dim : int,
head_num : int,
head_dim : int,
p_factor : int = 2,
ker_size : int = 4,
) -> None:
super().__init__()
self.inp_dim = inp_dim
self.head_num = head_num
self.head_dim = head_dim
hid_dim = head_num * head_dim
self.inp_norm = nn.LayerNorm(inp_dim)
self.hid_norm = nn.GroupNorm(head_num, hid_dim)
# NOTE: There are two separate up-projections here: the left branch expands
# the input by p_factor and feeds the mLSTM cell, while the right branch
# is used to gate the output further below (see the mLSTM block in the paper).
self.up_l_proj = nn.Linear(inp_dim, int(p_factor * inp_dim))
self.up_r_proj = nn.Linear(inp_dim, hid_dim)
self.down_proj = nn.Linear(hid_dim, inp_dim)
self.causal_conv = CausalConv1d(1, 1, kernel_size=ker_size)
self.skip = nn.Conv1d(int(p_factor * inp_dim), hid_dim, kernel_size=1, bias=False)
self.W_i = nn.Linear(int(p_factor * inp_dim), head_num)
self.W_f = nn.Linear(int(p_factor * inp_dim), head_num)
self.W_o = nn.Linear(int(p_factor * inp_dim), hid_dim)
self.W_q = nn.Linear(int(p_factor * inp_dim), hid_dim)
self.W_k = nn.Linear(int(p_factor * inp_dim), hid_dim)
self.W_v = nn.Linear(int(p_factor * inp_dim), hid_dim)
@property
def device(self) -> str:
'''Get the device of the model.
Returns:
str: The device of the model.
'''
return next(self.parameters()).device
def init_hidden(self, bs : int) -> Tuple[Tensor, Tensor, Tensor]:
'''Initialize the hidden state of the mLSTM model.
Args:
bs (int): The batch size of the input sequence.
Returns:
Tuple[Tensor, Tensor, Tensor]: The hidden state tuple containing the cell state,
normalizer state, and stabilizer state.
'''
c_0 = torch.zeros(bs, self.head_num, self.head_dim, self.head_dim, device=self.device)
n_0 = torch.ones (bs, self.head_num, self.head_dim , device=self.device)
m_0 = torch.zeros(bs, self.head_num , device=self.device)
return c_0, n_0, m_0
def forward(
self,
seq: Tensor,
hid: Tuple[Tensor, Tensor, Tensor],
) -> Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]]:
'''Forward pass of the mLSTM model.
Args:
seq (Tensor): The input tensor of shape (batch_size, input_dim) for the current time step.
hid (Tuple[Tensor, Tensor, Tensor]): The hidden state tuple containing the cell state,
normalizer state, and stabilizer state.
Returns:
Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]]: The output tensor with the residual
connection and the newly updated hidden state tuple.
'''
# Separate the hidden (previous) state into the cell state,
# the normalizer state, and the stabilizer state.
c_tm1, n_tm1, m_tm1 = hid
x_n : Tensor = self.inp_norm(seq) # shape: b i
x_t = self.up_l_proj(x_n) # shape: b (i * p_factor)
r_t = self.up_r_proj(x_n) # shape: b (h d)
# Compute the causal convolutional input (to be
# used for the query and key gates)
x_c = self.causal_conv(x_t) # shape: b 1 (i * p_factor)
x_c = rearrange(F.silu(x_c), 'b ... -> b (...)') # shape: b (i * p_factor)
q_t = rearrange(self.W_q(x_c), 'b (h d) -> b h d', h=self.head_num)
k_t = rearrange(self.W_k(x_c), 'b (h d) -> b h d', h=self.head_num) / sqrt(self.head_dim)
v_t = rearrange(self.W_v(x_t), 'b (h d) -> b h d', h=self.head_num)
i_t: Tensor = self.W_i(x_c) # shape: b h
f_t: Tensor = self.W_f(x_c) # shape: b h
o_t: Tensor = self.W_o(x_t) # shape: b (h d)
# Compute the gated outputs for the newly computed inputs
m_t = torch.max(f_t + m_tm1, i_t)
i_t = exp(i_t - m_t) # Eq. (25) in ref. paper
f_t = exp(f_t - m_t + m_tm1) # Eq. (26) in ref. paper
o_t = sigmoid(o_t) # Eq. (27) in ref. paper
# Update the internal states of the model
c_t = enlarge_as(f_t, c_tm1) * c_tm1 + enlarge_as(i_t, c_tm1) * einsum(v_t, k_t, 'b h d, b h p -> b h d p')
n_t = enlarge_as(f_t, n_tm1) * n_tm1 + enlarge_as(i_t, k_t) * k_t
h_t = o_t * rearrange(
einsum(c_t, q_t, 'b h d p, b h p -> b h d') /
einsum(n_t, q_t, 'b h d, b h d -> b h').clamp(min=1).unsqueeze(-1),
'b h d -> b (h d)'
) # Eq. (21) in ref. paper
x_c = rearrange(x_c, 'b i -> b i 1')
out = self.hid_norm(h_t) + self.skip(x_c).squeeze() # shape: b (h d)
out = out * F.silu(r_t) # shape: b (h d)
out = self.down_proj(out) # shape: h i
# Return output with the residual connection and the
# newly updated hidden state.
return out + seq, (c_t, n_t, m_t)
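# Illustrative single-step usage of mLSTM (shapes only, not part of the original script):
#   mlstm = mLSTM(inp_dim=256, head_num=16, head_dim=16)
#   hid = mlstm.init_hidden(bs=4)        # C: (4, 16, 16, 16), n: (4, 16, 16), m: (4, 16)
#   y, hid = mlstm(torch.randn(4, 256), hid)   # y: (4, 256), output with residual connection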
class xLSTM(nn.Module):
'''The extended Long Short Term Memory (xLSTM) module as
originally introduced in Beck et al. (2024); see:
(https://arxiv.org/abs/2405.04517).
This model stacks sLSTM and mLSTM modules with residual
connections and offers superior memory and performance
compared to the standard LSTM model, achieving competitive
or better performance and scaling than Transformer models
or State-Space models.
'''
def __init__(
self,
num_layers : int,
signature : Tuple[int, int],
inp_dim : int,
head_dim : int,
head_num : int,
p_factor : Tuple[float, float] = (2, 4/3),
ker_size : int = 4,
out_dim : int | None = None
) -> None:
'''Initialize the model.
Args:
num_layers (int): The number of layers in the model.
signature (Tuple[int, int]): The signature of the model, i.e. the
number of mLSTM and sLSTM blocks in each repeated group.
inp_dim (int): The dimension of the input tokens.
head_dim (int): The dimension of each attention head.
head_num (int): The number of attention heads.
p_factor (Tuple[float, float], optional): The expansion factor
for the MLP projection in the m|s-LSTM blocks. Defaults to (2, 4/3).
ker_size (int, optional): The kernel size for the causal convolutional layers.
Defaults to 4.
out_dim (int, optional): The output dimension of the prediction head.
Defaults to head_num * head_dim.
'''
super().__init__()
# Default out_dim to the hidden dimension if not provided
if out_dim is None:
out_dim = head_num * head_dim
m_factor, s_factor = p_factor
mlstm_par = {
'inp_dim' : inp_dim,
'head_dim' : head_dim,
'head_num' : head_num,
'p_factor' : m_factor,
'ker_size' : ker_size,
}
slstm_par = {
'inp_dim' : inp_dim,
'head_dim' : head_dim,
'head_num' : head_num,
'p_factor' : s_factor,
'ker_size' : ker_size,
}
m_num, s_num = signature
which = [True] * m_num + [False] * s_num
self.model : List[mLSTM | sLSTM] = nn.ModuleList([
mLSTM(**mlstm_par) if v else sLSTM(**slstm_par)
for w in repeat(which, num_layers) for v in w
]).to(device)
# Prediction head to map the output of the xLSTM model to the output
self.head = nn.Linear(inp_dim, out_dim, bias=False)
def forward(
self,
x: Tensor,
hid: Hidden | None = None,
batch_first : bool = False,
) -> Tuple[Tensor, Hidden]:
'''Forward pass of the xLSTM model.
Args:
x (Tensor): Input tensor of sequence features.
Expected shape: (batch, seq_len, inp_dim) if batch_first=True,
else (seq_len, batch, inp_dim).
hid (Hidden, optional): Cache object for storing intermediate hidden
values of the m|s-LSTM blocks of the model. If None, the hidden
states are initialized by the models. Defaults to None.
Returns:
Tuple[Tensor, Hidden]: Returns a tensor of predictions of shape
(batch, seq_len, out_dim) if batch_first=True or of shape
(seq_len, batch, out_dim) if batch_first=False, and the
updated hidden model states.
'''
if batch_first: x = rearrange(x, 'b s i -> s b i')
if hid is None: hid = [l.init_hidden(x.shape[1]) for l in self.model]
# Pass the sequence through the mLSTM and sLSTM blocks
out = []
for inp in x:
# Compute model output and update the hidden states
for i, lstm in enumerate(self.model):
inp, hid[i] = lstm(inp, hid[i])
out.append(inp)
out = torch.stack(out, dim=1 if batch_first else 0)
out = self.head(out)
return out, hid
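# Illustrative usage of the full xLSTM stack (not part of the original script; note
# that __init__ already moves the blocks to the hard-coded `device`):
#   lstm = xLSTM(num_layers=1, signature=(2, 3), inp_dim=256, head_dim=16,
#                head_num=16, ker_size=16, p_factor=(2, 4/3)).to(device)
#   x = torch.randn(4, 10, 256, device=device)   # (batch, seq_len, inp_dim)
#   y, hid = lstm(x, batch_first=True)           # y: (4, 10, 256), out_dim defaults to head_num * head_dim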
class FourierFeatureMapper(nn.Module):
def __init__(self, input_dim, num_features, scale=10.0):
"""
Fourier feature mapping module.
Args:
input_dim (int): Dimension of the input.
num_features (int): Number of Fourier features.
scale (float): Scaling factor for random Fourier frequencies.
"""
super(FourierFeatureMapper, self).__init__()
self.B = nn.Parameter(
scale * torch.randn((input_dim, num_features)), requires_grad=False
)
def forward(self, x):
"""
Map input to Fourier features.
Args:
x (tensor): [batch_size, seq_len, input_dim]
Returns:
tensor: [batch_size, seq_len, num_features * 2]
"""
# x: [batch_size, seq_len, input_dim]
x_proj = 2 * torch.pi * torch.matmul(x, self.B) # [batch_size, seq_len, num_features]
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) # Concatenate along the feature dimension
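# Shape note (for the configuration used below): with input_dim=5 and num_features=128,
# an input of shape (batch, seq_len, 5) is projected to (batch, seq_len, 128) and the
# sin/cos concatenation yields (batch, seq_len, 256), matching the inp_dim=256 expected
# by the xLSTM inside the Generator and Discriminator.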
# Define the Generator class
class Generator(nn.Module):
def __init__(self, input_features=5, hidden_size=64, output_features=1):
super(Generator, self).__init__()
self.model = nn.ModuleDict({
'fourier': FourierFeatureMapper(
input_dim=input_features,
num_features=128,
scale=10.0
),
'xLSTM': xLSTM(
num_layers=1,
signature=(2, 3),
inp_dim=256,
head_dim=16,
head_num=16,
ker_size=16,
p_factor=(2, 4/3)
),
'linear': nn.Linear(
in_features=hidden_size,
out_features=output_features
),
'Sigmoid': nn.Sigmoid()
}).to(device)
def forward(self, x):
# x: [batch_size, sequence_length, input_features]
fourier_out = self.model['fourier'](x)
xlstm_out, _ = self.model['xLSTM'](fourier_out)
# Apply Linear layer to each time step
linear_out = self.model['linear'](xlstm_out)
linear_out = self.model['Sigmoid'](linear_out)
# Return all time steps (not just the last one)
return linear_out # Shape will be [batch_size, sequence_length, output_features]
# Define the Discriminator class
class Discriminator(nn.Module):
def __init__(self, input_features=5, hidden_size=64):
super(Discriminator, self).__init__()
self.model = nn.ModuleDict({
'fourier': FourierFeatureMapper(
input_dim=input_features,
num_features=128,
scale=10.0
),
'xLSTM': xLSTM(
num_layers=1,
signature=(2, 3),
inp_dim=256,
head_dim=16,
head_num=16,
ker_size=16,
p_factor=(2, 4/3)
),
'linear': nn.Linear(
in_features=hidden_size,
out_features=1
)
}).to(device)
def forward(self, x):
# x: [batch_size, sequence_length, input_features]
fourier_out = self.model['fourier'](x)
xlstm_out, _ = self.model['xLSTM'](fourier_out)
# Apply Linear layer to each time step
linear_out = self.model['linear'](xlstm_out)
# Return the output for the last time step
return linear_out
torch.cuda.init()
torch.cuda.current_device()
torch.cuda.set_device(0) # For the first GPU
# Instantiate models
generator = Generator(input_features=5, hidden_size=256, output_features=5).to(device)
discriminator = Discriminator(input_features=5, hidden_size=256).to(device)
generator.to(device)
discriminator.to(device)
# Optimizers
alpha = 5e-5
g_optim = optim.RMSprop(generator.parameters(), lr=alpha)
c_optim = optim.RMSprop(discriminator.parameters(), lr=alpha)
g_losses = []
c_losses = []
images = []
# Data generation functions
def row(batch_size, sequence_length, input_features):
"""
Generate structured input data with shape (batch_size, sequence_length, input_features).
"""
# Create a random array for features of shape [batch_size, sequence_length, input_features]
random0 = np.random.rand(batch_size, sequence_length, input_features) * 2 * np.pi
# Generate a feature index array for the shape [1, 1, input_features]
feature_indices = np.arange(input_features).reshape(1, 1, -1) # Shape (1, 1, input_features)
# Apply broadcasting to generate data of shape [batch_size, sequence_length, input_features]
randarray = random0 * feature_indices # Broadcast to [batch_size, sequence_length, input_features]
# Create another random array for additional features, same shape [batch_size, sequence_length, input_features]
random1 = np.random.rand(batch_size, sequence_length, input_features) * 2 * np.pi
# Combine them to create a cyclic pattern
output = random1 + randarray
# Apply cosine transformation
output = np.cos(output) * 0.5 + 0.5 # Normalize between 0 and 1
# Convert the output to a PyTorch tensor and send it to the correct device
return torch.from_numpy(output.astype(np.float32)).to(device)
def noise(batch_size, sequence_length, input_features):
"""
Generate random noise with shape (batch_size, sequence_length, input_features).
"""
return torch.rand(batch_size, sequence_length, input_features).to(device)
# Training discriminator
from torch.autograd import grad as tg
if enGP:
def train_discriminator(optimizer, real_data, fake_data, c=0.01):
optimizer.zero_grad()
error_real = discriminator(real_data).mean()
error_fake = discriminator(fake_data).mean()
total_error = -(error_real - error_fake)
batch_size, sequence_length, _ = real_data.size() # Extract dimensions
ep = torch.rand(batch_size, sequence_length, 1).to(device) # Match shape with input
middle_data = ep * real_data + (1 - ep) * fake_data
middle_data.requires_grad_()
discriminator_middle = discriminator(middle_data)
# Adjust grad_outputs shape to match discriminator_middle
grad_outputs = torch.ones_like(discriminator_middle, device=device)
grad = tg(outputs=discriminator_middle, inputs=middle_data,
grad_outputs=grad_outputs, create_graph=True, retain_graph=True)[0]
GP_COEF = 5.0
loss = total_error + GP_COEF * (torch.norm(grad.view(batch_size, -1), dim=1)**2).mean()
loss.backward()
optimizer.step()
return -total_error
# Training generator
def train_generator(optimizer, fake_data):
optimizer.zero_grad()
error = -discriminator(fake_data).mean()
error.backward()
optimizer.step()
return error
# Training loop
num_epochs = 100
n_discriminator = 5
sequence_length = 1 # Set sequence length for LSTM
generator.train()
discriminator.train()
# Model summaries
print("Generator Summary:")
summary(generator, input_size=(64, 1, 5), col_names=["input_size","output_size","num_params"])
print("Discriminator Summary:")
summary(discriminator, input_size=(64, 1, 5), col_names=["input_size","output_size","num_params"])
for epoch in range(num_epochs + 1):
g_error = 0.0
c_error = 0.0
STEP_PER_EPOCH = 1
for i in range(n_discriminator * STEP_PER_EPOCH):
batch_size = 64
fake_data = generator(row(batch_size, sequence_length, input_features=5)).detach()
real_data = noise(batch_size, sequence_length, input_features=5)
c_error += train_discriminator(c_optim, real_data, fake_data)
if (i + 1) % n_discriminator == 0:
fake_data = generator(row(batch_size, sequence_length, input_features=5))
g_error += train_generator(g_optim, fake_data)
# Print losses for each epoch
print(f"Epoch [{epoch}/{num_epochs}] - Discriminator Loss: {c_error.item():.4f}, Generator Loss: {g_error.item():.4f}")
if epoch % 100 == 0:
img = generator(row(batch_size, sequence_length, 5)).cpu().detach()
img = make_grid(img)
to_image = lambda x: (np.clip(x.numpy().copy().transpose(1,2,0),0,1)*255).astype(np.uint8)
img = to_image(img)
images.append(img)
c_losses.append(c_error)
g_losses.append(g_error)
r_img = to_image(make_grid(real_data.cpu().detach()))
from IPython.display import clear_output
clear_output()
plt.clf()
plt.plot([loss.cpu().detach().numpy() for loss in c_losses], label='Discriminator Losses')
plt.plot([0,len(g_losses)],[0,0])
plt.yscale("log")
plt.legend()
plt.savefig('loss.png', dpi=900)
if epoch%100==0:
with open('G_output_10_'+str(epoch)+'.txt', 'a') as G_output:
for num in range(100):
rand_img = generator(row(64,1,5)).cpu().detach().numpy().copy().flatten()
np.savetxt(G_output,rand_img)
# Plot generated samples next to uniform noise for visual comparison
rand_img = generator(row(64,1,5)).cpu().detach()
rand_img_real = torch.rand(64,1,6,6).cpu().detach()
# Convert to numpy arrays for plotting
im_list = np.asarray(rand_img.view(-1,5))
im_list_real = np.asarray(rand_img_real.view(-1,6*6))
plt.figure(figsize=(10,10))
plt.subplot(1,2,1)
plt.imshow(im_list,cmap='gray',vmin=0, vmax=1)
plt.colorbar()
plt.subplot(1,2,2)
plt.imshow(im_list_real,cmap='gray',vmin=0, vmax=1)
plt.colorbar()
plt.savefig('rand_image_'+str(epoch)+'.jpg', dpi=1200)
print('Training Finished')
# Save the generator's state dictionary
torch.save(generator.state_dict(), 'mnist_generator.pth')
# Export the generator to ONNX
dummy_input = row(1, 1, 5).to(device)
torch.onnx.export(generator, dummy_input, "generator.onnx",
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
# Save the discriminator's state dictionary
torch.save(discriminator.state_dict(), 'mnist_discriminator.pth')
# Export the discriminator to ONNX
dummy_input = noise(1, 1, 5).to(device)
torch.onnx.export(discriminator, dummy_input, "discriminator.onnx",
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})