How to Build a REST API for ML Model Serving in Rust

Build a Rust REST API for ML serving using axum and tokio to handle requests and return predictions.

Serving models at scale

You trained a model in Python. It predicts house prices with 95% accuracy. Now the backend team needs it in production. They don't want to spin up a Flask server that eats 2GB of RAM just to run inference. They want something fast, small, and safe. Rust is the answer. You're going to wrap that model in a REST API using axum, tokio, and serde.

The receptionist and the doctor

Think of your ML model as a specialist doctor. It knows how to diagnose, but it doesn't want to deal with scheduling, paperwork, or angry patients. The REST API is the receptionist. Patients (clients) hand over forms (JSON data). The receptionist checks the forms, calls the doctor, gets the diagnosis, and hands back a result slip. The doctor stays focused on the math. The receptionist handles the traffic.

In Rust, axum is the receptionist. It parses HTTP requests, routes them to the right handler, and formats responses. tokio is the building's power grid. It lets the receptionist handle thousands of patients at once without waiting for one diagnosis to finish before calling the next. serde is the translation layer. It turns the paper forms into structured data the doctor understands, and turns the diagnosis back into a slip the patient can read.

Build the receptionist first. The doctor can wait.

Minimal working server

Start with a skeleton that proves the plumbing works. You need three crates: axum for the web framework, tokio for the async runtime, and serde for JSON serialization.

[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"

The derive feature in serde lets you generate serialization code automatically with #[derive]. Without it, you'd have to write manual Serialize and Deserialize implementations for every struct.

Here is the server code. It defines a route, a handler, and the data structures.

use axum::{routing::post, Json, Router};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;

/// Input expected by the model.
/// Derive Deserialize so axum can parse JSON into this struct.
#[derive(Deserialize)]
struct InputData {
    /// The model expects a vector of features.
    features: Vec<f32>,
}

/// Output returned to the client.
/// Derive Serialize so axum can convert this struct to JSON.
#[derive(Serialize)]
struct Prediction {
    /// The model outputs a single score.
    result: f32,
}

/// Handler for the /predict endpoint.
/// The Json extractor deserializes the request body automatically.
async fn predict(Json(input): Json<InputData>) -> Json<Prediction> {
    // In a real app, you'd load the model once and share it.
    // Here we simulate inference with a dummy calculation.
    // Summing features and dividing by the count gives a mean score.
    // len() returns usize, so cast it to f32 before dividing.
    let score = input.features.iter().sum::<f32>() / input.features.len() as f32;
    Json(Prediction { result: score })
}

#[tokio::main]
async fn main() {
    // Build the router with a single POST route.
    // post(predict) binds the handler to HTTP POST requests.
    let app = Router::new().route("/predict", post(predict));

    // Bind to localhost:3000.
    // 127.0.0.1 ensures the server is only accessible locally.
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    println!("Listening on {}", addr);

    // Start the TCP listener.
    // unwrap() panics if the port is already in use.
    // In production, use expect() with a message or handle the error gracefully.
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

Convention aside: The community prefers expect("failed to bind listener") over unwrap() in main. Both panic and print the underlying error, but expect adds your own message, which tells the next developer what the server was trying to do when it crashed. unwrap() gives only the generic "called Result::unwrap() on an Err value" prefix and a line number.

When you run cargo run, tokio starts its runtime and axum listens on port 3000. A client sends POST /predict with {"features": [1.0, 2.0]} and a Content-Type: application/json header. axum matches the route and calls predict. The Json extractor reads the body and deserializes it into InputData using serde. If extraction fails, the handler never runs and axum rejects the request on its own: 400 for syntactically invalid JSON, 422 for valid JSON that does not match InputData, and 415 for a missing JSON content type. Otherwise your handler runs the math, serde serializes the result back to JSON, and axum sends the response. The client gets {"result": 1.5}.
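One edge case the walkthrough does not cover is an empty features array. Dividing a zero sum by a zero count yields NaN in floating point, and serde_json writes non-finite floats as null, so the client would see a null result rather than an error. A minimal sketch of a guard, assuming a hypothetical helper named mean_score that is not part of the original handler:

```rust
/// Hypothetical helper: mean of the features, or None when the slice is empty.
fn mean_score(features: &[f32]) -> Option<f32> {
    if features.is_empty() {
        // Report the empty case explicitly instead of computing 0.0 / 0.0.
        return None;
    }
    Some(features.iter().sum::<f32>() / features.len() as f32)
}

fn main() {
    // Same input as the request above: (1.0 + 2.0) / 2 = 1.5.
    println!("{:?}", mean_score(&[1.0, 2.0]));
    // Empty input is surfaced as None rather than NaN.
    println!("{:?}", mean_score(&[]));
}
```

In a handler, the None case could be mapped to a 400 or 422 response, depending on how you want to classify the error.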

Run cargo run. Hit the endpoint. See the JSON fly back.

Realistic serving with shared state

The minimal example simulates inference. Real models take time to load. You never load a model per request. That would kill performance and exhaust memory. You load the model once at startup, wrap it in shared state, and pass that state to every handler.

axum provides the State extractor for this. You create a struct that holds your model, attach it to the router with with_state, and request it in your handlers.

Models are often large and read-only during inference. You need to share them across multiple async tasks. Arc (Atomic Reference Counted) is the tool for this. It allows multiple threads to read the same data safely. Rc (Reference Counted) skips the atomic operations and is slightly cheaper, but it is neither Send nor Sync, so the compiler refuses to let it cross a thread boundary. Since tokio's default runtime runs tasks on multiple threads, Arc is the correct choice.
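The sharing behaviour of Arc can be seen without any web framework. The sketch below uses plain std threads in place of tokio tasks, and a stand-in Model with made-up weights; predict_on_threads is a hypothetical helper introduced only for illustration.

```rust
use std::sync::Arc;
use std::thread;

/// Stand-in for a loaded model (hypothetical placeholder).
struct Model {
    weights: Vec<f32>,
}

impl Model {
    /// Dot product of features and weights.
    fn predict(&self, features: &[f32]) -> f32 {
        features.iter().zip(&self.weights).map(|(f, w)| f * w).sum::<f32>()
    }
}

/// Runs one prediction per thread against the same shared model.
fn predict_on_threads(model: Arc<Model>, inputs: Vec<Vec<f32>>) -> Vec<f32> {
    let handles: Vec<_> = inputs
        .into_iter()
        .map(|features| {
            // Cloning the Arc copies a pointer and bumps a counter;
            // the weights themselves are not copied.
            let model = Arc::clone(&model);
            thread::spawn(move || model.predict(&features))
        })
        .collect();

    // Collect results in the same order as the inputs.
    handles
        .into_iter()
        .map(|h| h.join().expect("worker thread panicked"))
        .collect()
}

fn main() {
    let model = Arc::new(Model { weights: vec![0.5, 0.25] });
    let inputs = vec![vec![2.0, 4.0], vec![4.0, 8.0]];
    let results = predict_on_threads(Arc::clone(&model), inputs);
    println!("{:?}", results);
}
```

Swapping Arc for Rc in this sketch fails to compile, because thread::spawn requires its closure to be Send.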

use axum::{extract::State, routing::{get, post}, Json, Router};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::Arc;

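/// Shared application state.
/// Holds the model wrapped in Arc for thread-safe sharing.
/// axum requires the state to be Clone; cloning copies the Arc pointer, not the model.
#[derive(Clone)]
struct AppState {
    model: Arc<Model>,
}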

/// Placeholder for the actual ML model.
/// In practice, this might wrap a safetensors file or a ggml context.
struct Model {
    /// Weights loaded from disk.
    weights: Vec<f32>,
}

impl Model {
    /// Run inference on a feature vector.
    /// Returns a prediction score.
    fn predict(&self, features: &[f32]) -> f32 {
        // Simulate inference.
        // Zip features with weights, multiply, and sum.
        features
            .iter()
            .zip(&self.weights)
            .map(|(f, w)| f * w)
            .sum()
    }
}

#[derive(Deserialize)]
struct InputData {
    features: Vec<f32>,
}

#[derive(Serialize)]
struct Prediction {
    result: f32,
}

/// Handler for /predict.
/// Receives State and Json extractors.
/// Json consumes the request body, so it must be the last extractor.
async fn predict(
    State(state): State<AppState>,
    Json(input): Json<InputData>,
) -> Result<Json<Prediction>, axum::http::StatusCode> {
    // Validate input size matches model expectations.
    // Returning an error status code stops execution and sends a response.
    if input.features.len() != state.model.weights.len() {
        return Err(axum::http::StatusCode::BAD_REQUEST);
    }

    let result = state.model.predict(&input.features);
    Ok(Json(Prediction { result }))
}

/// Health check endpoint.
/// Returns 200 OK to signal the server is alive.
async fn health() -> &'static str {
    "ok"
}

#[tokio::main]
async fn main() {
    // Load model once at startup.
    // Wrap in Arc immediately so it can be shared.
    let model = Arc::new(Model {
        weights: vec![0.5, 0.3, 0.2],
    });

    // Create the state struct.
    let state = AppState { model };

    // Build the router.
    // with_state attaches the state to all routes in the router.
    let app = Router::new()
        .route("/predict", post(predict))
        .route("/health", get(health))
        .with_state(state);

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

Convention aside: In axum, the extractor that consumes the request body must be the last argument in the handler signature. Json reads the body, so it goes last; State does not touch the body and can appear anywhere before it. If you put Json before State, the compiler rejects the code. Keeping State first is a common convention that avoids the problem.

The State extractor pulls the shared model from the router context. You access it via state.model. Since model is an Arc, you can call methods on it from any handler without cloning the weights. The validation step checks that the input length matches the model weights. If it doesn't, you return a BAD_REQUEST status. Without this check, the zip in Model::predict would stop at the shorter of the two sequences and return a plausible-looking but wrong score.
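The effect of the length check can be shown in isolation. The following sketch is framework-free and uses made-up weights; DimensionMismatch, checked_predict, and unchecked_predict are hypothetical names for illustration, whereas the handler above reports the same condition as a status code.

```rust
/// Hypothetical error type for this sketch.
#[derive(Debug, PartialEq)]
struct DimensionMismatch {
    expected: usize,
    got: usize,
}

/// Dot product with the length check done up front.
fn checked_predict(weights: &[f32], features: &[f32]) -> Result<f32, DimensionMismatch> {
    if features.len() != weights.len() {
        return Err(DimensionMismatch {
            expected: weights.len(),
            got: features.len(),
        });
    }
    Ok(features.iter().zip(weights).map(|(f, w)| f * w).sum::<f32>())
}

/// The same dot product without the check; zip stops at the shorter slice.
fn unchecked_predict(weights: &[f32], features: &[f32]) -> f32 {
    features.iter().zip(weights).map(|(f, w)| f * w).sum::<f32>()
}

fn main() {
    let weights: [f32; 3] = [0.5, 0.25, 0.25];

    // Matching lengths: 2*0.5 + 4*0.25 + 8*0.25 = 4.0
    println!("{:?}", checked_predict(&weights, &[2.0, 4.0, 8.0]));

    // One feature missing: rejected explicitly.
    println!("{:?}", checked_predict(&weights, &[2.0, 4.0]));

    // Without the check the third weight is silently ignored: 2*0.5 + 4*0.25 = 2.0
    println!("{}", unchecked_predict(&weights, &[2.0, 4.0]));
}
```

Carrying the expected and received lengths in the error makes it straightforward to return a descriptive message to the client instead of a bare status code.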

Load the model once. Share it with Arc. Never reload it per request.

Pitfalls and compiler signals

Rust catches mistakes at compile time. You'll see these errors often when building ML APIs.

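If you forget #[derive(Serialize)] on your response struct, the compiler rejects the code with E0277, an unsatisfied trait bound. The root cause is that Serialize is not implemented for Prediction, but because the failure surfaces where the handler is passed to post, the message is often phrased in terms of axum's Handler trait instead. Add the derive macro and the error vanishes. If the message is hard to read, axum's #[debug_handler] attribute (behind the macros feature) usually produces a more direct one.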

If you try to use the model variable after moving it into AppState, you get E0382 (use of moved value). The model lives in the state now. Don't try to access it directly in main. Access it through the State extractor in handlers.
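If main does need the model after building the state (to log the weight count, say), one option is to clone the Arc handle before the move. This is a sketch with stand-in types; build_state is a hypothetical helper, not something axum provides.

```rust
use std::sync::Arc;

/// Stand-in model and state types for illustration.
struct Model {
    weights: Vec<f32>,
}

struct AppState {
    model: Arc<Model>,
}

/// Builds the state and also returns a second handle to the same model.
fn build_state(weights: Vec<f32>) -> (AppState, Arc<Model>) {
    let model = Arc::new(Model { weights });

    // Clone the handle before the move; this copies the pointer, not the weights.
    let handle = Arc::clone(&model);

    // `model` is moved into the state here.
    let state = AppState { model };

    // Uncommenting the next line fails with E0382 (borrow of moved value: `model`).
    // println!("{}", model.weights.len());

    (state, handle)
}

fn main() {
    let (state, handle) = build_state(vec![0.5, 0.3, 0.2]);
    println!("loaded {} weights", handle.weights.len());
    // Both handles point at the same allocation.
    println!("same model: {}", Arc::ptr_eq(&state.model, &handle));
}
```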

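A common mistake is reaching for Rc instead of Arc in an async context. Rc is not thread-safe, so it implements neither Send nor Sync. Because tokio may move a task to a different thread, axum requires handlers and state to be Send, and the compiler rejects the code with E0277 and a note such as "Rc<Model> cannot be sent between threads safely". This is a compile-time error, not a runtime panic. Always use Arc for shared state in tokio applications.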

Another trap is blocking the async runtime. If your model inference is CPU-heavy and runs directly inside an async handler, it occupies one of tokio's worker threads until it finishes. tokio has a limited number of worker threads, so a few long inference calls can starve the server. Use tokio::task::spawn_blocking to offload heavy CPU work to a separate thread pool. This keeps the async runtime responsive.

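// Offload heavy inference to a blocking thread.
// Clone the Arc handle (cheap) and move it into the closure with the input.
let model = Arc::clone(&state.model);
let result = tokio::task::spawn_blocking(move || model.predict(&input.features))
    .await
    // A JoinError here means the closure panicked; report it as a 500.
    .map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?;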

Trust the error codes. They point to the missing trait or the moved value every time.

Choosing your stack

Rust has several web frameworks. Pick the one that fits your needs.

Use axum when you want a modern, modular API that composes well with the tower ecosystem and handles async state cleanly. It is maintained under the tokio project and is a popular default for new Rust web projects.

Use actix-web when raw throughput is the priority and you want a long-established framework with a mature ecosystem of its own. It does not build on tower, so middleware written for axum does not carry over, and its conventions take some time to learn.

Use warp when you like a filter-based routing style that feels like chaining operations. It's flexible but can become hard to debug as routes grow complex.

Use a Python framework like FastAPI when your model is deeply coupled to the Python ecosystem and you can't port the inference logic to Rust yet. You can still call Python from Rust using pyo3, but you lose some performance and safety guarantees.

Reach for axum with tokio for new Rust ML serving projects. It's the path of least resistance.

Where to go next