Implementing a discontinuous NN layer in `torch`

I’m trying to implement the following discontinuous layer, proposed in a recent paper that I found quite interesting, and I’m wondering whether I’m implementing it correctly.

Layer: y = ReLU(Wx + b) + ε ⊙ H(Wx + b), where H is the elementwise Heaviside step function and ε is a vector of trainable parameters.

My first attempt was the following:

import torch
import torch.nn as nn


class DiscontLayer(nn.Module):

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        # trainable jump amplitudes for the discontinuous part
        self.epsilon = nn.Parameter(torch.randn(output_dim))

    def forward(self, x):
        z = self.linear(x)
        # continuous part: standard ReLU activation
        cont_part = torch.relu(z)
        # discontinuous part: Heaviside step scaled by the trainable epsilon
        discont_part = self.epsilon * torch.heaviside(z, torch.ones_like(z))
        return cont_part + discont_part

But when I try to run it, I get an error when backpropagating through the Heaviside activation function. The derivative of this function should be 0 everywhere except at the discontinuity, but it does not seem to be implemented for torch.heaviside.
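
For reference, a minimal snippet like this one reproduces the error on my setup (the exact message, and the point where it is raised, may vary with the torch version):

import torch

x = torch.randn(5, requires_grad=True)
out = torch.heaviside(x, torch.ones_like(x))
# on my setup this raises a RuntimeError saying that the derivative
# of heaviside is not implemented
out.sum().backward()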

Then I found this post, which suggests wrapping the torch.heaviside call in torch.no_grad(), so I tried the following:

import torch
import torch.nn as nn


class DiscontLayer(nn.Module):

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.epsilon = nn.Parameter(torch.randn(output_dim))

    def forward(self, x):
        z = self.linear(x)
        # continuous part: gradients flow through the linear weights as usual
        cont_part = torch.relu(z)

        # discontinuous part: the Heaviside step is computed outside the
        # autograd graph, so no gradient flows back through it to the linear
        # weights; epsilon still gets a gradient, since the step output is
        # treated as a constant factor
        with torch.no_grad():
            step = torch.heaviside(z, torch.ones_like(z))

        discont_part = self.epsilon * step

        return cont_part + discont_part
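
For what it’s worth, a quick sanity check along these lines runs without errors for me: epsilon receives a gradient (the step output is treated as a constant), while the gradient of the linear weights comes only from the ReLU branch (the dimensions below are arbitrary):

import torch

layer = DiscontLayer(4, 3)
x = torch.randn(8, 4)

layer(x).sum().backward()

print(layer.epsilon.grad)        # non-None: epsilon is updated via the constant step output
print(layer.linear.weight.grad)  # non-None, but only the ReLU branch contributes to it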

But I’m not yet 100% sure that I have done that correctly.

Many thanks for your help,

Fede

Hey Fede, were you able to solve this problem? I was also interested in the effect of using a discontinuous activation function.