Hello, if I do something like this,
z = torch.randn(5, 5).unfold(0, 3, 1).unfold(1, 3, 1).reshape(1, 9, 9)
and then zero out some elements in this tensor, how do I get a 5x5 tensor back? If a value is zeroed out in any of these 9 3x3 blocks, it should also be zeroed out in the 5x5 output, and all other values should stay as they were.
If I use fold, it adds up the values from the overlapping 3x3 blocks, but I do not want addition.
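To make the addition concrete, here is a minimal sketch (assuming a single channel and the same 5x5 output, 3x3 window, stride 1 as above). Folding a tensor of ones should show how many 3x3 blocks cover each pixel, which is also how many copies fold sums there. The names fold, ones and counts are only for illustration.

```python
import torch
from torch import nn

# Fold sums every patch entry that lands on the same output pixel,
# so folding all-ones patches gives the number of blocks covering each pixel.
fold = nn.Fold(output_size=(5, 5), kernel_size=3)
ones = torch.ones(1, 9, 9)   # (N, C * 3 * 3, L) with C = 1 and L = 9 blocks
counts = fold(ones)          # shape (1, 1, 5, 5)
print(counts[0, 0])
# tensor([[1., 2., 3., 2., 1.],
#         [2., 4., 6., 4., 2.],
#         [3., 6., 9., 6., 3.],
#         [2., 4., 6., 4., 2.],
#         [1., 2., 3., 2., 1.]])
```

So a corner value appears once, an edge value two or three times, and the centre value nine times.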
For example, if z is something like,
tensor([[[ 0.4094, 1.3269, 2.1112, -1.8682, 0.0420, -0.9150, 1.7852,
1.2070, 0.6966],
[ 1.3269, 2.1112, -0.1709, 0.0420, -0.9150, 0.9318, 1.2070,
0.6966, -0.0834],
[ 2.1112, -0.1709, -0.7779, -0.9150, 0.9318, 0.3695, 0.6966,
-0.0834, -0.7832],
[-1.8682, 0.0420, -0.9150, 1.7852, 1.2070, 0.6966, -0.8919,
-0.7964, 0.1060],
[ 0.0420, -0.9150, 0.9318, 1.2070, 0.6966, -0.0834, -0.7964,
0.1060, -0.4739],
[-0.9150, 0.9318, 0.3695, 0.6966, -0.0834, -0.7832, 0.1060,
-0.4739, 0.5941],
[ 1.7852, 1.2070, 0.6966, -0.8919, -0.7964, 0.1060, -0.2107,
1.1313, 0.1733],
[ 1.2070, 0.6966, -0.0834, -0.7964, 0.1060, -0.4739, 1.1313,
0.1733, -0.9812],
[ 0.6966, -0.0834, -0.7832, 0.1060, -0.4739, 0.5941, 0.1733,
-0.9812, -0.3873]]])
and then I do,
z[0][1][0] = 0
so z is now,
tensor([[[ 0.4094, **1.3269**, *2.1112*, -1.8682, 0.0420, -0.9150, 1.7852,
1.2070, 0.6966],
[ **0.0000**, *2.1112*, -0.1709, 0.0420, -0.9150, 0.9318, 1.2070,
0.6966, -0.0834],
[ *2.1112*, -0.1709, -0.7779, -0.9150, 0.9318, 0.3695, 0.6966,
-0.0834, -0.7832],
[-1.8682, 0.0420, -0.9150, 1.7852, 1.2070, 0.6966, -0.8919,
-0.7964, 0.1060],
[ 0.0420, -0.9150, 0.9318, 1.2070, 0.6966, -0.0834, -0.7964,
0.1060, -0.4739],
[-0.9150, 0.9318, 0.3695, 0.6966, -0.0834, -0.7832, 0.1060,
-0.4739, 0.5941],
[ 1.7852, 1.2070, 0.6966, -0.8919, -0.7964, 0.1060, -0.2107,
1.1313, 0.1733],
[ 1.2070, 0.6966, -0.0834, -0.7964, 0.1060, -0.4739, 1.1313,
0.1733, -0.9812],
[ 0.6966, -0.0834, -0.7832, 0.1060, -0.4739, 0.5941, 0.1733,
-0.9812, -0.3873]]])
and now, when I apply fold, it gives me,
x = nn.Fold((5, 5), 3)
x(z)
tensor([[[[ 0.4094, **1.3269**, *6.3337*, -0.3418, -0.7779],
[-3.7364, 0.1680, -5.4902, 3.7274, 0.7390],
[ 5.3557, 7.2419, 6.2697, -0.5001, -2.3496],
[-1.7838, -3.1855, 0.6361, -1.8955, 1.1882],
[-0.2107, 2.2625, 0.5198, -1.9624, -0.3873]]]])
For the value 1.3269, I want zero, because it was zeroed out in one of the 9 3x3 blocks. For the rest of the values, I do not want addition: for example, 2.1112*3 gives 6.3337, but I want 2.1112 only.
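One possible way to get this, as a rough sketch rather than a definitive answer: fold three things (the values, a tensor of ones, and a mask of the non-zero entries), divide the summed values by the number of copies, and keep a pixel only where every copy is still non-zero. The names x, patches, total, nonzero, summed and out are made up for illustration, and the seed is only there to make the run repeatable. It also assumes that a value which was already exactly 0 in the original can be treated the same as a zeroed one, since the result is 0 either way.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(5, 5)
# Same construction as in the question; clone() makes the copy explicit.
z = x.unfold(0, 3, 1).unfold(1, 3, 1).reshape(1, 9, 9).clone()
z[0][1][0] = 0  # zero one of the two copies of x[0, 1]

fold = nn.Fold(output_size=(5, 5), kernel_size=3)

# nn.Fold expects (N, C * kh * kw, L); z is laid out as (N, L, kh * kw).
# For this 5x5 / 3x3 case both layouts hit the same pixels, but the
# transpose keeps the sketch valid for other sizes.
patches = z.transpose(1, 2)

total = fold(torch.ones_like(patches))             # copies of each pixel
nonzero = fold((patches != 0).to(patches.dtype))   # copies that are still non-zero
summed = fold(patches)                             # what fold returns on its own

# Average the copies, then zero every pixel that lost at least one copy.
out = torch.where(nonzero == total, summed / total, torch.zeros_like(summed))
out = out[0, 0]  # back to shape (5, 5)
print(out)
```

With this, out[0, 1] comes out as 0 and every other entry matches the original 5x5 tensor up to floating point rounding.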