0. Introduction
These days I am reading Programming Massively Parallel Processors: A Hands-on Approach, 4th Edition, and I created a project to store my notes as I learn.
One of the most important parts of the book is writing CUDA kernels, so I decided to build all the kernels into shared libraries and test those implementations in both C++ and Python.
I generated my project from this template, which is tailored for a similar scenario, but I still ran into some problems, such as conflicts when linking libtorch and gtest 🤯.
So the purpose of this blog is to provide a concise guide to:
- Build a C++, CUDA and LibTorch library and test it with gtest.
- Load the library into torch and call the operators in Python.
- Resolve the problems that come up when linking all the libraries.
⚠️WARNING
This blog assumes basic familiarity with cmake and vcpkg; find some tutorials on them before reading on if you haven't used them.
1. Environment and Quick Start
Check the README.md of the project repository.
2. Create a C++, CUDA and LibTorch Project
I put all the C++ code in “./csrc/” and build it with cmake. The intermediate files should be generated in “./build/”, which only takes a few command-line arguments; see this line.
Vcpkg is used to manage the dependencies of the project. I am not going to teach you how to use vcpkg in this blog, but I will mention some pitfalls I ran into while using it.
😍️ I really enjoy building C++ projects with cmake and vcpkg. Give them a try if you haven't used them before.
2.1. How to Link against LibTorch
Since you installed pytorch in 1. Environment and Quick Start, libtorch is already present in your conda environment. Run this command to get the cmake prefix path of libtorch:
```bash
python -c "import torch; print(torch.utils.cmake_prefix_path)"
```
To integrate libtorch into cmake, I created this file and this file to find libtorch in the current project, and I use them here.
Now you can link your targets against libtorch simply, like what I do here.
📝NOTE
When you link your target against ${TORCH_LIBRARIES}, the CUDA libraries are linked automatically, which means you don't have to find and link CUDA yourself with something like what I write here.
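For reference, the consuming CMakeLists.txt boils down to something like this (a minimal sketch; the my_ops target and its source files are placeholders, the real targets live in the repository):

```cmake
# Locate libtorch via the prefix path printed above,
# e.g. cmake -DCMAKE_PREFIX_PATH=<that path> ...
# (assumes the project() call enabled both CXX and CUDA)
find_package(Torch REQUIRED)

# Placeholder target; linking ${TORCH_LIBRARIES} also pulls in the CUDA libraries.
add_library(my_ops SHARED vec_add.cu vec_add_torch.cpp)
target_link_libraries(my_ops PUBLIC ${TORCH_LIBRARIES})
target_compile_features(my_ops PUBLIC cxx_std_17)
```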
2.2. CMake and VCPKG Configuration
Currently, I am planning to use the C/C++ packages listed in this file. I load the packages with these lines in “./csrc/CMakeLists.txt”. Then I link those packages to my targets here and here.
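For example, with vcpkg's cmake integration a dependency such as gtest is pulled in along these lines (a sketch; the actual package list lives in the linked manifest, and the real target names are defined in “./csrc/CMakeLists.txt”):

```cmake
# Packages installed by vcpkg are discovered through the usual find_package mechanism.
find_package(GTest CONFIG REQUIRED)

# Hypothetical test target.
add_executable(pmpp-tests test_vector_add.cpp)
target_link_libraries(pmpp-tests PRIVATE GTest::gtest GTest::gtest_main)
```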
📝NOTE
libtorch < 2.6 is compiled with _GLIBCXX_USE_CXX11_ABI=0 to use the legacy pre-C++11 ABI, which conflicts with the packages managed by vcpkg by default. Consequently, you have to create a custom vcpkg triplet to control how vcpkg actually builds the packages. The triplet file is here and is enabled by these lines when building the C++ part.
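For illustration, such a triplet might look roughly like this on x64 Linux (a sketch only; the real triplet file in the repository is authoritative):

```cmake
# Hypothetical x64-linux triplet that forces the pre-C++11 ABI onto every vcpkg port.
set(VCPKG_TARGET_ARCHITECTURE x64)
set(VCPKG_CRT_LINKAGE dynamic)
set(VCPKG_LIBRARY_LINKAGE static)

# vcpkg requires the C and CXX flags to be set together.
set(VCPKG_C_FLAGS "")
set(VCPKG_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=0")
```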
I also set CMAKE_CXX_SCAN_FOR_MODULES to OFF on this line because some compile errors occurred otherwise. This is a temporary workaround, but since I am not planning to use C++20 modules in this project, I am just ignoring the issue for now.
2.3. Write and Register Custom Torch Operators
To register a custom torch operator, what you basically need to do is write a function that usually takes several torch::Tensor arguments as input and returns a torch::Tensor as output, and then register this function with torch.
For example, I implement pmpp::ops::cpu::launchVecAdd in this cpp file and pmpp::ops::cuda::launchVecAdd in this cu file, and provide the corresponding torch implementations pmpp::ops::cpu::vectorAddImpl and pmpp::ops::cuda::vectorAddImpl in this file.
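To make the shape of these wrappers concrete, the CUDA-side implementation could look roughly like this (a sketch only; the names follow the blog, but the real signatures and checks are in the repository's source files):

```cpp
#include <torch/torch.h>

namespace pmpp::ops::cuda {

// Assumed low-level kernel launcher implemented in the .cu file.
void launchVecAdd(const float* a, const float* b, float* c, int64_t n);

// Torch-facing implementation, registered later as the CUDA backend of vector_add.
torch::Tensor vectorAddImpl(const torch::Tensor& a, const torch::Tensor& b)
{
    TORCH_CHECK(a.sizes() == b.sizes(), "vector_add: shape mismatch");
    auto out = torch::empty_like(a);
    launchVecAdd(a.data_ptr<float>(), b.data_ptr<float>(),
                 out.data_ptr<float>(), a.numel());
    return out;
}

}  // namespace pmpp::ops::cuda
```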
🤔 I didn't add any of those function declarations to the hpp files under “./include” because I don't think they should be exposed to the users of the library. For the testing part, I will look up and test the functions through torch::Dispatcher, which aligns with how the operators are invoked in Python.
To register these implementations as an operator in pytorch, see this line, this line, and this line (a rough sketch of the pattern follows below), where I:
- Define a python function vector_add with the signature vector_add(Tensor a, Tensor b) -> Tensor.
- Register the CPU implementation of the function.
- Register the CUDA implementation of the function.
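Here is roughly what that registration looks like (a sketch; the pmpp operator namespace is an assumption, the linked lines show the real one):

```cpp
#include <torch/library.h>
#include <torch/torch.h>

// Declarations of the implementations described above.
namespace pmpp::ops::cpu  { torch::Tensor vectorAddImpl(const torch::Tensor&, const torch::Tensor&); }
namespace pmpp::ops::cuda { torch::Tensor vectorAddImpl(const torch::Tensor&, const torch::Tensor&); }

// Define the operator schema once, under an assumed "pmpp" namespace.
TORCH_LIBRARY(pmpp, m) {
    m.def("vector_add(Tensor a, Tensor b) -> Tensor");
}

// Bind the backend-specific implementations to that schema.
TORCH_LIBRARY_IMPL(pmpp, CPU, m) {
    m.impl("vector_add", &pmpp::ops::cpu::vectorAddImpl);
}
TORCH_LIBRARY_IMPL(pmpp, CUDA, m) {
    m.impl("vector_add", &pmpp::ops::cuda::vectorAddImpl);
}
```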
Now vector_add is a custom torch operator that can be called from both C++ and Python. All you need to do is build this code into a shared library, like what I did here in cmake.
2.4. Test the Custom Torch Operators in C++
Once a custom torch operator is registered, normally one or more shared libraries will be generated. For C++ users, you should link your executable target against libtorch and the generated shared libraries so that the registered operators can be called.
Since I linked libPmppTorchOps against libtorch as PUBLIC in this line, the test target will link against libtorch automatically as long as it links against libPmppTorchOps; see this line.
📝NOTE
You may be confused about why -Wl,--no-as-needed is added before ${PROJECT_NAMESPACE}pmpp-torch-ops. This is because the shared libraries are not used directly by the test target (an operator is registered in the library but never called directly from the executable), so the linker would not link against them by default. This flag forces the linker to link against the shared libraries even though they are not directly used.
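For illustration, the link command ends up looking something like this (a sketch; the pmpp-tests target name is a placeholder for whatever the real test target is called):

```cmake
# Keep the operator library on the link line even though the executable references
# none of its symbols directly; the operators are registered via static initializers.
target_link_libraries(pmpp-tests PRIVATE
    -Wl,--no-as-needed
    ${PROJECT_NAMESPACE}pmpp-torch-ops
    -Wl,--as-needed
    GTest::gtest_main)
```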
The registered operators can be dispatched in a not-so-intuitive way 🤣 based on the official documentation; see here.
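Concretely, dispatching the operator from C++ looks roughly like this (a sketch, again assuming the operator was registered as pmpp::vector_add):

```cpp
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/torch.h>

torch::Tensor callVectorAdd(const torch::Tensor& a, const torch::Tensor& b)
{
    // Look up the schema registered by the shared library, then call it with a typed signature.
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("pmpp::vector_add", "")
                         .typed<torch::Tensor(const torch::Tensor&, const torch::Tensor&)>();
    return op.call(a, b);
}
```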
Now the only thing left is to test the operators in C++ using gtest, but that is not the focus of this blog, so let's move on to the next part.
3. Create and Package a Python Project
3.1. pyproject.toml and setup.py
In modern python, pyproject.toml is the de-facto standard configuration file for packaging. In this project, setuptools is used as the build backend because I believe it is the most popular one and it cooperates easily with cmake.
In particular, “./pyproject.toml” and “./setup.py” define what happens when you run pip install . in the root directory of the project. I created CMakeExtention and CMakeBuild (here) and pass them to the setup function (here) so that the C++ library libPmppTorchOps (under “./csrc/”) is built and installed before the python package itself is installed.
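To give a feel for the mechanism, a stripped-down version might look like this (a sketch only; the class names follow the blog and the cmake arguments are simplified compared to the real setup.py):

```python
import pathlib
import subprocess

from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext


class CMakeExtention(Extension):
    """Marker extension pointing at the cmake source tree instead of C sources."""

    def __init__(self, name: str, sourcedir: str = "csrc"):
        super().__init__(name, sources=[])
        self.sourcedir = str(pathlib.Path(sourcedir).resolve())


class CMakeBuild(build_ext):
    """Configure, build and install the cmake project before packaging."""

    def build_extension(self, ext: CMakeExtention) -> None:
        build_dir = pathlib.Path("build")
        build_dir.mkdir(exist_ok=True)
        subprocess.run(["cmake", "-S", ext.sourcedir, "-B", str(build_dir)], check=True)
        subprocess.run(["cmake", "--build", str(build_dir)], check=True)
        subprocess.run(["cmake", "--install", str(build_dir)], check=True)


setup(
    ext_modules=[CMakeExtention("pmpp._torch_ops")],
    cmdclass={"build_ext": CMakeBuild},
)
```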
You can easily understand what I did by reading the source code of these two files, and there is one more thing I want to mention.
Based on 2. Create a C++, CUDA and LibTorch Project, you should find that the generated shared library ends up under ./build/lib, ending with .so on Linux or .dll on Windows. Additionally, I added an install step here which copies the shared libraries to “./src/pmpp/_torch_ops”.
Note that “./src/pmpp” is an already existing directory which is the root of the actual python package, and “./src/pmpp/_torch_ops” will be created automatically while installing the shared libraries.
The problem is that when packaging the python project, only directories containing “__init__.py” are considered packages (or modules), and I don't want to add this file to “./src/pmpp/_torch_ops” due to my mysophobia 😷. Therefore, I used find_namespace_packages instead of find_packages and specified package_data to include the shared libraries here.
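The relevant part of the setup call then looks roughly like this (a sketch; the real arguments are in “./setup.py”):

```python
from setuptools import find_namespace_packages, setup

setup(
    # Picks up src/pmpp/_torch_ops even though it has no __init__.py.
    packages=find_namespace_packages(where="src"),
    package_dir={"": "src"},
    # Ship the built shared libraries inside the package.
    package_data={"pmpp._torch_ops": ["*.so", "*.dll"]},
)
```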
3.2. Install the Python Package
If you are planning to build your libraries against the dependencies listed here while installing the python project, I don't really suggest building it in an isolated python environment (which is the default behavior of pip). All the packages listed here would have to be re-installed into that isolated environment, and in our case you would at least need to append torch to that list.
Alternatively, try this command, which will directly use the torch installed in the current conda environment:
pip install --no-build-isolation -v .
3.3. Test the Custom Torch Operators in Python
As long as you have the shared libraries built in 2. Create a C++, CUDA and LibTorch Project, all you need to do is use torch.ops.load_library to load them and then call the registered operators.
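As a sketch (the library file name and the pmpp operator namespace are assumptions; the real path depends on your platform and build), loading and calling the operator looks like this:

```python
import torch

# Load the shared library produced by the C++ part.
torch.ops.load_library("src/pmpp/_torch_ops/libPmppTorchOps.so")

a = torch.arange(4, dtype=torch.float32)
b = torch.ones(4)
print(torch.ops.pmpp.vector_add(a, b))  # tensor([1., 2., 3., 4.])
```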
I put this process into “src/pmpp/__init__.py”, so by the time you import pmpp in python, your custom torch operators are ready to use. See this file for an example of testing the operators.