Integrating Custom PyTorch Models Into an Android App

Gabriel Mongaras
12 min read · Jul 5, 2022

Recently, I’ve been learning Android app development. Although it’s really cool to build an app from scratch, I was wondering if it would be possible to integrate a PyTorch model into my app. I didn’t want to make any backend calls as that would take a while and I would have to use a database service which would cost money. Fortunately, PyTorch has Android support and I will go over how you can put your custom models into your apps with ease.

I also have a repo of the code to go along with this tutorial.

Setup

To start, you’ll want to install Android Studio, which you can get for free!

Next, let’s create an empty project.

Selecting a new project

Make sure you choose Java as the language for this project, since that is what I will be using in this tutorial. You can name the project whatever you like; I’m going to name mine PyTorch_App. Your configuration should look roughly like this:

New project configuration example

Below are my initial setup files in case you want to use the same library versions as I do:

App-level build.gradle
Project-level build.gradle

If you don’t have a device emulator set up, you should create one so that you can run the app.

With the project created, let’s add the PyTorch library to our app. Doing this is easy enough and requires adding a single line to the app-level build.gradle file:

Updated app-level build.gradle

Click “Sync Now” in the upper right-hand corner and the project should update. Now the setup is complete and we can start creating our app!

Model Creation

To show how PyTorch can be used with a model, I’m going to be creating a simple model that takes N noise vectors as input and outputs N numbers between 0 and 9. So, let’s start coding!

import torch
from torch import nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        # Basic MLP with 2 inputs, 5 hidden layers
        # and 10 outputs where each output is
        # the softmax probability of a number 0 to 9
        self.MLP = nn.Sequential(
            nn.Linear(2, 5),
            nn.Linear(5, 10),
            nn.Linear(10, 15),
            nn.Linear(15, 20),
            nn.Linear(20, 15),
            nn.Linear(15, 10),
            nn.Softmax(-1)
        )

    def forward(self, X):
        # Return the index of the most probable digit
        return torch.argmax(self.MLP(X), dim=-1)

The model will take in a vector with 2 elements (which is just some arbitrary number I chose) and output a vector with 10 elements where each element in the output vector is the probability the model thinks that number should be chosen. Of course, the model isn’t trained, so the probability output doesn’t have any context. Nonetheless, we can just take the index of the max output value (argmax) in the vector to get the predicted value. The function below is used to show how we can get a sample of N essentially random numbers from the model by inputting N noise vectors.

def main():
    # Create the model
    model = Model()

    # Create 4 random noise vectors
    # meaning we want 4 random numbers
    X = torch.distributions.uniform.Uniform(-10000, 10000).sample((4, 2))

    # Send the noise vectors through the model
    # to get the argmax outputs
    outputs = model(X)

    # Print the outputs
    for o in outputs:
        print(f"{o.item()} ")

    # Save the model to a file named model.pkl
    torch.save(model.state_dict(), "model.pkl")

The model is then saved to a file named ‘model.pkl’ which we will use as the model for the Android app.

Before working on the app, you can imagine a normal model has several problems. For example:

  • A good model is probably very large
  • A model can take a while to load
  • Models can take a while to make a prediction

To reduce these problems, PyTorch has a very useful feature called TorchScript, which serializes the model into a form that can be optimized for deployment. Of course, the optimization can only do so much, so don’t expect it to magically shrink a massive model into a small one. Additionally, the model class does not have to be defined when loading a TorchScript model, making it very easy to use in our app. The code below loads our saved model, traces it, optimizes it for mobile, and then saves it to a new file.

from torch.utils.mobile_optimizer import optimize_for_mobile


def optimizeSave():
    # Load in the model
    model = Model()
    model.load_state_dict(torch.load("model.pkl",
                                     map_location=torch.device("cpu")))
    model.eval()  # Put the model in inference mode

    # Generate some random noise
    X = torch.distributions.uniform.Uniform(-10000, 10000).sample((4, 2))

    # Generate the optimized model
    traced_script_module = torch.jit.trace(model, X)
    traced_script_module_optimized = optimize_for_mobile(traced_script_module)

    # Save the optimized model
    traced_script_module_optimized._save_for_lite_interpreter("model.pt")

Now we have an optimized model to put in our app!

Creating The App

The app will consist of a basic edit text field for the user to specify the number of random numbers to generate, a button to generate some random numbers, and a text field to display the random numbers. The code in ‘activity_main.xml’ is as follows:

<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout
    xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <TextView
        android:id="@+id/textView"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginBottom="100dp"
        android:text="Number of digits to generate"
        android:textColor="@color/black"
        android:textSize="20sp"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

    <EditText
        android:id="@+id/etNumber"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:ems="10"
        android:hint="digits"
        android:inputType="number"
        android:minHeight="48dp"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/textView" />

    <Button
        android:id="@+id/btnInfer"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginTop="10dp"
        android:text="Infer"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/etNumber" />

    <TextView
        android:id="@+id/tvDigits"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginTop="10dp"
        android:textColor="@color/black"
        android:textSize="20sp"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/btnInfer"
        tools:text="123456789" />

</androidx.constraintlayout.widget.ConstraintLayout>

Now that we have a structure for the app, let’s add the model. To start, create a new directory in app/src/main called assets. Then place the ‘model.pt’ file in the assets folder. Your project’s file tree should look like the following:

Image showing the proper location of the assets directory

Now we can start loading the model. In ‘MainActivity.java’, let’s add a few fields to the class, outside of onCreate, that will hold the model and the elements in the view.

// Elements in the view
EditText etNumber;
Button btnInfer;
TextView tvDigits;

// Tag used for logging
private static final String TAG = "MainActivity";

// PyTorch model
Module module;

Now, let’s load in the model in the onCreate method.

// Get all the elements
etNumber = findViewById(R.id.etNumber);
btnInfer = findViewById(R.id.btnInfer);
tvDigits = findViewById(R.id.tvDigits);

// Load in the model
try {
    module = LiteModuleLoader.load(assetFilePath("model.pt"));
} catch (IOException e) {
    Log.e(TAG, "Unable to load model", e);
}

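assetFilePath is a function we need to define. LiteModuleLoader.load expects a path on the filesystem, but files in app/src/main/assets are bundled inside the APK, so the function copies ‘model.pt’ out of the assets into the app’s internal files directory (unless a copy already exists) and returns the absolute path of that copy.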

// Given the name of the pytorch model, get the path for that model
public String assetFilePath(String assetName) throws IOException {
    File file = new File(this.getFilesDir(), assetName);
    if (file.exists() && file.length() > 0) {
        return file.getAbsolutePath();
    }

    try (InputStream is = this.getAssets().open(assetName)) {
        try (OutputStream os = new FileOutputStream(file)) {
            byte[] buffer = new byte[4 * 1024];
            int read;
            while ((read = is.read(buffer)) != -1) {
                os.write(buffer, 0, read);
            }
            os.flush();
        }
        return file.getAbsolutePath();
    }
}
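To make the copy-then-return-a-path idea easier to try outside of Android, here is a minimal sketch in plain Java. The class name AssetCopySketch and the method copyIfMissing are hypothetical; a generic InputStream stands in for getAssets().open(...) and an arbitrary directory stands in for getFilesDir().

```java
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

// Hypothetical stand-alone version of the asset-copy pattern (no Android APIs).
final class AssetCopySketch {
    // Copies `source` into dir/name unless a non-empty copy already exists,
    // then returns the absolute path of the file on disk.
    static String copyIfMissing(InputStream source, File dir, String name) throws IOException {
        File file = new File(dir, name);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath(); // reuse the cached copy
        }
        try (InputStream is = source; OutputStream os = new FileOutputStream(file)) {
            byte[] buffer = new byte[4 * 1024];
            int read;
            while ((read = is.read(buffer)) != -1) {
                os.write(buffer, 0, read);
            }
        }
        return file.getAbsolutePath();
    }
}
```

One consequence of this caching that is worth keeping in mind: because an existing non-empty file is reused, swapping in a new ‘model.pt’ with the same name will likely keep loading the old copy until the app’s data is cleared or the app is reinstalled.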

Additionally, we should define another helper function that takes a shape as input and returns a tensor of random noise with that shape.

// Generate a tensor of random numbers given the size of that tensor.
public Tensor generateTensor(long[] size) {
    // Create a random array of floats between -10000 and 10000
    Random rand = new Random();
    int count = (int) (size[0] * size[1]);
    float[] arr = new float[count];
    for (int i = 0; i < count; i++) {
        arr[i] = -10000 + rand.nextFloat() * 20000;
    }

    // Create the tensor and return it
    return Tensor.fromBlob(arr, size);
}
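If you want to check the noise generation without running the app, the array-filling part can be pulled out into plain Java. This is only a sketch: the class NoiseSketch and its uniform method are hypothetical names, and the range is passed in rather than hard-coded.

```java
import java.util.Random;

// Hypothetical helper: fills a flat, row-major buffer for a (rows, cols) tensor.
final class NoiseSketch {
    // Returns rows*cols floats drawn uniformly between low and high.
    static float[] uniform(Random rand, int rows, int cols, float low, float high) {
        float[] data = new float[rows * cols];
        for (int i = 0; i < data.length; i++) {
            // nextFloat() is in [0, 1), so this lands between low and high
            data[i] = low + rand.nextFloat() * (high - low);
        }
        return data;
    }
}
```

Row i of the tensor occupies data[2*i] and data[2*i + 1] when cols is 2. In the app, a buffer like this would then be wrapped with Tensor.fromBlob(data, new long[]{rows, cols}), as in generateTensor above.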

Finally, let’s create an onClick event for the button so that a new sequence of numbers is generated when it’s clicked. This listener is registered in onCreate, immediately after loading the model:

// When the button is clicked, generate a noise tensor
// and get the output from the model
btnInfer.setOnClickListener(new View.OnClickListener() {
    @Override
    public void onClick(View view) {
        // Error checking
        String text = etNumber.getText().toString();
        if (text.length() == 0) {
            Toast.makeText(MainActivity.this, "A number must be supplied", Toast.LENGTH_SHORT).show();
            return;
        }

        // Get the number of numbers to generate from the edit text.
        // parseInt throws if the value does not fit in an int.
        int N;
        try {
            N = Integer.parseInt(text);
        } catch (NumberFormatException e) {
            Toast.makeText(MainActivity.this, "Digits must be between 1 and 10", Toast.LENGTH_SHORT).show();
            return;
        }

        // More error checking
        if (N < 1 || N > 10) {
            Toast.makeText(MainActivity.this, "Digits must be between 1 and 10", Toast.LENGTH_SHORT).show();
            return;
        }

        // Prepare the input tensor (N, 2)
        long[] shape = new long[]{N, 2};
        Tensor inputTensor = generateTensor(shape);

        // Get the output from the model
        long[] output = module.forward(IValue.from(inputTensor)).toTensor().getDataAsLongArray();

        // Get the output as a string
        String out = "";
        for (long l : output) {
            out += String.valueOf(l);
        }

        // Show the output
        tvDigits.setText(out);
    }
});

The methodology is very similar to the main function in the Python code: it creates a batch of noise vectors, passes it through the model, and then iterates over the output to display the generated digits.
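The input checks in the listener can also be factored into a small helper that is easy to test on its own. The sketch below is one possible way to do it; CountParser and parse are hypothetical names, not part of the app’s code.

```java
import java.util.OptionalInt;

// Hypothetical helper for validating the text typed into the EditText.
final class CountParser {
    // Returns the parsed count if it is an integer in [min, max], otherwise empty.
    static OptionalInt parse(String text, int min, int max) {
        if (text == null || text.trim().isEmpty()) {
            return OptionalInt.empty();
        }
        try {
            int n = Integer.parseInt(text.trim());
            return (n >= min && n <= max) ? OptionalInt.of(n) : OptionalInt.empty();
        } catch (NumberFormatException e) {
            // Covers non-numeric text and values too large for an int
            return OptionalInt.empty();
        }
    }
}
```

With a helper like this, the listener would only need to check whether CountParser.parse(text, 1, 10) is present and show a Toast otherwise.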

If you run the app now, you should be able to generate random digits by clicking the button:

Demo of the app we created!

Note: If you are getting a null error on the forward pass, the model probably failed to load (in which case module is still null), most likely because the model file is in the wrong directory. Make sure the assets directory is located in app/src/main/, not in app/src/[androidTest].

Picture Generation

As much fun as clicking a button to generate a random number is, what about images?

I’ve trained a StyleGAN3 model which we will use for image generation. Although the model is very large, it is still a good example of how an image model can be used in Android. To start, create a new empty activity called MainActivity2. Then go into the Android Manifest to change the initial activity from MainActivity to MainActivity2. The Android Manifest should look like the following:

Updated AndroidManifest.xml

Note the switch in the location of <intent-filter> from MainActivity to MainActivity2. This is what changes the starting activity. Now let’s create the XML file. The XML file will have a few views in it to display an image when a button is clicked:

<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity2">

    <Button
        android:id="@+id/btnGenerate"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Create new image"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

    <ImageView
        android:id="@+id/ivImage"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        app:layout_constraintBottom_toTopOf="@+id/btnGenerate"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent"
        tools:srcCompat="@tools:sample/avatars" />

    <TextView
        android:id="@+id/tvWaiting"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Generating image..."
        android:textColor="@color/black"
        android:textSize="20sp"
        android:visibility="invisible"
        app:layout_constraintBottom_toTopOf="@+id/btnGenerate"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

</androidx.constraintlayout.widget.ConstraintLayout>

Now, let’s start coding the activity. To start, let’s add some fields to the class that will help later in the activity.

// Elements in the view
Button btnGenerate;
ImageView ivImage;
TextView tvWaiting;

// Tag used for logging
private static final String TAG = "MainActivity2";

// PyTorch model
Module module;

// Size of the input tensor
int inSize = 512;

// Width and height of the output image
int width = 256;
int height = 256;

The StyleGAN model takes a 512-dimensional noise vector as input and outputs an image of shape 256✕256✕3, where the last dimension holds the three RGB channels.

Now let’s include the asset-path helper again. This function is exactly the same as the one in the previous activity.

// Given the name of the pytorch model, get the path for that model
public String assetFilePath(String assetName) throws IOException {
    File file = new File(this.getFilesDir(), assetName);
    if (file.exists() && file.length() > 0) {
        return file.getAbsolutePath();
    }

    try (InputStream is = this.getAssets().open(assetName)) {
        try (OutputStream os = new FileOutputStream(file)) {
            byte[] buffer = new byte[4 * 1024];
            int read;
            while ((read = is.read(buffer)) != -1) {
                os.write(buffer, 0, read);
            }
            os.flush();
        }
        return file.getAbsolutePath();
    }
}

As for the generateTensor function, we need to change it up a little bit. Since we are only generating a single image at a time, we want to generate a single tensor with 512 elements. Additionally, StyleGAN noise is sampled from a Gaussian distribution, so we can change rand.nextFloat() to rand.nextGaussian(). Since the model is actually trained, we don’t need to scale the noise to a wide range as we did in the other activity.

// Generate a tensor of random doubles given the size of
// the tensor to generate
public Tensor generateTensor(int size) {
    // Create a random array of doubles
    Random rand = new Random();
    double[] arr = new double[size];
    for (int i = 0; i < size; i++) {
        arr[i] = rand.nextGaussian();
    }

    // Create the tensor and return it
    long[] s = {1, size};
    return Tensor.fromBlob(arr, s);
}
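While debugging, it can be handy to get the same image every time. One way to do that, sketched below in plain Java, is to seed the random generator. The class GaussianNoiseSketch and the seed parameter are my own additions for illustration and are not part of the app above.

```java
import java.util.Random;

// Hypothetical helper: reproducible Gaussian noise for a single latent vector.
final class GaussianNoiseSketch {
    // Returns `size` samples from a standard normal distribution.
    // The same seed always yields the same vector.
    static double[] sample(long seed, int size) {
        Random rand = new Random(seed);
        double[] z = new double[size];
        for (int i = 0; i < size; i++) {
            z[i] = rand.nextGaussian();
        }
        return z;
    }
}
```

In the activity, the resulting array would be wrapped with Tensor.fromBlob(z, new long[]{1, size}), exactly as generateTensor does.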

Now that the helper variables and functions are in place, download the imageGen.pt model and place it into the assets directory as you did with the other model.

I’m not going to explain how the .pt model was created as the process is very similar to the number generation model, but if you are interested in seeing the script that generated the model, you can find it in this part of the code that goes along with this article.

With the model in place, we can start coding the onCreate function. To start, let’s add basic view finding and model loading.

// Get the elements in the activity
btnGenerate = findViewById(R.id.btnGenerate);
ivImage = findViewById(R.id.ivImage);
tvWaiting = findViewById(R.id.tvWaiting);

// Load in the model
try {
    module = LiteModuleLoader.load(assetFilePath("imageGen.pt"));
} catch (IOException e) {
    Log.e(TAG, "Unable to load model", e);
}

Finally, all we have to do is put an onClick listener on the button.

// When the button is clicked, generate a new image
btnGenerate.setOnClickListener(new View.OnClickListener() {
    @Override
    public void onClick(View view) {
        // Disable the button and show the waiting text
        // while the image is being generated
        btnGenerate.setClickable(false);
        ivImage.setVisibility(View.INVISIBLE);
        tvWaiting.setVisibility(View.VISIBLE);

        // Prepare the input tensor. This time, it's a
        // single noise vector with inSize (512) values.
        Tensor inputTensor = generateTensor(inSize);

        // Run the process on a background thread
        new Thread(new Runnable() {
            @Override
            public void run() {
                // Get the output from the model. The
                // length should be 256*256*3 or 196608
                // Note that the output is in the layout
                // [R, G, B, R, G, B, ..., B] and we
                // have to deal with that.
                float[] outputArr = module.forward(IValue.from(inputTensor)).toTensor().getDataAsFloatArray();

                // Ensure the output array has values between 0 and 255
                for (int i = 0; i < outputArr.length; i++) {
                    outputArr[i] = Math.min(Math.max(outputArr[i], 0), 255);
                }

                // Create a RGB bitmap of the correct shape
                Bitmap bmp = Bitmap.createBitmap(width, height, Bitmap.Config.RGB_565);

                // Iterate over all values in the output tensor
                // and put them into the bitmap, row by row
                int loc = 0;
                for (int y = 0; y < height; y++) {
                    for (int x = 0; x < width; x++) {
                        bmp.setPixel(x, y, Color.rgb((int)outputArr[loc], (int)outputArr[loc+1], (int)outputArr[loc+2]));
                        loc += 3;
                    }
                }

                // The output of the network is no longer needed
                outputArr = null;

                // Resize the bitmap to a larger image
                bmp = Bitmap.createScaledBitmap(
                        bmp, 512, 512, false);

                // Display the image
                Bitmap finalBmp = bmp;
                runOnUiThread(new Runnable() {
                    @Override
                    public void run() {
                        ivImage.setImageBitmap(finalBmp);

                        // Re-enable the button and show the image
                        btnGenerate.setClickable(true);
                        tvWaiting.setVisibility(View.INVISIBLE);
                        ivImage.setVisibility(View.VISIBLE);
                    }
                });

            }
        }).start();

    }
});

This function looks a bit scary, so let me break it down:

  1. When the button is clicked, disable the button and generate the input tensor.
  2. On a background thread, send the noise through the model to get a float output array named outputArr.
    - This output array is one-dimensional, with 256*256*3 or 196608 floating-point values.
    - The array is structured as a flattened image where the first 256*3 or 768 values belong to the first row of pixels in the image.
    - In the array, the RGB values are stored sequentially like the following: [R, G, B, R, G, B, … R, G, B]. So the first pixel is represented by the first three values in the array, the second by the second three values, and so on.
  3. After obtaining the float array, we need to clamp values between 0 and 255 since that’s the range RGB values can be in.
  4. We can then convert the values to an RGB bitmap using the properties of the array.
  5. The bitmap is then resized from 256✕256 to 512✕512 for better viewing purposes.
  6. Finally, on the main thread, we can change the image view with the newly created image and enable clicking on the button.
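The clamping and pixel-layout steps can be tried in plain Java without a device. The sketch below is not the code from the activity: PixelPacker is a hypothetical helper that turns the interleaved [R, G, B, ...] float array into one packed 0xAARRGGBB int per pixel, which is the color format Android’s Color.rgb produces.

```java
// Hypothetical helper: converts interleaved [R, G, B, R, G, B, ...] floats
// into packed 0xAARRGGBB ints, one per pixel, in row-major order.
final class PixelPacker {
    // Clamps a channel value to the 0..255 range and truncates it to an int.
    static int clampChannel(float v) {
        return (int) Math.min(Math.max(v, 0f), 255f);
    }

    static int[] pack(float[] rgb, int width, int height) {
        if (rgb.length != width * height * 3) {
            throw new IllegalArgumentException("expected " + (width * height * 3) + " values");
        }
        int[] pixels = new int[width * height];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int p = y * width + x; // index of this pixel
                int loc = p * 3;       // offset of this pixel's R value
                int r = clampChannel(rgb[loc]);
                int g = clampChannel(rgb[loc + 1]);
                int b = clampChannel(rgb[loc + 2]);
                // Fully opaque alpha, then R, G, B in descending byte order
                pixels[p] = 0xFF000000 | (r << 16) | (g << 8) | b;
            }
        }
        return pixels;
    }
}
```

For a 256✕256 image, pack expects 196608 floats and returns 65536 ints. On Android, an array like this could probably be handed to Bitmap.setPixels(pixels, 0, width, 0, 0, width, height) on an ARGB_8888 bitmap instead of calling setPixel once per pixel, though I have not measured whether that makes a noticeable difference here.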

An example of the output is shown below:

Example output

Note: If your app crashes while generating a new image, then it probably ran out of memory or had a related issue. I don’t know of many ways to fix this issue without changing the model, but just in case it makes a difference, I used a Pixel 4 emulator with API 32 (Sv2) on a Windows laptop.

Since the StyleGAN model is so large and uses so much memory, it may crash your emulator, but the process of extracting image data from a tensor should be the same for other image generation models. In practice, an image generation model shipped in an Android app would probably have to be optimized further for mobile devices.

That’s all there is to integrating a PyTorch model into an Android app. It’s not very difficult, though dealing with image output can be a bit tedious. Anyways, I hope you found this article helpful!

Gabriel Mongaras

AI enthusiast and CS student at SMU. For more information visit my website: https://gabrielm.cc/