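You can just assign the modified sub-tensor back to the original at the same index. Indexing with a boolean mask returns a copy, not a view, so changing sub_t on its own never touches t: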
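import torch

t = torch.arange(0, 36).view(2, 2, 3, 3)
binary_map = (torch.arange(0, 4) % 2 == 0).view(2, 2)  # [[True, False], [True, False]]
sub_t = t[binary_map, :, :]  # shape (2, 3, 3): a copy of t[0, 0] and t[1, 0]
sub_t = sub_t * 0            # modify the copy
t[binary_map, :, :] = sub_t  # just assigning it back into t at the masked positions
print(t)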
Output:
tensor([[[[ 0,  0,  0],
          [ 0,  0,  0],
          [ 0,  0,  0]],

         [[ 9, 10, 11],
          [12, 13, 14],
          [15, 16, 17]]],


        [[[ 0,  0,  0],
          [ 0,  0,  0],
          [ 0,  0,  0]],

         [[27, 28, 29],
          [30, 31, 32],
          [33, 34, 35]]]])
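If all you need is to zero (or otherwise fill) the selected slices, a minimal sketch of two shorter options follows, reusing the same t and binary_map. It first checks that an in-place change to the masked copy leaves t alone, then assigns a scalar straight through the mask, and finally uses masked_fill with the (2, 2) mask reshaped to (2, 2, 1, 1) so it broadcasts against t. The names t1 and t2 are just illustrative.

```python
import torch

t = torch.arange(0, 36).view(2, 2, 3, 3)
binary_map = (torch.arange(0, 4) % 2 == 0).view(2, 2)

# Boolean-mask indexing copies the data: zeroing sub_t in place leaves t unchanged.
sub_t = t[binary_map]
sub_t.zero_()
print(t[1, 0, 0, 0].item())  # 18, still the original value

# Option 1: assign a scalar through the mask; it broadcasts over the selected (3, 3) slices.
t1 = t.clone()
t1[binary_map] = 0

# Option 2: out-of-place masked_fill; the mask gets two trailing singleton dims
# so its shape (2, 2, 1, 1) broadcasts against t's shape (2, 2, 3, 3).
t2 = t.masked_fill(binary_map[:, :, None, None], 0)

print(torch.equal(t1, t2))  # True
```

Both give the same result as the write-back approach; masked_fill_ (with a trailing underscore) does the same thing in place on t.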