Hi, I am trying to use DataParallel to get rid of the out-of-memory error that I ran into when training my model. I read the PyTorch docs and it looked quite easy, but for some reason it doesn't work in my code.
Here is my code. The model is a 3D ResNet.
```
model = generate_resnet3D(conf.model_depth, conf.in_c)
model = nn.DataParallel(model, device_ids=[0, 4, 5])
model.to(device)

def train(epoch):
    model.train()
    # to track the training loss as the model trains
    train_losses = 0
    for num, (vol1, vol2) in enumerate(train_loader):
        # first we move the volumes to the device
        vol1 = vol1.to(device)
        vol2 = vol2.to(device)
        # next we apply the model to get representations
        x1, h1 = model(vol1)
        x2, h2 = model(vol2)
        print("Outside: input size", vol1.size(),
              "output_size", x1.size())
```
I use a batch size of 3, but when I check the output I find that the batch is not split.
```
Outside: input size torch.Size([3, 1, 128, 128, 128]) output_size torch.Size([3, 2048, 10, 10, 10])
Outside: input size torch.Size([3, 1, 128, 128, 128]) output_size torch.Size([3, 2048, 10, 10, 10])
Outside: input size torch.Size([3, 1, 128, 128, 128]) output_size torch.Size([3, 2048, 10, 10, 10])
Outside: input size torch.Size([3, 1, 128, 128, 128]) output_size torch.Size([3, 2048, 10, 10, 10])
Outside: input size torch.Size([3, 1, 128, 128, 128]) output_size torch.Size([3, 2048, 10, 10, 10])
Outside: input size torch.Size([3, 1, 128, 128, 128]) output_size torch.Size([3, 2048, 10, 10, 10])
```
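As far as I understand, DataParallel gathers the outputs of all replicas back onto one device, so a print outside the model would report the full batch size even when the split works. A print inside `forward` should show the per-GPU chunk instead. Below is a minimal sketch of that check; `ToyModel` and its layer sizes are placeholders made up for illustration, not my actual 3D ResNet, and the default `device_ids` is used so that it also runs on a machine without GPUs (where DataParallel simply calls the module directly).

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for the real network, used only for this check."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        # This runs once per replica, so it reports the chunk each device receives.
        print("  Inside: input size", x.size(), "device", x.device)
        return self.fc(x)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# With device_ids left at its default, DataParallel uses every visible GPU.
model = nn.DataParallel(ToyModel()).to(device)

batch = torch.randn(3, 8).to(device)
out = model(batch)

# The gathered output always has the full batch dimension.
print("Outside: input size", batch.size(), "output_size", out.size())
```

If several "Inside" lines with a smaller first dimension appear for each "Outside" line, the batch is being split and only the outside print is misleading.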
Here is another example where I checked DataParallel:
```
gpu_usage()
model = nn.DataParallel(model, device_ids=[0, 1])
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
input = torch.randn(2, 1, 128, 128, 128)
# Tensor.to() returns a new tensor, so the result has to be reassigned
input = input.to(device)
output1, output2 = model(input)
print("Outside: input size", input.size(),
      "output_size", output2.size())
torch.cuda.empty_cache()
sys.stdout.flush()
```
The output again wasn't split in two parts!
```
| ID | GPU | MEM |
------------------
|  0 |  0% |  7% |
|  1 |  0% |  0% |
|  2 | 84% | 64% |
|  3 | 56% | 50% |
|  4 |  0% |  7% |
|  5 |  0% |  7% |
|  6 | 94% | 99% |
|  7 | 92% | 99% |
Outside: input size torch.Size([2, 1, 128, 128, 128]) output_size torch.Size([2, 2048])
```
The table at the top shows the utilization and memory usage of each GPU. I would be thankful for any help fixing this problem.
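For reference, this is the split I was expecting in the two examples. It is only a sketch of my understanding, assuming DataParallel divides the batch along dimension 0 the same way `torch.chunk` does (chunks of size ceil(batch / devices), with the last one possibly smaller). `expected_chunk_sizes` is a helper name I made up for illustration, not a PyTorch function.

```python
import math


def expected_chunk_sizes(batch_size, num_devices):
    """Per-device batch sizes, assuming torch.chunk-style splitting along dim 0."""
    chunk = math.ceil(batch_size / num_devices)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        # Every chunk is `chunk` long except possibly the last one.
        sizes.append(min(chunk, remaining))
        remaining -= sizes[-1]
    return sizes


# First example: batch of 3 over device_ids=[0, 4, 5]
print(expected_chunk_sizes(3, 3))
# Second example: batch of 2 over device_ids=[0, 1]
print(expected_chunk_sizes(2, 2))
```

Under that assumption, each GPU should receive a single sample in both of my examples.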