Hi! I’m trying to do something fairly non-standard with my network: locally warping the 1D input signal in a manner similar to Spatial Transformer Networks (STNs). In other words, the first portion of my network (implementation shown below, simplified to omit irrelevant code) will learn an optimal warping transformation and apply it to the data before the latter portion of the network processes it.
The STN portion should learn the absolute time difference between each pair of consecutive samples, and linear interpolation is then performed to resample the input signal. Basically, we can’t trust that the original input signal was sampled at consistent intervals, so we’re warping the measured signal to make that the case. Unfortunately, `grid_sample()` won’t work for me, since I’m not interpolating from a regularly spaced set of points, so I used a third-party library, Interp1d, instead.
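For reference, piecewise-linear interpolation from irregular sample positions onto a regular grid can also be written with plain PyTorch ops, so autograd differentiates it with respect to both the positions and the values. This is a minimal sketch, not how Interp1d is implemented; `interp1d_linear` is a hypothetical helper, and it assumes the sample positions are strictly increasing along the last dimension.

```python
import torch


def interp1d_linear(x, y, xnew):
    """Hypothetical helper: piecewise-linear interpolation of y(x) at xnew.

    x:    (B, N) sample positions, assumed strictly increasing along dim -1
    y:    (B, N) sample values
    xnew: (B, M) query positions (e.g. a regular grid)
    Queries outside [x[:, 0], x[:, -1]] are linearly extrapolated
    from the first or last segment.
    """
    # right-hand neighbour of each query, clamped so every query has a segment
    idx = torch.searchsorted(x.detach().contiguous(), xnew.contiguous())
    idx = idx.clamp(1, x.shape[-1] - 1)
    # endpoints of the segment containing each query
    x0, x1 = x.gather(-1, idx - 1), x.gather(-1, idx)
    y0, y1 = y.gather(-1, idx - 1), y.gather(-1, idx)
    # fractional position in the segment; differentiable w.r.t. x and y
    t = (xnew - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)


# irregular step lengths, as a localization network might predict them
dt = torch.tensor([[0.5, 1.0, 1.5, 1.0]], requires_grad=True)
x = torch.cumsum(dt, dim=-1) - dt[:, :1]   # positions 0.0, 1.0, 2.5, 3.5
y = torch.tensor([[1.0, 3.0, 6.0, 8.0]])   # measured values at those positions
xnew = torch.arange(4, dtype=torch.float).unsqueeze(0)  # regular grid 0, 1, 2, 3

out = interp1d_linear(x, y, xnew)
out.sum().backward()
print(out)      # values at the regular grid points
print(dt.grad)  # non-zero: gradients reach the predicted step lengths
```

The index lookup (`torch.searchsorted`) has no gradient, but the interpolation weight `t` depends on the gathered positions, which is how the gradient flows back to the predicted step lengths.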
The network isn’t learning anything useful, and I’m not sure if that’s because I don’t understand well enough how
```python
import torch
import torch.nn as nn
from torchinterp1d import Interp1d  # third-party Interp1d; import path may differ by version


class STN(nn.Module):
    # omitting initialization stuff here

    def forward(self, y, raw_lengths):
        # extract features for STN (y has shape (B, 1, N))
        feats = self.localization(y)  # just a nn.Sequential(layers)
        # predict length of each time step, shape (B, 1, N)
        dt = self.fc_loc(feats)  # just a nn.Sequential(layers)
        # get new indices by accumulating predicted timesteps
        shifted_indices = torch.cumsum(dt, 2)
        # subtract first element so indices start at 0 (broadcasts over dim 2)
        old_indices = shifted_indices - dt[:, :, :1]
        # output indices will be regularly spaced, one row per batch element
        new_indices = torch.arange(30000, dtype=torch.float, device=y.device).repeat(y.shape[0], 1)
        # interpolate new output values; squeeze(1) keeps the batch dim even when B == 1
        new_vals = Interp1d.apply(old_indices.squeeze(1), y.squeeze(1), new_indices)
        # return result in half precision, staying on the same device
        return new_vals.unsqueeze(1).half()
```
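One thing worth checking in the code above: linear interpolation between neighbouring samples is only well defined when the sample positions are sorted, and nothing in `forward()` stops `fc_loc` from predicting negative step lengths, which would make `old_indices` non-monotonic. Below is a minimal sketch of one way to constrain this, assuming a softplus on the raw predictions is acceptable (softplus is a swapped-in choice, not part of the original code); `positive_step_positions` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def positive_step_positions(raw_dt):
    """Hypothetical helper: map unconstrained step predictions to sorted positions.

    raw_dt: (B, 1, N) raw output of a layer such as fc_loc (may be negative).
    Returns (B, 1, N) positions that start at 0 and increase along dim 2.
    """
    # softplus makes every step length positive, so the cumulative sum is increasing
    dt = F.softplus(raw_dt)
    positions = torch.cumsum(dt, dim=2)
    # shift so the first position is exactly 0, as in the post's forward()
    return positions - dt[:, :, :1]


torch.manual_seed(0)
raw_dt = torch.randn(2, 1, 100)  # stand-in for fc_loc output
pos = positive_step_positions(raw_dt)
print(pos[0, 0, :5])
```

Softplus is smooth and strictly positive, so gradients still reach every raw prediction; a plain `relu` would also keep steps non-negative but could produce zero-length steps and dead gradients.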
Should this work as-is, or is there something else I need to do? Assuming the `backward()` function in Interp1d works correctly, I’ve only done some summation and subtraction, which should be easily differentiable. Do I need to implement a custom `backward()` function? Do I have to ensure
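On the custom `backward()` question: `torch.cumsum` and tensor subtraction are ordinary differentiable ops, so autograd builds their backward pass automatically; a hand-written `backward()` is only needed inside a `torch.autograd.Function`, such as the one Interp1d provides. A minimal sketch of checking this numerically with `torch.autograd.gradcheck` on the index computation from the post (`step_positions` is a hypothetical name):

```python
import torch


def step_positions(dt):
    # accumulate step lengths, then subtract the first one so positions start at 0
    return torch.cumsum(dt, dim=2) - dt[:, :, :1]


# gradcheck compares autograd's gradients with finite differences;
# it expects double-precision inputs that require grad
dt = torch.rand(2, 1, 8, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(step_positions, (dt,))
print(ok)  # True
```

The same kind of call could in principle be pointed at `Interp1d.apply` with double-precision inputs to test the library's own `backward()`, which is the part the question assumes works; that is untested here.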