Background here. I want to implement Neural Style Transfer using Pytorch from scratch (for educational purpose). There are 2 models, which are transformer net (T)
and loss net (L)
. With given input x
, I then will compute o = T(x)
. Then compute loss
from L(x, o)
. But the loss from loss net did not propagate back to transformer net.
I don’t think I understand how pytorch autograd well enough to know what is going wrong. But the code below is what I came up with, and if I set lossNet
on requires_grad
to False
. The code give me a error of loss does not have grad_fn
.
Please advised, Thank you
Code
Here are the simplify code of my current situation where Model A
is transformer net and Model B
is loss net
class ModelA(nn.Module):
def __init__(self, requires_grad=True):
super(ModelA, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 16, 3, padding=1)
self.relu1 = torch.nn.ReLU(inplace=False)
self.maxpool = torch.nn.MaxPool2d(2, padding=0)
self.conv2 = torch.nn.Conv2d(16, 16, 1)
self.relu2 = torch.nn.ReLU(inplace=False)
self.conv3T = torch.nn.ConvTranspose2d(16, 3, 1)
self.relu3 = torch.nn.ReLU(inplace=False)
self.upSampling = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, in_x):
x = in_x
x = self.relu1(self.conv1(x))
x = self.maxpool(x)
x = self.relu2(self.conv2(x))
x = self.relu3(self.conv3T(x))
x = self.sigmoid(self.upSampling(x))
return x
class ModelB(nn.Module):
def __init__(self):
super(ModelB, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 64, 3, padding=1)
self.relu1 = torch.nn.ReLU(inplace=False)
self.maxpool1 = torch.nn.MaxPool2d(2, padding=1)
self.conv2 = torch.nn.Conv2d(64, 128, 3, padding=1)
self.relu2 = torch.nn.ReLU(inplace=False)
self.maxpool2 = torch.nn.MaxPool2d(2, padding=1)
self.conv3 = torch.nn.Conv2d(128, 256, 3, padding=1)
self.relu3 = torch.nn.ReLU(inplace=False)
self.maxpool3 = torch.nn.MaxPool2d(2, padding=1)
def forward(self, in_x, in_o):
x = in_x
o = in_o
x = self.maxpool1(self.relu1(self.conv1(x)))
x = self.maxpool2(self.relu2(self.conv2(x)))
x = self.maxpool3(self.relu3(self.conv3(x)))
o = self.maxpool1(self.relu1(self.conv1(o)))
o = self.maxpool2(self.relu2(self.conv2(o)))
o = self.maxpool3(self.relu3(self.conv3(o)))
return x, o
The Model B
is then separate into subnetwork and wrapped inside another class before merge again
class ModuleWrapper(nn.Module):
def __init__(self, subnetwork, isUseCuda):
super(ModuleWrapper, self).__init__()
self.layers = list(subnetwork.children())
self.net = nn.Sequential(*self.layers)
self.loss = torch.tensor(0.0).float()
self.not_inplace = lambda layer: nn.ReLU(inplace=False) if isinstance(layer, nn.ReLU) else layer
if isUseCuda:
self.loss = self.loss.cuda()
def forward(self, in_x, in_o):
if torch.cuda.is_available():
x = in_x.cuda()
o = in_o.cuda()
else:
x = in_x
o = in_o
for layer in self.layers:
layer = self.not_inplace(layer)
x = layer(x)
o = layer(o)
if torch.cuda.is_available():
x = x.cuda()
o = o.cuda()
self.loss = F.mse_loss(x, o)
return x, o
class MergeWrapper(nn.Module):
def __init__(self, modules):
super(MergeWrapper, self).__init__()
self.net = nn.Sequential(*modules)
def forward(self, in_x, in_o):
x, o = in_x, in_o
for module in self.net:
x, o = module(x, o)
return x, o
def get_module(self, index):
assert 0 <= index < len(self.net)
return self.net[index]
def max_seq(self):
return len(self.net)
backboneNet= ModelA()
for param in backboneNet.parameters():
param.requires_grad = True
lossNet = ModelB()
# Assume Model B is pretrained
for param in lossNet.parameters():
param.requires_grad = False
layers = list(lossNet.children())
subnet = []
subnet.append(ModuleWrapper(nn.Sequential(*layers[0:3]), useCuda))
subnet.append(ModuleWrapper(nn.Sequential(*layers[3:6]), useCuda))
subnet.append(ModuleWrapper(nn.Sequential(*layers[6:9]), useCuda))
lossNet = MergeWrapper(subnet)
Then I test with the following code
testInput = torch.rand(1, 3, 8, 8)
def train(testInput, num_steps = 10):
optimizer = optim.LBFGS([testInput.requires_grad_()])
run = [0]
while run[0] <= num_steps:
def closure():
optimizer.zero_grad()
x = testInput
o = backboneNet(x)
x = x.data.clamp_(0, 1)
o = o.data.clamp_(0, 1)
lossNet(x, o)
loss = torch.tensor(0.0).float().to(device)
for module in subnet:
loss += module.loss
print(loss)
print_backprop(loss)
loss.backward(retain_graph=True)
run[0] += 1
return loss
optimizer.step(closure)
train(testInput)
The printed loss
value is not zero when set requires_grad = True
on lossNet
but the printed back-propagated layers is incorrect. And I have no idea how to fix the issue.