When I use the torch.ge() function in my network, I get the following error:
‘Variable’ object is not callable
Code:
def forward(self, x):
    y = x.sum(dim=1)
    yy = y.view(y.size(0), -1)
    avg = yy.mean(dim=1)
    mask = torch.ge(y, avg(0))
where x is an N*C*H*W tensor.
This is because you are trying to call avg as if it were a function, passing it the value 0. I think you wanted to index into it instead (i.e. avg[0])? Even so, I don’t think the shapes of y and avg[0] would match here…?
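A minimal sketch of the difference between calling and indexing, assuming a recent PyTorch release (where Variable has been merged into Tensor, so calling a tensor raises a plain Python TypeError) and a small made-up input size. It also shows why comparing y against the whole avg fails: y has shape (N, H, W) while avg has shape (N,), so they do not broadcast.

```python
import torch

# Hypothetical N x C x H x W input; the sizes are chosen only for illustration.
x = torch.rand(2, 3, 4, 5)

y = x.sum(dim=1)                          # shape (N, H, W) = (2, 4, 5)
avg = y.view(y.size(0), -1).mean(dim=1)   # one mean per sample, shape (N,) = (2,)

# avg(0) treats the tensor as a function, which is what triggers the error.
try:
    avg(0)
    call_failed = False
except TypeError:
    call_failed = True

# avg[0] indexes instead: a 0-dim tensor holding the first sample's mean.
first_mean = avg[0]

# Comparing y (N, H, W) against avg (N,) aligns the trailing dims W=5 and N=2,
# which cannot broadcast, so PyTorch raises a RuntimeError.
try:
    torch.ge(y, avg)
    broadcast_failed = False
except RuntimeError:
    broadcast_failed = True

print(call_failed, broadcast_failed, tuple(y.shape), tuple(avg.shape), tuple(first_mean.shape))
```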
It only accepts a float as the second argument, so I changed the line to:
    mask = torch.ge(y, float(avg.data.numpy()[0]))
Then it works well.
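For reference, a sketch of the same fix in a recent PyTorch release, where the first sample's mean can be pulled out with avg[0].item() instead of going through .data and NumPy. Note that this thresholds every sample against sample 0's mean. If the intent was instead to compare each sample against its own mean (an assumption; the thread does not say), one option is to reshape avg to (N, 1, 1) so it broadcasts against y. The input sizes below are hypothetical.

```python
import torch

# Hypothetical N x C x H x W input; the sizes are chosen only for illustration.
x = torch.rand(2, 3, 4, 5)
y = x.sum(dim=1)                          # (N, H, W)
avg = y.view(y.size(0), -1).mean(dim=1)   # (N,)

# Current-PyTorch equivalent of float(avg.data.numpy()[0]):
# every sample is compared against the FIRST sample's mean.
mask_first = torch.ge(y, avg[0].item())

# Assumed alternative intent: compare each sample against its own mean.
# avg.view(-1, 1, 1) has shape (N, 1, 1), which broadcasts against (N, H, W).
mask_per_sample = torch.ge(y, avg.view(-1, 1, 1))

print(mask_first.shape, mask_per_sample.shape, mask_per_sample.dtype)
```

The broadcasting version keeps the whole computation on tensors, so it also avoids moving data to NumPy inside forward().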