import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.autograd import Function
def _sum_to_shape(grad, shape):
    """Sum a broadcast gradient back down to ``shape`` (the argument's own shape)."""
    if grad.shape == shape:
        return grad
    if len(shape) == 0:
        # 0-d argument: every element of the broadcast gradient contributes.
        return grad.sum()
    return grad.sum_to_size(shape)


class Gaussian(Function):
    """Gaussian response ``exp(-0.5 * sigma**2 * ||center - input||**2)``.

    The squared distance is summed over dimension 1 of ``center - input``, so
    one response is produced per row. Note that ``sigma`` *multiplies* the
    squared distance, so it behaves as an inverse width, not a standard
    deviation.

    Assumed usage (inferred from the original ``torch.mv(center.t(), ...)``
    gradient; confirm against callers): ``input`` is a single (D,) vector,
    ``center`` is (N, D), and ``sigma`` is a scalar / 0-d tensor or an (N,)
    tensor, giving an (N,) output. Other shapes that broadcast the same way
    also work, because each gradient is reduced to its argument's own shape.

    Call via ``Gaussian.apply(input, sigma, center)``.
    """

    @staticmethod
    def forward(ctx, input, sigma, center):
        """Return ``exp(-0.5 * sigma**2 * sum((center - input)**2, dim=1))``."""
        dist1 = center - input
        dist2 = torch.sum(torch.pow(dist1, 2), 1)
        sigma2 = torch.pow(sigma, 2)
        # -0.5 rather than -1/2: under Python 2 integer division -1/2 == -1.
        output = torch.exp(-0.5 * sigma2 * dist2)
        # Save only inputs/outputs; the cheap intermediates are recomputed in
        # backward instead of being stashed as ad-hoc ctx attributes.
        ctx.save_for_backward(input, sigma, center, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients w.r.t. ``(input, sigma, center)``.

        With f_i = exp(-0.5 * s_i**2 * d2_i), d_i = center_i - input and
        d2_i = ||d_i||**2:
          df_i/d(input)  =  f_i * s_i**2 * d_i
          df_i/d(center) = -f_i * s_i**2 * d_i
          df_i/d(sigma)  = -f_i * s_i * d2_i
        """
        input, sigma, center, output = ctx.saved_tensors
        need_input, need_sigma, need_center = ctx.needs_input_grad[:3]
        grad_input = grad_sigma = grad_center = None

        dist1 = center - input
        sigma2 = torch.pow(sigma, 2)

        if need_input or need_center:
            # Per-row scale g_i * f_i * s_i**2 broadcast across each row's
            # difference vector. This replaces diag(tmp) @ dist1, which
            # materialized an N x N matrix.
            scale = grad_output * output * sigma2
            weighted = scale.unsqueeze(1) * dist1
            if need_input:
                # The original used center.t() here; the gradient actually
                # depends on center - input.
                grad_input = _sum_to_shape(weighted, input.shape)
            if need_center:
                grad_center = _sum_to_shape(-weighted, center.shape)

        if need_sigma:
            dist2 = torch.sum(torch.pow(dist1, 2), 1)
            grad_sigma = _sum_to_shape(
                -grad_output * output * sigma * dist2, sigma.shape
            )

        return grad_input, grad_sigma, grad_center