Vectorizing index selection by nested dimension mean condition

I have the following snippet and I’m wondering how to effectively vectorize it in ‘pure’ PyTorch:

indices = []
for i in range(0, self.dataset.shape[0]):
    if torch.mean(self.dataset[i]) >= .2:
        indices.append(i)

self.dataset = self.dataset[indices, :, :, :, :]
self.target = self.target[indices, :, :, :, :]
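For reference, here is a minimal vectorized sketch of the loop above. It assumes `dataset` and `target` share their first (sample) dimension. The idea is to compute one mean per sample by reducing over every non-batch dimension, then index both tensors with the resulting boolean mask. The standalone function name `filter_by_mean` is hypothetical, and the random data is only for illustration.

```python
import torch


def filter_by_mean(dataset, target, threshold=0.2):
    # Hypothetical helper: one mean per sample, computed by flattening
    # everything after dim 0 and averaging each row.
    per_sample_mean = dataset.flatten(start_dim=1).mean(dim=1)  # shape: (N,)
    # Boolean mask with one entry per sample, matching the loop's `>=` test.
    keep = per_sample_mean >= threshold
    # Boolean indexing along dim 0 keeps all remaining dimensions intact.
    return dataset[keep], target[keep]


# Illustrative data: 10 small volumes, shifted so that some means cross 0.2.
torch.manual_seed(0)
offset = torch.linspace(-1, 1, 10).view(10, 1, 1, 1, 1)
dataset = torch.randn(10, 1, 4, 4, 4) + offset
target = torch.randn(10, 1, 4, 4, 4)

new_dataset, new_target = filter_by_mean(dataset, target)
print(new_dataset.shape, new_target.shape)
```

Inside a class, the same two lines would become `keep = self.dataset.flatten(1).mean(1) >= .2`, followed by indexing `self.dataset` and `self.target` with `keep`.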

I tried
b = torch.mean(self.dataset, dim=0) > .2
indices = b.nonzero()

and

torch.where(torch.mean(self.dataset[:, ???]) > .2, self.dataset, torch.FloatTensor([0.0]))

to no avail. Any thoughts?
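Both attempts can work once the mean is taken over every dimension except the sample dimension. The sketch below uses the question's 5-D layout with small random shapes for illustration. On the `nonzero` route, a 1-D per-sample mask makes `nonzero()` return shape `(k, 1)`, and `squeeze(1)` turns that into the equivalent of the loop's `indices` list. The `torch.where` part assumes the intent was to zero out low-mean volumes rather than drop them. In that case `keepdim=True` lets the mean broadcast against the full tensor. The sample count stays the same, though, so this part does not replace the loop.

```python
import torch

torch.manual_seed(0)
# Illustrative data: 8 small volumes with a random per-volume scale.
dataset = torch.rand(8, 1, 4, 4, 4) * torch.rand(8, 1, 1, 1, 1)

# Attempt 1, fixed: reduce over dims 1..4 (everything except the sample dim),
# so `b` has one entry per sample instead of one per voxel.
b = torch.mean(dataset, dim=(1, 2, 3, 4)) >= 0.2   # shape: (8,)
indices = b.nonzero().squeeze(1)                   # shape: (k,), like the loop's list
selected = dataset[indices]                        # shape: (k, 1, 4, 4, 4)

# Attempt 2, reinterpreted: keepdim=True gives shape (8, 1, 1, 1, 1), which
# broadcasts over each volume; low-mean volumes become zeros (none are dropped).
mask = torch.mean(dataset, dim=(1, 2, 3, 4), keepdim=True) >= 0.2
zeroed = torch.where(mask, dataset, torch.zeros_like(dataset))

print(indices.tolist(), selected.shape, zeroed.shape)
```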

The first one should work; it would help to see the expected output alongside the output you actually get.

a = torch.FloatTensor(3, 3)  # uninitialized memory, so the values are arbitrary
b = a.mean(dim=0) > 2
c = b.nonzero()

a -> tensor([[-2.8721e+27, 4.5780e-41, -2.8721e+27], [ 4.5780e-41, 0.0000e+00, 0.0000e+00], [ 0.0000e+00, 6.8929e+34, 8.5771e-39]])

b -> tensor([False, True, False])

c -> tensor([[1]])
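In this 2-D example, `dim=0` averages down each column, so `b` holds one value per column. When rows are the samples, as in the question where dim 0 indexes the volumes, the per-sample mean needs `dim=1` instead. Here is a small deterministic illustration with a known tensor in place of the uninitialized one:

```python
import torch

a = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0],
                  [7.0, 8.0, 9.0]])

col_mean = a.mean(dim=0)   # one mean per column -> [4., 5., 6.]
row_mean = a.mean(dim=1)   # one mean per row    -> [2., 5., 8.]

print(col_mean > 3)                         # -> [True, True, True]
print(row_mean > 3)                         # -> [False, True, True]
print((row_mean > 3).nonzero().squeeze(1))  # -> row indices 1 and 2
```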

Given the following:

dataset = torch.randn((123, 1, 64, 64, 64))
b = torch.mean(dataset, dim=0) > .2
c = b.nonzero()
print(c.shape)

This prints:

torch.Size([3495, 4])

whereas I would expect something closer to the following, assuming that 23 of the original 123 3D volumes don’t exceed the threshold:

torch.Size([100, 1, 64, 64, 64])
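The `[3495, 4]` result comes from reducing over the wrong dimension. `mean(dim=0)` averages the 123 volumes together and leaves a single `(1, 64, 64, 64)` voxel-wise mask. `nonzero()` then returns one row of 4 coordinates per `True` voxel. Reducing over dims 1 to 4 instead gives one value per volume. The sketch below compares both reductions. It uses smaller illustrative shapes (12 volumes of 8³) and adds an artificial per-volume offset so that some means actually exceed 0.2. With plain `torch.randn` data, each volume's mean would be very close to 0, so essentially nothing would pass the threshold.

```python
import torch

torch.manual_seed(0)
n, d = 12, 8
offset = torch.linspace(-0.5, 0.5, n).view(n, 1, 1, 1, 1)
dataset = torch.randn(n, 1, d, d, d) + offset

# What the question computes: a voxel-wise mean across all volumes.
voxel_mask = torch.mean(dataset, dim=0) > 0.2
print(voxel_mask.shape)            # torch.Size([1, 8, 8, 8])
print(voxel_mask.nonzero().shape)  # (number of True voxels, 4)

# What the loop computes: one mean per volume.
volume_mask = torch.mean(dataset, dim=(1, 2, 3, 4)) > 0.2
print(volume_mask.shape)           # torch.Size([12])
print(dataset[volume_mask].shape)  # (number of kept volumes, 1, 8, 8, 8)
```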