Is there any method to call sub networks paramters method using main network

I want to use subnetwork’s parameters method when I call main network’s parameters method
suppose there are two sub networks and one main network

import torch 
from torch import nn 
from torch.optim import Adam

class SubNetworkFrist(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.ln1=nn.Linear(dim,dim)
        self.ln2=nn.Linear(dim,dim)
        
    def parameters(self):
        optim_groups=[ 
                      {"params":self.ln1.parameters(), "weight_decay":0.01}, 
                      {"params":self.ln2.parameters(), "weight_decay":0.02}]
        return optim_groups
        
    def forward(self, x):
        return self.ln2(self.ln1(x))
        
        

class SubNetworkSecond(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.ln1=nn.Linear(dim,dim)
        self.ln2=nn.Linear(dim,dim)
        
    def parameters(self):
        optim_groups=[ 
                      {"params":self.ln1.parameters(), "weight_decay":0.03}, 
                      {"params":self.ln2.parameters(), "weight_decay":0.04}]
        return optim_groups
    
    def forward(self, x):
        return self.ln2(self.ln1(x))
        
class MainNetwork(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.subnetwork_first = SubNetworkFrist(dim)
        self.subnetwork_second = SubNetworkSecond(dim)
        self.ln1 = nn.Linear(dim,dim)
    
    def forward(self, x):
        return self.subnetwork_second(self.subnetwork_first(x))
    
    # def parameters(self, recurse=True):
    #     return list(self.subnetwork_first.parameters())+list(self.subnetwork_second.parameters())

If I do NOT code main networks parameters method like above(treat it comment # ), then If I create optimizer using nn.Module’s parameters like below, then It do not contain weight decay information of sub networks like below

net = MainNetwork(dim=5)
optimizer = Adam(net.parameters(recurse=True))

result

[{'params': [Parameter containing:
   tensor([[ 0.2269, -0.0205,  0.3401, -0.2977, -0.0842],
           [-0.3822,  0.0752, -0.0242, -0.0512,  0.2593],
           [ 0.0075, -0.1386, -0.3152,  0.3361,  0.0259],
           [-0.4198, -0.2412,  0.1211,  0.1048, -0.1516],
           [ 0.0945, -0.4442,  0.0972, -0.2301, -0.3241]], requires_grad=True),
   Parameter containing:
   tensor([-0.1699,  0.2401, -0.0934, -0.0333, -0.1239], requires_grad=True),
   Parameter containing:
   tensor([[-0.2171, -0.2811, -0.0994,  0.1280,  0.1468],
           [-0.1294, -0.2983,  0.2863, -0.2734,  0.2802],
           [-0.2428, -0.2191,  0.0845, -0.1452, -0.2173],
           [-0.0500,  0.0876,  0.2817, -0.1917,  0.4266],
           [ 0.3072,  0.2910, -0.3065,  0.3595,  0.3016]], requires_grad=True),
   Parameter containing:
   tensor([ 0.1471,  0.3241, -0.3961, -0.3012,  0.0260], requires_grad=True),
   Parameter containing:
   tensor([[ 0.1207, -0.3128,  0.0178,  0.4305,  0.4022],
           [-0.2462,  0.2588,  0.1159, -0.2598, -0.0643],
           [ 0.0877, -0.1187,  0.2677,  0.0473,  0.3103],
           [ 0.4395,  0.4039, -0.2941,  0.3161,  0.1373],
           [ 0.2146, -0.1177,  0.2910,  0.3985,  0.1692]], requires_grad=True),
   Parameter containing:
   tensor([ 0.4458, -0.4463,  0.3167, -0.2502,  0.2004], requires_grad=True),
   Parameter containing:
   tensor([[-0.4124,  0.2103,  0.0301, -0.2410,  0.1652],
           [ 0.3087, -0.3044, -0.3833,  0.0512,  0.4460],
           [ 0.1965, -0.0507,  0.3019, -0.0489, -0.4239],
           [-0.2510, -0.0977,  0.0828, -0.3054, -0.4008],
           [ 0.4293, -0.3715, -0.3389,  0.2993, -0.1951]], requires_grad=True),
   Parameter containing:
   tensor([-0.3922, -0.0883, -0.2532, -0.2097,  0.4341], requires_grad=True)],
  'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0,
  'amsgrad': False}]

If I code main networks parameters method like above WITHOUT comment, then result become like as I targeted

[{'params': [Parameter containing:
   tensor([[-0.1662, -0.3731,  0.2447,  0.0158,  0.3589],
           [-0.3973,  0.0554,  0.1346,  0.2705,  0.2953],
           [ 0.0382,  0.4328,  0.2174, -0.0453, -0.0708],
           [-0.0371, -0.3114, -0.2699, -0.1803, -0.3551],
           [ 0.4103, -0.3735, -0.2199, -0.4052, -0.0822]], requires_grad=True),
   Parameter containing:
   tensor([0.1290, 0.1823, 0.2404, 0.3977, 0.2686], requires_grad=True)],
  'weight_decay': 0.01,
  'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'amsgrad': False},
 {'params': [Parameter containing:
   tensor([[ 0.4347, -0.1732, -0.3305,  0.2113, -0.1103],
           [ 0.0536, -0.1947,  0.4224,  0.2709,  0.1174],
           [ 0.3652, -0.1602,  0.0927, -0.3235, -0.1919],
           [-0.3630,  0.2735, -0.2341, -0.4448,  0.4014],
           [-0.1127, -0.2531,  0.0986,  0.1517,  0.1193]], requires_grad=True),
   Parameter containing:
   tensor([ 0.2007, -0.1275,  0.1425, -0.0773, -0.0505], requires_grad=True)],
  'weight_decay': 0.02,
  'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'amsgrad': False},
 {'params': [Parameter containing:
   tensor([[-0.0983, -0.1208, -0.2512, -0.3053,  0.1468],
           [ 0.3571,  0.4010,  0.0321,  0.0985, -0.1396],
           [-0.0251, -0.3274, -0.1493,  0.3199, -0.1097],
           [-0.1964, -0.0931, -0.2085, -0.1951,  0.0726],
           [ 0.0462, -0.1719, -0.1833,  0.4282,  0.2670]], requires_grad=True),
   Parameter containing:
   tensor([ 0.2455, -0.3165,  0.3203, -0.1628,  0.3272], requires_grad=True)],
  'weight_decay': 0.03,
  'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'amsgrad': False},
 {'params': [Parameter containing:
   tensor([[ 0.3983,  0.2805, -0.2257,  0.1010, -0.3014],
           [ 0.0643, -0.1134, -0.3481, -0.3178,  0.3751],
           [-0.3416,  0.1170,  0.2967, -0.0593,  0.1188],
           [-0.2931, -0.1958,  0.4420,  0.3204, -0.1029],
           [ 0.0860,  0.0527, -0.2205,  0.2188, -0.1995]], requires_grad=True),
   Parameter containing:
   tensor([ 0.1071, -0.3164,  0.2664, -0.0405, -0.1424], requires_grad=True)],
  'weight_decay': 0.04,
  'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'amsgrad': False}]

my question is are there any ways to do it without custom parameters method?