@ptrblck thanks for the reply, here is a minimal setup of what I currently have
In this example, the number of tasks set is 2.
This is how my trainer looks like. The section i get the error is indicated as “FAILURE POINT”
tasks=2
optimizer = timm.optim.optim_factory.create_optimizer(args, model)
lambda_weights = torch.ones((tasks, ), requires_grad=True).cuda()
lambda_optim = timm.optim.optim_factory.create_optimizer(args, [torch.nn.Parameter(lambda_weights)])
for idx_epoch in epochs:
for idx_batch, item in enumerate(ds_loader):
img = item['img'].requires_grad_(True)
target_a = item['target_a']
target_b = item['target_b']
out_a, out_b = model(img)
loss_a = criterion(target_a, out_a)
loss_b = criterion(target_b, out_b)
loss_tasks = torch.stack([loss_a, loss_b])
if idx_epoch == 0:
initial_loss = loss_tasks
weighted_loss = lambda_weights * loss_tasks
loss = weighted_loss.sum()
lambda_weights.retain_grad(True)
loss_tasks.retain_grad(True)
loss.backward(retain_graph=True)
norms = []
# FAILURE POINT
for w_i, l_i in zip(lambda_weights, loss_tasks):
local_grad = torch.autograd.grad(l_i, model.stage3.parameters(), retain_graph=True)
norms.append(torch.norm(w_i * local_grad))
norms = torch.stack(norms)
nw = norms.mean()
with torch.no_grad():
# loss ratios
loss_ratios = loss_tasks / initial_loss
# inverse training rate r(t)
inverse_train_rates = loss_ratios / loss_ratios.mean()
constant_term = nw * (inverse_train_rates ** alpha)
# compute Lgrad
lgrad = (norms - constant_term).abs().sum()
lambda_weights.grad = torch.autograd.grad(lgrad, lambda_weights)
optimizer.step()
gradnorm_optim.step()
# Renormalize
with torch.no_grad():
renormalize = args.tasks / lambda_weights.sum()
lambda_weights *= renormalize
My model keeps the same CSwin setup with the different that stage4 corresponds to an array of “stages” according to number of task heads.
class CSWinTransformer(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=[1000], embed_dim=96, depth=[2,2,6,2], split_size = [3,5,7],
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, use_chk=False):
super().__init__()
self.use_chk = use_chk
self.head_class_count = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
heads=num_heads
self.stage1_conv_embed = nn.Sequential(
nn.Conv2d(in_chans, embed_dim, 7, 4, 2),
Rearrange('b c h w -> b (h w) c', h = img_size//4, w = img_size//4),
nn.LayerNorm(embed_dim)
)
curr_dim = embed_dim
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, np.sum(depth))] # stochastic depth decay rule
self.stage1 = nn.ModuleList([
CSWinBlock(
dim=curr_dim, num_heads=heads[0], reso=img_size//4, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[0],
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[i], norm_layer=norm_layer)
for i in range(depth[0])])
self.merge1 = Merge_Block(curr_dim, curr_dim*2)
curr_dim = curr_dim*2
self.stage2 = nn.ModuleList(
[CSWinBlock(
dim=curr_dim, num_heads=heads[1], reso=img_size//8, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[1],
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[np.sum(depth[:1])+i], norm_layer=norm_layer)
for i in range(depth[1])])
self.merge2 = Merge_Block(curr_dim, curr_dim*2)
curr_dim = curr_dim*2
temp_stage3 = []
temp_stage3.extend(
[CSWinBlock(
dim=curr_dim, num_heads=heads[2], reso=img_size//16, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[2],
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[np.sum(depth[:2])+i], norm_layer=norm_layer)
for i in range(depth[2])])
self.stage3 = nn.ModuleList(temp_stage3)
self.merge3 = Merge_Block(curr_dim, curr_dim*2)
curr_dim = curr_dim*2
self.heads = nn.ModuleList([
nn.ModuleList([
Merge_Block(curr_dim, curr_dim * 2),
nn.Sequential(
*[CSWinBlock(
dim=curr_dim * 2, num_heads=heads[3], reso=img_size // 32, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[-1],
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[np.sum(depth[:-1]) + i], norm_layer=norm_layer, last_stage=True)
for i in range(depth[-1])]),
norm_layer(curr_dim * 2),
nn.Linear(curr_dim * 2, head_class_count)
])
for head_class_count in self.heads_class_counts])
trunc_normal_(self.head.weight, std=0.02)
self.apply(self._init_weights)
...
def forward_features(self, x):
B = x.shape[0]
x = self.stage1_conv_embed(x)
for blk in self.stage1:
if self.use_chk:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
for pre, blocks in zip([self.merge1, self.merge2],
[self.stage2, self.stage3]):
x = pre(x)
for blk in blocks:
if self.use_chk:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
results = []
for head_idx, head_unit in enumerate(self.heads):
merge, stage, norm, head = head_unit
head_x = merge(x)
head_x = stage(head_x)
head_x = norm(head_x)
head_x = torch.mean(head_x, dim=1)
# if self.aggregate is True:
# logits += [head_x]
head_x = head(head_x)
# results[head_idx, :, :head_x.shape[1]] = head_x
results += [head_x]
return results
def forward(self, x):
x = self.forward_features(x)
return x