Hello,
This is probably a known issue (but it might be worth reporting just in case): PyTorch behavior is not consistent between GPU and CPU: parameters are deep copied only in the GPU case when initializing an nn.Module.
While I understand why this might be the case, it's rather cumbersome. Below is a 'minimal' example where I build two modules from the same parameter tensor, on the CPU and on the GPU:
import torch
import torch.nn as nn

class dummy_class(nn.Module):
    def __init__(self, param):
        super().__init__()
        self.param = nn.Parameter(param)

    def print_param(self):
        return str(self.param[0].item())
# Initial Parameter
param = torch.tensor([0.5])
# Devices
cpu_device = torch.device("cpu")
gpu_device = torch.device("cuda")
# Build 2 dummy_class objects on CPU and GPU
on_cpu = [dummy_class(param).to(cpu_device) for _ in range(2)]
on_gpu = [dummy_class(param).to(gpu_device) for _ in range(2)]
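One way to see directly what is going on (a quick check, assuming the same setup with a CUDA device available) is to compare the storage addresses of the parameters: on CPU the two modules share the same underlying data, on GPU they do not.

# Quick sanity check on the underlying storage (assumes CUDA is available, as above)
print(on_cpu[0].param.data_ptr() == on_cpu[1].param.data_ptr())  # True: same CPU storage is shared
print(on_gpu[0].param.data_ptr() == on_gpu[1].param.data_ptr())  # False: .to("cuda") created separate copies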
Now, if I modify only one object and print the result, the behavior differs:
# Print Param
def print_param(objs):
    for ii, obji in enumerate(objs):
        print('Object ' + str(ii) + '= ' + obji.print_param())
print('======')
print('On CPU')
print('======')
print('Initial Parameter')
print_param(on_cpu)
on_cpu[0].state_dict()['param'] += 1
print('After Modifying Object 0')
print_param(on_cpu)
print('')
print('======')
print('On GPU')
print('======')
print('Initial Parameter')
print_param(on_gpu)
on_gpu[0].state_dict()['param'] += 1
print('After Modifying Object 0')
print_param(on_gpu)
print('')
Outputs:
======
On CPU
======
Initial Parameter
Object 0= 0.5
Object 1= 0.5
After Modifying Object 0
Object 0= 1.5
Object 1= 1.5
======
On GPU
======
Initial Parameter
Object 0= 0.5
Object 1= 0.5
After Modifying Object 0
Object 0= 1.5
Object 1= 0.5
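If the goal is to get the same (independent) behavior on both devices, a minimal workaround sketch is to clone the tensor before wrapping it: nn.Parameter does not copy its input, and .to() only copies when the tensor actually changes device, so on CPU each module keeps sharing the original storage unless it is cloned explicitly.

# Workaround sketch: clone before wrapping, so each module owns its own data
on_cpu = [dummy_class(param.clone()).to(cpu_device) for _ in range(2)]
on_gpu = [dummy_class(param.clone()).to(gpu_device) for _ in range(2)]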