Hello,
I want to introduce a kind of total variation penalty.
I want to realize this by the fact, that the neural network is penalized for it, if the neighboured weights differ strongly.
But I want to apply this only in one layer. So I have added this additionally in the loss function.
My code parts are therefore as follows:
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# decoder
self.conv1 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=50, bias=False)
self.conv2 = nn.BatchNorm1d(1)
self.conv3 = nn.Conv1d(in_channels=1,out_channels=1, kernel_size=30, bias=False)
self.conv4 = nn.BatchNorm1d(1)
self.conv5 = nn.Conv1d(in_channels=1,out_channels=1, kernel_size=11, bias=False)
self.conv6 = nn.BatchNorm1d(1)
self.ASC = nn.Softmax(dim=1)
# encoder
self.DEC = nn.Linear(in_features=3,out_features=91)
def forward(self, x):
x = F.leaky_relu(self.conv1(x))
x = F.leaky_relu(self.conv2(x))
x = F.leaky_relu(self.conv3(x))
x = F.leaky_relu(self.conv4(x))
x = F.leaky_relu(self.conv5(x))
x = x*3.5
x = F.leaky_relu(self.conv6(x))
a = self.ASC(x)
x = F.leaky_relu(self.DEC(a))
return x, a
net = CNN()
def train(net, trainloader, NUM_EPOCHS, cur_epoch=0):
train_loss = []
for epoch in range(cur_epoch, NUM_EPOCHS):
running_loss = 0.0
for data in trainloader:
inputData = data
inputData = inputData.to(device)
optimizer.zero_grad()
outputData ,a = net(inputData)
loss_MSE = criterion(outputData, inputData)
loss_TVPenalty = TV_Penalty(net.DEC)
loss = loss_MSE + loss_TVPenalty
loss.backward()
optimizer.step()
net.DEC.weight.data.clamp_(min=0)
running_loss += loss.item()
loss = running_loss / len(trainloader)
train_loss.append(loss)
return train_loss
def TV_Penalty(Layer):
loss = torch.pow((Layer.weight.data[1:]-Layer.weight.data[:-1]), 2).sum()*TV_PENALTY_CONSTANT
return loss
If I use it this way, the autograde function of Pytorch is not applied to the python def TV_Penalty(Layer):
function.
How to use the autograd in Pytocht on the structure of the weights of the network itself? So then this is considered in an additional loss function.