I'm using DDP in my work, but I got this error:
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/mnt/disk3/nazik/RCPS_mama/models/umamba_botP.py", line 251, in forward
[rank1]:     x = self.stem(x)
[rank1]:   File "/home/nazik/anaconda3/envs/nazik/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/nazik/anaconda3/envs/nazik/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/nazik/anaconda3/envs/nazik/lib/python3.10/site-packages/torch/nn/modules/container.py", line 250, in forward
[rank1]:     input = module(input)
[rank1]:   File "/home/nazik/anaconda3/envs/nazik/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/nazik/anaconda3/envs/nazik/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/nazik/anaconda3/envs/nazik/lib/python3.10/site-packages/dynamic_network_architectures/building_blocks/residual.py", line 111, in forward
[rank1]:     out += residual
[rank1]:   File "/home/nazik/anaconda3/envs/nazik/lib/python3.10/site-packages/monai/data/meta_tensor.py", line 282, in __torch_function__
[rank1]:     ret = super().__torch_function__(func, types, args, kwargs)
[rank1]:   File "/home/nazik/anaconda3/envs/nazik/lib/python3.10/site-packages/torch/_tensor.py", line 1648, in __torch_function__
[rank1]:     ret = func(*args, **kwargs)
[rank1]: RuntimeError: Output 0 of SyncBatchNormBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
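
If I understand the message correctly, the forbidden pattern is an in-place op applied to a tensor that autograd treats as a view of a custom Function's output. Here is a toy sketch that, as far as I can tell, reproduces the same error class without DDP at all (the Passthrough Function is made up for illustration and has nothing to do with my model):

import torch

class Passthrough(torch.autograd.Function):
    # Returning an input as-is makes autograd treat the output as a view,
    # which is exactly the case the error message mentions.
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad):
        return grad

x = torch.randn(3, requires_grad=True)
y = Passthrough.apply(x)
y += 1  # raises: "Output 0 of PassthroughBackward is a view and is being modified inplace..."
# y = Passthrough.apply(x).clone(); y += 1  # cloning first avoids it, as the message suggests

The rest of the launcher output: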
E0303 10:37:05.454000 16464 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 0 (pid: 16550) of binary: /home/nazik/anaconda3/envs/nazik/bin/python
Traceback (most recent call last):
  File "/home/nazik/anaconda3/envs/nazik/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/home/nazik/anaconda3/envs/nazik/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
  File "/home/nazik/anaconda3/envs/nazik/lib/python3.10/site-packages/torch/distributed/run.py", line 918, in main
    run(args)
  File "/home/nazik/anaconda3/envs/nazik/lib/python3.10/site-packages/torch/distributed/run.py", line 909, in run
    elastic_launch(
  File "/home/nazik/anaconda3/envs/nazik/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/nazik/anaconda3/envs/nazik/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
train.py FAILED
Failures:
[1]:
  time      : 2025-03-03_10:37:05
  host      : siat140
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 16551)
  error_file: <N/A>
  traceback : To enable traceback see: Error Propagation — PyTorch 2.6 documentation
Root Cause (first observed failure):
[0]:
  time      : 2025-03-03_10:37:05
  host      : siat140
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 16550)
  error_file: <N/A>
  traceback : To enable traceback see: Error Propagation — PyTorch 2.6 documentation
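
The "To enable traceback" hint points at torch.distributed.elastic's error propagation. This is the sketch of what I plan to add to train.py so each rank records its full traceback in the error_file (assuming my entrypoint is a main() function):

from torch.distributed.elastic.multiprocessing.errors import record

@record  # on failure, writes the full per-rank traceback to the error_file shown above
def main():
    ...  # existing training loop

if __name__ == "__main__":
    main()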
My model definition is:
import numpy as np
import math
import torch
from torch import nn
from torch.nn import functional as F
from typing import Union, Type, List, Tuple
from dynamic_network_architectures.building_blocks.helper import get_matching_convtransp
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.dropout import _DropoutNd
from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim
from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager
from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op
from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0
from nnunetv2.utilities.network_initialization import InitWeights_He
from mamba_ssm import Mamba
from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op
from torch.cuda.amp import autocast
from dynamic_network_architectures.building_blocks.residual import BasicBlockD
class UpsampleLayer(nn.Module):
    def __init__(
self,
conv_op,
input_channels,
output_channels,
pool_op_kernel_size,
            mode='nearest'
):
        super().__init__()
self.conv = conv_op(input_channels, output_channels, kernel_size=1)
self.pool_op_kernel_size = pool_op_kernel_size
self.mode = mode
def forward(self, x):
print(f"Upsampling with scale factor: {self.pool_op_kernel_size}") # Debug: Print scale factor
x = F.interpolate(x, scale_factor=self.pool_op_kernel_size, mode=self.mode)
x = x.clone() # Clone the interpolated tensor
print(f"Upsampled shape: {x.shape}") # Debug: Print upsampled shape
x = self.conv(x)
return x
class MambaLayer(nn.Module):
    def __init__(self, dim, d_state=16, d_conv=4, expand=2):
        super().__init__()
self.dim = dim
self.norm = nn.LayerNorm(dim)
self.mamba = Mamba(
d_model=dim, # Model dimension d_model
d_state=d_state, # SSM state expansion factor
d_conv=d_conv, # Local convolution width
expand=expand, # Block expansion factor
)
@autocast(enabled=False)
def forward(self, x):
if x.dtype == torch.float16:
x = x.type(torch.float32)
x = x.clone() # Clone the input tensor
B, C = x.shape[:2]
assert C == self.dim
n_tokens = x.shape[2:].numel()
img_dims = x.shape[2:]
# print(f"Input shape: {x.shape}") # Debug: Print input shape
x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
# print(f"Reshaped input shape: {x_flat.shape}") # Debug: Print reshaped input shape
x_norm = self.norm(x_flat)
x_mamba = self.mamba(x_norm)
# print(f"Mamba output shape: {x_mamba.shape}") # Debug: Print Mamba output shape
out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims)
# print(f"Final output shape: {out.shape}") # Debug: Print final output shape
return out
class BasicResBlock(nn.Module):
    def __init__(
self,
conv_op,
input_channels,
output_channels,
norm_op,
norm_op_kwargs,
kernel_size=3,
padding=1,
stride=1,
use_1x1conv=False,
nonlin=nn.LeakyReLU,
        nonlin_kwargs={'inplace': False}
):
        super().__init__()
self.conv1 = conv_op(input_channels, output_channels, kernel_size, stride=stride, padding=padding)
self.norm1 = norm_op(output_channels, **norm_op_kwargs)
self.act1 = nonlin(**nonlin_kwargs)
self.conv2 = conv_op(output_channels, output_channels, kernel_size, padding=padding)
self.norm2 = norm_op(output_channels, **norm_op_kwargs)
self.act2 = nonlin(**nonlin_kwargs)
if use_1x1conv:
self.conv3 = conv_op(input_channels, output_channels, kernel_size=1, stride=stride)
else:
self.conv3 = None
def forward(self, x):
y = self.conv1(x)
print(f"Shape after conv1: {y.shape}") # Debug: Print shape after conv1
y = self.act1(self.norm1(y).clone())
print(f"Shape after norm1 and act1: {y.shape}") # Debug: Print shape after norm1 and act1
y = self.norm2(self.conv2(y))
if self.conv3:
x = self.conv3(x)
        y = y + x.clone()
return self.act2(y.clone())
class UNetResEncoder(nn.Module):
    def __init__(self,
input_channels: int,
n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...], Tuple[Tuple[int, ...], ...]],
                 n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]],
conv_bias: bool = False,
norm_op: Union[None, Type[nn.Module]] = None,
norm_op_kwargs: dict = None,
nonlin: Union[None, Type[torch.nn.Module]] = None,
nonlin_kwargs: dict = None,
return_skips: bool = False,
stem_channels: int = None,
                 pool_type: str = 'conv',
):
        super().__init__()
if isinstance(kernel_sizes, int):
kernel_sizes = [kernel_sizes] * n_stages
if isinstance(features_per_stage, int):
features_per_stage = [features_per_stage] * n_stages
if isinstance(n_blocks_per_stage, int):
n_blocks_per_stage = [n_blocks_per_stage] * n_stages
if isinstance(strides, int):
strides = [strides] * n_stages
assert len(
kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)"
assert len(
n_blocks_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)"
assert len(
features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)"
assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \
"Important: first entry is recommended to be 1, else we run strided conv drectly on the input"
pool_op = get_matching_pool_op(conv_op, pool_type=pool_type) if pool_type != 'conv' else None
self.conv_pad_sizes = []
for krnl in kernel_sizes:
self.conv_pad_sizes.append([i // 2 for i in krnl])
stem_channels = features_per_stage[0]
self.stem = nn.Sequential(
BasicResBlock(
conv_op=conv_op,
input_channels=input_channels,
output_channels=stem_channels,
norm_op=norm_op,
norm_op_kwargs=norm_op_kwargs,
kernel_size=kernel_sizes[0],
padding=self.conv_pad_sizes[0],
stride=1,
nonlin=nonlin,
nonlin_kwargs=nonlin_kwargs,
use_1x1conv=True
),
*[
BasicBlockD(
conv_op=conv_op,
input_channels=stem_channels,
output_channels=stem_channels,
kernel_size=kernel_sizes[0],
stride=1,
conv_bias=conv_bias,
norm_op=norm_op,
norm_op_kwargs=norm_op_kwargs,
nonlin=nonlin,
nonlin_kwargs=nonlin_kwargs,
) for _ in range(n_blocks_per_stage[0] - 1)
]
)
input_channels = stem_channels
stages = []
for s in range(n_stages):
stage = nn.Sequential(
BasicResBlock(
conv_op=conv_op,
norm_op=norm_op,
norm_op_kwargs=norm_op_kwargs,
input_channels=input_channels,
output_channels=features_per_stage[s],
kernel_size=kernel_sizes[s],
padding=self.conv_pad_sizes[s],
stride=strides[s],
use_1x1conv=True,
nonlin=nonlin,
nonlin_kwargs=nonlin_kwargs
),
*[
BasicBlockD(
conv_op=conv_op,
input_channels=features_per_stage[s],
output_channels=features_per_stage[s],
kernel_size=kernel_sizes[s],
stride=1,
conv_bias=conv_bias,
norm_op=norm_op,
norm_op_kwargs=norm_op_kwargs,
nonlin=nonlin,
nonlin_kwargs=nonlin_kwargs,
) for _ in range(n_blocks_per_stage[s] - 1)
]
)
stages.append(stage)
input_channels = features_per_stage[s]
self.stages = nn.Sequential(*stages)
self.output_channels = features_per_stage
self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides]
self.return_skips = return_skips
self.conv_op = conv_op
self.norm_op = norm_op
self.norm_op_kwargs = norm_op_kwargs
self.nonlin = nonlin
self.nonlin_kwargs = nonlin_kwargs
self.conv_bias = conv_bias
self.kernel_sizes = kernel_sizes
# print(f"Kernel sizes: {kernel_sizes}") # Debug: Print kernel sizes
# print(f"Strides: {strides}") # Debug: Print strides
def forward(self, x):
if self.stem is not None:
x = self.stem(x)
print(f"After stem: {x.shape}") # Debug: Print after stem
ret = []
for s in self.stages:
x = s(x)
print(f"Stage output shape: {x.shape}") # Debug: Print after each stage
ret.append(x)
if self.return_skips:
return ret
else:
return ret[-1]
def compute_conv_feature_map_size(self, input_size):
if self.stem is not None:
output = self.stem.compute_conv_feature_map_size(input_size)
else:
output = np.int64(0)
for s in range(len(self.stages)):
output += self.stages[s].compute_conv_feature_map_size(input_size)
input_size = [i // j for i, j in zip(input_size, self.strides[s])]
return output
class UNetResDecoder(nn.Module):
    def __init__(self,
encoder,
num_classes,
                 n_conv_per_stage: Union[int, Tuple[int, ...], List[int]],
deep_supervision, nonlin_first: bool = False, project_dim: int = 64):
super().__init__()
self.deep_supervision = deep_supervision
self.encoder = encoder
self.num_classes = num_classes
n_stages_encoder = len(encoder.output_channels)
if isinstance(n_conv_per_stage, int):
n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1)
assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \
"resolution stages - 1 (n_stages in encoder - 1), " \
"here: %d" % n_stages_encoder
stages = []
upsample_layers = []
seg_layers = []
# Add a projector for the second upsampling block
self.projector = nn.Sequential(
nn.Conv3d(encoder.output_channels[-3], encoder.output_channels[-3], kernel_size=1),
nn.PReLU(),
nn.Conv3d(encoder.output_channels[-3], project_dim, kernel_size=1)
)
for s in range(1, n_stages_encoder):
input_features_below = encoder.output_channels[-s]
input_features_skip = encoder.output_channels[-(s + 1)]
stride_for_upsampling = encoder.strides[-s]
upsample_layers.append(UpsampleLayer(
conv_op=encoder.conv_op,
input_channels=input_features_below,
output_channels=input_features_skip,
pool_op_kernel_size=stride_for_upsampling,
mode='nearest'
))
stages.append(nn.Sequential(
BasicResBlock(
conv_op=encoder.conv_op,
norm_op=encoder.norm_op,
norm_op_kwargs=encoder.norm_op_kwargs,
nonlin=encoder.nonlin,
nonlin_kwargs=encoder.nonlin_kwargs,
input_channels=2 * input_features_skip,
output_channels=input_features_skip,
kernel_size=encoder.kernel_sizes[-(s + 1)],
padding=encoder.conv_pad_sizes[-(s + 1)],
stride=1,
use_1x1conv=True
),
*[
BasicBlockD(
conv_op=encoder.conv_op,
input_channels=input_features_skip,
output_channels=input_features_skip,
kernel_size=encoder.kernel_sizes[-(s + 1)],
stride=1,
conv_bias=encoder.conv_bias,
norm_op=encoder.norm_op,
norm_op_kwargs=encoder.norm_op_kwargs,
nonlin=encoder.nonlin,
nonlin_kwargs=encoder.nonlin_kwargs,
) for _ in range(n_conv_per_stage[s - 1] - 1)
]
))
seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True))
self.stages = nn.ModuleList(stages)
self.upsample_layers = nn.ModuleList(upsample_layers)
self.seg_layers = nn.ModuleList(seg_layers)
def forward(self, skips):
lres_input = skips[-1]
seg_outputs = []
for s in range(len(self.stages)):
x = self.upsample_layers[s](lres_input)
x = torch.cat((x, skips[-(s + 2)]), 1)
x = self.stages[s](x)
# Apply projection to the second level of the decoder
if s == 1: # Second level (index 1)
project_output = self.projector(x.clone())
if self.deep_supervision:
seg_outputs.append(self.seg_layers[s](x.clone()))
elif s == (len(self.stages) - 1):
seg_outputs.append(self.seg_layers[-1](x.clone()))
lres_input = x
seg_outputs = seg_outputs[::-1]
# Prepare output dictionary
out = dict()
out['project'] = project_output
out['project_map'] = F.interpolate(seg_outputs[-1], size=skips[-1].shape[2:], mode='trilinear', align_corners=False)
out['level5'] = seg_outputs[0] if self.deep_supervision else None
out['level4'] = seg_outputs[1] if self.deep_supervision else None
out['level3'] = seg_outputs[2] if self.deep_supervision else None
out['level2'] = seg_outputs[3] if self.deep_supervision else None
out['level1'] = seg_outputs[4] if self.deep_supervision else None
out['out'] = seg_outputs[-1]
return out
# if not self.deep_supervision:
# r = seg_outputs[0]
# else:
# r = seg_outputs
# return r
def compute_conv_feature_map_size(self, input_size):
skip_sizes = []
for s in range(len(self.encoder.strides) - 1):
skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])])
input_size = skip_sizes[-1]
assert len(skip_sizes) == len(self.stages)
output = np.int64(0)
for s in range(len(self.stages)):
output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s + 1)])
output += np.prod([self.encoder.output_channels[-(s + 2)], *skip_sizes[-(s + 1)]], dtype=np.int64)
if self.deep_supervision or (s == (len(self.stages) - 1)):
output += np.prod([self.num_classes, *skip_sizes[-(s + 1)]], dtype=np.int64)
return output
class UMambaBot(nn.Module):
    def __init__(self,
input_channels: int,
n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 num_classes: int,
                 n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
conv_bias: bool = False,
norm_op: Union[None, Type[nn.Module]] = None,
norm_op_kwargs: dict = None,
dropout_op: Union[None, Type[_DropoutNd]] = None,
dropout_op_kwargs: dict = None,
nonlin: Union[None, Type[torch.nn.Module]] = None,
nonlin_kwargs: dict = None,
deep_supervision: bool = False,
stem_channels: int = None,
project_dim: int = 64 # Add project_dim for the projector layer
):
        super().__init__()
n_blocks_per_stage = n_conv_per_stage
if isinstance(n_blocks_per_stage, int):
n_blocks_per_stage = [n_blocks_per_stage] * n_stages
if isinstance(n_conv_per_stage_decoder, int):
n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
for s in range(math.ceil(n_stages / 2), n_stages):
n_blocks_per_stage[s] = 1
for s in range(math.ceil((n_stages - 1) / 2 + 0.5), n_stages - 1):
n_conv_per_stage_decoder[s] = 1
assert len(n_blocks_per_stage) == n_stages, "n_blocks_per_stage must have as many entries as we have " \
f"resolution stages. here: {n_stages}. " \
f"n_blocks_per_stage: {n_blocks_per_stage}"
assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \
f"as we have resolution stages. here: {n_stages} " \
f"stages, so it should have {n_stages - 1} entries. " \
f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}"
self.encoder = UNetResEncoder(
input_channels,
n_stages,
features_per_stage,
conv_op,
kernel_sizes,
strides,
n_blocks_per_stage,
conv_bias,
norm_op,
norm_op_kwargs,
nonlin,
nonlin_kwargs,
return_skips=True,
stem_channels=stem_channels
)
self.mamba_layer = MambaLayer(dim=features_per_stage[-1])
self.decoder = UNetResDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision)
def forward(self, x):
skips = self.encoder(x)
skips[-1] = self.mamba_layer(skips[-1])
return self.decoder(skips)
def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), \
            "just give the image size without color/feature channels or batch channel. " \
            "Do not give input_size=(b, c, x, y(, z)). Give input_size=(x, y(, z))!"
return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(
input_size)
The error says I'm doing an in-place operation that modifies a view, but I don't think I'm doing any in-place operations in my code. What could be wrong?
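
The only `+=` I can see in the traceback (out += residual) is inside dynamic_network_architectures' residual.py, not in my own file. For reference, this is my understanding of what counts as in-place for autograd (a toy illustration, not my training code): augmented assignment, trailing-underscore methods, and indexed assignment all bump the tensor's internal version counter.

import torch

t = torch.randn(3)
t += 1             # augmented assignment is in-place
t.relu_()          # trailing-underscore methods are in-place
t[0] = 0.0         # indexed assignment is in-place
print(t._version)  # internal version counter, incremented by each in-place op above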