Feature extraction in Conv-TasNet

Hello, I want to extract features from our Conv-TasNet model and use them for speaker diarization. My model code is pasted below. Thank you very much.

class ConvTasNet(nn.Module):
    """Conv-TasNet style time-domain separation network.

    The model is built from a 1-D convolutional encoder, a stack of
    dilated ``Conv1DBlock`` modules that estimates one mask per speaker,
    and a transposed-convolution decoder that maps each masked encoder
    output back to a waveform.

    Args:
        L: kernel size of the encoder/decoder; their stride is ``L // 2``.
        N: number of encoder output channels.
        X: number of ``Conv1DBlock`` modules in each repeat; block ``b``
            uses dilation ``2 ** b``.
        R: number of repeats.
        B: channels after the 1x1 projection that feeds the blocks.
        Sc: passed to each ``Conv1DBlock`` as ``Sc`` and used as the input
            channel count of the mask convolution (presumably the
            skip-connection width -- confirm against ``Conv1DBlock``).
        H: ``conv_channels`` of each ``Conv1DBlock``.
        P: ``kernel_size`` of each ``Conv1DBlock``.
        norm: normalisation type forwarded to each ``Conv1DBlock``.
        num_spks: number of masks / separated outputs.
        non_linear: mask activation, one of ``"relu"``, ``"sigmoid"``,
            ``"softmax"``.
        causal: forwarded to each ``Conv1DBlock``.

    Raises:
        RuntimeError: if ``non_linear`` is not a supported activation name.
    """

    def __init__(self,
                 L=16,
                 N=512,
                 X=8,
                 R=3,
                 B=128,
                 Sc=128,
                 H=512,
                 P=3,
                 norm="cLN",
                 num_spks=2,
                 non_linear="sigmoid",
                 causal=False):
        super(ConvTasNet, self).__init__()
        supported_nonlinear = {
            "relu": F.relu,
            "sigmoid": th.sigmoid,
            "softmax": F.softmax
        }
        if non_linear not in supported_nonlinear:
            raise RuntimeError(
                "Unsupported non-linear function: {}".format(non_linear))
        self.non_linear_type = non_linear
        self.non_linear = supported_nonlinear[non_linear]
        # n x S => n x N x T, e.g. S = 4s * 8000 = 32000
        # T = int((xlen - L) / (L // 2)) + 1
        self.encoder_1d = Conv1D(1, N, L, stride=L // 2, padding=0)
        # before repeat blocks, always cLN
        self.ln = ChannelWiseLayerNorm(N)
        # n x N x T => n x B x T
        self.proj = Conv1D(N, B, 1)
        # repeat blocks: n x B x T => n x B x T
        self.repeats = self._build_repeats(
            R,
            X,
            Sc=Sc,
            in_channels=B,
            conv_channels=H,
            kernel_size=P,
            norm=norm,
            causal=causal)
        self.PRelu = nn.PReLU()
        # a single 1x1 conv produces the masks of all speakers at once:
        # n x Sc x T => n x (num_spks * N) x T
        self.mask = Conv1D(Sc, num_spks * N, 1)
        # using ConvTrans1D: n x N x T => n x 1 x To
        # To = (T - 1) * L // 2 + L
        self.decoder_1d = ConvTrans1D(
            N, 1, kernel_size=L, stride=L // 2, bias=True)
        self.num_spks = num_spks
        self.R = R  # number of repeats
        self.X = X  # number of Conv1DBlock modules in each repeat

    def _build_blocks(self, num_blocks, **block_kwargs):
        """Build one repeat: ``num_blocks`` Conv1DBlocks, dilation 2**b."""
        blocks = [
            Conv1DBlock(**block_kwargs, dilation=(2**b))
            for b in range(num_blocks)
        ]
        return nn.Sequential(*blocks)

    def _build_repeats(self, num_repeats, num_blocks, **block_kwargs):
        """Build ``num_repeats`` repeats of ``num_blocks`` Conv1DBlocks."""
        repeats = [
            self._build_blocks(num_blocks, **block_kwargs)
            for _ in range(num_repeats)
        ]
        return nn.Sequential(*repeats)

    def _embed(self, x):
        """Run the encoder and the block stack, stopping before the masks.

        Args:
            x: 1-D (single utterance) or 2-D (batch) waveform tensor.

        Returns:
            Tuple ``(w, y)``: ``w`` is the ReLU-activated encoder output
            (n x N x T) and ``y`` is the PReLU-activated sum of the skip
            outputs of every block (n x Sc x T).

        Raises:
            RuntimeError: if ``x`` has three or more dimensions.
        """
        if x.dim() >= 3:
            # An nn.Module instance has no ``__name__``; take it from the class.
            raise RuntimeError(
                "{} accept 1/2D tensor as input, but got {:d}".format(
                    type(self).__name__, x.dim()))
        # when inference, only one utt
        if x.dim() == 1:
            x = th.unsqueeze(x, 0)
        # encoder: n x 1 x S => n x N x T
        w = F.relu(self.encoder_1d(x))
        # layer norm & 1x1 conv: n x N x T => n x B x T
        y = self.proj(self.ln(w))
        # Each block returns (skip, output): the output feeds the next
        # block while the skip paths are accumulated across all blocks.
        skip_connection = 0
        for repeat in self.repeats:
            for block in repeat:
                skip, y = block(y)
                skip_connection = skip_connection + skip
        return w, self.PRelu(skip_connection)

    def extract_features(self, x):
        """Return the separator output that feeds the mask convolution.

        This is the tensor (n x Sc x T) computed immediately before
        ``self.mask`` is applied in ``forward``; it can be used as a
        frame-level feature for downstream tasks.

        Args:
            x: 1-D (single utterance) or 2-D (batch) waveform tensor.

        Raises:
            RuntimeError: if ``x`` has three or more dimensions.
        """
        _, y = self._embed(x)
        return y

    def forward(self, x):
        """Separate a mixture into ``num_spks`` signals.

        Args:
            x: 1-D (single utterance) or 2-D (batch) waveform tensor.

        Returns:
            A list with ``num_spks`` entries, each the decoder output for
            one speaker (n x S per the decoder's ``squeeze=True`` call).

        Raises:
            RuntimeError: if ``x`` has three or more dimensions.
        """
        w, y = self._embed(x)
        # n x (num_spks * N) x T => num_spks chunks of n x N x T
        e = th.chunk(self.mask(y), self.num_spks, 1)
        # softmax normalises across the speaker axis, so it needs ``dim``
        if self.non_linear_type == "softmax":
            m = self.non_linear(th.stack(e, dim=0), dim=0)
        else:
            m = self.non_linear(th.stack(e, dim=0))
        # spks x [n x N x T]
        s = [w * m[n] for n in range(self.num_spks)]
        # spks x n x S
        return [self.decoder_1d(masked, squeeze=True) for masked in s]