Something like this?
import torch

class BitShiftFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, shift):
        # Save the input for backward and record the shift amount as a Python int.
        ctx.save_for_backward(x)
        ctx.shift = shift.item()
        # Reinterpret the float32 bits as int32 so the shift acts on the raw bit pattern.
        x_int = x.view(torch.int32)
        if ctx.shift > 0:
            shifted = x_int << ctx.shift
        else:
            shifted = x_int >> -ctx.shift
        return shifted.view(torch.float32)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # The bit shift has no meaningful derivative; this returns the saved input as the
        # "gradient" for x, and None for shift so it receives no gradient.
        return x, None
# Usage example
shift = 2
x = torch.randn(3, 3, requires_grad=True)
result = BitShiftFunction.apply(x, torch.tensor(shift))
result.backward(torch.ones_like(result))
print(f"x:\n{x}")
print(f"Shifted x:\n{result}")
print(f"Gradients of x:\n{x.grad}")
x:
tensor([[-2.5461, -0.0995,  0.4327],
        [-0.5932, -0.4999, -0.7176],
        [-0.7571,  0.7862,  1.7181]], requires_grad=True)
Shifted x:
tensor([[ 1.2839e-38, -3.5585e+33, -1.2778e+36],
        [-4.6406e+36, -2.6573e+36, -9.2560e+36],
        [-1.1241e+37, -1.3716e+37, -3.1861e+38]],
       grad_fn=<BitShiftFunctionBackward>)
Gradients of x:
tensor([[-2.5461, -0.0995,  0.4327],
        [-0.5932, -0.4999, -0.7176],
        [-0.7571,  0.7862,  1.7181]])
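Why the shifted tensor looks nothing like 4 * x: x.view(torch.int32) reinterprets the raw IEEE-754 bits, so the shift moves exponent and mantissa bits instead of scaling the value. A minimal sanity check of that reinterpretation (my addition, not part of the original answer):

import struct

import torch

v = torch.tensor([1.0])                        # float32 value 1.0
bits = v.view(torch.int32).item()              # reinterpret the bits, no numeric conversion
assert bits == struct.unpack('<i', struct.pack('<f', 1.0))[0]
print(hex(bits))                               # 0x3f800000, the IEEE-754 encoding of 1.0
print((v.view(torch.int32) >> 2).view(torch.float32))  # ~2.2e-29, not 0.25: bits, not arithmetic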
Tip: based on this DeepSeek answer
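If you want gradients to flow as though the shift were the identity, a straight-through-estimator backward is a common alternative. This is a sketch of my own (an assumption, not what the answer above does); it simply passes grad_output through to x:

import torch

class BitShiftSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, shift):
        shift = int(shift)
        x_int = x.view(torch.int32)
        shifted = x_int << shift if shift > 0 else x_int >> -shift
        return shifted.view(torch.float32)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: treat the non-differentiable bit shift as the identity.
        return grad_output, None

x = torch.randn(3, 3, requires_grad=True)
y = BitShiftSTE.apply(x, torch.tensor(2))
y.backward(torch.ones_like(y))
print(x.grad)  # all ones under the straight-through assumption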