ONNX / TensorFlow Lite for Edge Deployment
Description
Edge deployment refers to deploying machine learning models on edge devices such as mobile phones, IoT devices, microcontrollers, and embedded systems — enabling inference without relying on cloud connectivity.
ONNX (Open Neural Network Exchange) and TensorFlow Lite are two popular formats that allow trained models to be optimized and deployed on edge devices efficiently:
- ONNX is an open standard format supported by PyTorch, scikit-learn, Keras, and others. It enables interoperability across frameworks and supports edge deployment through runtimes like ONNX Runtime.
- TensorFlow Lite is a lightweight version of TensorFlow designed specifically for mobile and embedded devices. It supports model quantization and hardware acceleration.
Edge deployment using ONNX or TensorFlow Lite reduces latency, saves bandwidth, and enables real-time AI inference — ideal for offline, real-time, and privacy-sensitive applications.
Figure: Workflow of model conversion and deployment with TensorFlow Lite
Examples
These examples demonstrate how to build a simple MNIST image classifier, convert it to an edge-friendly format (TensorFlow Lite for the Keras model, ONNX for the PyTorch model), and then load it for inference on an edge device.
TensorFlow Lite Workflow
# Step 1: Build and train a simple Keras model
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=3)
model.save("mnist_model.h5")
# Step 2: Convert to TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('mnist_model.tflite', 'wb') as f:
    f.write(tflite_model)
# Step 3: Load and run inference with TFLite Interpreter
import numpy as np
interpreter = tf.lite.Interpreter(model_path="mnist_model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_data = np.expand_dims(x_test[0], axis=0).astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
predicted_label = np.argmax(output_data)
print("Predicted digit:", predicted_label)
PyTorch to ONNX Workflow
# Step 1: Build and train a simple PyTorch model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)
transform = transforms.Compose([transforms.ToTensor()])
train_loader = DataLoader(datasets.MNIST('.', train=True, download=True, transform=transform), batch_size=32)
model = Net()
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.NLLLoss()
for epoch in range(1):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
torch.save(model.state_dict(), "mnist_model.pt")
# Step 2: Convert to ONNX
dummy_input = torch.randn(1, 1, 28, 28)
model.eval()
torch.onnx.export(model, dummy_input, "mnist_model.onnx",
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
# Step 3: Load and run inference using ONNX Runtime
import onnxruntime as ort
import numpy as np
session = ort.InferenceSession("mnist_model.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# Simulate a test image with random pixel values
# (x_test from the TFLite example is not defined in this standalone script)
test_image = np.random.rand(1, 1, 28, 28).astype(np.float32)  # shape: (batch, channel, H, W)
output = session.run([output_name], {input_name: test_image})
predicted = np.argmax(output[0])
print("Predicted digit (ONNX):", predicted)
Real-World Applications
Mobile Image Classification
Using TensorFlow Lite, mobile apps can classify images locally (e.g., plant disease identification).
Smart Traffic Systems
Edge AI models deployed via ONNX detect traffic patterns in real time for surveillance and congestion control.
Wearable Health Monitoring
TensorFlow Lite powers real-time ECG or activity monitoring on fitness bands and medical wearables.
IoT Devices
ONNX or TFLite models run directly on microcontrollers to enable smart detection (e.g., leak sensors).
Resources
Recommended Books
- TensorFlow Lite for Mobile and Edge Devices by Bhavani Rao
- Edge Computing by Perry Lea
Interview Questions
What is the difference between ONNX and TensorFlow Lite?
ONNX is a framework-agnostic open standard that allows models trained in different frameworks (e.g., PyTorch, scikit-learn) to be exported and run across platforms. TensorFlow Lite is a specialized version of TensorFlow designed for mobile and edge deployment, offering size and performance optimizations specifically for TensorFlow models.
Why is model quantization important for edge deployment?
Quantization reduces model size and increases inference speed by converting weights and activations from 32-bit floats to lower precision (e.g., 8-bit integers), with minimal loss in accuracy. It is essential for constrained environments like mobile and IoT devices.
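For ONNX models, a comparable post-training step ships with onnxruntime. A minimal sketch, assuming the mnist_model.onnx file exported in the example above; quantize_dynamic rewrites the weights as 8-bit integers:
# Post-training dynamic quantization of an ONNX model
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic("mnist_model.onnx",        # float32 input model
                 "mnist_model_quant.onnx",  # quantized output model
                 weight_type=QuantType.QInt8)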
What are the benefits of edge deployment?
- Low latency: Faster response times as no cloud round-trip is needed.
- Offline capability: Works without internet access.
- Improved privacy: Data stays on-device.
- Reduced bandwidth: No need to send data to a server.
How do you run an ONNX model on a mobile device?
You can use ONNX Runtime Mobile or integrate with third-party frameworks like Microsoft ML.NET or OpenCV. ONNX Runtime provides APIs for Android and iOS and supports hardware acceleration through backends such as NNAPI (Android) or Core ML (iOS), as sketched below.
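A minimal sketch of provider selection, assuming an Android build of ONNX Runtime where the NNAPI execution provider is available; on a desktop build typically only CPUExecutionProvider is present, so the preference list is filtered against what the runtime actually offers:
# Pick the best available execution provider, falling back to CPU
import onnxruntime as ort
available = ort.get_available_providers()
preferred = [p for p in ["NnapiExecutionProvider", "CPUExecutionProvider"]
             if p in available]
session = ort.InferenceSession("mnist_model.onnx", providers=preferred)
print("Active providers:", session.get_providers())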