Variable assignment on multiple GPUs

I need to set a boolean flag in code that runs on multiple GPUs. With a single GPU, self.calculate_running is set to False correctly after the first iteration, but it is not being set when I use more than one GPU:

import torch
import torch.nn as nn


class PCTL_Layer(nn.Module):
    def __init__(self, calculate_running=False):
        super(PCTL_Layer, self).__init__()
        self.register_buffer('running', torch.zeros(1))
        self.calculate_running = calculate_running

    def forward(self, input):
        if self.calculate_running:
            # k-th smallest element with k = 90% of the element count (~90th percentile)
            pctl, _ = torch.kthvalue(input.view(-1), int(input.numel() * 0.9))
            self.running = pctl
            print(' gpu {}  calculate_running: {}'.format(torch.cuda.current_device(), self.calculate_running))
            self.calculate_running = False

        return self.running


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.percentile = PCTL_Layer(calculate_running=True)

    def forward(self, x):
        x = self.percentile(x)
        return x

model = Model()
model = torch.nn.DataParallel(model).cuda()

for i, image in enumerate(train_loader):
    output = model(image)

I tried inserting torch.cuda.synchronize() everywhere (before and after), but that didn’t help. Any advice on how to make it work?

P.S. I realize that a simple solution would be to pass i to the forward method of PCTL_Layer, but it's complicated because that layer is packaged inside an nn.Sequential container, so I don't know how to do that either.
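(For illustration, one way this could be done: subclass nn.Sequential and forward the extra argument only to the layers that need it. The sketch below is hypothetical, not code from this thread — the name SequentialWithIndex is made up, and it assumes PCTL_Layer.forward is changed to accept the extra argument.)

import torch.nn as nn

class SequentialWithIndex(nn.Sequential):
    # Hypothetical container: passes the iteration index only to
    # PCTL_Layer instances; every other layer is called as usual.
    def forward(self, x, i=0):
        for layer in self:
            if isinstance(layer, PCTL_Layer):
                # assumes PCTL_Layer.forward is extended to take i
                x = layer(x, i)
            else:
                x = layer(x)
        return x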

I’ve executed your code on a machine with 8 GPUs and got the following output:

 model(torch.randn(8, 100, device='cuda'))

 gpu 0  calculate_running: True
 gpu 1  calculate_running: True
 gpu 2  calculate_running: True
 gpu 4  calculate_running: True
 gpu 3  calculate_running: True
 gpu 7  calculate_running: True
 gpu 6  calculate_running: True
 gpu 5  calculate_running: True

Which seems to indicate the attribute is properly set.

Did you try more than one iteration? Because when I use 2 GPUs with 3 iterations I get the following:

 gpu 1  calculate_running: True
 gpu 0  calculate_running: True
 gpu 1  calculate_running: True
 gpu 0  calculate_running: True
 gpu 0  calculate_running: True
 gpu 1  calculate_running: True

But I want to see this:

 gpu 1  calculate_running: True
 gpu 0  calculate_running: True

Thanks for the follow-up.
I had missed the change to this attribute and can now reproduce the undesired behavior.
Register the condition as a BoolTensor and it should work:

self.calculate_running = torch.tensor(calculate_running, dtype=torch.bool)

Does not work :frowning:

I tried both setting it to False and to torch.tensor(False, dtype=torch.bool):


class PCTL_Layer(nn.Module):
    def __init__(self, calculate_running=False):
        super(PCTL_Layer, self).__init__()
        self.register_buffer('running', torch.zeros(1))
        self.calculate_running = torch.tensor(calculate_running, dtype=torch.bool)

    def forward(self, input):
        if self.calculate_running:
            pctl, _ = torch.kthvalue(input.view(-1), int(input.numel() * 0.9))
            self.running = pctl
            print(' gpu {}  calculate_running: {}'.format(torch.cuda.current_device(), self.calculate_running))
            self.calculate_running = False
            #self.calculate_running = torch.tensor(False, dtype=torch.bool)

        return self.running

Even when I synchronize GPUs like this:

    def forward(self, input):
        if self.calculate_running:
            pctl, _ = torch.kthvalue(input.view(-1), int(input.numel() * 0.9))
            self.running = pctl
            print(' gpu {}  calculate_running: {}'.format(torch.cuda.current_device(), self.calculate_running))
            torch.cuda.synchronize()
            self.calculate_running = False
            torch.cuda.synchronize()

        return self.running

It still does not work. Strange, isn’t it?

You are right. I missed the second output due to an error message, sorry.
Thinking about it, it might make sense that this flag is not propagated back through DataParallel.
E.g. what would the expected result be if the replicas on GPUs 0, 1, and 2 set this flag to False, while all others keep it as True?

I think the cleanest approach would be to manipulate this flag after the forward pass manually using:

_ = model(torch.randn(8, 100, device='cuda'))
model.module.percentile.calculate_running = False

which will make sure that each new replica of the model uses the new value, since the replicas are re-created from model.module at every forward call.
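(A rough illustration of why this works — the function below is a simplified sketch of DataParallel's forward, not the actual PyTorch source: the replicas are re-created from model.module on every call, so state changed inside a replica's forward is discarded afterwards, while anything changed on model.module between iterations is picked up by the next set of replicas.)

import torch.nn as nn

def data_parallel_forward(dp_model, inputs):
    # Simplified sketch of what nn.DataParallel does on each forward call.
    device_ids = dp_model.device_ids
    replicas = nn.parallel.replicate(dp_model.module, device_ids)  # fresh copies every call
    scattered = nn.parallel.scatter(inputs, device_ids)            # split the batch across GPUs
    outputs = nn.parallel.parallel_apply(replicas, scattered)      # one forward pass per replica
    return nn.parallel.gather(outputs, device_ids[0])              # collect results on the master GPU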

Great, boolean variable assignment works. Thank you. However, now self.running is not being assigned correctly. When I run the following code on 2 GPUs:

class PCTL_Layer(nn.Module):
    def __init__(self, calculate_running=False):
        super(PCTL_Layer, self).__init__()
        self.register_buffer('running', torch.zeros(1))
        self.calculate_running = calculate_running

    def forward(self, input):
        if self.calculate_running:
            pctl, _ = torch.kthvalue(input.view(-1), int(input.numel() * 0.9))
            self.running = pctl
            self.calculate_running = False
        print('calculate_running: {}   input: {}   running: {:.4f}'.format(self.calculate_running, input, self.running.item()))
        return self.running


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.percentile = PCTL_Layer(calculate_running=True)

    def forward(self, x):
        x = self.percentile(x)
        return x


model = Model()
model = torch.nn.DataParallel(model).cuda()

for i in range(2):
    print('\nIteration', i)
    output = model(torch.randn(2, 4, device='cuda'))
    print('model.module.percentile.running: {}  model output: {}'.format(model.module.percentile.running, output))
    if i == 0:
        model.module.percentile.calculate_running = False

Here’s the output:

Iteration 0
calculate_running: False   input: tensor([[ 0.3747, -1.1507, -1.4812,  0.1900]], device='cuda:1')   running: 0.1900
calculate_running: False   input: tensor([[ 0.3961, -0.0903, -0.0330, -1.9596]], device='cuda:0')   running: -0.0330
/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '
model.module.percentile.running: tensor([0.], device='cuda:0')  model output: tensor([-0.0330,  0.1900], device='cuda:0')

Iteration 1
calculate_running: False   input: tensor([[ 0.6758, -0.2846,  0.9055, -0.0532]], device='cuda:0')   running: 0.0000
calculate_running: False   input: tensor([[-0.3517, -0.8317, -0.7304,  0.7370]], device='cuda:1')   running: 0.0000
model.module.percentile.running: tensor([0.], device='cuda:0')  model output: tensor([0., 0.], device='cuda:0')

Why is self.running being reset to 0 after the first iteration? In the worst case I could live with each GPU using its own value of running, but ideally I want it to be the mean across all GPUs.

I can fix it like this:

class PCTL_Layer(nn.Module):
    def __init__(self, calculate_running=False):
        super(PCTL_Layer, self).__init__()
        self.register_buffer('running', torch.zeros(1))
        self.calculate_running = calculate_running

    def forward(self, input):
        if self.calculate_running:
            pctl, _ = torch.kthvalue(input.view(-1), int(input.numel() * 0.9))
            self.running = pctl
            model.module.running_list.append(self.running)
            self.calculate_running = False

        print('calculate_running: {}   input: {}   running: {:.4f}'.format(self.calculate_running, input, self.running.item()))
        return self.running


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.percentile = PCTL_Layer(calculate_running=True)
        self.running_list = []

    def forward(self, x):
        x = self.percentile(x)
        return x

model = Model()
model = torch.nn.DataParallel(model).cuda()

for i in range(3):
    print('\nIteration', i)
    output = model(torch.randn(2, 4, device='cuda'))
    print('model.module.percentile.running: {}'.format(model.module.percentile.running))
    if i == 0:
        model.module.percentile.calculate_running = False
        model.module.percentile.running = torch.tensor(model.module.running_list, device='cuda:0').mean()
    print('model output (running): {}  running_list: {}\n'.format(output, model.module.running_list))

Which produces the following output:

Iteration 0
calculate_running: False   input: tensor([[0.2759, 1.1083, 0.3523, 0.6275]], device='cuda:1')   running: 0.6275
calculate_running: False   input: tensor([[-1.5641,  0.3595,  1.7428, -0.5368]], device='cuda:0')   running: 0.3595
/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '
model.module.percentile.running: tensor([0.], device='cuda:0')
model output (running): tensor([0.3595, 0.6275], device='cuda:0')  running_list: [tensor(0.3595, device='cuda:0'), tensor(0.6275, device='cuda:1')]


Iteration 1
calculate_running: False   input: tensor([[ 0.4074,  1.4806, -0.5506,  0.3985]], device='cuda:0')   running: 0.4935
calculate_running: False   input: tensor([[ 0.2667,  0.4528,  0.1397, -0.5707]], device='cuda:1')   running: 0.4935
model.module.percentile.running: 0.4935113787651062
model output (running): tensor([0.4935, 0.4935], device='cuda:0')  running_list: [tensor(0.3595, device='cuda:0'), tensor(0.6275, device='cuda:1')]

But this is pretty ugly. Is there a better way?

Also, note that when averaging the running_list I have to move it to GPU-0. How do the other GPUs access it? I mean, do they copy it from GPU-0 every time, or just once? If every time, how can I distribute it manually?

@ptrblck any insight about keeping the variable on all GPUs?

nn.DataParallel replicates the model (including its parameters and buffers) from model.module onto all devices on every forward call, scatters the input batch across them, and gathers their outputs on the master device; the replicas are discarded after each forward, which is why a value assigned to running inside a replica's forward does not survive to the next iteration. This blog post gives a good overview.

Since you are assigning a new value to running, I think the best way would be to have a look at how data parallel works internally and adapt these methods to your use case.
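(Following that suggestion, one possible pattern — a sketch based on the minimal example above, where the model's output is the percentile itself, not code from this thread: DataParallel gathers the per-replica return values on the master device anyway, so after the first forward the output already holds one percentile per GPU. Averaging it and writing the result into model.module stores it in the buffer that is broadcast to every replica on subsequent forward calls, which also gives the mean across GPUs.)

model = Model()
model = torch.nn.DataParallel(model).cuda()

for i, image in enumerate(train_loader):
    output = model(image)  # on iteration 0: one percentile per GPU, gathered on cuda:0
    if i == 0:
        # average across GPUs and store the result on the master copy;
        # the next forward call broadcasts this buffer to all replicas
        model.module.percentile.running = output.mean().detach().view(1)
        model.module.percentile.calculate_running = False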