Functional conv2d accepting a batch of weights

Hi,

At the moment I am trying to implement a meta-learning algorithm, and since the model is quite large I am also trying to use DataParallel. However, I am currently running into an issue where one GPU takes the brunt of the load and runs out of memory. This is because I generate weights for each sample in my batch, which means I have to loop over these weights and apply a functional conv; that operation can't be data-parallelised, so it all ends up on the same GPU.

Is there any easy way to feed a batch of weights to a functional conv, or are there any plans to implement this in PyTorch in the near future?

Cheers,
Vincent

Not sure if I understand the request clearly; it would be helpful if you could share some pseudo code.

If all you need is to scatter conv weights (one weight per sample) across different GPUs, it looks like you could wrap that (samples + one conv layer per GPU) into a custom module. In its forward function, you can do something like:

def forward(self, samples):
    outputs = []
    for sample in samples:
        weight = generate_per_sample_weight(sample)
        replace_conv_weight(self.conv, weight)
        outputs.append(self.conv(sample))
    return outputs
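A variant of the same loop that avoids mutating the layer's stored weight is to call the functional API directly. Here is a minimal sketch, where generate_per_sample_weight is the same hypothetical helper as above:

import torch
import torch.nn.functional as F

def forward(self, samples):
    outputs = []
    for sample in samples:
        # hypothetical helper: returns a (C_out, C_in, kH, kW) weight for this sample
        weight = generate_per_sample_weight(sample)
        # use the functional conv with the generated weight instead of
        # swapping self.conv.weight in place
        outputs.append(F.conv2d(sample.unsqueeze(0), weight))
    return torch.cat(outputs, dim=0)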

Sorry for the late response.

The code you provided is the pseudo code I would give, and it is very similar to the code in my code base; one small change is that the weights are generated from a separate set of reference samples:

def forward(self, samples, reference_samples):
    outputs = []
    for sample in samples:
        weight = generate_per_sample_weight(reference_samples)
        replace_conv_weight(self.conv, weight)
        outputs.append(self.conv(sample))
    return outputs

However, the main issue with this is that PyTorch doesn't distribute it over multiple GPUs because of the for loop: these calculations all end up on the first (main) GPU (with the standard DataParallel package).

I was wondering if there is an easy way to write something like this which still allows for data parallelisation:

weights: BATCH x # INPUT FILTERS x # OUTPUT FILTERS x FILTER WIDTH x FILTER HEIGHT
samples: BATCH x CHANNELS x WIDTH x HEIGHT
def forward(self, samples, weights):
    outputs = self.conv(samples, weights)
    return outputs

This self.conv function would then be implemented purely in terms of batched tensor operations, like the original conv, which should allow data parallelisation.

This should work, as DataParallel simply replicates the model and scatters the input (assuming self.conv is a customized conv layer that replaces its weight). So if you wrap that with DataParallel, each thread/replica should see its own samples/weights on a different device. Did you encounter any issue when doing this?
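As a side note, one common way to implement such a customized conv layer without any Python loop is to fold the batch into the channel dimension and use a grouped convolution, so that each sample is convolved with its own filters in a single call. A rough sketch (the function name and argument layout are just for illustration, this is not an existing PyTorch API):

import torch
import torch.nn.functional as F

def batched_conv2d(x, weight, padding=0):
    # x:      (B, C_in, H, W)
    # weight: (B, C_out, C_in, kH, kW) -- one filter bank per sample
    B, C_in, H, W = x.shape
    _, C_out, _, kH, kW = weight.shape
    # with groups=B, group b only sees the channels of sample b and its own filters
    x = x.reshape(1, B * C_in, H, W)
    weight = weight.reshape(B * C_out, C_in, kH, kW)
    out = F.conv2d(x, weight, padding=padding, groups=B)
    return out.reshape(B, C_out, out.shape[-2], out.shape[-1])

The same trick works for F.conv_transpose2d (with the weight reshaped to (B * C_in, C_out, kH, kW) and groups=B). Since the whole batch stays inside a single tensor op, wrapping such a module in DataParallel will scatter both samples and weights along dim 0 as usual.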

At the moment I have the following pseudo code:

import torch
import torch.nn as nn
import torch.nn.functional as F

def batch_conv(x, weight, bias=None, stride=1, groups=1):
    # apply a different (transposed) conv weight to every sample in the batch
    outputs = []
    for i in range(x.size(0)):
        # per-sample bias slice; skip it entirely if no bias is given
        bi = bias[i, :weight.size(2)] if bias is not None else None
        # a fractional stride selects the transposed-conv (upsampling) stride
        yi = F.conv_transpose2d(x[i:i + 1], weight=weight[i], bias=bi,
                                padding=1, stride=int(1 / stride),
                                output_padding=1, groups=groups)
        outputs.append(yi)
    return torch.cat(outputs, dim=0)

class AdaptiveConv2d(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, input, weight=None, bias=None, stride=1):
        return batch_conv(input, weight, bias, stride)
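For completeness, a single call to this layer looks roughly like this (the tensor names and the fractional stride are only illustrative):

# samples: BATCH x CHANNELS x WIDTH x HEIGHT
# weights: BATCH x IN_CHANNELS x OUT_CHANNELS x FILTER_HEIGHT x FILTER_WIDTH
adaptive_conv = AdaptiveConv2d()
# stride=0.5 means the transposed conv upsamples by a factor of 2 (int(1 / 0.5) == 2)
out = adaptive_conv(samples, weight=weights, bias=None, stride=0.5)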

However, this doesn't distribute properly, and my assumption is that DataParallel isn't able to handle the for loop in my code.

I haven't tried implementing a conv layer that takes a batch of weights rather than operating on a single sample at a time, since I have only recently switched to PyTorch and am a bit out of my depth here.