Hi all, sorry if my question is basic, I am inept at PyTorch. I am trying to train a network for sparse feature maps (It works like generative models), where I have two different parameters. One is the Conv layer weights and another one is the feature map (z). when I print the grad for Conv layer weights, I get None. I think I am not wrapping or detaching it anywhere. Could you please take a look? I appreciate it.
Here is the code:
import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device is:', device)
# dataset definition
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))
mnist_trainset.data = mnist_trainset.data[:10000]
from torch.utils.data import DataLoader
train_dl = DataLoader(mnist_trainset, batch_size=16, shuffle=False)
from torch.optim import SGD
import torch.nn as nn
from torch.nn import Module
from torch.nn import Conv2d
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.autograd import Variable
class MNIST_ISTA(Module):
# define model elements
def __init__(self):
self.lambda_ = 0
super(MNIST_ISTA, self).__init__()
self.scale1 = Conv2d(in_channels = 1, out_channels = 1, kernel_size=20, bias = False)
self.alpha = 1e0
self.z = None
def ista_(self, img_batch):
output_test = self.scale1(img_batch)
self.z = nn.Parameter(torch.normal(0, 1, size = (output_test.shape[0], output_test.shape[1], output_test.shape[2], output_test.shape[3]), requires_grad=True))
self.scale1.weight.requires_grad=False
optim = SGD([{'params': self.z, "lr": 1e-5 }])
converged = False
while not converged:
z1_old = self.z.clone().detach()
output_image = (F.conv2d((self.z),(self.scale1.weight), padding=self.scale1.kernel_size[0]-1))
loss = ((img_batch-output_image)**2).sum() + self.alpha*torch.norm(self.z,p=1)
optim.zero_grad()
loss.backward()
optim.step()
self.z.grad.zero_()
converged = torch.norm(self.z - z1_old)/torch.norm(z1_old)<1e-2
def soft_thresholding_(self, x, alpha):
with torch.no_grad():
rtn = F.relu(x-alpha)- F.relu(-x-alpha)
return rtn.data
def forward(self, img_batch):
self.ista_(img_batch)
return F.conv2d((self.z),(self.scale1.weight), padding = self.scale1.kernel_size[0]-1)
def zero_grad(self):
self.scale1.zero_grad()
ista_model = MNIST_ISTA()
optim2 = SGD([{'params': ista_model.scale1.weight, "lr": 1e-3}])
for epoch in range(5):
running_loss = 0
for data in tqdm(train_dl, desc='training', total=len(train_dl)):
img_batch = data[0]
ista_model.scale1.weight.requires_grad=False
pred = ista_model(img_batch) #the original image size is returned
ista_model.z.requires_grad=False
ista_model.scale1.weight.requires_grad=True
criterion = nn.MSELoss()
loss2 = criterion(pred, img_batch)
running_loss += loss2.item()
print('grad')
print(ista_model.scale1.weight.grad)
optim2.zero_grad()
loss2.backward()
optim2.step()
ista_model.zero_grad()
ista_model.z.requires_grad=True