DataParallel and tensors on different GPUs

I’m training a network in which the first layer is a filter-bank-like projection and the second is a convolution. The example below is just to illustrate the problem.

When I use this network inside DataParallel with more than one GPU, I run into the error “tensors are on different GPUs”.

I also tried not moving the tensors to the GPU explicitly, hoping that calling .cuda() on the model would solve the issue.
@SimonW Any thoughts on this?

Does anyone have a solution to this problem?

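import numpy as np
import torch
from torch.autograd import Variable
from torch.nn import DataParallel

# Wrap the network for multi-GPU training (the classes are defined below).
model = DataParallel(Network(frame_size, stride).cuda())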
class Network(torch.nn.Module):
    def __init__(self, frame_size, stride):
        super(Network, self).__init__()
        self.projection = Projection(frame_size, stride)
        self.conv = torch.nn.Conv1d(frame_size // 2, 1, 2, 1)

    def forward(self, signal):
        output = self.projection(signal)
        output = self.conv(output)
        return output

class Projection(torch.autograd.Function):
    def __init__(self, frame_size, stride):
        super(Projection, self).__init__()
        x = np.array(range(0, frame_size))
        f = np.array([np.sin(x*2*f*np.pi/frame_size) for f in range(0,int(frame_size/2))])
        f = f.astype(float)
        f = torch.from_numpy(f)
        f = torch.unsqueeze(f, 1).float()
        self.conv = torch.nn.Conv1d(1, int(frame_size/2),
                                    kernel_size=frame_size,
                                    stride = stride,
                                    bias=False)
        self.conv.weight.data = f
        if torch.cuda.is_available():
            self.conv.weight.data = self.conv.weight.data.cuda()
        self.conv.weight.requires_grad = False

    def forward(self, signals):
        signals = Variable(signals)
        conv_signal = self.conv(signals).data
        return conv_signal


class Filterbank(torch.nn.Module):
    def __init__(self, frame_size, stride):
        super(Filterbank, self).__init__()
        self.proj = Projection(frame_size, stride)

    def forward(self, signals):
        signals = torch.unsqueeze(signals, 1)
        signals_proj = self.proj(signals)**2
        return signals_proj

You shouldn’t attach parameters to an autograd.Function. Instead, pass them as arguments to forward. If they need to be trainable, make them Parameters of a Module; otherwise make them buffers of a Module.
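
To make the Parameter/buffer distinction concrete, here is a minimal sketch. The module `Scale` and its attribute names are made up for illustration and are not from the thread; it only shows where each kind of tensor ends up.

```python
import torch
import torch.nn as nn


class Scale(nn.Module):
    """Hypothetical module, for illustration only."""

    def __init__(self):
        super().__init__()
        # Trainable: registered as a Parameter, so parameters() returns it.
        self.gain = nn.Parameter(torch.ones(3))
        # Fixed: registered as a buffer, moved by .cuda()/.to() but never trained.
        self.register_buffer("offset", torch.zeros(3))

    def forward(self, x):
        return x * self.gain + self.offset


m = Scale()
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
print(param_names)   # ['gain']
print(buffer_names)  # ['offset']
```

Both tensors appear in `m.state_dict()`, but only `gain` would be handed to an optimizer built from `m.parameters()`.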

@SimonW thanks for the reply!
I had the same error even when the parameters are defined as in the loss function below.
In this case, where the projection is defined inside the loss, should the “weights” be defined as buffers?
What about the attribute self.mel_basis? What should it be defined as?

class TacotronLoss(nn.Module):
    def __init__(self, sr=16000, mel_size=80, frame_size=512, stride=128, size_average=True):
        super(TacotronLoss, self).__init__()
        from librosa.filters import mel
        mel_basis = mel(sr, n_fft=frame_size-1, n_mels=mel_size)
        self.mel_basis = torch.from_numpy(mel_basis).float().cuda()
        self.spectrogram = Projection(frame_size, stride)
        self.size_average = size_average
        self.L1 = nn.L1Loss(self.size_average)

    def forward(self, denoised_mel, denoised_spect, original):
        original.requires_grad = False
        original_spect = self.spectrogram(original)
        original_mel = torch.matmul(self.mel_basis, original_spect.data)
        original_mel = torch.autograd.Variable(original_mel, requires_grad=False)
        mel_loss = self.L1(denoised_mel, original_mel)
        spec_loss = self.L1(denoised_spect, original_spect)
        return mel_loss + spec_loss

Sorry for the confusion. Let me answer it more clearly.

The problem is that a conv is initialized in a Function. Its weights are therefore not registered as part of an nn.Module, so calling .cuda() on the module can’t change their location, and DataParallel can’t assign them to the correct GPU.
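
A small sketch of that registration behaviour, runnable on CPU. `Holder` is a hypothetical stand-in for the Function in the question: any plain Python object that is not an nn.Module behaves the same way as far as registration goes.

```python
import torch
import torch.nn as nn


class Holder(object):
    """Hypothetical stand-in for the Function: a plain object, not a Module."""

    def __init__(self):
        self.conv = nn.Conv1d(1, 4, kernel_size=8, bias=False)


class Hidden(nn.Module):
    def __init__(self):
        super().__init__()
        # The conv lives inside a non-Module object, so its weight is
        # invisible to parameters(), .cuda() and DataParallel.
        self.proj = Holder()


class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning a Module directly registers it as a submodule.
        self.proj = nn.Conv1d(1, 4, kernel_size=8, bias=False)


hidden_count = len(list(Hidden().parameters()))
registered_count = len(list(Registered().parameters()))
print(hidden_count, registered_count)  # 0 1
```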

Since the projection is just a conv1d, I suggest not making an extra autograd.Function. Instead, do this:

def get_proj_conv_weight(frame_size, stride):
    ...

class Filterbank(nn.Module):
    def __init__(self, frame_size, stride):
        super(Filterbank, self).__init__()
        # A buffer is moved by .cuda() and replicated by DataParallel, but is not trained.
        self.register_buffer('proj_w', get_proj_conv_weight(frame_size, stride))
        self.frame_size = frame_size
        self.stride = stride

    def forward(self, signals):
        signals = torch.unsqueeze(signals, 1)
        # The kernel size is taken from the weight's shape, so only stride is passed.
        signals_proj = nn.functional.conv1d(signals, Variable(self.proj_w), stride=self.stride)**2
        ...
        return sin_signals

If you want to use this operation in multiple modules, you can write a base module class that registers the buffer and has a method that applies the projection.
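
For reference, a self-contained sketch of the buffer-based approach, written for recent PyTorch versions where tensors can be passed to functional ops without wrapping them in Variable. The helper `make_sine_filters` and the class `SineFilterbank` are hypothetical names; the filter values follow the sine construction from the question, and `F.conv1d` is used to match the Conv1d there.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def make_sine_filters(frame_size):
    """Hypothetical helper: sine filters shaped (frame_size // 2, 1, frame_size)."""
    n = torch.arange(frame_size, dtype=torch.float32)       # sample index
    k = torch.arange(frame_size // 2, dtype=torch.float32)  # filter index
    filters = torch.sin(2 * math.pi * k[:, None] * n[None, :] / frame_size)
    return filters.unsqueeze(1)  # add the in_channels dimension


class SineFilterbank(nn.Module):
    def __init__(self, frame_size, stride):
        super().__init__()
        # Fixed weights: a buffer, so .cuda()/.to() and DataParallel handle it.
        self.register_buffer("proj_w", make_sine_filters(frame_size))
        self.stride = stride

    def forward(self, signals):
        signals = signals.unsqueeze(1)  # (batch, time) -> (batch, 1, time)
        return F.conv1d(signals, self.proj_w, stride=self.stride) ** 2


bank = SineFilterbank(frame_size=16, stride=4)
out = bank(torch.randn(2, 64))
print(out.shape)  # torch.Size([2, 8, 13])
```

Because `proj_w` is a buffer, the module has no trainable parameters, and wrapping `bank.cuda()` in DataParallel should replicate the weights to each GPU along with the rest of the module.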

In your loss, mel_basis should be similarly defined as

self.register_buffer('mel_basis', torch.from_numpy(mel_basis).float())

The projection weights should be handled the same way as in my reply above.

Thanks a lot for your help, @SimonW !
With register_buffer used in the loss function, i.e. for mel_basis, what moves the buffer to the GPU device? Note that mel_basis is not part of a model but part of a loss function.

loss = TacotronLoss(...).cuda() should work. After all, loss is just another nn.Module.
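
A minimal sketch of that, assuming a simplified loss. `BasisL1Loss` is a hypothetical class, and the random `basis` is only a placeholder for a real mel basis. It uses `.to(device)` so that it also runs on a machine without a GPU; `.cuda()` moves the buffer in the same way.

```python
import torch
import torch.nn as nn


class BasisL1Loss(nn.Module):
    """Hypothetical loss: project the target through a fixed basis, then L1."""

    def __init__(self, basis):
        super().__init__()
        self.register_buffer("basis", basis)  # fixed, moves with the module
        self.l1 = nn.L1Loss()

    def forward(self, prediction, target):
        # basis: (n_out, n_in); target: (batch, n_in, frames)
        projected = torch.matmul(self.basis, target)
        return self.l1(prediction, projected)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
basis = torch.rand(5, 12)                  # placeholder for a real mel basis
criterion = BasisL1Loss(basis).to(device)  # the buffer moves with the module

target = torch.randn(3, 12, 7, device=device)
prediction = torch.matmul(basis.to(device), target)
loss = criterion(prediction, target)
print(criterion.basis.device.type, loss.item())
```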
