Hi, I'm new to PyTorch and I'm having trouble running my program on multiple GPUs.
I wrapped my network as described in the docs:

```python
import torch

net = CNN()  # CNN is my own model class
net.load_state_dict(torch.load("cnn/model/epoch_1_subfile_55.pkl"))
net = torch.nn.DataParallel(net, device_ids=[0, 1, 2, 3])
net = net.cuda()
```
My CNN module has a method named forward_batch. When I ran this code, I got the error 'DataParallel' object has no attribute 'forward_batch'. Is there a way to solve this?
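A minimal sketch of why the error appears, assuming a toy stand-in for CNN (the Linear layer, the input size, and the forward_batch body are hypothetical). DataParallel only parallelizes calls that go through net(...), i.e. the wrapped module's forward, and it does not pass other attribute lookups through to the wrapped model. The original model is still reachable as net.module, but calling net.module.forward_batch runs on a single device and does not split the batch. The sketch falls back to CPU when no GPU is visible, in which case DataParallel simply calls the wrapped module.

```python
import torch
import torch.nn as nn


class CNN(nn.Module):
    # Hypothetical stand-in for the CNN in the question.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

    def forward_batch(self, x):
        # Hypothetical custom method, mirroring the one in the question.
        return self.fc(x).softmax(dim=1)


net = nn.DataParallel(CNN())
if torch.cuda.is_available():
    net = net.cuda()

x = torch.randn(4, 8)
if torch.cuda.is_available():
    x = x.cuda()

# DataParallel does not proxy arbitrary attributes of the wrapped module,
# so this lookup raises AttributeError.
attr_error = None
try:
    net.forward_batch(x)
except AttributeError as err:
    attr_error = err
    print(err)

# The wrapped model lives in net.module; the custom method is reachable there,
# but this call runs on one device with no batch splitting across GPUs.
out = net.module.forward_batch(x)
print(out.shape)
```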
Thanks for your reply! I just want to forward batches. I know people usually override the forward method and get the result with out = net(input).
I added the forward_batch method because the inputs to my network are sometimes different, and I get the result with out = net.forward_batch(input).
It seems the object returned by DataParallel doesn't expose this method. Maybe I should rename the method to forward?
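Renaming is one way to get the parallel behavior. A minimal sketch, assuming the two code paths can be selected by an extra argument (the mode keyword and the softmax branch are hypothetical, not taken from the original model): put both paths behind forward and call net(x, mode=...). DataParallel splits tensor arguments along dimension 0 across the GPUs, copies non-tensor arguments such as a string to every replica, and gathers the outputs on the first device.

```python
import torch
import torch.nn as nn


class CNN(nn.Module):
    # Hypothetical stand-in for the CNN in the question.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x, mode="default"):
        # Single entry point; `mode` (hypothetical name) selects the code path
        # that previously lived in a separate forward_batch method.
        logits = self.fc(x)
        if mode == "batch":
            return logits.softmax(dim=1)
        return logits


net = nn.DataParallel(CNN())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)

x = torch.randn(6, 8, device=device)
# net(...) goes through DataParallel.forward: x is split along dim 0 across
# the visible GPUs and the string `mode` is replicated to each replica.
out_default = net(x)
out_batch = net(x, mode="batch")
print(out_default.shape, out_batch.shape)
```

Any tensor passed this way needs its batch in dimension 0, since that is the dimension DataParallel splits by default.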