How to calculate the instance-wise mean value in a batch?


(李志) #1

The question might be a little confusing, so I'll clarify here.
Given a training batch with shape NCHW, I want to calculate the mean for each example, that is, one mean value for every CHW block, so the result has shape N×1, where N is the batch size.
Then I'll do a point-wise division within each instance, so that each point in each image is divided by its own image's sum.


#2

Here is a small code snippet calculating the mean for each sample in the batch:

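```python
import torch
N, C, H, W = 10, 3, 24, 24
x = torch.randn(N, C, H, W)
# Flatten each sample into one row and average it: shape (N, 1)
x_mean = x.view(N, -1).mean(1, keepdim=True)
# Add two singleton dims -> (N, 1, 1, 1) so it broadcasts over C, H, W
x_norm = x / x_mean[:, :, None, None]
```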

I'm not sure if the last line achieves what you are trying to do: each point is divided by the mean of its sample, not by its sum.
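Since the original question asked for division by each sample's own sum, here is a minimal sketch of that variant. It assumes non-negative inputs (e.g. image intensities), because a per-sample sum close to zero would blow up the division; `torch.rand` stands in for real data.

```python
import torch

N, C, H, W = 10, 3, 24, 24
x = torch.rand(N, C, H, W)  # non-negative stand-in for image data

# Per-sample sum over C, H, W; keepdim=True keeps shape (N, 1, 1, 1) for broadcasting
x_sum = x.sum(dim=(1, 2, 3), keepdim=True)
x_norm = x / x_sum

# Per-sample mean in the same style, also shape (N, 1, 1, 1)
x_mean = x.mean(dim=(1, 2, 3), keepdim=True)
```

Reducing over a tuple of dims with `keepdim=True` avoids the manual `[:, :, None, None]` indexing.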


(李志) #3

Seems right, though I haven't run it. But the API `torch.renorm` works well for my own purpose.
I'm trying to finish some code for image inpainting. If it works, I'll push it to GitHub.
Thank you!
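For reference, `torch.renorm(input, p, dim, maxnorm)` rescales each slice along `dim` whose p-norm exceeds `maxnorm` so that its norm becomes `maxnorm`; slices already within the limit are left untouched. A minimal sketch, assuming non-negative data whose per-sample sum is above 1, where `p=1, dim=0, maxnorm=1.0` behaves like dividing each sample by its sum:

```python
import torch

x = torch.rand(4, 3, 8, 8)  # non-negative, so each sample's L1 norm equals its sum

# Each x[i] with L1 norm > 1 is scaled down to L1 norm 1
y = torch.renorm(x, p=1, dim=0, maxnorm=1.0)

# Explicit division by the per-sample sum, for comparison
y_ref = x / x.sum(dim=(1, 2, 3), keepdim=True)
```

Because samples whose norm is already below `maxnorm` are not scaled up, this only matches division by the sum when every sample's sum exceeds 1.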


(李志) #4

Another question: is there an existing API that can do the following point-wise tensor operation?
For example, with the threshold set to 0.3, every value less than 0.3 is replaced by 0, and every other value is set to 1 (because it is bigger than 0.3).


#5

I think torch.where is what you are looking for:

```python
import torch
x = torch.randn(10)
# 1.0 where x > 0.3, otherwise 0.0
torch.where(x > 0.3, torch.tensor(1.0), torch.tensor(0.0))
```
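A minimal sketch of two equivalent ways to build the 0/1 mask, using the threshold 0.3 from the question; with `>`, values that are not greater than the threshold map to 0:

```python
import torch

x = torch.tensor([-1.0, 0.1, 0.2, 0.5, 2.0])
threshold = 0.3

# torch.where takes values from the second argument where the condition holds, else from the third
mask_where = torch.where(x > threshold, torch.ones_like(x), torch.zeros_like(x))

# Casting the boolean comparison to float gives the same result
mask_cast = (x > threshold).float()
```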

(李志) #6

Wow, fantastic! My PyTorch environment lives on a remote host in my lab, and it shows no documentation for the functions in the torch module, such as torch.abs, torch.where, etc. I'm also still getting familiar with this library. Thank you!


(李志) #7

But this API performs a deletion in my PyTorch environment: it drops every value where the condition is False and forms a new tensor from the values where it is True.
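The thread doesn't show the exact call, so this is only a guess: that behaviour matches boolean-mask indexing `x[x > 0.3]`, or the single-argument form `torch.where(condition)`, which returns the indices of the True elements instead of a same-shaped tensor. A minimal sketch contrasting the three calls:

```python
import torch

x = torch.tensor([0.1, 0.5, 0.2, 0.9])
cond = x > 0.3

selected = x[cond]          # boolean-mask indexing: keeps only the True values, shape (2,)
(idx,) = torch.where(cond)  # single-argument form: a tuple of index tensors, one per dim
full = torch.where(cond, torch.ones_like(x), torch.zeros_like(x))  # same shape as x
```

Only the three-argument form keeps the original shape.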


(李志) #8

I have an idea about this: for instance, if you want all values > 0.3 set to 1 and values < 0.3 set to 0, you can subtract the threshold (0.3) and then apply a sign op.
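Note that `torch.sign` returns -1, 0 or +1, so `sign(x - 0.3)` on its own maps values below the threshold to -1 rather than 0. A minimal sketch of the subtract-then-sign idea, with a `clamp` (an addition not in the original post) that turns the -1s into 0s:

```python
import torch

x = torch.tensor([-1.0, 0.1, 0.2, 0.5, 2.0])
threshold = 0.3

s = torch.sign(x - threshold)   # -1 below the threshold, +1 above (0 exactly at it)
mask = torch.clamp(s, min=0.0)  # map -1 to 0, keep +1
```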


#9

I don’t really understand the issue.
torch.where should already return a tensor with ones and zeros based on the condition you’ve provided.