PyTorch C++ extension versus TensorFlow op

I converted a TensorFlow-based model into PyTorch. The TF code used a C++ op, which I converted into a PyTorch C++ extension. I compared the timings of both by running the model for 100 epochs: the TF op takes less than 3 seconds to finish for a batch size of 32, while the PyTorch version takes on average 4 minutes on the CPU, and on the GPU it takes about double that time for the same batch size. Please see the attached code sample from PyTorch below.

#include <torch/extension.h>

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <vector>

torch::Tensor extract_spans(
    torch::Tensor span_scores,
    torch::Tensor candidate_starts,
    torch::Tensor candidate_ends,
    torch::Tensor num_output_spans,
    int max_sentence_length,
    bool _sort_spans,
    bool _suppress_crossing
) {

    int num_sentences = span_scores.size(0);
    int num_input_spans = span_scores.size(1);
    int max_num_output_spans = 0;


    for (int i = 0; i < num_sentences; i++) {

      if (num_output_spans[i].item<int64_t>() > max_num_output_spans) {
        max_num_output_spans = num_output_spans[i].item<int64_t>();
      }
    }


    std::vector<std::vector<int>> sorted_input_span_indices(num_sentences,
                                                            std::vector<int>(num_input_spans));

    // Use an integer tensor for the indices (the TF op outputs int32).
    torch::Tensor output_span_indices = torch::ones({num_sentences, max_num_output_spans}, torch::kInt64);

    for (int i = 0; i < num_sentences; i++) {
      std::iota(sorted_input_span_indices[i].begin(), sorted_input_span_indices[i].end(), 0);
      std::sort(sorted_input_span_indices[i].begin(), sorted_input_span_indices[i].end(),
                [&span_scores, &i](int j1, int j2) {
                 if (j1 >= span_scores.size(1) || j1 < 0 || j2 >= span_scores.size(1) || j2 < 0) {
                    return false;
                 }

                  // span_scores is a float tensor, so compare as float (the TF op compares float32).
                  return span_scores[i][j2].item<float>() < span_scores[i][j1].item<float>();
                });
    }


    for (int l = 0; l < num_sentences; l++) {
      std::vector<int> top_span_indices;
      std::unordered_map<int, int> end_to_earliest_start;
      std::unordered_map<int, int> start_to_latest_end;
      int current_span_index = 0, num_selected_spans = 0;

      while (num_selected_spans < num_output_spans[l].item<int64_t>() && current_span_index < num_input_spans) {
        int i = sorted_input_span_indices[l][current_span_index];
        bool any_crossing = false;
        if (_suppress_crossing) {
          const int start = candidate_starts[l][i].item<int64_t>();
          const int end = candidate_ends[l][i].item<int64_t>();

          for (int j = start; j <= end; ++j) {
            if (j > start) {
              auto latest_end_iter = start_to_latest_end.find(j);
              if (latest_end_iter != start_to_latest_end.end() && latest_end_iter->second > end) {
                // Given (), exists [], such that ( [ ) ]
                any_crossing = true;
                break;
              }
            }
            if (j < end) {
              auto earliest_start_iter = end_to_earliest_start.find(j);
              if (earliest_start_iter != end_to_earliest_start.end() && earliest_start_iter->second < start) {
                // Given (), exists [], such that [ ( ] )
                any_crossing = true;
                break;
              }
            }
          }
        }
        if (!any_crossing) {
          if (_sort_spans) {
            top_span_indices.push_back(i);
          } else {
            output_span_indices[l][num_selected_spans] = i;
          }
          ++num_selected_spans;
          if (_suppress_crossing) {
            // Update data struct.
            const int start = candidate_starts[l][i].item<int64_t>();
            const int end = candidate_ends[l][i].item<int64_t>();
            auto latest_end_iter = start_to_latest_end.find(start);
            if (latest_end_iter == start_to_latest_end.end() || end > latest_end_iter->second) {
              start_to_latest_end[start] = end;
            }
            auto earliest_start_iter = end_to_earliest_start.find(end);
            if (earliest_start_iter == end_to_earliest_start.end() || start < earliest_start_iter->second) {
              end_to_earliest_start[end] = start;
            }
          }
        }
        ++current_span_index;
      }
      // Sort and produce span indices.
      if (_sort_spans) {
        std::sort(top_span_indices.begin(), top_span_indices.end(),
                [&candidate_starts, &candidate_ends, &l] (int i1, int i2) {
                 if (i1 >= candidate_starts.size(1) || i1 < 0 || i2 >= candidate_starts.size(1) || i2 < 0) {
                    return false;
                 }
                  if (candidate_starts[l][i1].item<int64_t>() < candidate_starts[l][i2].item<int64_t>()) {
                    return true;
                  } else if (candidate_starts[l][i1].item<int64_t>() > candidate_starts[l][i2].item<int64_t>()) {
                    return false;
                  } else if (candidate_ends[l][i1].item<int64_t>() < candidate_ends[l][i2].item<int64_t>()) {
                    return true;
                  } else if (candidate_ends[l][i1].item<int64_t>() > candidate_ends[l][i2].item<int64_t>()) {
                    return false;
                  } else {
                    return i1 < i2;
                  }
                });

        for (int i = 0; i < num_output_spans[l].item<int64_t>(); ++i) {
          output_span_indices[l][i] = top_span_indices[i];
        }
      }

      // Pad with the last selected span index to ensure monotonicity.
      int last_selected = num_selected_spans - 1;
      if (last_selected >= 0) {
        for (int i = num_selected_spans; i < max_num_output_spans; ++i) {
          output_span_indices[l][i] = output_span_indices[l][last_selected].item<int64_t>();
        }
      }

    }



    return output_span_indices;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("extract_spans", &extract_spans, "extract_spans");
}
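For context, here is a minimal sketch of how a C++ extension like this can be JIT-built and called from Python; the source file name, tensor sizes, and dtypes below are illustrative placeholders, not my exact setup:

import torch
from torch.utils.cpp_extension import load

# JIT-compile the extension; "extract_spans.cpp" is a placeholder for the actual source file.
ext = load(name="extract_spans_ext", sources=["extract_spans.cpp"], verbose=True)

# Illustrative inputs only: 8 sentences, 50 candidate spans each.
num_sentences, num_input_spans = 8, 50
span_scores = torch.rand(num_sentences, num_input_spans)
candidate_starts = torch.randint(0, 90, (num_sentences, num_input_spans))
candidate_ends = candidate_starts + torch.randint(0, 10, (num_sentences, num_input_spans))
num_output_spans = torch.full((num_sentences,), 10, dtype=torch.int64)

output_span_indices = ext.extract_spans(
    span_scores, candidate_starts, candidate_ends, num_output_spans,
    100,    # max_sentence_length
    True,   # sort_spans
    True)   # suppress_crossing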

The equivalent op in TensorFlow is:

REGISTER_OP("ExtractSpans")
.Input("span_scores: float32")
.Input("candidate_starts: int32")
.Input("candidate_ends: int32")
.Input("num_output_spans: int32")
.Input("max_sentence_length: int32")
.Attr("sort_spans: bool")
.Attr("suppress_crossing: bool")
.Output("output_span_indices: int32");

class ExtractSpansOp : public OpKernel {
public:
  explicit ExtractSpansOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("sort_spans", &_sort_spans));
    OP_REQUIRES_OK(context, context->GetAttr("suppress_crossing", &_suppress_crossing));
  }

  void Compute(OpKernelContext* context) override {
    TTypes<float>::ConstMatrix span_scores = context->input(0).matrix<float>();
    TTypes<int32>::ConstMatrix candidate_starts = context->input(1).matrix<int32>();
    TTypes<int32>::ConstMatrix candidate_ends = context->input(2).matrix<int32>();
    TTypes<int32>::ConstVec num_output_spans = context->input(3).vec<int32>();
    int max_sentence_length = context->input(4).scalar<int32>()();

    int num_sentences = span_scores.dimension(0);
    int num_input_spans = span_scores.dimension(1);
    int max_num_output_spans = 0;
    for (int i = 0; i < num_sentences; i++) {
      if (num_output_spans(i) > max_num_output_spans) {
        max_num_output_spans = num_output_spans(i);
      }
    }

    Tensor* output_span_indices_tensor = nullptr;
    TensorShape output_span_indices_shape({num_sentences, max_num_output_spans});
    OP_REQUIRES_OK(context, context->allocate_output(0, output_span_indices_shape,
                                                     &output_span_indices_tensor));
    TTypes<int32>::Matrix output_span_indices = output_span_indices_tensor->matrix<int32>();

    std::vector<std::vector<int>> sorted_input_span_indices(num_sentences,
                                                            std::vector<int>(num_input_spans));
    for (int i = 0; i < num_sentences; i++) {
      std::iota(sorted_input_span_indices[i].begin(), sorted_input_span_indices[i].end(), 0);
      std::sort(sorted_input_span_indices[i].begin(), sorted_input_span_indices[i].end(),
                [&span_scores, &i](int j1, int j2) {
                  return span_scores(i, j2) < span_scores(i, j1);
                });
    }

    for (int l = 0; l < num_sentences; l++) {
      std::vector<int> top_span_indices;
      std::unordered_map<int, int> end_to_earliest_start;
      std::unordered_map<int, int> start_to_latest_end;
      int current_span_index = 0,
          num_selected_spans = 0;
      while (num_selected_spans < num_output_spans(l) && current_span_index < num_input_spans) {
        int i = sorted_input_span_indices[l][current_span_index];
        bool any_crossing = false;
        if (_suppress_crossing) {
          const int& start = candidate_starts(l, i);
          const int& end = candidate_ends(l, i);
          for (int j = start; j <= end; ++j) {
            if (j > start) {
              auto latest_end_iter = start_to_latest_end.find(j);
              if (latest_end_iter != start_to_latest_end.end() && latest_end_iter->second > end) {
                // Given (), exists [], such that ( [ ) ]
                any_crossing = true;
                break;
              }
            }
            if (j < end) {
              auto earliest_start_iter = end_to_earliest_start.find(j);
              if (earliest_start_iter != end_to_earliest_start.end() && earliest_start_iter->second < start) {
                // Given (), exists [], such that [ ( ] )
                any_crossing = true;
                break;
              }
            }
          }
        }
        if (!any_crossing) {
          if (_sort_spans) {
            top_span_indices.push_back(i);
          } else {
            output_span_indices(l, num_selected_spans) = i;
          }
          ++num_selected_spans;
          if (_suppress_crossing) {
            // Update data struct.
            const int& start = candidate_starts(l, i);
            const int& end = candidate_ends(l, i);
            auto latest_end_iter = start_to_latest_end.find(start);
            if (latest_end_iter == start_to_latest_end.end() || end > latest_end_iter->second) {
              start_to_latest_end[start] = end;
            }
            auto earliest_start_iter = end_to_earliest_start.find(end);
            if (earliest_start_iter == end_to_earliest_start.end() || start < earliest_start_iter->second) {
              end_to_earliest_start[end] = start;
            }
          }
        }
        ++current_span_index;
      }
      // Sort and produce span indices.
      if (_sort_spans) {
        std::sort(top_span_indices.begin(), top_span_indices.end(),
                [&candidate_starts, &candidate_ends, &l] (int i1, int i2) {
                  if (candidate_starts(l, i1) < candidate_starts(l, i2)) {
                    return true;
                  } else if (candidate_starts(l, i1) > candidate_starts(l, i2)) {
                    return false;
                  } else if (candidate_ends(l, i1) < candidate_ends(l, i2)) {
                    return true;
                  } else if (candidate_ends(l, i1) > candidate_ends(l, i2)) {
                    return false;
                  } else {
                    return i1 < i2;
                  }
                });
        for (int i = 0; i < num_output_spans(l); ++i) {
          output_span_indices(l, i) = top_span_indices[i];
        }
      }
      // Pad with the last selected span index to ensure monotonicity.
      int last_selected = num_selected_spans - 1;
      if (last_selected >= 0) {
        for (int i = num_selected_spans; i < max_num_output_spans; ++i) {
          output_span_indices(l, i) = output_span_indices(l, last_selected);
        }
      }
    }
  }
private:
  bool _sort_spans, _suppress_crossing;
};

REGISTER_KERNEL_BUILDER(Name("ExtractSpans").Device(DEVICE_CPU), ExtractSpansOp);

So why is PyTorch so much slower than TF in this case? And why does running on the GPU take more time than running the same code on a CPU?

How did you measure the timings and did you synchronize both codes?
I’m not familiar enough with TensorFlow and don’t know how the code can be synchronized, but in PyTorch you should call torch.cuda.synchronize() before starting and stopping the timer. Otherwise your measurements are invalid.
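A minimal sketch of what I mean, reusing the names from the Python snippet above as placeholders for whatever you actually time:

import time
import torch

torch.cuda.synchronize()   # finish all pending GPU work before starting the clock
start = time.perf_counter()

out = ext.extract_spans(span_scores, candidate_starts, candidate_ends,
                        num_output_spans, 100, True, True)

torch.cuda.synchronize()   # make sure the op has really finished before stopping the clock
elapsed = time.perf_counter() - start
print(f"extract_spans: {elapsed:.4f} s")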

Thanks for your reply. At the least, I logged the time before and after inside the C++ code, assuming it is blocking. I have now converted the code back to Python, and it runs very fast compared to the C++ setup; I don't know why.