Last week I implemented a ResNet block for my SRGAN code, and I was unsure about one aspect of the implementation.
When you pass the input through the residual skip connection, do you detach it from the computation graph (like var.detach() in PyTorch), or do you let it stay in the graph?
If it stays in the graph, does that mean that during the backward pass the gradient reaching the input, and therefore the weights of the layers that produced it, receives contributions from both the main path through the next layer and the skip connection?
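A toy scalar example (not from the original post) makes the question concrete. If the skip connection is left in the graph, a block computing y = x + F(x) has dy/dx = 1 + F'(x), so the gradient reaching x is the sum of the identity path and the residual-branch path. Detaching the skip removes the 1. The sketch below assumes F(x) = x**2 as a stand-in for the residual branch; grad_through_skip is a hypothetical helper name.

```python
import torch


def grad_through_skip(detach_skip: bool) -> float:
    """Return dy/dx at x = 2 for y = skip(x) + F(x), with F(x) = x**2 (toy stand-in)."""
    x = torch.tensor(2.0, requires_grad=True)
    # Either keep the skip path in the autograd graph or cut it off with detach().
    skip = x.detach() if detach_skip else x
    y = skip + x ** 2
    y.backward()
    return x.grad.item()


print(grad_through_skip(detach_skip=False))  # 5.0 = 1 (skip path) + 2*x (residual branch)
print(grad_through_skip(detach_skip=True))   # 4.0 = 2*x only; the skip path contributes nothing
```

The same additive rule applies to tensors: with an attached skip, the layers upstream of the block receive gradient from both paths, which is the usual motivation for identity shortcuts in ResNets.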
Below is my implementation in PyTorch:
import torch
import torch.nn as nn

class ResnetBlock(nn.Module):
    def __init__(self, kernel_size, stride, n_filters, in_channels):
        super(ResnetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, kernel_size=kernel_size, stride=stride, out_channels=n_filters, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(n_filters)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(in_channels=n_filters, kernel_size=kernel_size, stride=stride, out_channels=n_filters, padding=1, bias=False)

    def forward(self, x):
        xp = torch.clone(x).detach()  # Detached this variable xp and passed it through the skip connection
        x = self.prelu(self.bn(self.conv1(x)))
        x = xp.add(self.prelu(self.bn(self.conv2(x))))
        return x
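For comparison, here is a minimal sketch of the same block with the skip connection left in the graph. It assumes the layout usually described for SRGAN generator residual blocks (conv, BN, PReLU, conv, BN, then element-wise sum), which differs from the code above in three labeled ways: no detach, a separate BatchNorm layer for each conv, and no PReLU after the second BN. It also assumes stride 1 and an odd kernel with padding = kernel_size // 2, since any stride above 1 would shrink the feature map and the identity could no longer be added directly. ResidualBlock and channels are hypothetical names.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Hypothetical SRGAN-style block: conv-BN-PReLU-conv-BN plus an identity skip."""

    def __init__(self, channels: int, kernel_size: int = 3):
        super().__init__()
        padding = kernel_size // 2  # keeps H and W unchanged for odd kernels at stride 1
        self.conv1 = nn.Conv2d(channels, channels, kernel_size, stride=1, padding=padding, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size, stride=1, padding=padding, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)  # separate BN: its running stats track conv2's outputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.prelu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        # No detach: gradients flow back through both the residual branch and the identity.
        return x + residual


block = ResidualBlock(channels=64)
x = torch.randn(2, 64, 24, 24, requires_grad=True)
out = block(x)
out.sum().backward()
print(out.shape)           # torch.Size([2, 64, 24, 24])
print(x.grad is not None)  # True: the input receives gradient from both paths
```

Because the identity term is part of the graph, x.grad includes the direct contribution from the skip in addition to whatever flows back through conv1 and conv2.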