torch.where breaks the gradient flow

Hi!

I’m working with a dataset of wireless interference measurements, and for my calculations, I need to extract the measurements made for specific channels. The channels are chosen according to a hopping algorithm specified by the IEEE 802.15.4 standard. I’m using a neural network to blacklist a few of the channels and exclude them from the hopping procedure. I came up with the following solution that uses torch.where:

datapoints, channels = interference.shape
available_channels, = torch.where(blacklisted < 0.5)
channels_in_a_round = available_channels.view(-1, 1).repeat(1, 20).flatten()
round_length, = channels_in_a_round.shape
length_ratio = np.ceil(datapoints / round_length).astype('int')
assigned_channels = channels_in_a_round.repeat(length_ratio)[:datapoints]
interference_in_assigned_channels = interference[torch.arange(datapoints), assigned_channels]
# Some calculations on interference_in_assigned_channels leading to the final output of the model

Now, even though blacklisted is a tensor with requires_grad=True (it’s actually the sigmoid-activated output of a neural network with 16 dimensions), the gradient flow breaks and the available_channels tensor no longer carries this property. I’m guessing this operation (or at least the way I’m doing it) is not supported by autograd. I’d appreciate it if you could help me out. How should I modify this code so that the gradient flow stays intact?

According to https://pytorch.org/docs/stable/torch.html#torch.where, there are two ways to use torch.where, and it seems you are using the non-differentiable one. torch.where(condition) is equivalent to torch.nonzero(condition, as_tuple=True), which returns integer indices. Indices cannot carry gradients, so the graph is broken at that point.
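To make the difference concrete, here is a minimal sketch comparing the two forms. The `logits` tensor is a hypothetical stand-in for your network output (16 random values passed through a sigmoid), not your actual model, and the zero fill value in the three-argument call is only an illustrative choice.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for the network output: 16 sigmoid-activated scores.
logits = torch.randn(16, requires_grad=True)
blacklisted = torch.sigmoid(logits)

# One-argument form: same as torch.nonzero(condition, as_tuple=True).
# The result is a tensor of integer indices, which is detached from the graph.
(available_channels,) = torch.where(blacklisted < 0.5)
print(available_channels.dtype)          # integer index dtype
print(available_channels.requires_grad)  # False

# Three-argument form: picks elementwise between two tensors.
# Gradients flow into the selected values, but not through the condition.
mask = blacklisted < 0.5
kept = torch.where(mask, blacklisted, torch.zeros_like(blacklisted))
print(kept.requires_grad)                # True

kept.sum().backward()
# logits.grad is zero wherever the condition was False.
print(logits.grad)
```

Note that even the three-argument form only passes gradients to the values it selects. The comparison `blacklisted < 0.5` is a hard threshold and produces a boolean tensor, so no gradient reaches the network through the decision of which channels are blacklisted. Keeping that decision trainable would likely need a soft formulation, for example weighting the per-channel interference by `1 - blacklisted` instead of indexing with hard channel choices, though whether that fits the hopping procedure depends on your setup.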