AMP not casting custom Parameter tensor

I am trying to reproduce SReLU activation.

The model I wrote is below and it runs fine with normal precision (float32).

class SReLU(nn.Module):
    def __init__(self, normalized_shape=(1,), threshold=0.8, alpha=1e-1):
        """
SReLU activation (S-shaped Rectified Linear Unit), according to
arxiv.org/abs/1512.07030

The normalized_shape is similar to LayerNorm parameter (https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html).
For example, if normalized_shape is (3, 5) (a 2-dimensional shape), alpha
and threshold are computed over the last 2 dimensions of the input.

---

Common use cases:

# SReLU after Linear() layer which output shape is (N, C)
activation = SReLU(normalized_shape=(C,))

# SReLU after Conv1d() layer which output shape is (N, C, L)
activation = SReLU(normalized_shape=(C, 1))

# SReLU after Conv2d() layer which output shape is (N, C, H, W)
activation = SReLU(normalized_shape=(C, 1, 1))

# SReLU after Linear() layer which output shape is (N, L1, L2, L3)
activation = SReLU(normalized_shape=(L1, L2, L3))


        :param normalized_shape: (int or iterable) Input shape from an expected
        input of size: (*, normalized_shape[0], normalized_shape[1], ..., normalized_shape[−1]).
        If a single integer is used, it is treated as a singleton
        list, and this module will normalize over the last dimension which is
        expected to be of that specific size.
        :param threshold: (float) Initial threshold value for both sides.
        :param alpha: (float) Initial slope value for both sides.
        """
        super().__init__()

        # Cast to Tuple, whatever the original type
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        else:
            normalized_shape = tuple(normalized_shape)

        self.threshold_l = nn.Parameter(torch.full(normalized_shape, -threshold, requires_grad=True))
        self.threshold_r = nn.Parameter(torch.full(normalized_shape, +threshold, requires_grad=True))

        self.alpha_l = nn.Parameter(torch.full(normalized_shape, alpha, requires_grad=True))
        self.alpha_r = nn.Parameter(torch.full(normalized_shape, alpha, requires_grad=True))

    def forward(self, x):
        return torch.where(x > self.threshold_r, self.threshold_r + self.alpha_r * (x - self.threshold_r),
                           torch.where(x < self.threshold_l, self.threshold_l + self.alpha_r * (x - self.threshold_l), x))

When I set with autocast(enabled=True): it gives me an error message in the forward pass:

torch.where(x < self.threshold_l, self.threshold_l + self.alpha_r * (x - self.threshold_l), x))
RuntimeError: expected scalar type float but found c10::Half

Debugging right before this operation shows that x.dtype == float16 and my parameters have float32 dtype.

Shouldn’t AMP cast my custom Parameters to the correct type before operations?
Maybe it doesn’t work inside torch.where() method?
Or did I badly define my Parameters?

I cannot reproduce the issue using a recent nightly binary:

class SReLU(nn.Module):
    def __init__(self, normalized_shape=(1,), threshold=0.8, alpha=1e-1):
        super().__init__()

        # Cast to Tuple, whatever the original type
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        else:
            normalized_shape = tuple(normalized_shape)

        self.threshold_l = nn.Parameter(torch.full(normalized_shape, -threshold, requires_grad=True))
        self.threshold_r = nn.Parameter(torch.full(normalized_shape, +threshold, requires_grad=True))

        self.alpha_l = nn.Parameter(torch.full(normalized_shape, alpha, requires_grad=True))
        self.alpha_r = nn.Parameter(torch.full(normalized_shape, alpha, requires_grad=True))

    def forward(self, x):
        return torch.where(x > self.threshold_r, self.threshold_r + self.alpha_r * (x - self.threshold_r),
                           torch.where(x < self.threshold_l, self.threshold_l + self.alpha_r * (x - self.threshold_l), x))

act = SReLU().cuda()
x = torch.randn(10).cuda()

with torch.cuda.amp.autocast():
    out = act(x)
    
print(out)
# tensor([-9.5997e-01, -2.3201e-01, -2.2078e-04,  1.8353e-01,  6.2051e-01,
#          2.6683e-01,  4.3467e-01,  6.7360e-01,  3.1208e-02,  3.9367e-01],
#        device='cuda:0', grad_fn=<WhereBackward0>)