Bitwise Operations on a CUDA Float Tensor

Something like this?

import torch

class BitShiftFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, shift):
        # Save the input for the backward pass; shift is extracted as a
        # plain Python int and stashed on ctx.
        ctx.save_for_backward(x)
        ctx.shift = shift.item()
        # Reinterpret the float32 bits as int32 (no value conversion).
        x_int = x.view(torch.int32)
        if ctx.shift > 0:
            shifted = x_int << ctx.shift
        else:
            shifted = x_int >> -ctx.shift
        # Reinterpret the shifted bits back as float32.
        return shifted.view(torch.float32)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # The bit shift is not meaningfully differentiable, so this example
        # simply passes the saved input back as the "gradient" of x.
        # shift is non-differentiable, hence None.
        return x, None

# Usage example
shift = 2
x = torch.randn(3, 3, requires_grad=True)
result = BitShiftFunction.apply(x, torch.tensor(shift))
result.backward(torch.ones_like(result))

print(f"x:\n{x}")
print(f"Shifted x:\n{result}")
print(f"Gradients of x:\n{x.grad}")
x:
tensor([[-2.5461, -0.0995,  0.4327],
        [-0.5932, -0.4999, -0.7176],
        [-0.7571,  0.7862,  1.7181]], requires_grad=True)
Shifted x:
tensor([[ 1.2839e-38, -3.5585e+33, -1.2778e+36],
        [-4.6406e+36, -2.6573e+36, -9.2560e+36],
        [-1.1241e+37, -1.3716e+37, -3.1861e+38]],
       grad_fn=<BitShiftFunctionBackward>)
Gradients of x:
tensor([[-2.5461, -0.0995,  0.4327],
        [-0.5932, -0.4999, -0.7176],
        [-0.7571,  0.7862,  1.7181]])
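
Note that view(torch.int32) reinterprets the raw IEEE-754 bits rather than converting values, which is why the shifted tensor contains such extreme magnitudes, and the printed gradients equal x itself only because this backward returns the saved input as the gradient. Since the question is about CUDA tensors: Tensor.view(dtype) and the shift operators also work on GPU tensors, so the same function can be applied there. A minimal sketch, assuming a CUDA device is available and BitShiftFunction is defined as above:

# Minimal sketch: same autograd function, but on a CUDA tensor.
if torch.cuda.is_available():
    x_gpu = torch.randn(3, 3, device="cuda", requires_grad=True)
    result_gpu = BitShiftFunction.apply(x_gpu, torch.tensor(2))
    result_gpu.backward(torch.ones_like(result_gpu))
    print(f"Shifted x (CUDA):\n{result_gpu}")
    print(f"Gradients of x (CUDA):\n{x_gpu.grad}")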

:tophat:-tip: based on this DeepSeek answer