Files
autopilot/misc/rtsp_ai_player/src-onnx-runtime/aiengineinferenceonnxruntime.cpp
T
2024-07-27 11:29:41 +03:00

218 lines
5.6 KiB
C++

#include <QDebug>
#include <QThread>
#include <vector>
#include "aiengineinferenceonnxruntime.h"
// Thresholds forwarded to the predictor (mPredictor) when it is constructed.
// NOTE(review): judging by the names these are the minimum detection
// confidence, the IoU limit used for overlap suppression and the mask
// cut-off -- confirm against the predictor implementation.
static constexpr float confThreshold = 0.4f;
static constexpr float iouThreshold = 0.4f;
static constexpr float maskThreshold = 0.5f;
/**
 * Constructs the ONNX Runtime based inference engine.
 *
 * The model path is handed both to the AiEngineInference base class and to
 * the predictor, which additionally receives the file-level confidence, IoU
 * and mask thresholds. The class name table is then filled in; its index is
 * the class id used to look names up (mClassNames[classId]).
 *
 * @param modelPath Path of the model file given to the predictor.
 * @param parent    Parent QObject forwarded to the base class.
 */
AiEngineInferencevOnnxRuntime::AiEngineInferencevOnnxRuntime(QString modelPath, QObject *parent) :
    AiEngineInference{modelPath, parent},
    mPredictor(modelPath.toStdString(), confThreshold, iouThreshold, maskThreshold)
{
    qDebug() << "TUOMAS AiEngineInferencevOnnxRuntime() mModelPath=" << mModelPath;

    // Alternative class table kept for reference (currently disabled):
    //
    //   "Armoured vehicle", "Truck", "Vehicle", "Artillery",
    //   "Shadow artillery", "Trenches", "Military man", "Tyre tracks",
    //   "Additional protection tank", "Smoke"

    // Active class table, 80 entries.
    // NOTE(review): this looks like the standard COCO label set -- confirm it
    // matches the model that is actually loaded.
    mClassNames = {
        // 0 - 9
        "person", "bicycle", "car", "motorcycle", "airplane",
        "bus", "train", "truck", "boat", "traffic light",
        // 10 - 19
        "fire hydrant", "stop sign", "parking meter", "bench", "bird",
        "cat", "dog", "horse", "sheep", "cow",
        // 20 - 29
        "elephant", "bear", "zebra", "giraffe", "backpack",
        "umbrella", "handbag", "tie", "suitcase", "frisbee",
        // 30 - 39
        "skis", "snowboard", "sports ball", "kite", "baseball bat",
        "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
        // 40 - 49
        "wine glass", "cup", "fork", "knife", "spoon",
        "bowl", "banana", "apple", "sandwich", "orange",
        // 50 - 59
        "broccoli", "carrot", "hot dog", "pizza", "donut",
        "cake", "chair", "couch", "potted plant", "bed",
        // 60 - 69
        "dining table", "toilet", "tv", "laptop", "mouse",
        "remote", "keyboard", "cell phone", "microwave", "oven",
        // 70 - 79
        "toaster", "sink", "refrigerator", "book", "clock",
        "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
    };
}
/**
 * Returns a copy of @p image with every detection drawn onto it.
 *
 * For each detection a green bounding box is drawn, and a "<class>: <conf>"
 * caption in black text on a filled white background is placed at the box's
 * top-left corner. The input image itself is not modified.
 *
 * @param image      Image the detection boxes refer to.
 * @param detections Detections to visualise.
 * @return Annotated clone of @p image.
 */
