libtorch c++ 使用例子
pytorch可使用flask作为服务器部署,但是由于Python的可移植性和速度比不上c++, pytorch还提供将模型转化到c++端运行的方案。主要结合TorchScript与libtorch
TorchScript
TorchScript可以视为PyTorch模型的一种中间表示,TorchScript表示的PyTorch模型可以直接在C++中进行读取。所以第一步先将模型转化为c++可读取的格式. 以下以resnet 图像分类为例子.
import torch
import torchvision.models.resnet as resnet
from PIL import Image
from torchvision import transforms
tran = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
img = Image.open("cat.jpg")
img = tran(img)
model = resnet.resnet50(pretrained=True)
model.eval()
out = model(img.unsqueeze(0))
print(torch.argmax(out))
# 输入示例
sample = torch.rand((1, 3, 224, 224))
# torch.jit.trace方法对模型构建TorchScript
trace_model = torch.jit.trace(model, sample)
trace_model.save('trace_model.pt')
# script 方式
script_model = torch.jit.script(model)
trace 方法需要给模型传入一个sample input,它会跟踪在模型的forward方法中的过程,使用于不含控制语句的网络,若包含控制结构如if-else 则使用Script方式.
libtorch
libtorch 是c++下的Torch,从官网下载后解压缩,在项目的 CMakeLists.txt 中添加
# 指定libtorch 路径
set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} "/Users/linyang/Downloads/libtorch")
find_package(Torch REQUIRED)
target_link_libraries(example ${TORCH_LIBRARIES})
#include <iostream>
#include <torch/script.h>
#include <opencv2/opencv.hpp>
using namespace std;
cv::Mat load_img(const string& img_path){
auto image = cv::imread(img_path,cv::ImreadModes::IMREAD_COLOR);
cv::Mat image_transfomed;
cv::resize(image, image_transfomed, cv::Size(224, 224));
cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);
return image_transfomed;
}
torch::Tensor process_img(cv::Mat img){
// 图像转换为Tensor
torch::Tensor tensor_image = torch::from_blob(img.data, {img.rows, img.cols,3},torch::kByte);
tensor_image = tensor_image.permute({2,0,1});
tensor_image = tensor_image.toType(torch::kFloat);
tensor_image = tensor_image.div(255);
tensor_image = tensor_image.unsqueeze(0);
return tensor_image;
}
int main() {
torch::jit::script::Module module;
try {
module = torch::jit::load("/Users/linyang/CLionProjects/example_torch/trace_model.pt");
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
auto image_transfomed = load_img("cat.jpg");
auto tensor_image = process_img(image_transfomed);
at::Tensor output = module.forward({tensor_image}).toTensor();
std::cout << torch::argmax(output)<< '\n';
return 0;
}
输出与Python下一样,通过.
这是个最简单的将pytorch模型转换到c++下,对于目标检测等模型,还需解码输出等操作.