Hello,
I am attempting to build an image classification pipeline that uses a convolutional autoencoder (CAE) to compute reconstruction error maps in the `__getitem__` method of a Dataset for a CNN classifier. This way I don't have to save a new set of error maps for every CAE variant I try; instead, the error maps are computed on the fly with the selected CAE. You can see how this works in my dataset class, where I call `self.model` to get the reconstructed image and then return the error map.
```python
import os
from typing import Callable, List, Optional

import numpy as np
import torch


class ErrorMapDataset(torch.utils.data.Dataset):
    def __init__(self, model, root: str = None, dirs: List[str] = None,
                 transform: Optional[Callable] = None, label: torch.Tensor = None):
        super(ErrorMapDataset, self).__init__()
        self.root = root
        self.dirs = dirs
        self.transform = transform
        self.collect_sets()
        self.model = model
        # self.model.freeze()
        self.model.eval()

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, index):
        label = self.img_files[index][1]
        img_file = os.path.join(self.root, self.img_files[index][0])
        x = np.load(img_file)
        with torch.no_grad():
            if self.transform:
                x = self.transform(x)
            # as_tensor also covers the case where no transform converted x
            x = torch.as_tensor(x).float()
            # Unsqueeze to add a batch dimension of 1 for the CAE
            x_reconst = self.model(x.unsqueeze(0))
            # squared_error is a helper defined elsewhere in my code (not shown)
            error_map = squared_error(x_reconst, x)
        return error_map.squeeze(), label

    def collect_sets(self):
        # Build (relative_path, class_index) pairs; each entry of dirs is one class
        self.img_files = []
        for i, path in enumerate(self.dirs):
            for f in os.listdir(os.path.join(self.root, path)):
                file = os.path.join(path, f)
                self.img_files.append((file, i))
```
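The snippet calls `squared_error`, which isn't shown. For reference, here is a minimal sketch of what such a helper might look like, assuming it is a plain per-pixel squared difference (the body below is hypothetical, not the actual helper). It also shows why the final `.squeeze()` is needed: the batch dimension added by `unsqueeze(0)` survives broadcasting.

```python
import torch


def squared_error(x_reconst: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the helper used in __getitem__.

    Returns the per-pixel squared reconstruction error. x_reconst carries the
    batch dimension added by unsqueeze(0), so broadcasting against the
    unbatched x yields a (1, C, H, W) map.
    """
    return (x_reconst - x) ** 2


# Usage with dummy data shaped like a 6-channel 64x64 image
x = torch.rand(6, 64, 64)
x_reconst = torch.rand(1, 6, 64, 64)
error_map = squared_error(x_reconst, x).squeeze()
print(error_map.shape)  # torch.Size([6, 64, 64])
```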
The problem I am running into (I think) is that PyTorch is incorporating the CAE from the Dataset into the CNN, even though the only place the CAE should exist is the Dataset. This can be seen by calling `print(cnn)`:
```
BinaryCNN(
  (cae): CAE(
    (conv1): Conv2d(6, 12, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (conv2): Conv2d(12, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (conv3): Conv2d(8, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv4): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv5): Conv2d(8, 12, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (conv6): Conv2d(12, 6, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (unpool): Upsample(scale_factor=2.0, mode=bilinear)
  )
  (loss_function): BCEWithLogitsLoss()
  (conv): Sequential(
    (0): Sequential(
      (0): Conv2d(6, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  )
  (fc): Sequential(
    (0): Linear(in_features=16384, out_features=512, bias=True)
  )
  (output): Linear(in_features=512, out_features=1, bias=True)
)
```
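The `(cae)` entry in this printout suggests that the same CAE object is also assigned as an attribute of `BinaryCNN` somewhere (something like `self.cae = cae`; that line is an assumption, since it isn't shown here). `nn.Module` registers any module assigned as an attribute as a submodule, and `.to()`, `.cuda()` and `.double()` recurse into submodules and modify the module in place. Because the Dataset holds a reference to that same object, it would see the move too. The sketch below uses hypothetical `TinyCAE`/`Classifier` classes, with `.double()` standing in for `.cuda()` so it runs on a CPU-only machine, and compares against an independent `copy.deepcopy`:

```python
import copy

import torch
from torch import nn


class TinyCAE(nn.Module):
    # Hypothetical stand-in for the real CAE
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(6, 6, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


class Classifier(nn.Module):
    # Hypothetical stand-in for BinaryCNN
    def __init__(self, cae: nn.Module):
        super().__init__()
        self.cae = cae  # attribute assignment registers cae as a submodule
        self.fc = nn.Linear(4, 1)


shared_cae = TinyCAE()                   # the object the Dataset would also hold
private_cae = copy.deepcopy(shared_cae)  # an independent copy

clf = Classifier(shared_cae)
print([name for name, _ in clf.named_children()])  # ['cae', 'fc']

# .double() recurses into submodules the same way .to(device) / .cuda() do
clf.double()
print(shared_cae.conv.weight.dtype)   # torch.float64: changed through clf
print(private_cae.conv.weight.dtype)  # torch.float32: untouched
```

If this hypothesis holds for the real code, the Dataset's CAE ends up on whatever device the CNN is moved to, which would match the error below.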
Thus, when I move the CNN to my GPU for training, the CAE is moved to the GPU as well, which breaks the `__getitem__` method of the Dataset (which runs on the CPU). I get the following error:
```
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/brahste/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/brahste/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/brahste/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/brahste/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/dataset.py", line 257, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/home/brahste/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/dataset.py", line 257, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/home/brahste/Documents/masters_docs/code/lighn_cae/lightn/utils/datasets.py", line 69, in __getitem__
    x_reconst = self.model(x.unsqueeze(0))
  File "/home/brahste/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/brahste/Documents/masters_docs/code/lighn_cae/lightn/models/cae.py", line 65, in forward
    z = self.encode(x)
  File "/home/brahste/Documents/masters_docs/code/lighn_cae/lightn/models/cae.py", line 49, in encode
    x = F.relu( self.conv1(x) ) # 64.64.6 -> 64.64.12
  File "/home/brahste/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/brahste/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 345, in forward
    return self.conv2d_forward(input, self.weight)
  File "/home/brahste/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 341, in conv2d_forward
    return F.conv2d(input, weight, self.bias, self.stride,
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
```
To fix this, I tried to manually override the device of the CAE in the Dataset class, but it would not change. Furthermore, if I move the original image `x` to the GPU, I get the error `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
Training on the CPU works fine; the problem only appears when using the GPU. I have done some digging and tried everything I could find about separating the CAE from the computation graph, including wrapping the returned item in `torch.no_grad()`, using `.detach()`, `.freeze()` and `.eval()`, and even loading the CAE model onto the CPU:
```python
import torch

# models and config come from my project code (not shown)
device = torch.device('cpu')
with torch.no_grad():
    cae = models.cae.CAE(config['exp_params'])
    cae.load_state_dict(torch.load('logs/CAE_.pt', map_location=device))
```
but with no success.
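As far as I understand, the tools listed above control autograd or layer behaviour rather than device placement: `torch.no_grad()` and `.detach()` only stop gradient tracking, `.eval()` switches layers such as dropout and batch norm to inference mode, and Lightning's `freeze()` (to my understanding) combines `.eval()` with turning off `requires_grad`. `map_location` only decides where the checkpoint tensors are loaded. None of these would stop a parent module's `.to()`/`.cuda()` from moving a registered submodule afterwards. A small CPU-only check, with hypothetical modules and `.double()` again standing in for `.cuda()`:

```python
import io

import torch
from torch import nn

# Hypothetical stand-in for the CAE, "frozen" in several ways
child = nn.Conv2d(6, 6, kernel_size=3, padding=1)
child.eval()                 # inference mode for layers such as dropout/batch norm
child.requires_grad_(False)  # roughly what a Lightning-style freeze() does

# Round-trip the weights through torch.load with map_location='cpu'
buffer = io.BytesIO()
torch.save(child.state_dict(), buffer)
buffer.seek(0)
child.load_state_dict(torch.load(buffer, map_location=torch.device("cpu")))

# Register the frozen module inside a parent, as BinaryCNN appears to do
parent = nn.Sequential(child, nn.ReLU())

with torch.no_grad():
    # no_grad only disables gradient recording; it does not pin dtype or device
    parent.double()

print(child.weight.dtype)          # torch.float64: converted anyway
print(child.weight.requires_grad)  # False: the freeze itself is kept
```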
I am relatively new to PyTorch and suspect that my issue comes from a gap in my understanding of autograd; any help understanding what is going on is highly appreciated.
System:
Python 3.8.2
PyTorch 1.4.0
Torchvision 0.5.0
PyTorch Lightning 0.7.5
CUDA Runtime 10.1