Skip to content

Commit 324de59

Browse files
committed
[update] add point number function parameter to YOLO-Pose
1 parent 71ba7e4 commit 324de59

3 files changed

Lines changed: 18 additions & 14 deletions

File tree

projects/llm_framework/main_yolo/src/EngineWrapper.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,8 @@ static const std::vector<std::vector<uint8_t>> SKELETON = {
311311
{8, 10}, {9, 11}, {2, 3}, {1, 2}, {1, 3}, {2, 4}, {3, 5}, {4, 6}, {5, 7}};
312312

313313
void post_process(AX_ENGINE_IO_INFO_T* io_info, AX_ENGINE_IO_T* io_data, const cv::Mat& mat, int& input_w, int& input_h,
314-
int& cls_num, float& prob_threshold, float& nms_threshold, std::vector<detection::Object>& objects,
315-
std::string& model_type)
314+
int& cls_num, int& point_num, float& prob_threshold, float& nms_threshold,
315+
std::vector<detection::Object>& objects, std::string& model_type)
316316
{
317317
// std::vector<detection::Object> objects;
318318
std::vector<detection::Object> proposals;
@@ -352,7 +352,7 @@ void post_process(AX_ENGINE_IO_INFO_T* io_info, AX_ENGINE_IO_T* io_data, const c
352352
auto feat_kps_ptr = output_kps_ptr[i];
353353
int32_t stride = (1 << i) * 8;
354354
detection::generate_proposals_yolov8_pose_native(stride, feat_ptr, feat_kps_ptr, prob_threshold, proposals,
355-
input_h, input_w, 17, cls_num);
355+
input_h, input_w, point_num, cls_num);
356356
}
357357
detection::get_out_bbox_kps(proposals, objects, nms_threshold, input_h, input_w, mat.rows, mat.cols);
358358
// detection::draw_keypoints(mat, objects, KPS_COLORS, LIMB_COLORS, SKELETON, "yolo11_pose_out");
@@ -368,10 +368,12 @@ void post_process(AX_ENGINE_IO_INFO_T* io_info, AX_ENGINE_IO_T* io_data, const c
368368
}
369369
}
370370

371-
int EngineWrapper::Post_Process(cv::Mat& mat, int& input_w, int& input_h, int& cls_num, float& pron_threshold,
372-
float& nms_threshold, std::vector<detection::Object>& objects, std::string& model_type)
371+
int EngineWrapper::Post_Process(cv::Mat& mat, int& input_w, int& input_h, int& cls_num, int& point_num,
372+
float& pron_threshold, float& nms_threshold, std::vector<detection::Object>& objects,
373+
std::string& model_type)
373374
{
374-
post_process(m_io_info, &m_io, mat, input_w, input_h, cls_num, pron_threshold, nms_threshold, objects, model_type);
375+
post_process(m_io_info, &m_io, mat, input_w, input_h, cls_num, point_num, pron_threshold, nms_threshold, objects,
376+
model_type);
375377
return 0;
376378
}
377379

projects/llm_framework/main_yolo/src/EngineWrapper.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ class EngineWrapper {
4949

5050
int RunSync();
5151

52-
int Post_Process(cv::Mat& mat, int& input_w, int& input_, int& cls_num, float& pron_threshold, float& nms_threshold,
53-
std::vector<detection::Object>& objects, std::string& model_type);
52+
int Post_Process(cv::Mat& mat, int& input_w, int& input_, int& cls_num, int& point_num, float& pron_threshold,
53+
float& nms_threshold, std::vector<detection::Object>& objects, std::string& model_type);
5454

5555
int GetOutput(void* pOutput, int index);
5656

projects/llm_framework/main_yolo/src/main.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ typedef struct {
3131
int img_h = 640;
3232
int img_w = 640;
3333
int cls_num = 80;
34+
int point_num = 17;
3435
float pron_threshold = 0.45f;
3536
float nms_threshold = 0.45;
3637
} yolo_config;
@@ -115,6 +116,7 @@ class llm_task {
115116
CONFIG_AUTO_SET(file_body["mode_param"], nms_threshold);
116117
CONFIG_AUTO_SET(file_body["mode_param"], cls_name);
117118
CONFIG_AUTO_SET(file_body["mode_param"], cls_num);
119+
CONFIG_AUTO_SET(file_body["mode_param"], point_num);
118120
CONFIG_AUTO_SET(file_body["mode_param"], model_type);
119121
mode_config_.yolo_model = base_model + mode_config_.yolo_model;
120122
yolo_ = std::make_unique<EngineWrapper>();
@@ -207,19 +209,19 @@ class llm_task {
207209
}
208210
std::vector<detection::Object> objects;
209211
yolo_->Post_Process(img_mat, mode_config_.img_w, mode_config_.img_h, mode_config_.cls_num,
210-
mode_config_.pron_threshold, mode_config_.nms_threshold, objects,
211-
mode_config_.model_type);
212+
mode_config_.point_num, mode_config_.pron_threshold, mode_config_.nms_threshold,
213+
objects, mode_config_.model_type);
212214
std::vector<nlohmann::json> yolo_output;
213215
for (size_t i = 0; i < objects.size(); i++) {
214216
const detection::Object &obj = objects[i];
215217
nlohmann::json output;
216218
output["class"] = mode_config_.cls_name[obj.label];
217219
output["confidence"] = format_float(obj.prob, 2);
218220
output["bbox"] = nlohmann::json::array();
219-
output["bbox"].push_back(format_float(obj.rect.x, 0));
220-
output["bbox"].push_back(format_float(obj.rect.y, 0));
221-
output["bbox"].push_back(format_float(obj.rect.x + obj.rect.width, 0));
222-
output["bbox"].push_back(format_float(obj.rect.y + obj.rect.height, 0));
221+
output["bbox"].push_back(format_float(obj.rect.x, 2));
222+
output["bbox"].push_back(format_float(obj.rect.y, 2));
223+
output["bbox"].push_back(format_float(obj.rect.x + obj.rect.width, 2));
224+
output["bbox"].push_back(format_float(obj.rect.y + obj.rect.height, 2));
223225
if (mode_config_.model_type == "segment") {
224226
std::vector<std::string> formatted_mask_feat;
225227
for (const auto &mask : obj.mask_feat) {

0 commit comments

Comments
 (0)