I tried to customize a layer to automatically adjust the brightness and contrast by a linear layer to determine the coefficient, but could not backpropagation, my layer is:
class MyLinear(torch.autograd.Function):
def __init__(self, in_units, units):
super().__init__()
self.weight = nn.Parameter(torch.randn(in_units, units))
self.bias = nn.Parameter(torch.randn(units,))
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(0))
self.weight.data.uniform_(0, stdv)
def forward(self, x1, x2):
x1 = torch.matmul(x1, self.weight.data) + self.bias.data
print(x1 ** 2)
x1 = nn.functional.relu(x1)
print(x1 ** 2)
print('-' * 50)
img_new = torch.zeros((x2.shape[0], 3, size, size))
for i in range(0, x2.shape[0]):
for j in range(0, x2.shape[1]):
if j == 0:
coef1 = round((x1[0][0] * x1[0][0]).item(), 8)
coef2 = round((x1[0][1] * x1[0][1]).item(), 8)
elif j == 1:
coef1 = round((x1[0][2] * x1[0][2]).item(), 8)
coef2 = round((x1[0][3] * x1[0][3]).item(), 8)
elif j == 2:
coef1 = round((x1[0][4] * x1[0][4]).item(), 8)
coef2 = round((x1[0][5] * x1[0][5]).item(), 8)
# print(coef1, coef2)
img = x2[i, j, :, :]
img = F.adjust_brightness(img, coef1)
img = img.reshape(-1, size, size)
img = F.adjust_contrast(img, coef2)
if j == 0:
z = img.reshape(1, size, size).cuda()
else:
z = torch.cat((z, img), 0)
z = z.reshape(1, 3, size, size)
img_new[i] = z.clone().detach()
img_new = img_new.clone().detach().requires_grad_(True)
return img_new.cuda()
And my model is:
module = torchvision.models.efficientnet_b4(pretrained=True)
class bsEfficientnetb4(nn.Module):
def __init__(self):
super(bsEfficientnetb4, self).__init__()
self.brightcontrast = MyLinear(1, 6)
self.eff = module.features
self.avgpool = module.avgpool
self.classifier = module.classifier
def forward(self, x1, x2):
x = self.brightcontrast(x1, x2)
x = self.eff(x)
x = self.avgpool(x)
x = x.reshape(-1, 1792)
x = self.classifier(x)
return x
My input x1 is:
x1 = torch.tensor([[1.,]], requires_grad=True).cuda()
And x2 is image, I’d like to get your suggestions for modifications, thanks!