class RBFfun(Function):
def forward(self,input,weight,sigma,center):
t=torch.mv(weight,input)
dtc=t-center
dist2=torch.sum(dtc*dtc)
output=-(sigma*sigma)*dist2/2
output=torch.exp(output)
self.save_for_backward(input,weight,sigma,center,output)
self.dtc=dtc
return output
def backward(self,grad_output):
input,weight,sigma,center,output=self.saved_tensors
dtc=self.dtc
grad_input=grad_weight=grad_sigma=grad_center=None
if self.needs_input_grad[0]:
m,n=weight.size()[0],weight.size()[1]
Hess=Variable(torch.zeros(n,m))
for i in range(n):
Hess[i]=-(torch.dot(dtc,weight[:,i]))*output*pow(sigma,2)
grad_input=torch.mv(Hess,grad_output)
return grad_input,grad_weight,grad_sigma,grad_center
the error message is “mv(): argument ‘vec’ (position 1) must be Variable, not torch.FloatTensor”
but I think the grad_output is Variable