After some debugging I have found that, depending on the model input size, the custom gather operation (defined here, ~line 35) is consistently the culprit of OOM exceptions raised during the backward pass of a KPFCNN model (defined here, ~line 189).
```python
def gather(x, idx, method=2):
    """
    Implementation of a custom gather operation for faster backwards.
    :param x: input with shape [N, D_1, ..., D_d]
    :param idx: indexing with shape [n_1, ..., n_m]
    :param method: Choice of the method
    :return: x[idx] with shape [n_1, ..., n_m, D_1, ..., D_d]
    """
    if method == 0:
        # Plain advanced indexing
        return x[idx]
    elif method == 1:
        # Hard-coded 2-D case: broadcast x and idx to a common
        # [*, idx.shape[-1], D] shape and gather along dim 0
        x = x.unsqueeze(1)
        x = x.expand((-1, idx.shape[-1], -1))
        idx = idx.unsqueeze(2)
        idx = idx.expand((-1, -1, x.shape[-1]))
        return x.gather(0, idx)
    elif method == 2:
        # General case: insert one broadcast dim into x for every indexing
        # dim of idx beyond the first ...
        for i, ni in enumerate(idx.size()[1:]):
            x = x.unsqueeze(i + 1)
            new_s = list(x.size())
            new_s[i + 1] = ni
            x = x.expand(new_s)
        # ... and append one broadcast dim to idx for every feature dim of x,
        # so both end up with matching shapes for a dim-0 gather
        n = len(idx.size())
        for i, di in enumerate(x.size()[n:]):
            idx = idx.unsqueeze(i + n)
            new_s = list(idx.size())
            new_s[i + n] = di
            idx = idx.expand(new_s)
        return x.gather(0, idx)
    else:
        raise ValueError('Unknown method')
```
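To see what method 2 actually builds, here is a minimal shape trace with sizes invented purely for illustration (they are not my real batch shapes). The `expand()` calls are only views and allocate nothing, but the final `gather` materializes its output at the full broadcast shape, and, if I understand autograd's handling of `expand` correctly, the backward of that `gather` also materializes a dense gradient at the expanded `[N, n2, D]` input shape before the expand's backward sums it back down to `[N, D]`:

```python
import torch

N, D = 100_000, 64    # hypothetical point count and feature width
n1, n2 = 100_000, 40  # hypothetical neighborhood indexing shape

x = torch.randn(N, D)
idx = torch.randint(0, N, (n1, n2))

# method 2 turns x into a [N, n2, D] view and idx into a [n1, n2, D] view
x_exp = x.unsqueeze(1).expand(N, n2, D)       # view, no copy
idx_exp = idx.unsqueeze(2).expand(n1, n2, D)  # view, no copy

out = x_exp.gather(0, idx_exp)                # real allocation: [n1, n2, D]
print(out.shape, f'{out.numel() * out.element_size() / 1e9:.2f} GB')
```

With these made-up sizes the output alone is about 1 GB of float32, before counting anything autograd saves for the backward pass.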
Based on the code comments, this is a custom implementation of `torch.gather` that provides some execution speed-up. I am attempting to determine whether this custom operation is truly the culprit of the OOM I am experiencing, though I am not yet well versed enough to fully understand the memory requirements of the custom operation during a backward pass.
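One way I am trying to pin this down is to run each method in isolation on representative shapes and read off the peak CUDA memory over a forward and backward pass. A rough harness (the sizes below are placeholders, not my real batch shapes, and it assumes the `gather` function quoted above is in scope):

```python
import torch

def peak_memory_gb(method, N=100_000, D=64, n1=100_000, n2=40):
    device = 'cuda'
    x = torch.randn(N, D, device=device, requires_grad=True)
    idx = torch.randint(0, N, (n1, n2), device=device)
    torch.cuda.reset_peak_memory_stats(device)
    out = gather(x, idx, method=method)  # the custom op quoted above
    out.sum().backward()                 # dummy scalar loss to trigger backward
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 1e9

for m in (0, 1, 2):
    print(f'method {m}: {peak_memory_gb(m):.2f} GB peak')
```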
I am currently attempting to replace the use of this operation with `torch.gather` to determine whether any decrease in memory usage is possible; the replacement I am testing looks roughly like the sketch below (the function name is mine, and it assumes the 2-D `x` / `idx` case that method 1 covers):
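```python
import torch

def gather_via_torch(x, idx):
    # Assumes x: [N, D] and idx: [n1, n2].
    # torch.gather is applied to x directly, so its backward should only need
    # an [N, D] gradient buffer rather than one at an expanded [N, n2, D] shape.
    flat_idx = idx.reshape(-1, 1).expand(-1, x.shape[-1])  # [n1*n2, D] view
    return torch.gather(x, 0, flat_idx).reshape(*idx.shape, x.shape[-1])
```

An equivalent `x.index_select(0, idx.reshape(-1))` followed by the same reshape should behave the same way, as far as I can tell.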
But until then, or if that does not provide a clear answer: is it possible that this custom gather operation has a larger, or much larger, memory requirement during the backward pass than `torch.gather`?