class BinSFActive(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
input = input.sign()
return input
@staticmethod
def backward(ctx, grad_output): #hard_tanh()=STE
input, = ctx.saved_tensors
grad_input = torch.clip(torch.ones_like(input),min=-1, max=1)#clip
grad_input[input.gt(1)] = 0
grad_input[input.lt(-1)] = 0
return grad_input
binsfactive = BinSFActive.apply
class BinSFConv2d(nn.Module):
def __init__(self, input_channels, output_channels,
kernel_size=-1, stride=-1, padding=-1, dropout=0):
super(BinSFConv2d, self).__init__()
self.layer_type = 'BinSFConv2d'
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dropout_ratio = dropout
self.bn = nn.BatchNorm2d(input_channels)
if dropout!=0:
self.dropout = nn.Dropout(dropout)
self.conv = nn.Conv2d(input_channels, output_channels,
kernel_size=kernel_size, stride=stride, padding=padding)
self.conv_g = nn.Conv2d(input_channels, output_channels * input_channels,
kernel_size=kernel_size, stride=stride, padding=padding, groups=input_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
img_n, img_ch, img_r, img_c =x.size()
out_ch, _, _, _ = self.conv.weight.data.size()
t_min, t_max = x.min(), x.max()
t_variable = torch.div(255.0, torch.sub(t_max, t_min))
x_scaled = torch.multiply(torch.sub(x, t_min), t_variable)
x_scaled = torch.round(x_scaled)
results = torch.zeros((img_n, out_ch, img_r, img_c), device = x.device).cuda()
x = self.bn(x)
sign_x, mean = binsfactive(x) #shape N, Cin, W, H
with torch.no_grad():
filter_size = 11
order = 8
SF_matrix = smoothing_filter(x_scaled, filter_size, order)
#make scaling factor for binarized activation
if self.dropout_ratio!=0:
sign_x = self.dropout(sign_x)
outputs = self.conv_g(sign_x)
for i in range(img_ch):
results += torch.multiply(outputs[:,(out_ch*i):(out_ch*(i+1)),:,:],SF_matrix[:,i:i+1,:,:])
results = self.relu(results)
return results
This layer is binarized with weights and activations, and we will approximate real-value through the scaling factor.
So, the sign function is used for the forward of BinSFActive, which is responsible for the binarization of the activation. The gradient is defined by clip and straight through estimation because of the sign function’s gradient vanishing in the backward.
SF_matrix is a scaling factor and was derived using the histogram value of the input.
Grouped convolution(self.con_g) was used to multiply the scaling factor for each channel before they are added.
Since the smoothing filter function uses the input value, should I remove the no_grad()?
How to fix this code?