# Using autograd in a custom 1D STN architecture

Hi! I’m trying to do something fairly non-standard with my network, locally warping the 1D input signal in a manner similar to Spatial Transformer Networks. In other words, the first portion of my network (implementation shown below, albeit simplified to omit irrelevant code) will learn an optimal warping transformation and apply it to the data before the latter portion of the network processes it.

The STN portion should learn the absolute time difference between the previous and current sample, and then linear interpolation rescales the input signal. Basically, we can’t trust that the original input signal was sampled at consistent intervals, so we’re warping the measured signal to make that the case. Unfortunately, `grid_sample()` won’t work for me, since I’m not interpolating at a regularly spaced set of points; I used a third-party library, Interp1d, instead.
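For anyone unfamiliar with the operation, here is a minimal sketch of differentiable 1D linear interpolation at arbitrary query points, using only core PyTorch (`lerp_1d` is an illustrative name, not the Interp1d API):

```python
import torch

def lerp_1d(x, y, x_new):
    # Differentiable 1D linear interpolation of y(x) at query points x_new.
    # Assumes x is sorted in ascending order; sketch only, not the Interp1d API.
    idx = torch.searchsorted(x, x_new).clamp(1, x.numel() - 1)
    x0, x1 = x[idx - 1], x[idx]
    y0, y1 = y[idx - 1], y[idx]
    w = (x_new - x0) / (x1 - x0)           # fractional position in each interval
    return y0 + w * (y1 - y0)

x = torch.tensor([0.0, 0.7, 1.5, 3.0])     # irregular sample times
y = torch.tensor([0.0, 1.4, 3.0, 6.0])     # measured values (here y = 2x)
x_new = torch.linspace(0.0, 3.0, 7)        # regular output grid
print(lerp_1d(x, y, x_new))
```

Note that `searchsorted` itself is non-differentiable, but the arithmetic that follows is, so gradients flow back to `x`, `y`, and `x_new` just as they need to for the warping indices here.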

The network isn’t learning anything useful, and I’m not sure if that’s because I don’t understand well enough how `autograd` works.

```python
class STN(nn.Module):
    # omitting initialization stuff here

    def forward(self, y, raw_lengths):
        # extract features for STN
        feats = self.localization(y)  # just an nn.Sequential(layers)
        # predict the length of each time step
        dt = self.fc_loc(feats)  # just an nn.Sequential(layers)
        # get new indices by accumulating the predicted time steps
        shifted_indices = torch.cumsum(dt, 2)
        # subtract the first element so the indices start at 0
        old_indices = shifted_indices.sub(dt[:, :, 0].unsqueeze(2).expand_as(shifted_indices))
        # output indices will be regularly spaced
        new_indices = torch.arange(30000, dtype=torch.float).cuda().repeat(y.shape[0], 1)
        # interpolate the new output values
        new_vals = None
        new_vals = Interp1d.apply(old_indices.squeeze(), y.squeeze(), new_indices, new_vals)
        # return result
        return new_vals.unsqueeze(1).type(torch.HalfTensor).cuda()
```

Should this work as-is, or is there something else I need to do? Assuming the `backward()` function in Interp1d works correctly, I’ve just done some summation and subtraction, which should be easily differentiable. Do I need to implement a custom `backward()` function? Do I have to ensure `y.requires_grad=True`?

Hi,

You don’t need `y` to require gradients unless you actually want gradients with respect to the input itself. Your `fc_loc` has parameters that require gradients, and that is enough for the module to train.
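A quick way to convince yourself that gradients reach a learnable tensor through the `cumsum`/subtract index construction used in the `forward()` above (the names here are illustrative, not from the original code):

```python
import torch

# dt stands in for the output of fc_loc; make it a leaf parameter so we can
# inspect its gradient directly.
dt = torch.nn.Parameter(torch.rand(1, 1, 5))
shifted = torch.cumsum(dt, 2)
old_indices = shifted - dt[:, :, :1]  # subtract first element so indices start at 0
loss = old_indices.sum()              # dummy scalar loss
loss.backward()
print(dt.grad)                        # gradient flows back through cumsum and sub
```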

This code looks OK to me.
You can use `gradcheck` (from `torch.autograd`) to check whether the gradients computed by your `Interp1d` function match finite differences.
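A minimal sketch of how `gradcheck` is typically used, with a stand-in function where `Interp1d.apply` would go (use double-precision inputs, since `gradcheck` compares against finite differences):

```python
import torch
from torch.autograd import gradcheck

# Stand-in for Interp1d.apply; replace with the function under test.
def my_fn(x):
    return torch.cumsum(x, dim=-1)

x = torch.randn(4, dtype=torch.double, requires_grad=True)
print(gradcheck(my_fn, (x,)))  # returns True if analytic grads match finite differences
```

`gradcheck` raises a descriptive error (rather than returning `False`) when the analytic and numerical Jacobians disagree, which makes it easy to spot which input dimension is wrong.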