Hi, I want to implement a `torch.autograd.Function` that accepts an arbitrary number of inputs. Here is an example:
```python
class my_func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_1, x_2, ..., x_n):  # the number of inputs n can vary
        result = f(x_1, ..., x_n)
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return gradient_x_1, ..., gradient_x_n  # one gradient per input x_1 ... x_n
```
I have tried passing a list `[x_1, x_2, …, x_n]` into the function, but it does not seem to work. Another option is to stack x_1, x_2, …, x_n into a single tensor, but in my implementation each x represents a neural network's parameters, so that would be quite troublesome, and they may not have the same shape.
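A minimal sketch of one common pattern, assuming the goal is simply to accept a variable number of tensors: declare `forward(ctx, *inputs)` and call `apply(*tensors)`, so that every tensor is a separate positional argument that autograd can track (tensors wrapped inside a Python list are generally not treated as inputs of the Function, so no gradient flows to them). `backward` then returns a tuple with one gradient per input, in the same order, and the tensors are free to have different shapes. The class name `SumOfSquares` and the function it computes, f(x_1, …, x_n) = Σ_i sum(x_i²), are hypothetical stand-ins for the real `f`.

```python
import torch


class SumOfSquares(torch.autograd.Function):
    """Hypothetical example: f(x_1, ..., x_n) = sum over i of (x_i ** 2).sum().

    Accepts any number n >= 1 of tensors, which may have different shapes.
    """

    @staticmethod
    def forward(ctx, *inputs):
        # Each tensor arrives as its own positional argument, so autograd
        # treats every one of them as an input of this Function.
        ctx.save_for_backward(*inputs)
        return sum((x ** 2).sum() for x in inputs)

    @staticmethod
    def backward(ctx, grad_output):
        inputs = ctx.saved_tensors
        # Return exactly one gradient per forward input, in the same order.
        # d/dx_i of sum(x_i ** 2) is 2 * x_i, scaled by the incoming gradient.
        return tuple(2 * x * grad_output for x in inputs)


# Usage: the parameters of a model have different shapes (weight: 2x4, bias: 2).
model = torch.nn.Linear(4, 2)
loss = SumOfSquares.apply(*model.parameters())
loss.backward()
for p in model.parameters():
    print(p.shape, torch.allclose(p.grad, 2 * p.detach()))
```

If some arguments are not tensors (for example a configuration flag), they can be passed in the same argument list; `backward` must then return `None` in their positions.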
I would really appreciate any help with this.