Okay, here you go — I have a simple autoencoder and a modified loss function:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
class ConvAutoencoder(nn.Module):
    """Small convolutional autoencoder with a parallel weight-map branch.

    ``forward(x, wm)`` returns ``(reconstruction, weight_map)``:

    * ``x`` passes through two conv -> ReLU -> 2x2 max-pool stages
      (1 -> 16 -> 4 channels), then two kernel-2 / stride-2 transposed
      convolutions (4 -> 16 -> 1 channels). The last layer has no activation.
    * ``wm`` passes through two 3x3 'same'-padded convolutions
      (1 -> 16 -> 1 channels). Only the second is followed by ReLU, so the
      weight map is non-negative but can be exactly zero.

    Assumes single-channel inputs whose height/width are divisible by 4, so
    the reconstruction comes back at the input size -- TODO confirm with callers.
    """

    def __init__(self):
        super().__init__()
        # Keep this construction order: it fixes the random-initialisation
        # sequence and the parameter / state_dict ordering.
        # Encoder
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 4, 3, padding=1)
        self.pool = nn.MaxPool2d(2)
        # Decoder: each kernel-2 / stride-2 transposed conv doubles H and W.
        self.t_conv1 = nn.ConvTranspose2d(4, 16, 2, stride=2)
        self.t_conv2 = nn.ConvTranspose2d(16, 1, 2, stride=2)
        # Weight mapper (attribute name ``conwm2`` kept for checkpoint compatibility).
        self.convw = nn.Conv2d(1, 16, 3, padding='same')
        self.conwm2 = nn.Conv2d(16, 1, 3, padding='same')

    def _encode(self, image):
        """Apply two conv -> ReLU -> max-pool stages (spatial size / 4)."""
        for stage_conv in (self.conv1, self.conv2):
            image = self.pool(F.relu(stage_conv(image), inplace=True))
        return image

    def _decode(self, code):
        """Upsample x4 back to one channel; the output layer is linear."""
        hidden = F.relu(self.t_conv1(code), inplace=True)
        return self.t_conv2(hidden)

    def _map_weights(self, wm):
        """Apply conv -> conv -> ReLU, giving a non-negative weight map."""
        return F.relu(self.conwm2(self.convw(wm)), inplace=True)

    def forward(self, x, wm):
        """Return ``(reconstruction of x, weight map computed from wm)``."""
        reconstruction = self._decode(self._encode(x))
        weight_map = self._map_weights(wm)
        return reconstruction, weight_map
# The model and its Adam optimizer are constructed once, right before the
# training loop below. A duplicate pair used to be built here, but it was
# unconditionally overwritten before any use. It only wasted memory and
# advanced the torch RNG.
class weightedloss(nn.Module):
    """Element-wise L1 loss scaled by a per-element weight map.

    Args:
        normalize: If False (the default, original behaviour), return
            ``mean(|input - targets| * wmat)``. If True, return the weighted
            average ``sum(|input - targets| * wmat) / sum(wmat)``, which does
            not change when ``wmat`` is rescaled.
        eps: Lower bound on the weight sum when ``normalize`` is True, so an
            all-zero weight map gives 0 instead of NaN.

    Why ``normalize`` exists: in the unnormalized form the gradient of the
    loss with respect to each weight is ``|input - targets| / N >= 0``. When
    ``wmat`` comes from a trainable network, gradient descent can therefore
    always lower the loss by shrinking the weights, and the weight map tends
    to collapse toward zero. Normalizing removes that scale shortcut. The
    weights can still concentrate on low-error elements, so an additional
    constraint or regulariser may be needed.
    """

    def __init__(self, *, normalize=False, eps=1e-8):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction='none')
        self.normalize = normalize
        self.eps = eps

    def forward(self, input, targets, wmat):
        """Return the scalar weighted L1 loss.

        ``wmat`` is multiplied element-wise (with broadcasting) into the
        per-element L1 loss. Normalized mode assumes non-negative weights.
        """
        per_element = self.loss_fn(input, targets)
        weighted = per_element * wmat
        if not self.normalize:
            return torch.mean(weighted)
        # Broadcast the weights to the product's shape so the numerator and
        # the denominator sum over exactly the same elements.
        total_weight = torch.broadcast_to(wmat, weighted.shape).sum()
        return weighted.sum() / total_weight.clamp_min(self.eps)
# Train the autoencoder to map random 32x32 inputs onto one fixed random
# target (bmat), weighting the L1 loss with the weight-mapper branch's output.
epochs = 20
outputs = []
weights = []  # mean of the predicted weight map, one entry per step
losses = []   # scalar loss, one entry per step

bmat = np.random.rand(32, 32).astype('float32')
wmat = np.random.rand(32, 32).astype('float32')
# wmat is only an *input* to the weight-mapper branch and is never handed to
# the optimizer. Wrapping it in nn.Parameter made it look learnable and just
# accumulated an unused .grad every step. A plain tensor feeds the model the
# same values. (If wmat is meant to be learned, add it to the optimizer.)
wmat = torch.from_numpy(wmat).unsqueeze(0)  # (1, 32, 32)

convo = ConvAutoencoder()
losser = weightedloss()
optimizer = torch.optim.Adam(convo.parameters(), lr=1e-3, weight_decay=1e-8)

bmat = torch.from_numpy(bmat).unsqueeze(0)  # (1, 32, 32) fixed target

# NOTE(review): with the unnormalized weighted loss, d(loss)/d(weight) is
# |predict - bmat| / N >= 0. The optimizer can therefore lower the loss simply
# by shrinking the weight map, so the `weights` history is likely to trend
# toward 0.
for epoch in range(epochs):
    for _ in range(5):
        amat = np.random.rand(32, 32).astype('float32')
        amat = torch.from_numpy(amat).unsqueeze(0)  # (1, 32, 32)
        predict, weight = convo(amat, wmat)
        loss = losser(predict, bmat, weight)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        weights.append(torch.mean(weight).item())
In this code, the list named `weights` collects the mean of the weight map at each mini-batch.
Thanks again