ONNX / TensorFlow Lite for edge deployment

Description

Edge deployment refers to deploying machine learning models on edge devices such as mobile phones, IoT devices, microcontrollers, and embedded systems — enabling inference without relying on cloud connectivity.

ONNX (Open Neural Network Exchange) and TensorFlow Lite are two popular formats that allow trained models to be optimized and deployed on edge devices efficiently:

  • ONNX is an open, framework-agnostic format supported by PyTorch, scikit-learn, Keras, and others (natively or via converters). It enables interoperability across frameworks and supports edge deployment through runtimes such as ONNX Runtime.
  • TensorFlow Lite is a lightweight version of TensorFlow designed specifically for mobile and embedded devices. It supports model quantization and hardware acceleration.

Key Insight

Edge deployment using ONNX or TensorFlow Lite reduces latency, saves bandwidth, and enables real-time AI inference — ideal for offline, real-time, and privacy-sensitive applications.

Figure: Edge deployment architecture with TensorFlow Lite

Figure: Workflow of model conversion and deployment with TensorFlow Lite

Examples

This example demonstrates how to build a simple image classifier, convert it to TensorFlow Lite and ONNX formats, and then load it for inference on an edge device.

TensorFlow Lite Workflow

# Step 1: Build and train a simple Keras model
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=3)
model.save("mnist_model.h5")

# Step 2: Convert to TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open('mnist_model.tflite', 'wb') as f:
    f.write(tflite_model)

# Step 3: Load and run inference with TFLite Interpreter
import numpy as np
interpreter = tf.lite.Interpreter(model_path="mnist_model.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

input_data = np.expand_dims(x_test[0], axis=0).astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()

output_data = interpreter.get_tensor(output_details[0]['index'])
predicted_label = np.argmax(output_data)
print("Predicted digit:", predicted_label)

PyTorch to ONNX Workflow

# Step 1: Build and train a simple PyTorch model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

transform = transforms.Compose([transforms.ToTensor()])
train_loader = DataLoader(datasets.MNIST('.', train=True, download=True, transform=transform), batch_size=32)

model = Net()
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.NLLLoss()

for epoch in range(1):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

torch.save(model.state_dict(), "mnist_model.pt")

# Step 2: Convert to ONNX
dummy_input = torch.randn(1, 1, 28, 28)
model.eval()
torch.onnx.export(model, dummy_input, "mnist_model.onnx",
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})

# Step 3: Load and run inference using ONNX Runtime
import onnxruntime as ort
import numpy as np

session = ort.InferenceSession("mnist_model.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Take one image from the MNIST test set (reusing the transform defined in Step 1)
test_dataset = datasets.MNIST('.', train=False, download=True, transform=transform)
test_image, _ = test_dataset[0]
test_image = test_image.unsqueeze(0).numpy()  # shape: (batch, channel, H, W) = (1, 1, 28, 28)

output = session.run([output_name], {input_name: test_image})
predicted = np.argmax(output[0])
print("Predicted digit (ONNX):", predicted)

Real-World Applications

Mobile Image Classification

Using TensorFlow Lite, mobile apps can classify images locally (e.g., plant disease identification).

Smart Traffic Systems

Edge AI models detect traffic patterns in real time, using ONNX models for surveillance and congestion control.

Wearable Health Monitoring

TensorFlow Lite powers real-time ECG or activity monitoring on fitness bands and medical wearables.

IoT Devices

ONNX or TFLite models run directly on microcontrollers to enable smart detection (e.g., leak sensors).

Interview Questions

What is the difference between ONNX and TensorFlow Lite?

ONNX is a framework-agnostic open standard that allows models trained in different frameworks (e.g., PyTorch, Scikit-learn) to be exported and run across platforms. TensorFlow Lite is a specialized version of TensorFlow designed for mobile and edge deployment, offering size and performance optimizations specifically for TensorFlow models.
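
To make the framework-agnostic point concrete, a model trained in scikit-learn can be exported to the same format. The following is a minimal sketch, assuming the skl2onnx package is installed; the model and file names are only illustrative.

# Export a scikit-learn classifier to ONNX (requires skl2onnx)
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from skl2onnx import to_onnx

X, y = load_iris(return_X_y=True)
clf = LogisticRegression(max_iter=200).fit(X, y)

# The sample input tells the converter the expected input type and shape
onnx_model = to_onnx(clf, X[:1].astype(np.float32))
with open("iris_logreg.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

# The resulting file runs on the same ONNX Runtime used in the Examples section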

Why is model quantization important for edge deployment?

Quantization reduces model size and increases inference speed by converting weights and activations from 32-bit floats to lower precision (e.g., 8-bit integers), with minimal loss in accuracy. It is essential for constrained environments like mobile and IoT devices.
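
As an illustration, dynamic-range post-training quantization can be enabled with a single flag on the TFLite converter. This is a minimal sketch that reuses the Keras model from the Examples section; the output file name is only illustrative.

# Post-training dynamic-range quantization with TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # quantize weights to 8-bit integers
quantized_tflite_model = converter.convert()

with open("mnist_model_quant.tflite", "wb") as f:
    f.write(quantized_tflite_model)

# The quantized file is typically about 4x smaller than the float32 version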

What are the benefits of edge deployment?

  • Low latency: Faster response times as no cloud round-trip is needed.
  • Offline capability: Works without internet access.
  • Improved privacy: Data stays on-device.
  • Reduced bandwidth: No need to send data to a server.

How do you run an ONNX model on a mobile device?

You can use ONNX Runtime Mobile or integrate through third-party frameworks such as Microsoft ML.NET or OpenCV. ONNX Runtime provides APIs for Android and iOS and supports hardware acceleration through execution providers such as NNAPI and Core ML.