
Neural Networks in Unity using Native Libraries

March 11, 2020

By Nicholas Guttenberg
Cross Compass
GoodAI

Summary

  • This guide shows how to use Pytorch’s C++ API to run neural networks in Unity.
  • We can use this with existing Python-based models, by freezing the execution trace into a binary file that is loaded by the library at runtime.
  • In this form, it is easier to deploy a completed project to users (e.g. no concerns about running a Python server in sync with the Unity side of things).

Introduction

While machine learning techniques including neural networks are expanding in terms of what they can do in a research environment, deployment into specific contexts can still be quite problematic. In particular, here I’m going to consider the case of shipping a Unity demo or game, and what sorts of limitations to expect in terms of moving some arbitrary research-grade neural network into a form that should work more or less out of the box for someone who downloads it.

There are a few ways of getting a neural network into Unity. In this document I’m going to focus on using a C++ API for Pytorch called libtorch to make a native shared library, which can then be loaded as a plugin in Unity. There are other approaches, such as ML-Agents, which may be more appropriate, but the advantage of the shared library method is that you can do pretty much anything with it that you can do in Pytorch. So if you have some exotic model, you just want to use some existing Pytorch code that was written without Unity in mind, or you want to set up a workflow where some people develop the model in a vacuum and shouldn’t have to worry about the Unity side of things, this may be an appropriate method.

This guide has four steps – setting up a development environment, the C++ side of the interface, the Unity/C# side of the interface, and saving/deploying the model.

(A caveat: since the project I was working on that needed this was developed under Linux, my installation/development/etc. examples are also Linux-based. I don’t think anything here should be too OS-dependent, but if you want a demo to run on Windows, for example, you will also need to prepare a Windows version of the native library.)

Development environment

I’m not going to go into too much detail, but before we get to libtorch you’ll need:

  • cmake

If you want to use a GPU, you will also need:

  • The CUDA toolkit – As of writing this, Pytorch says 10.1 is the supported version, so I went with that.
  • The CUDNN library

CUDA in particular can be finicky because the driver, libraries, etc. all need to match to some degree. And you’ll have to ship these libraries with the demo if you want to make sure things run out of the box. So this, to me, is the most uncomfortable part. If you don’t plan to run things on the GPU you can avoid CUDA entirely, but CPU evaluation is likely to be 50-100x slower than even a fairly weak GPU. Even if your network is called fairly infrequently, this would make any lag spikes the user experiences more severe. So whether to take that risk depends on your particular use case.

Once you have those things installed, you can download and (locally) install libtorch. This doesn’t have to be a system-wide install; you’ll just put it in a directory within your project and point cmake at it.

The C++ side

The next step is setting up a CMakeLists.txt to build the native library. I basically started from the example in Pytorch’s own documentation for setting up this environment, then modified it slightly so that the output is a shared library rather than an executable. Put this in the root directory of the native library project.

CMakeLists.txt
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(networks)

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_library(networks SHARED networks.cpp)
target_link_libraries(networks "${TORCH_LIBRARIES}")
set_property(TARGET networks PROPERTY CXX_STANDARD 14)

if (MSVC)
	file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
	add_custom_command(TARGET networks
		POST_BUILD
		COMMAND ${CMAKE_COMMAND} -E copy_if_different
		${TORCH_DLLS}
		$<TARGET_FILE_DIR:networks>)
endif (MSVC)

The source code for the C++ side will be in networks.cpp. One nice thing about this approach is we don’t have to care too much about exactly what kind of neural network we want to interface with Unity yet. The reason (jumping ahead a bit) is that we can run the network once in Python, get the execution trace, and just tell libtorch ‘do this trace on these inputs’. So mainly the C++ side of things just deals with I/O.

If you want to do fancy things like for example training a network online as part of a Unity environment, then you would have to write the network architecture and training code in C++, which can be substantially more involved (but it is entirely possible to do). However, it’s beyond the scope of this guide, so I’ll refer you to the Pytorch documentation and example projects for more information.

Anyhow, what we’re going to do in networks.cpp is to define an external function to initialize the networks (load from disk) and an external function which runs the network and returns the results.

networks.cpp
#include <torch/script.h>
#include <vector>
#include <memory> 

extern "C"
{
// This is going to store the loaded network
torch::jit::script::Module network;

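In order to call these functions from Unity, we need to expose entry points to them in the shared library. The extern "C" block stops the C++ compiler from mangling the function names, so Unity can look them up by their plain names. Under Linux we can also put __attribute__((visibility("default"))) before each function to make sure it is exported. On Windows I think this would be __declspec(dllexport), but I haven’t personally checked whether that works, so be careful.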
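If you want one source file to build on both platforms, a common approach is to hide the difference behind a preprocessor macro. This was not part of the original project, and the Windows branch is just as untested here as above. NETWORKS_API and NetworksAbiVersion are hypothetical names, used only to show the macro in action:

```cpp
// Hypothetical macro: pick the export annotation for the current platform.
#if defined(_WIN32)
  #define NETWORKS_API __declspec(dllexport)
#else
  #define NETWORKS_API __attribute__((visibility("default")))
#endif

extern "C"
{
// Hypothetical entry point; in networks.cpp you would put NETWORKS_API
// in front of InitNetwork and ApplyNetwork instead.
NETWORKS_API int NetworksAbiVersion()
{
    return 1;
}
}
```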

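First, a function to load the network trace from disk. A relative path like this one is resolved against the process’s working directory. In the Unity Editor that is the root directory of the Unity project, NOT Assets/, so be careful about that. In a built player the working directory can be different, so you could also just pass a filename from Unity.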

extern __attribute__((visibility("default"))) void InitNetwork()
{
	network = torch::jit::load("network_trace.pt");
	network.to(at::kCUDA); // If we're doing this on GPU
}
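It is also worth guarding against load failures. torch::jit::load throws a c10::Error, which derives from std::exception, if the file is missing or unreadable. A C++ exception escaping through an extern "C" function into Unity will typically take the whole process down. The following is a minimal sketch, assuming you would rather get an error code back and pass the path in from C#. LoadModelOrThrow and InitNetworkFromPath are hypothetical names. The stand-in loader only checks that the file can be opened, so the snippet compiles without libtorch. The export annotation is omitted for brevity; add the same one as the other entry points.

```cpp
#include <exception>
#include <fstream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for `network = torch::jit::load(path);`.
// Like torch::jit::load, it throws when the file cannot be read.
static void LoadModelOrThrow(const std::string& path)
{
    std::ifstream file(path, std::ios::binary);
    if (!file)
        throw std::runtime_error("cannot open model file: " + path);
    // The real code would deserialize the trace here.
}

extern "C"
{
// Hypothetical replacement for InitNetwork() that takes the path from the caller.
// Returns 0 on success, 1 if loading threw, 2 if no path was given.
int InitNetworkFromPath(const char* path)
{
    if (path == nullptr)
        return 2;
    try
    {
        LoadModelOrThrow(path);
        return 0;
    }
    catch (const std::exception&)
    {
        // Stop the exception here: it must not cross the extern "C"
        // boundary into Unity.
        return 1;
    }
}
}
```

On the C# side, the matching DllImport declaration would take a string parameter. The default marshaling passes it as a null-terminated char*. You could build the path from something like Application.streamingAssetsPath.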

Now, the function to apply the network to input data. We’re going to use pointers to memory managed by Unity to shuffle information back and forth with C++. In this example, I’m just going to assume that my network has a fixed-size input and output; that is, I’m not going to let Unity vary the batch size or dimensions. Here I will take in a tensor with dimensions {1,3,64,64} and output a tensor with dimensions {1,5,64,64} (as you might have, for example, with a network that segments an RGB image into 5 classes).

In more general cases you’d want to pass information about the size of the allocated space to avoid buffer overflows.
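Here is a minimal sketch of one way to pass that size information. ApplyNetworkChecked is a hypothetical variant of ApplyNetwork. It takes the length of each buffer (in floats) and returns an error code instead of writing past the end. The torch calls are left out so the sketch compiles on its own; the memset is only a placeholder for the real from_blob/forward/copy work.

```cpp
#include <cstddef>
#include <cstring>

// Element counts for the example shapes {1,3,64,64} and {1,5,64,64}.
constexpr std::size_t kInputCount  = 1 * 3 * 64 * 64;  // 12288 floats
constexpr std::size_t kOutputCount = 1 * 5 * 64 * 64;  // 20480 floats

extern "C"
{
// Hypothetical size-checked entry point.
// Returns 0 on success, -1 on a null pointer or a size mismatch.
int ApplyNetworkChecked(const float* data, int dataLength,
                        float* output, int outputLength)
{
    if (data == nullptr || output == nullptr)
        return -1;
    if (dataLength < 0 || outputLength < 0 ||
        static_cast<std::size_t>(dataLength) != kInputCount ||
        static_cast<std::size_t>(outputLength) != kOutputCount)
        return -1;
    // Placeholder for the real network call: just clear the output buffer.
    std::memset(output, 0, kOutputCount * sizeof(float));
    return 0;
}
}
```

From C#, this would be declared with two extra int parameters and called with input.Length and output.Length alongside the ref arguments.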

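To get the data into a format that libtorch can use, we use torch::from_blob. This takes a pointer to an array of floats and a specification of the tensor sizes, and returns a Tensor. Note that from_blob does not copy or take ownership of the memory: the Tensor simply views the caller’s buffer. That is fine here because we either copy the data to the GPU or finish using it before the function returns.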

A trace can also come from a network whose forward() call takes multiple input arguments (for example x, y, z). To handle this, all the input tensors are packed, in argument order, into a std::vector of torch::jit::IValue (even if there’s only one argument).

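To get data out of a Tensor, the easiest-looking approach is to read it element-wise with item<float>(). However, each call extracts a single value, and for a tensor on the GPU it also performs its own device-to-host copy, so it quickly becomes a bottleneck. It is usually much cheaper to copy the whole result in one go, for example by wrapping the output buffer with torch::from_blob and calling copy_() on it. There is also something called Tensor::accessor for setting up structured read-outs of CPU Tensors (it was overkill for my project).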
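Once the result has been copied out, the floats follow the usual row-major (NCHW) order of a contiguous {1,5,64,64} tensor. The last index varies fastest, so each class channel is one block of 64*64 values. FlatIndex and ArgmaxOverChannels are hypothetical helpers, not libtorch functions. This minimal sketch shows how to locate a given element, and how you might reduce the five per-class scores to one label per pixel if that is all the game needs:

```cpp
#include <cstddef>
#include <vector>

// Position of element (n, c, y, x) in a contiguous NCHW buffer with
// C channels, H rows and W columns (last index varies fastest).
inline std::size_t FlatIndex(std::size_t n, std::size_t c, std::size_t y, std::size_t x,
                             std::size_t C, std::size_t H, std::size_t W)
{
    return ((n * C + c) * H + y) * W + x;
}

// For batch element 0, pick the class with the highest score
// (e.g. log-probability) at each pixel; returns H*W labels in row-major order.
std::vector<int> ArgmaxOverChannels(const float* scores,
                                    std::size_t C, std::size_t H, std::size_t W)
{
    std::vector<int> labels(H * W, 0);
    for (std::size_t y = 0; y < H; ++y)
        for (std::size_t x = 0; x < W; ++x)
        {
            std::size_t best = 0;
            for (std::size_t c = 1; c < C; ++c)
                if (scores[FlatIndex(0, c, y, x, C, H, W)] >
                    scores[FlatIndex(0, best, y, x, C, H, W)])
                    best = c;
            labels[y * W + x] = static_cast<int>(best);
        }
    return labels;
}
```

For the example output, FlatIndex(0, c, y, x, 5, 64, 64) runs from 0 to 20479, which is exactly the 1*5*64*64 floats Unity allocates.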

Anyhow, to apply the network:

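extern __attribute__((visibility("default"))) void ApplyNetwork(float *data, float *output)
{
	// Wrap Unity's input buffer without copying, then move it to the GPU
	torch::Tensor x = torch::from_blob(data, {1,3,64,64}).cuda();
	std::vector<torch::jit::IValue> inputs;
	inputs.push_back(x);
	torch::Tensor z = network.forward(inputs).toTensor();
	// Wrap Unity's output buffer and copy the whole result into it at once;
	// copy_ also takes care of the GPU -> CPU transfer
	torch::from_blob(output, {1,5,64,64}).copy_(z);
}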
}

To compile this, following the docs, make a build/ subdirectory and call:

  • cmake -DCMAKE_PREFIX_PATH=/absolute/path/to/libtorch ..
  • cmake --build . --config Release

If all goes well, you’ll get a libnetworks.so or networks.dll file that you can drop into Assets/Plugins/ in Unity.

And that’s it on the C++ side!

Unity-side of the interface

On the Unity side, we will use DllImport to pull in functions from this library. First we’re going to want to call InitNetwork() during our startup in the Unity environment. This only needs to happen once globally. The code for this is just (inside whatever loader script you want):

using System.Runtime.InteropServices;
using UnityEngine;

public class Startup : MonoBehaviour
{
	// ...
	[DllImport("networks")]
	private static extern void InitNetwork();

	void Start()
	{
		// ...
		InitNetwork();
		// ...
	}
}

Now for whatever script(s) we want to be able to use the network:

In order to move data from Unity to C++ and back, I’m taking the approach here of letting Unity handle the memory management. To do this, I’ll allocate arrays of the right size on the Unity side of things, pass references to their first element to C++, and let C++ pointer math just handle loading/unloading data from that memory.

Doing it this way means that the C++ side should not allocate or free anything that needs to persist into Unity. Similarly, if I pass a pointer to C++ from Unity, C++ should not hold onto it past the end of the function call, since Unity’s garbage collector is free to move or reclaim that memory afterwards.

But thankfully the C++ side of this is really simple, so this isn’t a big ask. However, if you want to do something with parallelizing the neural network calls so the demo keeps going while the neural network evaluates, it could be a concern.
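If you do go down the parallel route, one way to respect those rules is to copy Unity's input into memory owned by C++ before handing it to a worker thread. The result is then copied back into a Unity buffer only inside a later call. This minimal sketch assumes Unity calls a hypothetical StartNetwork once, then calls a hypothetical PollNetwork each frame (say from Update()) until it returns 1. RunModel is a placeholder for the real from_blob/forward code. Only one evaluation is allowed in flight, so the single loaded network is never used from two threads at once.

```cpp
#include <chrono>
#include <cstddef>
#include <cstring>
#include <future>
#include <utility>
#include <vector>

namespace {
constexpr std::size_t kIn  = 1 * 3 * 64 * 64;
constexpr std::size_t kOut = 1 * 5 * 64 * 64;

// Hypothetical stand-in for the real network evaluation.
std::vector<float> RunModel(std::vector<float> input)
{
    std::vector<float> out(kOut, 0.0f);
    out[0] = input[0];  // placeholder "computation"
    return out;
}

std::future<std::vector<float>> pending;  // at most one job in flight
}

extern "C"
{
// Copies the input into a C++-owned vector, then evaluates in the background,
// so Unity's pointer is never used after this call returns.
// Returns 0 if started, -1 if a job is already running.
int StartNetwork(const float* data)
{
    if (pending.valid())
        return -1;
    std::vector<float> copy(data, data + kIn);
    pending = std::async(std::launch::async, RunModel, std::move(copy));
    return 0;
}

// Non-blocking: returns 1 and fills `output` if the result is ready,
// 0 if still running, -1 if nothing was started.
int PollNetwork(float* output)
{
    if (!pending.valid())
        return -1;
    if (pending.wait_for(std::chrono::seconds(0)) != std::future_status::ready)
        return 0;
    std::vector<float> result = pending.get();  // get() leaves `pending` invalid
    std::memcpy(output, result.data(), kOut * sizeof(float));
    return 1;
}
}
```

Because std::future::get() can only be called once, PollNetwork returns -1 after delivering a result, and StartNetwork accepts a new job.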

[DllImport("networks")]
private static extern void ApplyNetwork(ref float data, ref float output);

void SomeFunction() {
	float[] input = new float[1*3*64*64];
	float[] output = new float[1*5*64*64];

	// Load input with whatever data you want
	// ...

	// Passing the first element by ref hands C++ a pointer to the array's data
	ApplyNetwork(ref input[0], ref output[0]);

	// Do whatever you want with the output
	// ...
}

And that’s it for the Unity side.

Saving the model

So far we haven’t discussed the actual neural network model at all. For the sake of example, I’m going to go with the kind of simple convolutional neural network you might use to do image segmentation. I’m not going to cover data collection, training, and so on here; this is just the bare bones of defining an architecture and turning it into a frozen execution trace.

Again, the Pytorch documentation (https://pytorch.org/tutorials/advanced/cpp_export.html) has a good example of how to do this and talks about some of the corner cases and potential issues, so I won’t repeat that extensively here. The big point to keep in mind is that tracing only records the tensor operations that actually ran on the example input. Python control flow (e.g. if statements or loops that depend on the data), places where you dip into and out of Pytorch operations, etc. may therefore not be captured by the trace. The documentation talks about ways to get around this using annotations and explicitly compiling the model with torch.jit.script.

Anyhow, this is what our model might look like on the Python side of things:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
	def __init__(self):
		super().__init__()
		self.c1 = nn.Conv2d(3,64,5,padding=2)
		self.c2 = nn.Conv2d(64,5,5,padding=2)

	def forward(self, x):
		z = F.leaky_relu(self.c1(x))
		z = F.log_softmax(self.c2(z), dim=1)
		return z

Not a very good segmentation architecture, but hopefully you get the idea.

Then, if we want to export that model (with its weights, whatever they might be at the time of export), we would do something like:

network = Net().cuda().eval()
example = torch.rand(1, 3, 64, 64).cuda()  # same shape the C++ side feeds in
traced_network = torch.jit.trace(network, example)
traced_network.save("network_trace.pt")

And that’s basically it!

A note on distribution

Unless you’ve made a statically linked library, you’re also going to need to distribute some supporting libraries with your project. A caveat here: I haven’t tried distributing such a project yet, so I’m not 100% sure which libraries you need to package. But at the very least, this is probably libtorch, libc10, libc10_cuda, libnvToolsExt, and libcudart (based on running ldd libnetworks.so and seeing what pops up). This adds about 2 GB to the size of a project, so it’s something to keep in mind.

Why did I do this when there’s ML-Agents?

For a lot of projects, especially in research and prototyping, I think ML-Agents is a really good way to just get into things and start using neural networks in Unity. But when things get more complex, it’s good to have a general option to fall back on.

Up until a couple of weeks ago, I was using ML-Agents to interface between a Unity demo and a couple of different neural networks on the Python side of things. Depending on what happened in Unity, the demo might need to call one or more of these networks, on different sets of data.

I was using the Python interface for ML-Agents already because some of the operations my networks used weren’t yet supported in Barracuda (which is the trace execution library ML-Agents is currently using) – things like 1d convolutions and transpose operations.

The issue I ran into was that ML-Agents collects ‘requests’ from agents over the course of a time-step and then sends all the requests to the Python notebook to be evaluated. However, some of my networks depended on the output of others of my networks, so I would have to make a request, wait a timestep, grab the result, make another request, wait a timestep, grab the result, … in order to evaluate a chain of networks. Furthermore, the execution order and depth depended non-trivially on user input, meaning that I couldn’t just run the networks in sequence in Python. Also, in some cases the amount of data I would need to send could vary, but ML-Agents has you define a fixed dimension of sensor space per-agent (it might just be as simple as changing this on the fly, but I’m not actually sure about ML-Agents’ internal flow to know whether that’s kosher or not).

Although I could have done something like computing the sequence of calculations requested, sending that as a sensor input to Python, then applying the networks accordingly, it was starting to build up as excess complexity in my code on the Unity side and on the Python side. So I looked into this approach instead.

Before, if someone asked me to put a pretrained GPT-2 model into a demo or use MAML, I would have said that if we could at all avoid doing so we should, because of the potential complexity it would entail. But with the native library approach, I’d consider it potentially feasible.
