So do I need to do something like this?
for x, y in iterator:
x, y = x.to(self.device), y.to(self.device)
y = torch.unsqueeze(y, x.shape[1])  # not sure if this is correct?
Adapted from https://github.com/qubvel/segmentation_models.pytorch/blob/e5d3db20e9c2ddb76f88642409e527239943c983/segmentation_models_pytorch/utils/train.py#L48-L52
class TrainEpoch(Epoch):
    """Training stage of an epoch runner.

    Registers itself with the base class under ``stage_name='train'`` and
    implements the per-batch optimisation step: zero gradients, forward pass,
    loss, backward pass, optimizer step.

    NOTE(review): when ``loss`` is ``torch.nn.CrossEntropyLoss`` and the model
    output is ``(N, C, H, W)``, the target ``y`` must hold class indices with
    shape ``(N, H, W)`` -- it must NOT be expanded to ``(N, C, H, W)``. The
    channel count of the input image is unrelated to the number of classes.
    Confirm against the loss actually passed in.
    """

    def __init__(self, model, loss, metrics, optimizer, logger, device='cpu', verbose=True, writer=None):
        """Store the training collaborators and initialise the base epoch.

        Args:
            model: Network to train; must expose ``train()`` and be callable.
            loss: Callable invoked as ``loss(prediction, y)``; its result must
                support ``backward()``.
            metrics: Forwarded unchanged to the base class.
            optimizer: Object exposing ``zero_grad()`` and ``step()``.
            logger: Object exposing ``warning(msg)``.
            device: Forwarded unchanged to the base class.
            verbose: Forwarded unchanged to the base class.
            writer: Stored on ``self.writer`` after base initialisation.
        """
        super().__init__(
            model=model,
            loss=loss,
            logger=logger,
            metrics=metrics,
            stage_name='train',
            device=device,
            verbose=verbose,
            # The base class receives no writer; the caller's writer is
            # attached just below. NOTE(review): confirm this is intended.
            writer=None,
        )
        self.writer = writer
        self.optimizer = optimizer
        self.logger = logger
        # One-shot flag so the "dict output" warning is logged only once.
        self.log_on_start = True

    def on_epoch_start(self):
        """Switch the model to training mode before the epoch begins."""
        self.model.train()

    def batch_update(self, x, y):
        """Run one optimisation step on a single batch.

        Args:
            x: Input batch passed to the model.
            y: Target batch passed to the loss as its second argument.

        Returns:
            Tuple ``(loss, prediction)``. If the model returned a dict,
            ``prediction`` is the value stored under its ``'out'`` key.
        """
        # Debug aid: shapes of the incoming batch.
        print('Shape', x.shape, y.shape)
        self.optimizer.zero_grad()
        # Call the module itself rather than ``.forward`` so that any
        # registered forward hooks run (assumes ``model`` is an nn.Module).
        prediction = self.model(x)
        if isinstance(prediction, dict):
            prediction = prediction['out']
            if self.log_on_start:
                # A bare ``Warning(...)`` expression only builds an exception
                # object and discards it; the logger call is what reports.
                self.logger.warning("prediction is a dictionary, using 'out' key")
                self.log_on_start = False
        # Debug aid: shapes actually handed to the loss.
        print(prediction.shape, y.shape)
        loss = self.loss(prediction, y)
        loss.backward()
        self.optimizer.step()
        return loss, prediction
When I try the following, I get this error:
for x, y in iterator:
x, y = x.to(self.device), y.to(self.device)
print('Shape iterator', x.shape, y.shape, torch.unsqueeze(y, dim=x.shape[1]).shape, x.shape[1])
# Add number of channels to y
# Get the number of channels from tensor x
num_channels = x.size(1)
# If y lacks the channel dimension, add it with the same number of channels as x
y = y.unsqueeze(1).expand(-1, num_channels, -1, -1)
Shape iterator torch.Size([2, 3, 512, 512]) torch.Size([2, 512, 512]) torch.Size([2, 512, 512, 1]) 3
new shape torch.Size([2, 3, 512, 512]) torch.Size([2, 3, 512, 512])
Shape Train torch.Size([2, 3, 512, 512]) torch.Size([2, 3, 512, 512])
prediction shape torch.Size([2, 8, 512, 512]) torch.Size([2, 3, 512, 512])
train Epoch 0: 0%| | 0/35 [00:04<?, ?it/s]
Traceback (most recent call last):
File "/kristina/dev/training/UNet/train_smp.py", line 96, in <module>
train_logs = train_epoch.run(_train_loader, epoch=i)
File "/kristina/dev/training/UNet/../smp/utils/train.py", line 65, in run
loss, y_pred = self.batch_update(x, y)
File "/kristina/dev//-training/UNet/../smp/utils/train.py", line 139, in batch_update
loss = self.loss(prediction, y)
File "/miniconda3/envs/conda_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/miniconda3/envs/conda_env/lib/python3.10/site-packages/torch/nn/modules/loss.py", line 1174, in forward
return F.cross_entropy(input, target, weight=self.weight,
File "/home/kristina/miniconda3/envs/conda_env/lib/python3.10/site-packages/torch/nn/functional.py", line 3029, in cross_entropy
return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: : [2, 3, 512, 512]
@ptrblck