Is there a better way to add peephole connections to an LSTM in PyTorch?

I tried adding peephole connections to an LSTM using torch.nn.Linear(), but the result is very slow to compute.
Is there a better way to make it faster?
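For reference, this is the update rule I am trying to implement (my understanding of the peephole formulation; note that my peephole terms W_ci, W_cf, W_co are full nn.Linear matrices, whereas the classic formulation uses elementwise weights):

    i_t = sigmoid(W_ii x_t + W_hi h_{t-1} + W_ci c_{t-1} + b_i)
    f_t = sigmoid(W_if x_t + W_hf h_{t-1} + W_cf c_{t-1} + b_f)
    g_t = tanh(W_ic x_t + W_hc h_{t-1} + b_g)
    c_t = f_t * c_{t-1} + i_t * g_t
    o_t = sigmoid(W_io x_t + W_ho h_{t-1} + W_co c_t + b_o)
    h_t = o_t * tanh(c_t)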
My current code is below.
Any help would be appreciated.

import torch


class Handmade_LSTMCell(torch.nn.Module):
    def __init__(self, _input_size, _hidden_size):
        super().__init__()
        self.input_size = _input_size
        self.hidden_size = _hidden_size

        # input gate
        self.linear_ii = torch.nn.Linear(_input_size, _hidden_size, bias=True)
        torch.nn.init.xavier_uniform_(self.linear_ii.weight, gain=1.0)
        torch.nn.init.zeros_(self.linear_ii.bias)
        self.linear_hi = torch.nn.Linear(_hidden_size, _hidden_size, bias=True)
        torch.nn.init.xavier_uniform_(self.linear_hi.weight, gain=1.0)
        torch.nn.init.zeros_(self.linear_hi.bias)
        self.input_activation = torch.nn.Sigmoid()

        # forget gate
        self.linear_if = torch.nn.Linear(_input_size, _hidden_size, bias=True)
        torch.nn.init.xavier_uniform_(self.linear_if.weight, gain=1.0)
        torch.nn.init.zeros_(self.linear_if.bias)
        self.linear_hf = torch.nn.Linear(_hidden_size, _hidden_size, bias=True)
        torch.nn.init.xavier_uniform_(self.linear_hf.weight, gain=1.0)
        torch.nn.init.zeros_(self.linear_hf.bias)
        self.forget_activation = torch.nn.Sigmoid()

        # cell gate (tanh activation, hence the 5/3 Xavier gain)
        self.linear_ic = torch.nn.Linear(_input_size, _hidden_size, bias=True)
        torch.nn.init.xavier_uniform_(self.linear_ic.weight, gain=5.0 / 3.0)
        torch.nn.init.zeros_(self.linear_ic.bias)
        self.linear_hc = torch.nn.Linear(_hidden_size, _hidden_size, bias=True)
        torch.nn.init.xavier_uniform_(self.linear_hc.weight, gain=5.0 / 3.0)
        torch.nn.init.zeros_(self.linear_hc.bias)
        self.cell_activation = torch.nn.Tanh()

        # output gate
        self.linear_io = torch.nn.Linear(_input_size, _hidden_size, bias=True)
        torch.nn.init.xavier_uniform_(self.linear_io.weight, gain=1.0)
        torch.nn.init.zeros_(self.linear_io.bias)
        self.linear_ho = torch.nn.Linear(_hidden_size, _hidden_size, bias=True)
        torch.nn.init.xavier_uniform_(self.linear_ho.weight, gain=1.0)
        torch.nn.init.zeros_(self.linear_ho.bias)
        self.output_activation = torch.nn.Sigmoid()

        # peephole connections (cell state into the input, forget, and output gates)
        self.linear_ci = torch.nn.Linear(_hidden_size, _hidden_size, bias=True)
        torch.nn.init.xavier_uniform_(self.linear_ci.weight, gain=1.0)
        torch.nn.init.zeros_(self.linear_ci.bias)
        self.linear_cf = torch.nn.Linear(_hidden_size, _hidden_size, bias=True)
        torch.nn.init.xavier_uniform_(self.linear_cf.weight, gain=1.0)
        torch.nn.init.zeros_(self.linear_cf.bias)
        self.linear_co = torch.nn.Linear(_hidden_size, _hidden_size, bias=True)
        torch.nn.init.xavier_uniform_(self.linear_co.weight, gain=1.0)
        torch.nn.init.zeros_(self.linear_co.bias)

    def forward(self, _input, _state=None):
        # accept batched (batch, seq, feature) or unbatched (seq, feature) input
        is_batched = _input.dim() == 3

        if not is_batched:
            _input = _input.unsqueeze(0)

        # _input.size(0) = batch size, _input.size(1) = sequence length
        if _state is None:
            hx = torch.zeros(_input.size(0), self.hidden_size, dtype=_input.dtype, device=_input.device)
            cx = torch.zeros(_input.size(0), self.hidden_size, dtype=_input.dtype, device=_input.device)
        else:
            hx, cx = _state

        # per-step hidden states, stored as (seq_len, batch, hidden), transposed to batch-first on return
        outputs = torch.empty(_input.size(1), _input.size(0), self.hidden_size, dtype=_input.dtype, device=_input.device)
        for i in range(_input.size(1)):
            # the input and forget gates peek at the previous cell state cx (= c_{t-1})
            input_gate = self.input_activation(self.linear_ii(_input[:, i, :]) + self.linear_hi(hx) + self.linear_ci(cx))
            forget_gate = self.forget_activation(self.linear_if(_input[:, i, :]) + self.linear_hf(hx) + self.linear_cf(cx))
            cell_gate = self.cell_activation(self.linear_ic(_input[:, i, :]) + self.linear_hc(hx))

            cy = (forget_gate * cx) + (input_gate * cell_gate)

            # the output gate peeks at the updated cell state cy (= c_t)
            output_gate = self.output_activation(self.linear_io(_input[:, i, :]) + self.linear_ho(hx) + self.linear_co(cy))

            hx = output_gate * torch.tanh(cy)
            outputs[i] = hx
            cx = cy
        if is_batched:
            return outputs.transpose(0, 1).contiguous(), (hx.unsqueeze(0), cy.unsqueeze(0))
        else:
            return outputs.transpose(0, 1).contiguous().squeeze(0), (hx, cy)
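The only speedup idea I have come up with myself is fusing the four input-to-hidden projections into a single nn.Linear of width 4 * hidden_size (and likewise for hidden-to-hidden), and replacing the full peephole matrices with elementwise nn.Parameter vectors, which is how I understand the classic peephole formulation anyway. Something along these lines (a rough, untested single-step sketch; the class and parameter names are my own):

import torch


class FusedPeepholeLSTMCell(torch.nn.Module):
    # sketch: all four gates from two fused matmuls, plus elementwise peepholes
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear_ih = torch.nn.Linear(input_size, 4 * hidden_size, bias=True)
        self.linear_hh = torch.nn.Linear(hidden_size, 4 * hidden_size, bias=False)
        # elementwise (diagonal) peephole weights instead of full matrices (my assumption)
        self.w_ci = torch.nn.Parameter(torch.zeros(hidden_size))
        self.w_cf = torch.nn.Parameter(torch.zeros(hidden_size))
        self.w_co = torch.nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x, state):
        hx, cx = state
        # one matmul per source instead of four separate ones
        gates = self.linear_ih(x) + self.linear_hh(hx)
        i, f, g, o = gates.chunk(4, dim=-1)
        # input and forget gates peek at c_{t-1}; the output gate peeks at c_t
        i = torch.sigmoid(i + self.w_ci * cx)
        f = torch.sigmoid(f + self.w_cf * cx)
        cy = f * cx + i * torch.tanh(g)
        o = torch.sigmoid(o + self.w_co * cy)
        hy = o * torch.tanh(cy)
        return hy, cy

But even with the fused matmuls, the per-timestep Python loop is still there, so I am not sure how much this actually gains, which is why I am asking whether there is a better approach.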