I have the following snippet, and I’m wondering how to effectively vectorize it in ‘pure’ PyTorch:
indices = []
for i in range(0, self.dataset.shape[0]):
    if torch.mean(self.dataset[i]) >= .2:
        indices.append(i)
self.dataset = self.dataset[indices, :, :, :, :]
self.target = self.target[indices, :, :, :, :]
I tried
b = torch.mean(self.dataset, dim=0) > .2
indices = b.nonzero()
and
torch.where(torch.mean(self.dataset[:, ???]) > .2, self.dataset, torch.FloatTensor([0.0]))
to no avail. Any thoughts?
bsridatta (Sri Datta Budaraju):
The first one should work. It would help to see the expected output and the output you actually get.
a = torch.FloatTensor(3,3)
b = a.mean(dim=0) > 2
c = b.nonzero()
a -> tensor([[-2.8721e+27, 4.5780e-41, -2.8721e+27], [ 4.5780e-41, 0.0000e+00, 0.0000e+00], [ 0.0000e+00, 6.8929e+34, 8.5771e-39]])
b -> tensor([False, True, False])
c -> tensor([[1]])
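For a reproducible version of this check, here is a minimal sketch that uses fixed values instead of an uninitialized `torch.FloatTensor(3,3)`. The values are made up for illustration and are not from the thread. It shows that `dim=0` averages across rows, so the mask has one entry per column, while `dim=1` gives one entry per row.

```python
import torch

# Hypothetical values chosen for illustration; rows play the role of samples.
a = torch.tensor([[0.0, 3.0, 0.0],
                  [0.0, 3.0, 0.0],
                  [9.0, 3.0, 0.0]])

col_mask = a.mean(dim=0) > 2   # reduces over rows: one value per column
row_mask = a.mean(dim=1) > 2   # reduces over columns: one value per row

col_idx = col_mask.nonzero()   # indices of columns whose mean exceeds 2
row_idx = row_mask.nonzero()   # indices of rows whose mean exceeds 2

print(col_mask, row_mask)
print(col_idx, row_idx)
```

If the rows are the samples to keep or drop, it is the `dim=1` variant that yields one decision per sample.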
Given the following:
dataset = torch.randn((123, 1, 64, 64, 64))
b = torch.mean(dataset, dim=0) > .2
c = b.nonzero()
print(c.shape)
This prints:
torch.Size([3495, 4])
I expected something more like the following, assuming that 23 of the original 123 3D volumes don’t exceed the threshold:
torch.Size([100, 1, 64, 64, 64])
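The `[3495, 4]` shape is consistent with `dim=0` averaging across the 123 volumes: the mean then has shape `(1, 64, 64, 64)`, and `nonzero()` returns one 4-column row for each voxel position above the threshold. A minimal sketch of a per-volume filter, assuming the goal is the same as in the original loop (keep volume `i` when its overall mean is at least .2), reduces over every dimension except the first and indexes with the resulting boolean mask. The tensor sizes are smaller than in the thread, the `+ 0.2` shift is an assumption made only so that some volumes pass the threshold, and `target` is a stand-in tensor of the same shape.

```python
import torch

torch.manual_seed(0)

# Stand-in data: 12 volumes instead of 123 and 8^3 voxels instead of 64^3,
# shifted by 0.2 so that some per-volume means land above the threshold.
dataset = torch.randn(12, 1, 8, 8, 8) + 0.2
target = torch.randn(12, 1, 8, 8, 8)

# Reference: the loop from the original question.
indices = [i for i in range(dataset.shape[0]) if torch.mean(dataset[i]) >= 0.2]

# Vectorized: one mean per volume (reduce over dims 1..4), then a boolean mask.
keep = dataset.mean(dim=(1, 2, 3, 4)) >= 0.2   # shape: (12,)
filtered_dataset = dataset[keep]               # shape: (num_kept, 1, 8, 8, 8)
filtered_target = target[keep]                 # same volumes kept in target

# The mask selects the same volumes as the loop.
print(keep.nonzero().squeeze(1).tolist() == indices)
print(filtered_dataset.shape, filtered_target.shape)
```

Indexing with the boolean mask keeps the trailing dimensions intact, so no explicit `[indices, :, :, :, :]` is needed. If integer indices are preferred, `keep.nonzero().squeeze(1)` gives them.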