cv::Mat AiEngineInferencevOnnxRuntime::drawLabels(const cv::Mat &image, const std::vector<Yolov8Result> &detections)
{
    cv::Mat result = image.clone();

    for (const auto &detection : detections)
    {
        cv::rectangle(result, detection.box, cv::Scalar(0, 255, 0), 2);

        // The class id comes from the model output; never index the name
        // table with an id it does not cover. Fall back to the numeric id so
        // the detection is still labelled.
        const bool knownClass = detection.classId >= 0 &&
                                detection.classId < static_cast<int>(mClassNames.size());
        const std::string className = knownClass
            ? mClassNames[detection.classId].toStdString()
            : "class " + std::to_string(detection.classId);
        const std::string label = className + ": " + std::to_string(detection.conf);

        int baseLine = 0;
        const cv::Size labelSize = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);

        // White background sized to the caption so the text stays readable.
        cv::rectangle(
            result,
            cv::Point(detection.box.x, detection.box.y - labelSize.height),
            cv::Point(detection.box.x + labelSize.width, detection.box.y + baseLine),
            cv::Scalar(255, 255, 255),
            cv::FILLED);

        cv::putText(
            result,
            label,
            cv::Point(detection.box.x, detection.box.y),
            cv::FONT_HERSHEY_SIMPLEX,
            0.5,
            cv::Scalar(0, 0, 0),
            1);
    }

    return result;
}
/**
 * Runs object detection on one frame and publishes the outcome.
 *
 * The frame is passed through resizeAndPad(), fed to the predictor, and the
 * detections are reduced to a fixed set of classes. If anything remains, an
 * AiEngineInferenceResult holding the objects and an annotated copy of the
 * scaled image is emitted through resultsReady(). Box coordinates are copied
 * as returned by the predictor for the scaled image.
 *
 * Exceptions are caught and logged to stdout; they never leave this function.
 * mActive is true while the function runs and is cleared on every exit path,
 * including after an exception.
 *
 * @param frame Input frame to analyse.
 */
void AiEngineInferencevOnnxRuntime::performInferenceSlot(cv::Mat frame)
{
    //qDebug() << __PRETTY_FUNCTION__;
    mActive = true;

    try {
        cv::Mat scaledImage = resizeAndPad(frame);
        //cv::imwrite("/tmp/frame.png", scaledImage);
        std::vector<Yolov8Result> detections = mPredictor.predict(scaledImage);

        // Only keep following detected objects.
        // car = 2
        // train = 6
        // cup = 41
        // banana = 46
        detections.erase(
            std::remove_if(detections.begin(), detections.end(),
                           [](const Yolov8Result &candidate) {
                               return candidate.classId != 2 &&
                                      candidate.classId != 6 &&
                                      candidate.classId != 41 &&
                                      candidate.classId != 46;
                           }),
            detections.end());

        AiEngineInferenceResult result;
        for (const Yolov8Result &detection : detections) {
            // Add detected objects to the results
            AiEngineObject object;
            object.classId = detection.classId;
            object.classStr = mClassNames[detection.classId];
            object.propability = detection.conf;
            object.rectangle.top = detection.box.y;
            object.rectangle.left = detection.box.x;
            object.rectangle.bottom = detection.box.y + detection.box.height;
            object.rectangle.right = detection.box.x + detection.box.width;
            result.objects.append(object);
        }

        // Nothing is emitted for frames without any object of interest.
        if (!result.objects.empty()) {
            result.frame = drawLabels(scaledImage, detections);
            emit resultsReady(result);
        }
    }
    catch (const cv::Exception& e) {
        std::cout << "OpenCV exception caught: " << e.what() << std::endl;
    }
    catch (const std::exception& e) {
        std::cout << "Standard exception caught: " << e.what() << std::endl;
    }
    catch (...) {
        std::cout << "Unknown exception caught" << std::endl;
    }

    // Reset here rather than inside the try block: otherwise a single failed
    // frame would leave the flag set to true permanently.
    mActive = false;
}
/**
 * Initialisation hook; this implementation has nothing to set up.
 *
 * @param number Ignored by this implementation.
 */
void AiEngineInferencevOnnxRuntime::initialize(int number)
{
    // Intentionally a no-op; the cast only silences the unused-parameter warning.
    static_cast<void>(number);
}