I am implementing a custom autograd function. In the forward of class xxxFunction(torch.autograd.Function), I set ctx.matches, and I want to access ctx.matches outside this function.
I tried:

    @staticmethod
    def get_matches(ctx):
        return ctx.matches
Then, in the class xxx(torch.nn.Module), I call xxxFunction.get_matches(), but I get this error:

    get_matches() takes exactly 1 argument (0 given)
Thank you in advance!
@albanD could you help me with this problem?
albanD (Alban D), December 17, 2018, 10:37am
Hi,

You can access this from the backward method, which receives the same ctx as input.
I don't think you can extract it outside the Function in a reliable way, though.
Why do you need to do this? Can't you just return it as another output of the forward method?
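A minimal sketch of that suggestion: forward returns the side result as a second output and marks it with ctx.mark_non_differentiable, and the module stores it as an attribute so it can be read later. MatchFunction, MatchModule and the argmax "matches" are hypothetical stand-ins for the poster's xxxFunction and ctx.matches.

```python
import torch


class MatchFunction(torch.autograd.Function):
    """Hypothetical Function: a differentiable output plus a side result."""

    @staticmethod
    def forward(ctx, x):
        out = x * 2
        # Hypothetical "matches": index of the largest entry in each row.
        matches = x.argmax(dim=-1)
        # Tell autograd that this output has no gradient.
        ctx.mark_non_differentiable(matches)
        return out, matches

    @staticmethod
    def backward(ctx, grad_out, grad_matches):
        # One incoming gradient per output; grad_matches is simply ignored.
        return grad_out * 2


class MatchModule(torch.nn.Module):
    def forward(self, x):
        out, matches = MatchFunction.apply(x)
        # Keep the side result on the module so callers can read it later.
        self.matches = matches
        return out
```

Usage: after `out = module(x)`, the value is available as `module.matches`, and `out.sum().backward()` works as usual because the extra output never enters the gradient computation.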
I am facing the same issue. If I return it as another output of the forward method, how should I adjust the backward method? @albanD
Many thanks!
albanD (Alban D), September 10, 2019, 8:28pm
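The backward method must accept a gradient argument for every output of forward (including the non-differentiable ones, whose incoming gradient you can ignore) and return None for every input that is not differentiable.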
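To illustrate the rule, here is a sketch with hypothetical names, assuming current PyTorch semantics. backward takes one gradient per forward output, including an unused one for the non-differentiable `info` output, and returns one value per forward input, with None for the plain Python `scale` argument. It also reads `ctx.scale`, set in forward, which is the ctx access mentioned earlier in the thread.

```python
import torch


class ScaleWithInfo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        # scale is a plain Python float, so it cannot receive a gradient.
        ctx.scale = scale
        out = x * scale
        # Hypothetical side output: number of entries with |x| > 1.
        info = (x.abs() > 1).sum()
        ctx.mark_non_differentiable(info)
        return out, info

    @staticmethod
    def backward(ctx, grad_out, grad_info):
        # Inputs were (x, scale): return a gradient for x and None for scale.
        # grad_info belongs to the non-differentiable output and is ignored.
        return grad_out * ctx.scale, None
```

The number of returned values must match the number of forward inputs (excluding ctx); returning too few or too many raises an error during the backward pass.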
Zhaoyi-Yan (Zhaoyi Yan), September 17, 2019, 3:02am
Here is an example:
    # (Fragment) Select the slice of the mask that belongs to the current GPU.
    cur_device = torch.cuda.current_device()
    self.cur_mask = self.mask_all[cur_device*cur_bsize:(cur_device+1)*cur_bsize, :, :, :]

    # If the mask changes, cal_fix_flag needs to be set to true each iteration.
    def forward(self, input):
        self.bz, self.c, self.h, self.w = input.size()
        self._split_mask(self.bz)
        self.flag = util.cal_flag_given_mask_thred(self.cur_mask, self.shift_sz, self.stride, self.mask_thred)
        final_out = InnerShiftTripleFunction.apply(input, self.shift_sz, self.stride, self.triple_weight, self.flag, self.show_flow)
        if self.show_flow:
            # Retrieve the flow from the Function class (get_flow_src is not shown here).
            self.flow_srcs = InnerShiftTripleFunction.get_flow_src()
        return final_out

    def get_flow(self):
        return self.flow_srcs

    def set_flow_true(self):
        self.show_flow = True

    def set_flow_false(self):
        self.show_flow = False
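The snippet does not show InnerShiftTripleFunction.get_flow_src. One way such a zero-argument getter could work, assuming it reads a class-level attribute that forward writes, is sketched below; FlowFunction, FlowModule and `_last_flow` are hypothetical. Because the getter takes no ctx, calling it without arguments avoids the "takes exactly 1 argument" error from the original question. The trade-off is that the attribute is shared by every call and every module instance, so under DataParallel (which the torch.cuda.current_device() slicing hints at) replicas may overwrite each other's value; returning the tensor as an extra output avoids that.

```python
import torch


class FlowFunction(torch.autograd.Function):
    # Hypothetical class-level slot shared by ALL calls of this Function.
    _last_flow = None

    @staticmethod
    def forward(ctx, x, show_flow):
        out = x * 2
        if show_flow:
            # Store a detached copy so it does not hold on to the graph.
            FlowFunction._last_flow = x.detach().clone()
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # Gradient for x, None for the show_flow flag.
        return grad_out * 2, None

    @staticmethod
    def get_flow_src():
        # No ctx argument: reads the class-level slot instead.
        return FlowFunction._last_flow


class FlowModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.show_flow = False
        self.flow_srcs = None

    def forward(self, x):
        out = FlowFunction.apply(x, self.show_flow)
        if self.show_flow:
            self.flow_srcs = FlowFunction.get_flow_src()
        return out
```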