Working with Nvidia TensorRT and Pose Estimation
Photo by Pietro Jeng on Unsplash
Using a custom PyTorch model with TensorRT⌗
Our goal: to create a C++ API for pose estimation. This API will handle loading models, assigning inputs, and processing the model outputs.
What is TensorRT?⌗
TensorRT is an SDK that provides a model optimizer and runtime that increase the performance of our deep learning models.
Model optimization is done through various means such as precision calibration, layer and tensor fusion, kernel auto-tuning, etc.
Why TensorRT?⌗
We are currently focused on embedded / edge artificial intelligence using low-power devices.
TensorRT helps us by delivering additional performance, allowing us to handle more demanding AI inference tasks in real time, e.g. video analysis, speech recognition, natural language processing, etc.
What do we want to learn?⌗
We want to learn how to create APIs for any type of model.
This is an extremely helpful skill, as it will allow you to use ALL of the publicly available models.
Prerequisites⌗
Make sure you have followed the previous 3 tutorials (Here, Here and Here)
Install the following dependencies if you haven’t already:
First we have to convert our PyTorch (.pth) models to ONNX.
We can do this using the built-in PyTorch call torch.onnx.export.
Here is the Python code (we are using Python 3):
import torch
import trt_pose.models
# the human_pose topology used by trt_pose (and in poseNet.cpp below)
# has 18 keypoints and 21 skeleton links
num_parts = 18
num_links = 21
# Uncomment based on which model you have
# for resnet18_baseline_att_224x224_A_epoch_249
# model = trt_pose.models.resnet18_baseline_att(num_parts, 2 * num_links, pretrained=False).cuda().eval()
# for densenet121_baseline_att_256x256_B_epoch_160
model = trt_pose.models.densenet121_baseline_att(num_parts, 2 * num_links, pretrained=False).cuda().eval()
# MODEL_WEIGHTS = 'resnet18_baseline_att_224x224_A_epoch_249.pth'
# ONNX_OUTPUT = 'resnet18_baseline_att_224x224_A_epoch_249.onnx'
MODEL_WEIGHTS = 'densenet121_baseline_att_256x256_B_epoch_160.pth'
ONNX_OUTPUT = 'densenet121_baseline_att_256x256_B_epoch_160.onnx'
model.load_state_dict(torch.load(MODEL_WEIGHTS))
# WIDTH = 224 # for resnet18_baseline_att_224x224_A_epoch_249
# HEIGHT = 224 # for resnet18_baseline_att_224x224_A_epoch_249
WIDTH = 256 # for densenet121_baseline_att_256x256_B_epoch_160
HEIGHT = 256 # for densenet121_baseline_att_256x256_B_epoch_160
data = torch.zeros((1, 3, HEIGHT, WIDTH)).cuda()
torch_out = model(data)
# export the model
torch.onnx.export(model,
data,
ONNX_OUTPUT,
export_params=True,
opset_version=10,
do_constant_folding=True
)
import onnx
# load our newly converted ONNX model
onnx_model = onnx.load(ONNX_OUTPUT)
# check our newly converted ONNX model for any errors
onnx.checker.check_model(onnx_model)
# Print a human readable representation of the model graph
print(onnx.helper.printable_graph(onnx_model.graph))
There is also a Jupyter notebook in the repo that you can use.
Move your newly produced ONNX models to <path-to-jetson-inference-folder>/data/networks/
Time to make our Model APIs⌗
Note: we are going to be creating a C++ API and a ROS node for testing the API.
We will not be compiling the code just yet, as we still have to draw the results on the image using CUDA, which will be covered in an upcoming post.
Make sure to download the parse/ folder and the poseNet.cu / poseNet.cuh files from the repo.
Tensor Template Class⌗
Create 2 files, Tensor.tcc and Tensor.h, in your <catkin-workspace>/<your-package>/src/ directory.
Our Tensor class is a template class, allowing us to work with multiple data types, e.g. integers, floats, etc.
Tensor.h
#ifndef __JETSONCAM_TENSOR_H__
#define __JETSONCAM_TENSOR_H__
#include <vector>
#include <cstdint>
namespace jetsoncam {
/**
Create a tensor class / struct, look at how
tensors are used in places like "find_peaks_out_torch" in plugins.hpp!
IT CAN BE DONE!
*/
template <class T>
class Tensor
{
public:
Tensor();
/**
Create an empty tensor with the given dimensions and allocate memory for it
*/
Tensor(const char* name, std::vector<int> dimensions);
/**
Create an empty tensor with the given dimensions, allocate memory for it,
and fill it with the given fill value
*/
Tensor(const char* name, std::vector<int> dimensions, T fill_value);
/**
Create a tensor with existing values in shared CPU, GPU memory
*/
Tensor(const char* name, std::vector<int> tensorDims, T* outputCPU, T* outputCUDA);
/**
* Destroy
*/
virtual ~Tensor();
T operator [] (int i) const {return CUDA[i];}
T& operator [] (int i) {return CUDA[i];}
T* data_ptr();
int size(int d) const;
void printDims();
T retrieve(std::vector<int> dimensions);
/**
* Return the stride of a tensor at some dimension.
*/
// virtual int64_t stride(int64_t d) const;
const char* tensor_name;
std::vector<int> dims;
size_t dims_size;
uint32_t dataSize;
T* CPU;
T* CUDA;
protected:
};
struct ParseResult
{
Tensor<int> object_counts;
Tensor<int> objects;
Tensor<float> normalized_peaks;
};
}
// include our template file definitions
#include "Tensor.tcc"
#endif
The Tensor class has been namespaced to avoid collisions, as Tensor is a very generic name.
Tensor.tcc
#include "Tensor.h"
#include <jetson-utils/cudaMappedMemory.h>
#include <vector>
#include <cstdint>
#include <string>
#include <algorithm>
namespace jetsoncam {
template <class T>
Tensor<T>::Tensor()
{
// default constructor - an empty tensor with no allocated memory
tensor_name = "empty";
CPU = NULL;
CUDA = NULL;
dataSize = 0;
dims_size = 0;
}
// Constructor - create an empty tensor with specified dimensions and allocate memory
template <class T>
Tensor<T>::Tensor(const char* name, std::vector<int> tensorDims)
{
// allocate output memory
void* outputCPU = NULL;
void* outputCUDA = NULL;
// size_t outputSize = 1 * DIMS_C(tensorDims) * DIMS_H(tensorDims) * DIMS_W(tensorDims) * sizeof(T);
size_t outputSize = 1;
for (std::size_t i = 0; i < tensorDims.size(); i++) {
// multiply out each dimension to get the total element count
int val = tensorDims[i];
outputSize *= val;
}
outputSize *= sizeof(T);
if( !cudaAllocMapped((void**)&outputCPU, (void**)&outputCUDA, outputSize) )
{
printf("failed to alloc CUDA mapped memory for tensor %s output, %zu bytes\n", name, outputSize);
}
// create output tensors
tensor_name = name;
CPU = (T*)outputCPU;
CUDA = (T*)outputCUDA;
dataSize = outputSize;
dims = tensorDims;
dims_size = tensorDims.size();
}
// Constructor - create a tensor and fill it with a default value
template <class T>
Tensor<T>::Tensor(const char* name, std::vector<int> tensorDims, T fill_value)
{
// allocate output memory
void* outputCPU = NULL;
void* outputCUDA = NULL;
// size_t outputSize = 1 * DIMS_C(tensorDims) * DIMS_H(tensorDims) * DIMS_W(tensorDims) * sizeof(T);
size_t outputSize = 1;
size_t array_Len = 1;
for (std::size_t i = 0; i < tensorDims.size(); i++) {
// multiply out each dimension to get the total element count
int val = tensorDims[i];
outputSize *= val;
array_Len *= val;
}
outputSize *= sizeof(T);
if( !cudaAllocMapped((void**)&outputCPU, (void**)&outputCUDA, outputSize) )
{
printf("failed to alloc CUDA mapped memory for tensor %s output, %zu bytes\n", name, outputSize);
}
// create output tensors
tensor_name = name;
CPU = (T*)outputCPU;
CUDA = (T*)outputCUDA;
// fill our array with the values
std::fill_n(CUDA, array_Len, fill_value);
dataSize = outputSize;
dims = tensorDims;
dims_size = tensorDims.size();
}
// Constructor - create a tensor around already existing data in memory
template <class T>
Tensor<T>::Tensor(const char* name, std::vector<int> tensorDims, T* outputCPU, T* outputCUDA)
{
// size_t outputSize = 1 * DIMS_C(tensorDims) * DIMS_H(tensorDims) * DIMS_W(tensorDims) * sizeof(T);
size_t outputSize = 1;
for (std::size_t i = 0; i < tensorDims.size(); i++) {
// multiply out each dimension to get the total element count
int val = tensorDims[i];
outputSize *= val;
}
outputSize *= sizeof(T);
// create output tensors
tensor_name = name;
CPU = outputCPU;
CUDA = outputCUDA;
dataSize = outputSize;
dims = tensorDims;
dims_size = tensorDims.size();
}
template <class T>
T* Tensor<T>::data_ptr() {
return CUDA;
}
// return the size of a specific dimension
template <class T>
int Tensor<T>::size(int d) const {
return dims[d];
}
// print tensor dimensions
template <class T>
void Tensor<T>::printDims() {
printf("Tensor %s ", tensor_name);
printf("%lu dimensions ", dims.size());
printf("{");
for (std::vector<int>::size_type i = 0; i < dims.size(); i++) {
int val = dims[i];
printf(" %d ", val);
}
printf("}\n");
}
// retrieve a value via its dimensional indices
template <class T>
T Tensor<T>::retrieve(std::vector<int> indexes) {
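// The underlying data is a flat, row-major array: an index {i, j, k, l} on a
// tensor with dims {D0, D1, D2, D3} maps to the flat offset
// l + k*D3 + j*D2*D3 + i*D1*D2*D3 (and analogously for fewer dimensions).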
if (indexes.size() == 1) {
int index = indexes[0];
return CUDA[index];
}
else if (indexes.size() == 2) {
int index = indexes[1] + (indexes[0] * dims[1]);
return CUDA[index];
}
else if (indexes.size() == 3) {
int index = indexes[2] + (indexes[1] * dims[2]) + (indexes[0] * dims[1] * dims[2]);
return CUDA[index];
}
else if (indexes.size() == 4) {
int index = indexes[3] + (indexes[2] * dims[3]) + (indexes[1] * dims[2] * dims[3]) + (indexes[0] * dims[1] * dims[2] * dims[3]);
return CUDA[index];
}
else {
return CUDA[0];
}
// size_t access_dim = indexes.size();
// int index = x + y * D1 + z * D1 * D2 + t * D1 * D2 * D3;
// return CUDA[index];
}
// destructor
template <class T>
Tensor<T>::~Tensor()
{
// CUDA(cudaFreeHost(imgCPU));
}
}
Tensor class usage / features⌗
- It is a simple wrapper to allocate shared (GPU / CPU) memory and specify dimensions
- The data for models are simply a 1-Dimensional array of floats
- The Tensor class allows you to access that array in specified ways
- e.g. a Tensor<float> T of size {10, 10} is:
  - a 2-dimensional tensor with a length of 10 and a height of 10
  - the underlying data is an array of floats of length 100
  - when we access the tensor at T[0][4] we access the underlying array at array[4]
  - when we access the tensor at T[1][4] we access the underlying array at array[14]
  - etc.
- An example usage of our Tensor class is:
jetsoncam::Tensor<float> cmap_tensor = jetsoncam::Tensor<float>();
- Note: this Tensor class only supports a maximum of 4 dimensions
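To make the indexing above concrete, here is a minimal sketch of using the Tensor class (illustrative only; the tensor name and values are made up, and the constructors allocate CUDA mapped memory via jetson-utils, so this only runs on the Jetson):
#include "Tensor.h"
#include <cstdio>

int main()
{
    // 2-dimensional 10 x 10 tensor, filled with 0.0f
    // (the constructor allocates 100 floats of shared CPU/GPU memory)
    jetsoncam::Tensor<float> T("example", {10, 10}, 0.0f);
    T.printDims(); // prints: Tensor example 2 dimensions { 10 10 }

    // the data is zero-copy mapped memory, so it is also CPU-accessible on the Jetson
    T.data_ptr()[14] = 1.0f;

    // retrieve({1, 4}) computes 4 + 1 * dims[1] = 14 into the flat array
    printf("element {1, 4} = %f\n", T.retrieve({1, 4}));
    return 0;
}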
poseNet API Class⌗
Create 2 files, poseNet.cpp and poseNet.h, in your <catkin-workspace>/<your-package>/src/ directory.
Let's fill these out.
poseNet.h
#ifndef __POSE_NET_H__
#define __POSE_NET_H__
#include "Tensor.h"
#include "ParseObjects.hpp"
#include <jetson-inference/tensorNet.h>
#include <vector>
/**
* Name of default input blob for the pose estimation model.
* @ingroup poseNet
*/
#define POSENET_DEFAULT_INPUT "Input"
/**
* Names of the default cmap / paf output blobs for the pose estimation models.
* @ingroup poseNet
*/
#define RESNET_DEFAULT_CMAP_OUTPUT "262"
#define RESNET_DEFAULT_PAF_OUTPUT "264"
#define DENSENET_DEFAULT_CMAP_OUTPUT "1227"
#define DENSENET_DEFAULT_PAF_OUTPUT "1229"
/**
* Command-line options able to be passed to poseNet::Create()
* @ingroup poseNet
*/
#define POSENET_USAGE_STRING "poseNet arguments: \n" \
" --network NETWORK pre-trained model to load, one of the following:\n" \
" * fcn-resnet18-cityscapes-512x256\n" \
" * fcn-resnet18-cityscapes-1024x512\n" \
" --model MODEL path to custom model to load (caffemodel, uff, or onnx)\n" \
" --input_blob INPUT name of the input layer (default: '" POSENET_DEFAULT_INPUT "')\n" \
" --output_blob OUTPUT name of the output layer (default: '" RESNET_DEFAULT_CMAP_OUTPUT "')\n" \
" --batch_size BATCH maximum batch size (default is 1)\n" \
" --profile enable layer profiling in TensorRT\n"
/**
* Pose estimation with trt_pose models or custom ONNX models, using TensorRT.
* @ingroup poseNet
*/
class poseNet : public tensorNet
{
public:
/**
* Enumeration of pretrained/built-in network models.
*/
enum NetworkType
{
DENSENET121_BASELINE_ATT_256x256,
RESNET18_BASELINE_ATT_224x224,
/* add new models here */
POSENET_CUSTOM
};
/**
* Parse a string from one of the built-in pretrained models.
* Valid names are "densenet121_baseline_att", "resnet18_baseline_att", etc.
* @returns one of the poseNet::NetworkType enums, or poseNet::POSENET_CUSTOM on invalid string.
*/
static NetworkType NetworkTypeFromStr( const char* model_name );
/**
* Convert a NetworkType enum to a human-readable string.
* @returns stringized version of the provided NetworkType enum.
*/
static const char* NetworkTypeToStr( NetworkType networkType );
/**
* Load a new network instance
*/
static poseNet* Create(
NetworkType networkType=RESNET18_BASELINE_ATT_224x224,
uint32_t maxBatchSize=DEFAULT_MAX_BATCH_SIZE,
precisionType precision=TYPE_FASTEST,
deviceType device=DEVICE_GPU,
bool allowGPUFallback=true
);
/**
* Load a new network instance
* @param prototxt_path File path to the deployable network prototxt (NULL for ONNX models)
* @param model_path File path to the caffemodel, uff, or onnx model
* @param input Name of the input layer blob. @see POSENET_DEFAULT_INPUT
* @param cmap_blob Name of the cmap output layer blob. @see RESNET_DEFAULT_CMAP_OUTPUT
* @param paf_blob Name of the paf output layer blob. @see RESNET_DEFAULT_PAF_OUTPUT
* @param maxBatchSize The maximum batch size that the network will support and be optimized for.
*/
static poseNet* Create(
const char* prototxt_path,
const char* model_path,
const char* input = POSENET_DEFAULT_INPUT,
const char* cmap_blob = RESNET_DEFAULT_CMAP_OUTPUT,
const char* paf_blob = RESNET_DEFAULT_PAF_OUTPUT,
uint32_t maxBatchSize=DEFAULT_MAX_BATCH_SIZE,
precisionType precision=TYPE_FASTEST,
deviceType device=DEVICE_GPU,
bool allowGPUFallback=true
);
/**
* Usage string for command line arguments to Create()
*/
static inline const char* Usage() { return POSENET_USAGE_STRING; }
/**
* Destroy
*/
virtual ~poseNet();
/**
* Perform pose estimation on the given image: pre-process the input, run inference,
* then parse the CMAP / PAF outputs and overlay the detected pose on the image.
* @param input float4 input image in CUDA device memory, RGBA colorspace with values 0-255.
* @param width width of the input image in pixels.
* @param height height of the input image in pixels.
*/
bool Process( float* input, uint32_t width, uint32_t height );
/**
* Retrieve the number of columns in the network output grid.
* This indicates the resolution of the raw model output.
*/
inline uint32_t GetGridWidth() const { return DIMS_W(mOutputs[0].dims); }
/**
* Retrieve the number of rows in the network output grid.
* This indicates the resolution of the raw model output.
*/
inline uint32_t GetGridHeight() const { return DIMS_H(mOutputs[0].dims); }
/**
* Retrieve the network type (densenet or resnet)
*/
inline NetworkType GetNetworkType() const { return mNetworkType; }
/**
* Retrieve a string describing the network name.
*/
inline const char* GetNetworkName() const { return NetworkTypeToStr(mNetworkType); }
protected:
poseNet();
bool processOutput(float* output, uint32_t width, uint32_t height);
bool overlayPosePoints(
float* input,
uint32_t width,
uint32_t height,
jetsoncam::Tensor<int> topology,
jetsoncam::Tensor<int> object_counts,
jetsoncam::Tensor<int> objects,
jetsoncam::Tensor<float> normalized_peaks
);
float* mClassColors[2]; /**< array of overlay colors in shared CPU/GPU memory */
uint8_t* mClassMap[2]; /**< runtime buffer for the argmax-classified class index of each tile */
float* mLastInputImg; /**< last input image to be processed, stored for overlay */
uint32_t mLastInputWidth; /**< width in pixels of last input image to be processed */
uint32_t mLastInputHeight; /**< height in pixels of last input image to be processed */
NetworkType mNetworkType; /**< Pretrained built-in model type enumeration */
const char* topology_supercategory;
uint32_t topology_id;
const char* topology_name;
std::vector<const char*> topology_keypoints;
std::vector<std::vector<int>> topology_skeleton;
int num_parts;
int num_links;
jetsoncam::Tensor<int> topology;
trt_pose::ParseObjects NetworkOutputParser;
};
#endif
poseNet.cpp
#include "poseNet.h"
#include "poseNet.cuh"
#include "Tensor.h"
#include "ParseObjects.hpp"
#include "plugins.hpp"
#include <jetson-utils/cudaMappedMemory.h>
#include <jetson-utils/cudaOverlay.h>
#include <jetson-utils/cudaResize.h>
#include <jetson-utils/cudaFont.h>
#include <jetson-utils/commandLine.h>
#include <jetson-utils/filesystem.h>
#include <jetson-utils/imageIO.h>
#include <vector>
#define OUTPUT_CMAP 0 // CMAP
#define OUTPUT_PAF 1 // PAF
// constructor
poseNet::poseNet() : tensorNet()
{
mLastInputImg = NULL;
mLastInputWidth = 0;
mLastInputHeight = 0;
mClassColors[0] = NULL; // cpu ptr
mClassColors[1] = NULL; // gpu ptr
mClassMap[0] = NULL;
mClassMap[1] = NULL;
mNetworkType = POSENET_CUSTOM;
topology_supercategory = "person";
topology_id = 1;
topology_name = "person";
topology_keypoints = {
"nose",
"left_eye",
"right_eye",
"left_ear",
"right_ear",
"left_shoulder",
"right_shoulder",
"left_elbow",
"right_elbow",
"left_wrist",
"right_wrist",
"left_hip",
"right_hip",
"left_knee",
"right_knee",
"left_ankle",
"right_ankle",
"neck"
};
// original topology
topology_skeleton = {
{ 16, 14 },
{ 14, 12 },
{ 17, 15 },
{ 15, 13 },
{ 12, 13 },
{ 6, 8 },
{ 7, 9 },
{ 8, 10 },
{ 9, 11 },
{ 2, 3 },
{ 1, 2 },
{ 1, 3 },
{ 2, 4 },
{ 3, 5 },
{ 4, 6 },
{ 5, 7 },
{ 18, 1 },
{ 18, 6 },
{ 18, 7 },
{ 18, 12 },
{ 18, 13 }
};
num_parts = static_cast<int>(topology_keypoints.size());
num_links = static_cast<int>(topology_skeleton.size());
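// coco_category_to_topology (ported from the trt_pose Python package) builds a
// {num_links, 4} int tensor: for each skeleton link k it stores the two PAF
// channel indices (2*k and 2*k + 1) plus the zero-based indices of the two
// keypoints that the link connects.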
topology = trt_pose::plugins::coco_category_to_topology(topology_skeleton);
topology.printDims();
NetworkOutputParser = trt_pose::ParseObjects(topology);
}
// destructor
poseNet::~poseNet()
{
}
// NetworkTypeFromStr
poseNet::NetworkType poseNet::NetworkTypeFromStr( const char* modelName )
{
if( !modelName )
return poseNet::POSENET_CUSTOM;
poseNet::NetworkType type = poseNet::DENSENET121_BASELINE_ATT_256x256;
// ONNX models
if( strcasecmp(modelName, "densenet121_baseline_att_256x256_B_epoch_160") == 0 || strcasecmp(modelName, "densenet121_baseline_att") == 0 )
type = poseNet::DENSENET121_BASELINE_ATT_256x256;
else if( strcasecmp(modelName, "resnet18_baseline_att_224x224_A_epoch_249") == 0 || strcasecmp(modelName, "resnet18_baseline_att") == 0 )
type = poseNet::RESNET18_BASELINE_ATT_224x224;
else
type = poseNet::POSENET_CUSTOM;
return type;
}
// NetworkTypeToStr
const char* poseNet::NetworkTypeToStr( poseNet::NetworkType type )
{
switch(type)
{
// ONNX models
case DENSENET121_BASELINE_ATT_256x256: return "densenet121_baseline_att_256x256_B_epoch_160";
case RESNET18_BASELINE_ATT_224x224: return "resnet18_baseline_att_224x224_A_epoch_249";
default: return "custom poseNet";
}
}
// Create
poseNet* poseNet::Create(
NetworkType networkType,
uint32_t maxBatchSize,
precisionType precision,
deviceType device,
bool allowGPUFallback
){
poseNet* net = NULL;
// ONNX models
if( networkType == DENSENET121_BASELINE_ATT_256x256 ) {
net = Create(
NULL,
"networks/densenet121_baseline_att_256x256_B_epoch_160.onnx",
"0",
"1227",
"1229",
maxBatchSize,
precision,
device,
allowGPUFallback
);
}
else if( networkType == RESNET18_BASELINE_ATT_224x224 ) {
net = Create(
NULL,
"networks/resnet18_baseline_att_224x224_A_epoch_249.onnx",
"0",
"262",
"264",
maxBatchSize,
precision,
device,
allowGPUFallback
);
}
else {
return NULL;
}
if( net != NULL )
net->mNetworkType = networkType;
return net;
}
// Create
poseNet* poseNet::Create(
const char* prototxt,
const char* model,
const char* input_blob,
const char* cmap_blob,
const char* paf_blob,
uint32_t maxBatchSize,
precisionType precision,
deviceType device,
bool allowGPUFallback
){
// create the pose estimation model
poseNet* net = new poseNet();
if( !net )
return NULL;
printf("\n");
printf("poseNet -- loading segmentation network model from:\n");
printf(" -- prototxt: %s\n", prototxt);
printf(" -- model: %s\n", model);
printf(" -- input_blob '%s'\n", input_blob);
printf(" -- cmap_blob '%s'\n", cmap_blob);
printf(" -- paf_blob '%s'\n", paf_blob);
printf(" -- batch_size %u\n\n", maxBatchSize);
//net->EnableProfiler();
//net->EnableDebug();
//net->DisableFP16(); // debug;
// load network
std::vector<std::string> output_blobs;
output_blobs.push_back(cmap_blob);
output_blobs.push_back(paf_blob);
if( !net->LoadNetwork(prototxt, model, NULL, input_blob, output_blobs, maxBatchSize,
precision, device, allowGPUFallback) )
{
printf("poseNet -- failed to initialize.\n");
return NULL;
}
return net;
}
// pre-process the image and run pose estimation
bool poseNet::Process( float* rgba, uint32_t width, uint32_t height )
{
if( !rgba || width == 0 || height == 0 )
{
printf("poseNet::Process( 0x%p, %u, %u ) -> invalid parameters\n", rgba, width, height);
return false;
}
PROFILER_BEGIN(PROFILER_PREPROCESS);
if( IsModelType(MODEL_ONNX) )
{
// downsample, convert to band-sequential RGB, and apply pixel normalization, mean pixel subtraction and standard deviation
if( CUDA_FAILED(cudaPreImageNetNormMeanRGB(
(float4*)rgba,
width,
height,
mInputCUDA,
mWidth,
mHeight,
make_float2(0.0f, 1.0f), // range
make_float3(0.485f, 0.456f, 0.406f), // mean
make_float3(0.229f, 0.224f, 0.225f), // stdDev
GetStream())) )
{
printf(LOG_TRT "poseNet::Process() -- cudaPreImageNetNormMeanRGB() failed\n");
return false;
}
}
else
{
// downsample and convert to band-sequential BGR
if( CUDA_FAILED(cudaPreImageNetBGR(
(float4*)rgba,
width,
height,
mInputCUDA,
mWidth,
mHeight,
GetStream())) )
{
printf("poseNet::Process() -- cudaPreImageNetBGR() failed\n");
return false;
}
}
PROFILER_END(PROFILER_PREPROCESS);
PROFILER_BEGIN(PROFILER_NETWORK);
// process with TensorRT
void* inferenceBuffers[] = { mInputCUDA, mOutputs[OUTPUT_CMAP].CUDA, mOutputs[OUTPUT_PAF].CUDA };
// execute the neural network with your image input
if( !mContext->execute(1, inferenceBuffers) )
{
printf(LOG_TRT "poseNet::Process() -- failed to execute TensorRT context\n");
return false;
}
PROFILER_END(PROFILER_NETWORK);
PROFILER_BEGIN(PROFILER_POSTPROCESS);
printf("width: %u, height: %u, mWidth: %u, mHeight: %u\n", width, height, mWidth, mHeight);
// process model Output
if( !processOutput(rgba, width, height) )
return false;
PROFILER_END(PROFILER_POSTPROCESS);
// cache pointer to last image processed
mLastInputImg = rgba;
mLastInputWidth = width;
mLastInputHeight = height;
return true;
}
// processOutput
bool poseNet::processOutput(
float* output,
uint32_t width,
uint32_t height
)
{
size_t outputLen = mOutputs.size();
for (size_t i = 0; i < outputLen; i++) {
const char* output_name = mOutputs[i].name.c_str();
printf(LOG_TRT "poseNet::processOutput() : %s \n", output_name);
}
// retrieve scores
float* cmap = mOutputs[OUTPUT_CMAP].CPU;
float* paf = mOutputs[OUTPUT_PAF].CPU;
const int c_w = DIMS_W(mOutputs[OUTPUT_CMAP].dims);
const int c_h = DIMS_H(mOutputs[OUTPUT_CMAP].dims);
const int c_c = DIMS_C(mOutputs[OUTPUT_CMAP].dims);
const int p_w = DIMS_W(mOutputs[OUTPUT_PAF].dims);
const int p_h = DIMS_H(mOutputs[OUTPUT_PAF].dims);
const int p_c = DIMS_C(mOutputs[OUTPUT_PAF].dims);
// data/image:: torch.Size([1, 3, 224, 224])
jetsoncam::Tensor<float> cmap_tensor = jetsoncam::Tensor<float>(
"cmap_tensor",
{1, c_c, c_h, c_w},
mOutputs[OUTPUT_CMAP].CPU,
mOutputs[OUTPUT_CMAP].CUDA
);
cmap_tensor.printDims();
jetsoncam::Tensor<float> paf_tensor = jetsoncam::Tensor<float>(
"paf_tensor",
{1, p_c, p_h, p_w},
mOutputs[OUTPUT_PAF].CPU,
mOutputs[OUTPUT_PAF].CUDA
);
paf_tensor.printDims();
// cmap:: torch.Size([1, 18, 56, 56]) [Correct]
// Tensor cmap_tensor 4 dimensions { 1 18 56 56 } [Correct]
// paf:: torch.Size([1, 42, 56, 56]) [Correct]
// Tensor paf_tensor 4 dimensions { 1 42 56 56 } [Correct]
jetsoncam::ParseResult networkResults = NetworkOutputParser.Parse(cmap_tensor, paf_tensor);
jetsoncam::Tensor<int> object_counts = networkResults.object_counts;
object_counts.printDims();
jetsoncam::Tensor<int> objects = networkResults.objects;
objects.printDims();
jetsoncam::Tensor<float> normalized_peaks = networkResults.normalized_peaks;
normalized_peaks.printDims();
// counts:: torch.Size([1]) [Correct]
// Tensor object_counts 1 dimensions { 1 } [Correct]
// objects:: torch.Size([1, 100, 18]) [Correct]
// Tensor objects 3 dimensions { 1 100 18 } [Correct]
// peaks:: torch.Size([1, 18, 100, 2]) [Correct]
// Tensor refined_peaks 4 dimensions { 1 18 100 2 } [Correct]
printf(LOG_TRT "poseNet::processOutput() Computed\n");
printf(LOG_TRT " ----- object_counts\n");
printf(LOG_TRT " ----- objects\n");
printf(LOG_TRT " ----- normalized_peaks\n");
return overlayPosePoints(
output,
width,
height,
topology,
object_counts,
objects,
normalized_peaks
);
}
#define OVERLAY_CUDA
// overlayPosePoints
bool poseNet::overlayPosePoints(
float* input,
uint32_t width,
uint32_t height,
jetsoncam::Tensor<int> topology,
jetsoncam::Tensor<int> object_counts,
jetsoncam::Tensor<int> objects,
jetsoncam::Tensor<float> normalized_peaks
)
{
PROFILER_BEGIN(PROFILER_VISUALIZE);
#ifdef OVERLAY_CUDA
// uint8_t* scores;
// generate overlay on the GPU
if( CUDA_FAILED(cudaDrawPose(
(float4*)input,
width,
height,
topology,
object_counts,
objects,
normalized_peaks,
GetStream())) )
{
printf(LOG_TRT "poseNet -- failed to process %ux%u overlay/mask with CUDA\n", width, height);
return false;
}
#endif
PROFILER_END(PROFILER_VISUALIZE);
printf(LOG_TRT "poseNet -- completed Drawing Pose\n");
return true;
}
poseNet API Highlights⌗
- We extended the jetson-inference/tensorNet.h class
- Put our images as input to the inferenceBuffers
- Put an array of outputs for the inferenceBuffers
- Retrieve the output float arrays from mOutputs
  - mOutputs is a list of all outputs; for our purposes we had 2 outputs (cmap and paf), in mOutputs[0] and mOutputs[1]
- Our outputs are flat float arrays, which we had to wrap in a custom Tensor class so we could work with them
  - luckily, the tensorNet parent class automatically reshapes our input
- Take note of the data bindings printed when you load a TensorRT optimized model. You will need these binding names when loading a model, e.g. densenet121_baseline_att_256x256_B_epoch_160.onnx had the bindings:
  - in/out = INPUT, name = 0 for the type of layer (our input layer) and its name
  - in/out = OUTPUT, name = 1227 for the type of layer (our first model output layer, for CMAP) and its name
  - in/out = OUTPUT, name = 1229 for the type of layer (our second model output layer, for PAF) and its name
- We will need these 3 values when loading the network; they are mapped to input_blob and an array of output_blobs, since for this pose model we have 1 input and 2 outputs (CMAP and PAF), i.e.

net = Create(
    NULL,
    "networks/densenet121_baseline_att_256x256_B_epoch_160.onnx",
    "0",    // input layer name
    "1227", // output 1 (CMAP) layer name
    "1229", // output 2 (PAF) layer name
    maxBatchSize,
    precision,
    device,
    allowGPUFallback
);
The input to our model is 4-dimensional, i.e. {1, C, H, W}
- where C is the number of color channels
- where H is the height of the image
- where W is the width of the image
For our model input, if you are using
- densenet121_baseline_att_256x256_B_epoch_160
  - the image must be of size 256 x 256 with 3 channels, so our dimensions are {1, 3, 256, 256}
- resnet18_baseline_att_224x224_A_epoch_249
  - the image must be of size 224 x 224 with 3 channels, so our dimensions are {1, 3, 224, 224}
For our model outputs we have 2:
- CMAP, whose dimensions are {1, 18, 64, 64}
- PAF, whose dimensions are {1, 42, 64, 64}
After we get our CMAP and PAF tensors, we will need to process them with our Topology tensor to create our usable outputs, i.e.
- Tensor Object Counts - dimensions {1}
- Tensor Objects - dimensions {1, 100, 18}
- Tensor Refined Peaks - dimensions {1, 18, 100, 2}
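To show how the pieces fit together, here is a minimal usage sketch of the finished API (illustrative only; the function and variable names are made up, it will not compile until the CUDA drawing code from the upcoming post is in place, and it assumes imgRGBA is a float4 RGBA image with values 0-255 in shared CPU/GPU memory, e.g. captured as in the previous tutorials):
#include "poseNet.h"
#include <cstdio>
#include <cstdint>

// imgRGBA: float4 RGBA image (values 0-255) in CUDA mapped memory
bool runPoseEstimation( float* imgRGBA, uint32_t imgWidth, uint32_t imgHeight )
{
    // build / load the TensorRT-optimized engine for the densenet pose model
    poseNet* net = poseNet::Create(poseNet::DENSENET121_BASELINE_ATT_256x256);

    if( !net )
    {
        printf("runPoseEstimation -- failed to create poseNet\n");
        return false;
    }

    // pre-process, run inference, parse the CMAP / PAF outputs
    // and draw the detected pose on top of the input image
    const bool ok = net->Process(imgRGBA, imgWidth, imgHeight);

    delete net;
    return ok;
}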
To compile the CUDA files (.cu, .cuh), add the following to your CMakeLists.txt:
NOTE: I will explain the CUDA files (*.cu, *.cuh) in the post about GPU / Parallelization
# we have to compile CUDA libs separately
CUDA_COMPILE(posenet_cuda_o src/poseNet.cu)
add_executable(imagetaker_posenet
...
${posenet_cuda_o}
...
)
- NOTE: the CUDA_COMPILE(...) call and the ${posenet_cuda_o} entry in add_executable are the important additions here
Notes and Tips - Parameter Names⌗
Do not give your function/constructor parameters the same name as your member variables; the parameter will shadow the member and the assignment will silently do nothing, e.g.
// BAD!
// DO NOT DO THIS
class Myclass {
public:
Myclass(int myint);
protected:
int myint; // member variable
};
// Note how the constructor parameter and the member variable have the
// same name: the parameter shadows the member, so the line below assigns
// the parameter to itself and the member is never set
Myclass::Myclass(int myint) {
myint = myint;
}
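A simple way to avoid this is to give the parameter a different name, or to qualify the member with this-> (a minimal illustrative sketch; the names are made up):
// GOOD
class Myclass {
public:
Myclass(int myint_param);
protected:
int myint; // member variable
};
// the parameter no longer shadows the member, so the member is actually assigned
Myclass::Myclass(int myint_param) {
myint = myint_param;
}
// alternatively, keep the name and disambiguate with this->:
// Myclass::Myclass(int myint) { this->myint = myint; }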
If you have any tips or questions, feel free to leave a comment!
NOTE: this code was created as a proof of concept; do not put it in any production code unless you are absolutely sure you know what you are doing.
I am breaking this post into 2 as it is very long right now; next time I will show you the poseNet ROS node.
See you in Part 2!
Links to source code⌗
Full source code available on GitHub
Sources and Recommended Reading⌗
- Nvidia AI IOT Trt Pose on GitHub
- TensorRT Home
- TensorRT Documentation
- Lots of Googling