I am working on an image registration project and received this error:
Epoch 0/99: 33%|████████████████████████████████████████████████████████████████████████▎ | 5/15 [00:04<00:05, 1.74batch/s]
collate dict key "fixed_image" out of 2 keys
collate/stack a list of tensors
E: stack expects each tensor to be equal size, but got [1, 1024, 1024] at entry 0 and [4, 1024, 1024] at entry 1, shape [torch.Size([1, 1024, 1024]), torch.Size([4, 1024, 1024])] in collate([metatensor([[[1.8903e-02, 1.6555e-02, 1.4817e-02, …, 1.8006e-01,
1.8057e-01, 1.8318e-01],
[1.7500e-02, 1.5912e-02, 1.5085e-02, …, 1.7924e-01,
1.7876e-01, 1.8112e-01],
[1.7203e-02, 1.4965e-02, 1.5309e-02, …, 1.7940e-01,
1.7656e-01, 1.7810e-01],
…,
[5.5115e-04, 1.3683e-04, 0.0000e+00, …, 1.6299e-01,
1.6520e-01, 1.6168e-01],
[1.5516e-03, 7.7386e-04, 2.7677e-05, …, 1.6049e-01,
1.6329e-01, 1.6351e-01],
[2.6995e-03, 1.3090e-03, 2.9158e-04, …, 1.6152e-01,
1.6472e-01, 1.6307e-01]]]), metatensor([[[0.1033, 0.1383, 0.1528, …, 0.1248, 0.1120, 0.0833],
[0.1207, 0.1617, 0.1792, …, 0.1439, 0.1293, 0.0963],
[0.1208, 0.1617, 0.1793, …, 0.1421, 0.1276, 0.0951],
…,
[0.0073, 0.0099, 0.0111, …, 0.1347, 0.1210, 0.0900],
[0.0074, 0.0100, 0.0112, …, 0.1343, 0.1206, 0.0898],
[0.0064, 0.0085, 0.0095, …, 0.1145, 0.1028, 0.0765]],
[[0.1033, 0.1383, 0.1528, ..., 0.1248, 0.1120, 0.0833],
[0.1207, 0.1617, 0.1792, ..., 0.1439, 0.1293, 0.0963],
[0.1208, 0.1617, 0.1793, ..., 0.1421, 0.1276, 0.0951],
...,
[0.0073, 0.0099, 0.0111, ..., 0.1347, 0.1210, 0.0900],
[0.0074, 0.0100, 0.0112, ..., 0.1343, 0.1206, 0.0898],
[0.0064, 0.0085, 0.0095, ..., 0.1145, 0.1028, 0.0765]],
[[0.1033, 0.1383, 0.1528, ..., 0.1248, 0.1120, 0.0833],
[0.1207, 0.1617, 0.1792, ..., 0.1439, 0.1293, 0.0963],
[0.1208, 0.1617, 0.1793, ..., 0.1421, 0.1276, 0.0951],
...,
[0.0073, 0.0099, 0.0111, ..., 0.1347, 0.1210, 0.0900],
[0.0074, 0.0100, 0.0112, ..., 0.1343, 0.1206, 0.0898],
[0.0064, 0.0085, 0.0095, ..., 0.1145, 0.1028, 0.0765]],
[[0.1412, 0.1900, 0.2119, ..., 0.2119, 0.1900, 0.1412],
[0.1645, 0.2213, 0.2469, ..., 0.2469, 0.2213, 0.1645],
[0.1646, 0.2215, 0.2471, ..., 0.2471, 0.2215, 0.1646],
...,
[0.1646, 0.2215, 0.2471, ..., 0.2471, 0.2215, 0.1646],
[0.1645, 0.2213, 0.2469, ..., 0.2469, 0.2213, 0.1645],
[0.1412, 0.1900, 0.2119, ..., 0.2119, 0.1900, 0.1412]]])])
collate dict key "moving_image" out of 2 keys
collate/stack a list of tensors
Epoch 0/99: 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 12/15 [00:06<00:01, 1.77batch/s]
Traceback (most recent call last):
  File "/project/med/Hassan_Ghavidel/transformer_target_localization/code/main_train_unsup_TransMorph.py", line 269, in <module>
    train_val_test.train_val_model_unsupervised(model, train_loader, optimizer, config.loss_name, config.loss_weights,
  File "/project/med/Hassan_Ghavidel/transformer_target_localization/code/auxiliary/train_val_test.py", line 49, in train_val_model_unsupervised
    for train_batch_data in tepoch:
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1333, in _next_data
    return self._process_data(data)
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1359, in _process_data
    data.reraise()
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/torch/_utils.py", line 543, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/monai/data/utils.py", line 514, in list_data_collate
    ret[key] = collate_fn(data_for_batch)
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 120, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/monai/data/utils.py", line 458, in collate_meta_tensor_fn
    collated = collate_fn(batch) # type: ignore
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 163, in collate_tensor_fn
    return torch.stack(batch, 0, out=out)
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/monai/data/meta_tensor.py", line 282, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/torch/_tensor.py", line 1279, in __torch_function__
    ret = func(*args, **kwargs)
RuntimeError: stack expects each tensor to be equal size, but got [1, 1024, 1024] at entry 0 and [4, 1024, 1024] at entry 1
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 61, in fetch
    return self.collate_fn(data)
  File "/project/med/Hassan_Ghavidel/TransMorph_Transformer_for_Medical_Image_Registration/myenv/lib/python3.8/site-packages/monai/data/utils.py", line 529, in list_data_collate
    raise RuntimeError(re_str) from re
RuntimeError: stack expects each tensor to be equal size, but got [1, 1024, 1024] at entry 0 and [4, 1024, 1024] at entry 1
Collate error on the key 'fixed_image' of dictionary data.
MONAI hint: if your transforms intentionally create images of different shapes, creating your DataLoader
with collate_fn=pad_list_data_collate
might solve this problem (check its documentation).
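For context: with a dictionary dataset, the collate function stacks each key's per-sample tensors into one batch tensor (that is the torch.stack call in the traceback), and stacking requires every sample to have exactly the same shape. The sketch below is only an illustration: it uses NumPy's np.stack as a stand-in for torch.stack and two hypothetical zero-filled arrays with the shapes from the message. It suggests that only the channel axis differs. As far as I understand MONAI's pad_list_data_collate, it pads the spatial dimensions only, so the hint above would probably not help here (worth confirming in its documentation).

```python
import numpy as np

# Hypothetical samples with the shapes reported in the error message.
fixed_images = [
    np.zeros((1, 1024, 1024), dtype=np.float32),  # entry 0: 1 channel
    np.zeros((4, 1024, 1024), dtype=np.float32),  # entry 1: 4 channels
]

# Which dimensions differ between the samples? Index 0 is the channel axis.
differing_dims = [
    dim
    for dim in range(fixed_images[0].ndim)
    if len({img.shape[dim] for img in fixed_images}) > 1
]
print("dims that differ:", differing_dims)  # [0] -> only the channel axis

try:
    # Default collation stacks the per-sample arrays along a new batch axis.
    np.stack(fixed_images, axis=0)
except ValueError as err:
    print("stack failed:", err)
```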
The line the traceback points to is "for train_batch_data in tepoch:" in the train_val_test.py script, located in the auxiliary folder:
# loop over all batches of data
with tqdm.tqdm(train_loader, unit="batch", initial=0) as tepoch:
    for train_batch_data in tepoch:
        tepoch.set_description(f"Epoch {epoch}/{epoch_start + epoch_nr - 1}")
        if wandb_usage:
            wandb.log({"lr": optimizer.param_groups[0]['lr']})
        # clear stored gradients
        optimizer.zero_grad()
        # get batch of data
        train_inputs, train_targets = train_batch_data["moving_image"].to(device), train_batch_data["fixed_image"].to(device)
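The loop body itself never runs for the failing batch: the exception is raised inside the DataLoader worker while it collates the batch, which is why the traceback points at the "for" line. One way to find the offending files is to index the dataset directly, which applies the per-sample transforms but skips collation. The helper group_by_shape below is hypothetical (it is not part of the original code or of MONAI) and assumes each sample is a dict with "fixed_image" and "moving_image" entries that have a .shape, as the loop above implies.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Sequence

import numpy as np


def group_by_shape(
    dataset: Sequence, keys: Iterable[str] = ("fixed_image", "moving_image")
) -> Dict[str, Dict[tuple, List[int]]]:
    """Map each key to {shape: [indices of the samples with that shape]}.

    Indexing the dataset directly applies the per-sample transforms but
    skips the collate step, so mismatched samples are listed instead of
    crashing a batch.
    """
    groups = {key: defaultdict(list) for key in keys}
    for idx in range(len(dataset)):
        sample = dataset[idx]
        for key, by_shape in groups.items():
            by_shape[tuple(sample[key].shape)].append(idx)
    # convert the defaultdicts to plain dicts for readable printing
    return {key: dict(by_shape) for key, by_shape in groups.items()}


if __name__ == "__main__":
    # Two fake samples mimicking the error: sample 1 has a 4-channel fixed image.
    fake_ds = [
        {"fixed_image": np.zeros((1, 8, 8)), "moving_image": np.zeros((1, 8, 8))},
        {"fixed_image": np.zeros((4, 8, 8)), "moving_image": np.zeros((1, 8, 8))},
    ]
    print(group_by_shape(fake_ds))
    # {'fixed_image': {(1, 8, 8): [0], (4, 8, 8): [1]}, 'moving_image': {(1, 8, 8): [0, 1]}}
```

With the data from this post it would be called as group_by_shape(train_ds); note that it loads every sample once, which may take a while for 1024 x 1024 images.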
The DataLoaders in the main_train_unsup_TransMorph.py script are set up as follows:
train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=False, num_workers=4)
if config.unsupervised_validation:
    val_loader = DataLoader(val_ds, batch_size=config.batch_size, shuffle=False, num_workers=4)
else:
    val_loader = None
if config.supervised_validation:
    val_loader_supervised = DataLoader(val_ds_supervised, batch_size=200, shuffle=False, num_workers=2)
else:
    val_loader_supervised = None
check_data = monai.utils.first(train_loader)
print(f"Shape of check data: {check_data['fixed_image'].shape}")  # e.g. (2, 1, 224, 224)
fixed_image = check_data["fixed_image"][0][0]  # (h, w)
moving_image = check_data["moving_image"][0][0]
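Since the loaders above stack the samples as they come out of train_ds, one option is to make every image single-channel in the per-sample transforms, before batching. In the log, the first three channels of the 4-channel entry look identical, which suggests a grayscale image stored in a multi-channel file; what the fourth channel holds is not clear from the log. The sketch below assumes a channel-first (C, H, W) layout, averages the first three channels and drops any fourth one. to_single_channel is a hypothetical helper, not a MONAI function.

```python
import numpy as np


def to_single_channel(image):
    """Return a channel-first image (C, H, W) as (1, H, W).

    Assumed handling (a guess, not taken from the original code):
      C == 1 -> returned unchanged
      C == 3 -> RGB, averaged into one grey channel
      C == 4 -> RGB plus an extra channel; the extra channel is dropped
    Only indexing and .mean(0) are used, which NumPy arrays and torch
    tensors both support.
    """
    if image.ndim != 3:
        raise ValueError(f"expected a (C, H, W) image, got shape {tuple(image.shape)}")
    channels = image.shape[0]
    if channels == 1:
        return image
    if channels in (3, 4):
        # mean over the first three channels, then restore the channel axis
        return image[:3].mean(0)[None]
    raise ValueError(f"unsupported channel count: {channels}")


if __name__ == "__main__":
    gray = np.zeros((1, 1024, 1024), dtype=np.float32)
    four_channel = np.zeros((4, 1024, 1024), dtype=np.float32)
    batch = np.stack([to_single_channel(gray), to_single_channel(four_channel)])
    print(batch.shape)  # (2, 1, 1024, 1024)
```

In a MONAI pipeline it could presumably be applied right after loading with monai.transforms.Lambdad(keys=["fixed_image", "moving_image"], func=to_single_channel); I have not verified this with MetaTensor inputs. A plain average is used here for simplicity; luminance-weighted RGB conversion would be an alternative if the channels actually differ.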
I have no idea how to fix this issue. Please help me.
This is the GitHub link to the code:
I am trying to use this chest dataset: