Tensor has the wrong shape when passing tensor[..., i] into thread i and replacing its elements

I’m writing a C++ extension to sample different columns of a table. The sampling process for each column is independent, so I hope to speed it up with C++ multithreading. My code looks like this:

mysampler.sample(tuples, new_tuples, new_preds, self.columns_size, self.has_nones, num_samples)

// C++ interface
sample(const torch::Tensor& tuples, torch::Tensor& new_tuples, torch::Tensor& new_preds, const std::vector<int>& columns_size, std::vector<bool>& has_nones, int num_samples){
    unsigned long long col_num = columns_size.size();
    std::vector<std::thread> thread_list;
    for (int i = 0; i < col_num; ++i) {
        torch::Tensor new_tuple_i = new_tuples.index({"...", i});
        torch::Tensor new_pred_i = new_preds.index({"...", Slice(i*5,(i+1)*5)});
        // The shapes are correct here.
        // The thread body does nothing yet except print the shape of new_pred_i,
        // and that shape is smaller than it should be in every thread except the first two.
        std::thread t(sample_i, i, tuples.index({"...", i}), std::ref(new_tuple_i), std::ref(new_pred_i), columns_size[i], 0, num_samples, has_nones[i], num_samples, -1, nullptr);

sample_i(int i, const torch::Tensor& tuples, torch::Tensor& new_tuples, torch::Tensor& new_preds, int column_size, int start_idx, int end_idx, bool has_none, int num_samples, int first_pred, const torch::Tensor* first_mask) {
    try {
        // ...other stuff
        // I use things like torch::index_put_(new_preds.index({mask}), samples);
    catch (std::exception &e){

I got output like this:

col:0new_preds.shape:[2048, 2, 5]          // correct
col:1new_preds.shape:[2048, 2, 5]          // correct
col:2new_preds.shape:[5]                        // incorrect
...all incorrect

tuples is the tensor that I want to sample from, and new_tuples and new_preds are the tensors that I write my results into.
Their shapes are: tuples[batch, col_num], new_tuples[batch, 2, col_num], new_preds[batch, 2, col_num*5]. So I index them along dim=-1 and pass the result (by reference) into each column’s thread.
However, only the first two threads receive new_tuples and new_preds with the correct shape; the others see wrong shapes such as [539], [4], or [].

My guess is that, because these tensors are all views of one common tensor and share its data storage, there may be a thread-safety problem. However, guarding every operation on these tensors with a mutex lock does not fix it.
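
Separately from the shared-storage guess, the call above passes new_tuple_i and new_pred_i via std::ref even though both are declared inside the loop body. If the threads outlive the iteration (the rest of the loop is not shown here), they may be reading handles that have already been destroyed. The sketch below is a standard-library-only illustration of the alternative, which is to pass the handle by value. ColumnView, fill_column and fill_all_columns are hypothetical stand-ins, not libtorch types, and the layout is simplified to one contiguous block per column.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <thread>
#include <vector>

// Hypothetical stand-in for a tensor view (NOT a libtorch type): a cheap
// handle that shares storage with its parent but has its own offset/length.
struct ColumnView {
    std::shared_ptr<std::vector<float>> storage;  // shared with the parent
    std::size_t offset;                           // first element of this column's block
    std::size_t length;                           // number of elements in the block
};

// Worker: takes the view BY VALUE, so each thread owns its own copy of the
// handle and does not depend on the lifetime of the caller's local variable.
void fill_column(int col, ColumnView view) {
    for (std::size_t k = 0; k < view.length; ++k) {
        (*view.storage)[view.offset + k] = static_cast<float>(col);
    }
}

// Launches one thread per column; each thread writes only its own block,
// so no two threads touch the same element.
std::vector<float> fill_all_columns(int col_num, std::size_t block) {
    auto storage = std::make_shared<std::vector<float>>(col_num * block, -1.0f);
    std::vector<std::thread> threads;
    threads.reserve(col_num);
    for (int i = 0; i < col_num; ++i) {
        ColumnView view{storage, i * block, block};  // local to this iteration
        // Passing `view` itself copies the handle into the new thread.
        // std::ref(view) would instead give the thread a reference to a
        // variable that is destroyed when this iteration ends.
        threads.emplace_back(fill_column, i, view);
    }
    for (auto& t : threads) {
        t.join();  // wait for every column before reading the result
    }
    return *storage;
}
```

Calling fill_all_columns(3, 5) fills each column's block of five elements with that column's index. If the same pattern applies to torch::Tensor, the analogous change would be to pass new_tuple_i and new_pred_i without std::ref (copying a tensor handle is cheap and the copy still refers to the parent's storage), and sample_i would then need to take those parameters by value or by const reference. I have not verified that this resolves the shape issue.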