Hi, I’m currently working on a PyTorch project implementing minibatch discrimination (a submodule nested inside the discriminator) in my GAN code.
But while monitoring the training procedure, I found that RAM usage keeps increasing over the epochs, and eventually it halts my whole system so the GAN cannot continue learning.
I already referred to "CPU RAM usage increases inside each epoch and keeps increasing for all epochs (OSError: [Errno 12] Cannot allocate memory)", but I cannot simply detach, because gradients still have to flow to T and x.
So I want to stop gradient computation with respect to some temporary variables (in my case M, cb, and Ox). But whenever I try to set requires_grad=False on these tensors, I get a "can’t change requires_grad for non-leaf variables" error.
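For reference, here is a minimal standalone snippet (not my actual code) that reproduces the same error:

import torch

w = torch.randn(3, requires_grad=True)  # leaf tensor: toggling requires_grad here is fine
y = w * 2                               # non-leaf tensor: produced by an op inside the graph

try:
    y.requires_grad = False             # same error as in my module
except RuntimeError as e:
    print(e)  # "you can only change requires_grad flags of leaf variables ..."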
This is my module for the minibatch discriminator:
import torch
import torch.nn as nn
import torch.nn.init as init


class MinibatchDiscriminator(nn.Module):
    def __init__(self, in_features, out_features, row_size, batch_size):
        super(MinibatchDiscriminator, self).__init__()
        self.in_features = in_features    # A in the paper
        self.out_features = out_features  # B in the paper
        self.row_size = row_size          # C in the paper
        self.batch_size = batch_size      # N
        # T is the A x B x C tensor of kernels
        self.T = nn.Parameter(torch.Tensor(self.in_features, self.out_features, self.row_size))
        init.normal_(self.T, 0, 1)

    def forward(self, x):
        # R^(N*A) x R^(A*(B*C)) -> R^(N*B*C)
        M = torch.mm(x, self.T.view(self.in_features, -1)).view(-1, self.out_features, self.row_size)
        # N*N*B; created on x's device so the assignments and the cat below work on CPU or GPU
        cb = torch.zeros(self.batch_size, self.batch_size, self.out_features, device=x.device)
        for i in range(self.batch_size):
            for j in range(self.batch_size):
                # L1 distance between the rows of M[i] and M[j]
                cb[i][j] = torch.abs(M[i] - M[j]).sum(1)
        cb = torch.exp(-cb)
        Ox = cb.sum(1)  # N*B
        # Ox.requires_grad = False  # <- raises "can't change requires_grad for non-leaf variables"
        return torch.cat([x, Ox], 1)
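For what it's worth, a quick shape check on CPU (the sizes below are placeholder values, not my real hyperparameters):

# dummy sizes: A=16, B=8, C=4, N=32
mbd = MinibatchDiscriminator(in_features=16, out_features=8, row_size=4, batch_size=32)
x = torch.randn(32, 16)
out = mbd(x)
print(out.shape)  # torch.Size([32, 24]): the original 16 features plus 8 minibatch features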
And this is the discriminator module:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.ff1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=128, stride=4, kernel_size=4, padding=0),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=128, out_channels=256, stride=1, kernel_size=4, padding="same"),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=256, out_channels=512, stride=4, kernel_size=4, padding=0),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=512, out_channels=1024, stride=1, kernel_size=4, padding="same"),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=1024, out_channels=1024, stride=4, kernel_size=4, padding=0),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(),
            nn.Flatten(),
            # with 256x256 RGB input the last conv output is 1024 x 4 x 4
            nn.Linear(in_features=1024 * 4 * 4, out_features=minibatch_A),
        )
        self.ff2 = nn.Sequential(
            nn.Linear(in_features=minibatch_A + minibatch_B, out_features=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        p = self.ff1(x)
        # minibatch_discriminator is a global MinibatchDiscriminator instance
        # (minibatch_A and minibatch_B are globals as well)
        p = minibatch_discriminator(p)
        p = self.ff2(p)
        return p
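In case it matters, the globals the discriminator refers to are wired up roughly like this (the values and the minibatch_C / batch_size names here are placeholders, not my real settings):

minibatch_A = 128   # A: features produced by ff1
minibatch_B = 64    # B: minibatch-discrimination features appended by the layer
minibatch_C = 16    # C: row size of each kernel
batch_size = 32     # N

minibatch_discriminator = MinibatchDiscriminator(
    in_features=minibatch_A,
    out_features=minibatch_B,
    row_size=minibatch_C,
    batch_size=batch_size,
)
discriminator = Discriminator()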
Thank you for your help.