libtorch c++ 使用例子

pytorch可使用flask作为服务器部署,但是由于Python的可移植性和速度比不上c++, pytorch还提供将模型转化到c++端运行的方案。主要结合TorchScript与libtorch

TorchScript

TorchScript可以视为PyTorch模型的一种中间表示,TorchScript表示的PyTorch模型可以直接在C++中进行读取。所以第一步先将模型转化为c++可读取的格式. 以下以resnet 图像分类为例子.

import torch
import torchvision.models.resnet as resnet
from PIL import Image
from torchvision import transforms


tran = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
])

img = Image.open("cat.jpg")
img = tran(img)

model = resnet.resnet50(pretrained=True)
model.eval()

out = model(img.unsqueeze(0))
print(torch.argmax(out))

# 输入示例
sample = torch.rand((1, 3, 224, 224))
# torch.jit.trace方法对模型构建TorchScript
trace_model = torch.jit.trace(model, sample)
trace_model.save('trace_model.pt')
# script 方式
script_model = torch.jit.script(model)

trace 方法需要给模型传入一个sample input,它会跟踪在模型的forward方法中的过程,使用于不含控制语句的网络,若包含控制结构如if-else 则使用Script方式.

libtorch

libtorch 是c++下的Torch,从官网下载后解压缩,在项目的 CMakeLists.txt 中添加

# 指定libtorch 路径
set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} "/Users/linyang/Downloads/libtorch") 
find_package(Torch REQUIRED)
target_link_libraries(example ${TORCH_LIBRARIES})
#include <iostream>
#include <torch/script.h>
#include <opencv2/opencv.hpp>


using namespace std;

cv::Mat load_img(const string& img_path){
    auto image = cv::imread(img_path,cv::ImreadModes::IMREAD_COLOR);
    cv::Mat image_transfomed;
    cv::resize(image, image_transfomed, cv::Size(224, 224));
    cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);
    return image_transfomed;
}

torch::Tensor process_img(cv::Mat img){
    // 图像转换为Tensor
    torch::Tensor tensor_image = torch::from_blob(img.data, {img.rows, img.cols,3},torch::kByte);
    tensor_image = tensor_image.permute({2,0,1});
    tensor_image = tensor_image.toType(torch::kFloat);
    tensor_image = tensor_image.div(255);
    tensor_image = tensor_image.unsqueeze(0);
    return tensor_image;
}


int main() {
    torch::jit::script::Module module;
    try {
        module = torch::jit::load("/Users/linyang/CLionProjects/example_torch/trace_model.pt");
    }

    catch (const c10::Error& e) {
        std::cerr << "error loading the model\n";
        return -1;
    }
    auto image_transfomed = load_img("cat.jpg");
    auto tensor_image = process_img(image_transfomed);

    at::Tensor output = module.forward({tensor_image}).toTensor();
    std::cout << torch::argmax(output)<< '\n';
    return 0;
}

输出与Python下一样,通过.

这是个最简单的将pytorch模型转换到c++下,对于目标检测等模型,还需解码输出等操作.

ref

pytorch.org