Using a custom CUDA stream in a custom torch.autograd.Function

I would like to know whether there are any caveats when using a custom CUDA stream inside a custom torch.autograd.Function. My custom Function is shown below; in backward I want to offload input and grad_output to the CPU on a separate copy stream. Is the code wrong?

import torch
from torch.nn import functional as F

class LinearFunction(torch.autograd.Function):
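    # Dedicated high-priority side stream for the device-to-host copies, plus a
    # class-level list that collects the offloaded (input, grad_output) pairs.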
    copy_stream = torch.cuda.Stream(priority=-1)
    tensors = []
    @staticmethod
    def pack_to_cpu(tensor):
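        # Allocate a pinned CPU buffer and start an asynchronous device-to-host
        # copy into it; keep the original device so the tensor can be moved back.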
        packed = torch.empty(
            tensor.size(),
            dtype=tensor.dtype,
            layout=tensor.layout,
            pin_memory=True)
        packed.copy_(tensor, non_blocking=True)
        return (tensor.device, packed)

    @staticmethod
    def unpack_from_cpu(packed):
        device, tensor = packed
        return tensor.to(device, non_blocking=True)

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = F.linear(input, weight, bias)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ weight
        if ctx.needs_input_grad[1]:
            # grad_weight = grad_output.t().mm(input)
            grad_weight = grad_output.view(-1, grad_output.size()[-1]).t() @ input.view(-1, input.size()[-1])
        if bias is not None and ctx.needs_input_grad[2]:
            assert False, "bias is not supported yet!"
            # grad_bias = grad_output.sum(0).squeeze(0)  # unreachable until bias is supported
        with torch.cuda.stream(LinearFunction.copy_stream):
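            # record_stream marks input/grad_output as in use on copy_stream so
            # the caching allocator does not reuse their memory before the
            # asynchronous copies below have finished.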
            input.record_stream(LinearFunction.copy_stream)
            grad_output.record_stream(LinearFunction.copy_stream)
            input_detach = input.detach()
            grad_output_detach = grad_output.detach()
            input_cpu = LinearFunction.pack_to_cpu(input_detach)
            grad_output_cpu = LinearFunction.pack_to_cpu(grad_output_detach)
            LinearFunction.tensors.append((input_cpu, grad_output_cpu))
        return grad_input, grad_weight, grad_bias
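
For context, this is roughly how I intend to drive it. The shapes and device below are just placeholders, and the final synchronize() reflects my (possibly wrong) understanding that the copy stream has to be drained before the offloaded CPU tensors are read:

device = torch.device("cuda")
x = torch.randn(8, 16, device=device, requires_grad=True)
w = torch.randn(32, 16, device=device, requires_grad=True)

out = LinearFunction.apply(x, w)  # bias omitted
out.sum().backward()

# Wait for the asynchronous D2H copies issued in backward() before
# touching the pinned CPU buffers.
LinearFunction.copy_stream.synchronize()
for (dev, input_cpu), (_, grad_output_cpu) in LinearFunction.tensors:
    print(dev, input_cpu.shape, grad_output_cpu.shape)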

cc @ptrblck
Could you help me?