Header File

#pragma once
#ifndef TORCH_CLASSIFICATION_H
#define TORCH_CLASSIFICATION_H

#include <opencv2/opencv.hpp>
#include <iostream>

// Configuration
#include "config.h"

// Image processing utilities
#include "img_utils.h"

// dlib
#include "dlib_utils.h"

// pytorch
#include "torch/torch.h"
#include "torch/script.h"

class EmoTorch
{
private:
	static void torch_model_load(torch::jit::script::Module& module);
	static int torch_predict(cv::Mat frame, torch::jit::script::Module& model);
public:
	static int torch_process();
};

#endif	// TORCH_CLASSIFICATION_H
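
The helper headers above supply the Img_preprocessing, face_detect, and mmod_net types used in the CPP file, and config.h supplies the constants (TORCH_MODEL_PATH, IMG_SIZE, FACE_METHOD, DLIB_STYLE, PADDING, SHOW_FPS). The post does not show config.h, so the sketch below is only a plausible reconstruction; every value in it is an assumption.

// config.h - a minimal sketch; every value below is an assumption,
// not the original project's configuration
#pragma once
#include <string>

const std::string TORCH_MODEL_PATH = "models/emotion_model.pt";	// TorchScript model file (assumed path)
const int IMG_SIZE = 224;		// input width/height expected by the model (assumed)
const int FACE_METHOD = 0;		// 0: dlib face detection (assumed encoding)
const std::string DLIB_STYLE = "CNN";	// "CNN" (MMOD) or "HOG" (assumed)
const bool PADDING = true;		// pad face crops before resizing (assumed)
const bool SHOW_FPS = true;		// overlay FPS on the output frame (assumed)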

 

CPP File

#include "torch_classification.h"

void EmoTorch::torch_model_load(torch::jit::script::Module& module)
{
	try {
		// Load the TorchScript model
		module = torch::jit::load(TORCH_MODEL_PATH);

		// Keep the model on the CPU
		module.to(torch::kCPU);
	}
	catch (const c10::Error& e) {
		std::cerr << "[error 02]" << std::endl;
		std::cerr << "error loading the pytorch model: " << e.what() << std::endl;
	}
}
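
A loaded module still runs in training mode and tracks gradients by default. libtorch provides Module::eval() and torch::NoGradGuard for inference, so the load step can be tightened as in the sketch below (prepare_for_inference is a hypothetical helper name, not part of this project).

#include "torch/script.h"

// a sketch: put a loaded TorchScript module into inference mode
// (hypothetical helper; not part of the original code)
void prepare_for_inference(torch::jit::script::Module& module)
{
	module.eval();	// switch dropout/batchnorm layers to evaluation behavior
}

// at the call site, gradient tracking can be suppressed around forward():
//   torch::NoGradGuard no_grad;
//   torch::Tensor out = module.forward({ input }).toTensor();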

int EmoTorch::torch_predict(cv::Mat frame, torch::jit::script::Module& model)
{
	// Convert BGR (OpenCV default) to RGB and to 32-bit float
	cv::cvtColor(frame, frame, cv::COLOR_BGR2RGB);
	frame.convertTo(frame, CV_32FC3);

	// Wrap the pixel buffer as an NHWC tensor, then reorder to NCHW
	torch::Tensor tensor_image = torch::from_blob(frame.data, { 1, IMG_SIZE, IMG_SIZE, 3 });
	tensor_image = tensor_image.permute({ 0, 3, 1, 2 });
	//tensor_image[0][0] = tensor_image[0][0].sub_(0.485).div_(0.229);
	//tensor_image[0][1] = tensor_image[0][1].sub_(0.456).div_(0.224);
	//tensor_image[0][2] = tensor_image[0][2].sub_(0.406).div_(0.225);

	torch::Tensor output = model.forward({ tensor_image }).toTensor();

	// Index of the highest-scoring emotion class
	int emo_idx = output.argmax(1).item<int>();

	return emo_idx;
}
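
The three commented-out lines are the standard ImageNet per-channel normalization. If the model was trained with torchvision-style preprocessing, the pixel values would also need scaling to [0, 1] first; a sketch of the full transform, assuming the model actually expects it:

// a sketch: torchvision-style input normalization, assuming the model expects it
tensor_image = tensor_image.div(255.0);	// [0, 255] -> [0, 1]
tensor_image[0][0] = tensor_image[0][0].sub_(0.485).div_(0.229);	// R
tensor_image[0][1] = tensor_image[0][1].sub_(0.456).div_(0.224);	// G
tensor_image[0][2] = tensor_image[0][2].sub_(0.406).div_(0.225);	// B

Note also that torch::from_blob does not copy the Mat's buffer: the tensor is only valid while frame's data stays alive, so append .clone() if the tensor has to outlive it.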

int EmoTorch::torch_process()
{
	Img_preprocessing img_tools = Img_preprocessing();
	face_detect face_detector = face_detect();
	mmod_net mmod_model;
	dlib::frontal_face_detector hog_model;

	// Declare the PyTorch model variable
	torch::jit::script::Module model;

	// Load the model
	EmoTorch::torch_model_load(model);

	// Read Video
	cv::VideoCapture cap;

	// Setting camera
	img_tools.set_camera(cap);

	// Setting face_detector
	if (FACE_METHOD == 0)	// dlib
	{
		if (DLIB_STYLE == "CNN")
		{
			face_detector.load_mmod(mmod_model);
		}
		else
		{
			face_detector.load_hog(hog_model);
		}
	}

	// Start Video
	cv::Mat frame;
	cv::Mat face_img;
	int output = -1;
	std::vector<int> face_positions;

	while (true)
	{
		cap >> frame;
		if (frame.empty()) { break; }	// stop when the capture returns no frame

		double start_time = clock();
		face_detector.detect_face(mmod_model, hog_model, frame, face_positions);
		if (!face_positions.empty())
		{
			// detect_face returns four coordinates per face: x1, y1, x2, y2
			for (size_t idx = 0; idx + 3 < face_positions.size(); idx += 4)
			{
				cv::Rect face_rect(cv::Point(face_positions[idx], face_positions[idx + 1]), cv::Point(face_positions[idx + 2], face_positions[idx + 3]));
				face_rect &= cv::Rect(0, 0, frame.cols, frame.rows);	// clamp the box to the frame
				if (face_rect.area() == 0) { continue; }
				frame(face_rect).copyTo(face_img);
				// torch_predict converts BGR to RGB itself, so no conversion here

				if (PADDING) { img_tools.add_padding(face_img); }
				img_tools.resize_img(face_img);
				output = EmoTorch::torch_predict(face_img, model);

				img_tools.draw_bbox(frame, face_positions[idx], face_positions[idx + 1], face_positions[idx + 2], face_positions[idx + 3], output);
				img_tools.put_text(frame, face_positions[idx], face_positions[idx + 1], face_positions[idx + 2], face_positions[idx + 3], output);
				
				output = -1;
			}
		}
		double terminate_time = clock();
		// clock() measures CPU time, so this is only an approximation of per-frame latency
		std::cout << "elapsed : " << (terminate_time - start_time) / CLOCKS_PER_SEC << " sec" << std::endl;
		if (SHOW_FPS)
		{
			img_tools.put_fps(frame, (1.0 / ((terminate_time - start_time) / CLOCKS_PER_SEC)));
		}

		face_positions.clear();

		cv::imshow("test", frame);
		if (cv::waitKey(1) == 27) { break; }
	}

	return 0;
}
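
Since the class only exposes the static torch_process() entry point, a program using it reduces to a single call. A minimal main (assumed; the post does not show one):

// main.cpp - a minimal sketch of the entry point (assumed, not shown in the post)
#include "torch_classification.h"

int main()
{
	return EmoTorch::torch_process();
}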
