My parameters have none grad

This is my custom convolution function (I modified conv2d() ) and I replaced all conv2d functions with this function in Alex-net.
but I found that parameters in Qilconv2d were not changed after backpropagation and grad was also None.
What is the problem in my source?
My custom convolution has 3 steps.

  1. transformer (weight, inputs)
  2. discretized (weight, inputs)
  3. F.conv2d()
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import copy
import time
import math
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.modules.conv import _ConvNd


class RoundNoGradient(torch.autograd.Function):
	@staticmethod
	def forward(ctx, x):
		return x.round()
	
	@staticmethod
	def backward(ctx, g):
		return g 

class QilConv2d(_ConvNd):
	
	discretization_level = 32.0

	def __init__(self, in_channels, out_channels, kernel_size, stride=1,
				padding=0, dilation=1, groups=1,
				bias=True, padding_mode='zeros'):
		kernel_size = _pair(kernel_size)
		stride = _pair(stride)
		padding = _pair(padding)
		dilation = _pair(dilation)
		
		super(QilConv2d, self).__init__(
			in_channels, out_channels, kernel_size, stride, padding, dilation,
			False, _pair(0), groups, bias, padding_mode)

		# Qil parameters
		self.cw = nn.Parameter(torch.randn(1),requires_grad=True) 
		self.dw = nn.Parameter(torch.randn(1),requires_grad=True) 
		self.cx = nn.Parameter(torch.randn(1),requires_grad=True) 
		self.dx = nn.Parameter(torch.randn(1),requires_grad=True) 
#		self.gamma = nn.Parameter(torch.tensor(1.0)) 
		self.gamma = torch.tensor(1.0) 
	
	def transfomer_weights(self,weights):
		device = weights.device
		aw = (0.5 / self.dw)
		bw = (-0.5*self.cw / self.dw + 0.5)
		
		weights_t = torch.where( abs(weights) < self.cw - self.dw, weights - weights ,weights)
		weights_t = torch.where( abs(weights_t) > self.cw + self.dw,
									torch.sign(weights_t), weights_t)
		weights_t = torch.where( (abs(weights_t) >= self.cw - self.dw)&(abs(weights_t) <= self.cw + self.dw),
									(aw*abs(weights_t) + bw)**self.gamma  * weights_t.sign() , weights_t)
		return weights_t

	def transfomer_activation(self,x):
#		device = x.device
		ax,bx = (0.5 / self.dx) , (-0.5*self.cx / self.dx + 0.5)
		x_t = torch.where( x < self.cx - self.dx, x-x, x)
		x_t = torch.where( x_t > self.cx + self.dx,
				torch.sign(x_t), x_t )
		x_t = torch.where( ( x_t >= self.cx - self.dx ) & ( x_t <= self.cx + self.dx ),
							ax*x_t + bx, x_t)
		return x_t
	
	def discretizer(self,tensor):
		
		q_D = torch.pow(torch.tensor(2.0), torch.tensor(QilConv2d.discretization_level - 1)) 
		tensor_d = RoundNoGradient.apply( tensor * q_D) / q_D

		return tensor_d

	def params_print(self):
		print("\ncw : {}	dw : {} cx : {} dx : {} gamma : {} ".format(
			  self.cw, self.dw, self.cx, self.dx, self.gamma))

	def conv2d_forward(self, input, weight):
		if self.padding_mode == 'circular':
			expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
								(self.padding[0] + 1) // 2, self.padding[0] // 2)
			return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
							weight, self.bias, self.stride,
							_pair(0), self.dilation, self.groups)
		return F.conv2d(input, weight, self.bias, self.stride,
						self.padding, self.dilation, self.groups)

	def forward(self,input,Inference = False): 
		if not Inference:
			weight_t = self.transfomer_weights(self.weight)
			weight_qil = self.discretizer(weight_t)
		input_t = self.transfomer_activation(input)
		input_qil = self.discretizer(input_t)

		return self.conv2d_forward(input_qil,weight_qil)

Hi, I noticed two mistakes in this code after fast browsing the paper.
First, while may not influence in some circumstances:

weights_t = torch.where( abs(weights_t) > self.cw + self.dw,
									torch.sign(weights_t), weights_t)
weights_t = torch.where( (abs(weights_t) >= self.cw - self.dw)&(abs(weights_t) <= self.cw + self.dw),
									(aw*abs(weights_t) + bw)**self.gamma  * weights_t.sign() , weights_t)

The condition are based on the modified weights instead of the original weight, which may produce incorrect results.

Second, you are initializing

self.dw = nn.Parameter(torch.randn(1),requires_grad=True)

I think that dw should be the quantization interval and should be positive.
If it’s negative,

(abs(weights_t) >= self.cw - self.dw)&(abs(weights_t) <= self.cw + self.dw)

is never satisfied and your parameters never get trained, right?
You initialize it as random gaussian so it can be negative 50%.

Thanks for your reply
you pointed conditions are based on the modified weights instead of the original weight
Consequently, The result was the same when I replaced modified values with original weights in torch.where() condition.

And the second part, then If I initialize like

self.dw = nn.Parameter(torch.rand(1),requires_grad=True) # 0 ~ 1 real values

I think ‘self.dw’ can be negative after backpropagation.

Yes, it can.
But that’s the result after back propagation, so the self.dw received gradient and are updated to negative and then freeezed.

If you initialize self.dw negative, you always receive None gradients and the quantization parameters never updated from your intial value.