You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

79 lines
2.8 KiB

#pragma once
#include <map>
#include <memory>
#include <vector>
#include <ByteTrack/STrack.h>
#include <ByteTrack/lapjv.h>
#include <ByteTrack/Object.h>
namespace byte_track
{
class BYTETracker
{
public:
using STrackPtr = std::shared_ptr<STrack>;
BYTETracker(const int& frame_rate = 30,
const int& track_buffer = 30,
const float& track_thresh = 0.5,
const float& high_thresh = 0.6,
const float& match_thresh = 0.8);
~BYTETracker();
std::vector<STrackPtr> update(const std::vector<Object>& objects);
inline void set_thresh_values(float track_thresh, float high_thresh, float match_thresh) {
this->track_thresh_ = track_thresh;
this->high_thresh_ = high_thresh;
this->match_thresh_ = match_thresh;
}
private:
std::vector<STrackPtr> jointStracks(const std::vector<STrackPtr> &a_tlist,
const std::vector<STrackPtr> &b_tlist) const;
std::vector<STrackPtr> subStracks(const std::vector<STrackPtr> &a_tlist,
const std::vector<STrackPtr> &b_tlist) const;
void removeDuplicateStracks(const std::vector<STrackPtr> &a_stracks,
const std::vector<STrackPtr> &b_stracks,
std::vector<STrackPtr> &a_res,
std::vector<STrackPtr> &b_res) const;
void linearAssignment(const std::vector<std::vector<float>> &cost_matrix,
const int &cost_matrix_size,
const int &cost_matrix_size_size,
const float &thresh,
std::vector<std::vector<int>> &matches,
std::vector<int> &b_unmatched,
std::vector<int> &a_unmatched) const;
std::vector<std::vector<float>> calcIouDistance(const std::vector<STrackPtr> &a_tracks,
const std::vector<STrackPtr> &b_tracks) const;
std::vector<std::vector<float>> calcIous(const std::vector<Rect<float>> &a_rect,
const std::vector<Rect<float>> &b_rect) const;
double execLapjv(const std::vector<std::vector<float> > &cost,
std::vector<int> &rowsol,
std::vector<int> &colsol,
bool extend_cost = false,
float cost_limit = LONG_MAX,
bool return_cost = true) const;
private:
float track_thresh_;
float high_thresh_;
float match_thresh_;
const size_t max_time_lost_;
size_t frame_id_;
size_t track_id_count_;
std::vector<STrackPtr> tracked_stracks_;
std::vector<STrackPtr> lost_stracks_;
std::vector<STrackPtr> removed_stracks_;
};
}