How to Use burn for Deep Learning in Rust

Install the burn crate, define a model struct with the Module derive macro, and call forward to run inference.

When Python feels too heavy

You've built a recommendation engine in Python, but deploying it means spinning up a heavy container, managing dependencies, and praying the GIL doesn't choke your throughput. Or you're writing a game engine and want on-device inference without the overhead. You want the math of deep learning with the performance and safety of Rust. That's where burn comes in. It's not just a port of PyTorch; it's a framework designed from the ground up to play nice with Rust's type system while giving you the flexibility to swap backends like changing tires.

Reach for burn when you need Rust's safety without sacrificing the flexibility of a deep learning framework.

The core abstractions

burn treats your model as a composable graph of operations. The framework rests on three pillars: Module, Tensor, and Backend.

A Module is a container for parameters and the logic to transform inputs into outputs. It's the Rust way of saying "this is a neural network layer or a whole network." The Module trait provides methods for saving, loading, and traversing the model structure. You derive Module on your struct, and the macro generates the boilerplate for managing state.

A Tensor is a multi-dimensional array of numbers. It's the currency of deep learning. In burn, tensors are generic over a backend type and a rank. You write Tensor<B, D>, where B is the backend and D is the number of dimensions, so Tensor<B, 2> is a matrix. This means the same tensor API works on CPU, on GPU, or in a WebAssembly build. The type system enforces that you don't mix backends accidentally.

The Backend trait abstracts the hardware. burn provides backends like NdArray (CPU), Wgpu (cross-platform GPU), and Cuda, plus Autodiff, a decorator that wraps any of them to add gradient tracking for training. You pick a backend at compile time. Your model code stays the same; only the type parameter changes.

Think of burn as a universal translator for math. You write the logic once using Module and Tensor, and burn handles the heavy lifting on whatever hardware you have. The Module holds the state; forward holds the logic. Keep them separate.
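
To make the backend-as-type-parameter idea concrete, here is a toy sketch that uses only the standard library. ToyBackend, Cpu, Gpu, ToyTensor, and double are hypothetical names invented for this illustration; they are not burn's types, and burn's real Backend trait is far richer. The sketch only shows the pattern: tag every tensor with a backend type and a rank, and write model code once over the generic parameter.

```rust
use std::marker::PhantomData;

// Hypothetical stand-in for a backend trait (not burn's Backend).
trait ToyBackend {
    fn name() -> &'static str;
}

struct Cpu;
struct Gpu;

impl ToyBackend for Cpu {
    fn name() -> &'static str {
        "cpu"
    }
}

impl ToyBackend for Gpu {
    fn name() -> &'static str {
        "gpu"
    }
}

// A tensor tagged with its backend B and its rank D.
struct ToyTensor<B: ToyBackend, const D: usize> {
    data: Vec<f32>,
    shape: [usize; D],
    _backend: PhantomData<B>,
}

impl<B: ToyBackend, const D: usize> ToyTensor<B, D> {
    fn new(data: Vec<f32>, shape: [usize; D]) -> Self {
        // Dimension sizes are only known at runtime, so this is a runtime check.
        assert_eq!(data.len(), shape.iter().product::<usize>(), "data does not fit shape");
        Self { data, shape, _backend: PhantomData }
    }

    // Element-wise multiply. `other: Self` forces the same backend and rank
    // at compile time; the sizes are still compared at runtime.
    fn mul(self, other: Self) -> Self {
        assert_eq!(self.shape, other.shape, "shape mismatch");
        let data: Vec<f32> = self.data.iter().zip(other.data.iter()).map(|(a, b)| a * b).collect();
        Self::new(data, self.shape)
    }

    fn backend_name(&self) -> &'static str {
        B::name()
    }
}

// Model logic written once, generic over the backend.
fn double<B: ToyBackend>(x: ToyTensor<B, 1>) -> ToyTensor<B, 1> {
    let n = x.shape[0];
    let two = ToyTensor::<B, 1>::new(vec![2.0; n], [n]);
    x.mul(two)
}
```

Calling double with a ToyTensor<Cpu, 1> or a ToyTensor<Gpu, 1> compiles unchanged, while multiplying a ToyTensor<Cpu, 1> by a ToyTensor<Gpu, 1> is rejected before the program runs, because the two are different types.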

Minimal working model

Start with a single weight and a forward pass. This example uses the Autodiff backend, which wraps another backend to track gradients for training.

// Assumes a recent burn release with the `ndarray` feature enabled.
use burn::backend::{Autodiff, NdArray};
use burn::module::{Module, Param};
use burn::tensor::{backend::Backend, Tensor};

// Autodiff is a decorator: it wraps a base backend (here the CPU-based
// NdArray backend) and adds the gradient tracking needed for training.
type MyBackend = Autodiff<NdArray>;

#[derive(Module, Debug)]
struct SimpleModel<B: Backend> {
    // Param marks the tensor as a learnable parameter, so the derive macro
    // includes it when saving and loading the model.
    // The rank is part of the type: this is a rank-1 tensor.
    weight: Param<Tensor<B, 1>>,
}

impl<B: Backend> SimpleModel<B> {
    fn new(device: &B::Device) -> Self {
        // Initialize weight with a fixed value for reproducibility.
        // burn has no rank-0 tensors, so a "scalar" is a rank-1 tensor of length 1.
        Self {
            weight: Param::from_tensor(Tensor::from_floats([2.0], device)),
        }
    }

    fn forward(&self, input: Tensor<B, 1>) -> Tensor<B, 1> {
        // val() returns the parameter's current tensor.
        // The multiplication is tracked by Autodiff for backpropagation.
        self.weight.val() * input
    }
}

fn main() {
    // Every tensor is created on a device; the default is fine for NdArray.
    let device = Default::default();
    let model = SimpleModel::<MyBackend>::new(&device);
    // Create a rank-1 input tensor of length 1.
    let input = Tensor::<MyBackend, 1>::from_floats([5.0], &device);

    // forward takes ownership of the input tensor.
    let output = model.forward(input);
    println!("Output: {}", output);
}

When you compile this, the type system checks that your tensor operations line up. The * operator requires both operands to share a backend type and a rank. If you try to multiply a Tensor<Autodiff<NdArray>, 1> by a Tensor<Wgpu, 1>, the compiler rejects the code with a type error (typically E0308, mismatched types, or E0277, trait bound not satisfied). Dimension sizes are checked at runtime. Element-wise operations generally broadcast along a dimension of size 1; if the sizes are otherwise incompatible, you get a runtime panic.

The forward method takes ownership of the input tensor. Tensors are owned values, so passing one to forward moves it. Clone the input if you need to reuse it; in burn a clone is cheap because it shares the underlying buffer rather than copying the data.

Real-world patterns

Production models use a Config struct to separate hyperparameters from initialization. This pattern makes your models reproducible and serializable. You define dimensions and architecture choices in the config, then build the model from it.

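use burn::config::Config;
use burn::module::{Module, Param};
use burn::tensor::{backend::Backend, Tensor};

// The Config derive already provides Clone and serialization,
// so only Debug is derived alongside it.
#[derive(Config, Debug)]
struct LinearConfig {
    in_dim: usize,
    out_dim: usize,
}

#[derive(Module, Debug)]
struct LinearModel<B: Backend> {
    weight: Param<Tensor<B, 2>>,
    bias: Param<Tensor<B, 1>>,
}

impl LinearConfig {
    // The model is generic over the backend; init only needs the device
    // on which to allocate the parameters.
    fn init<B: Backend>(&self, device: &B::Device) -> LinearModel<B> {
        LinearModel {
            // Weight shape: [in_dim, out_dim].
            // Zeros keep the example short; a real layer needs random initialization.
            weight: Param::from_tensor(Tensor::zeros([self.in_dim, self.out_dim], device)),
            // Bias shape: [out_dim].
            bias: Param::from_tensor(Tensor::zeros([self.out_dim], device)),
        }
    }
}

impl<B: Backend> LinearModel<B> {
    // Maps [batch, in_dim] to [batch, out_dim].
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        // unsqueeze turns the bias into shape [1, out_dim] so it broadcasts over the batch.
        input.matmul(self.weight.val()) + self.bias.val().unsqueeze()
    }
}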

The Config derive macro generates a constructor and serialization methods. You can save the config to a file and reload it later to recreate the model with the same hyperparameters. burn's own examples follow this convention. Hardcoding dimensions in new() makes it impossible to change architecture without recompiling.

Models compose naturally. You can nest modules inside other modules. The Module derive macro recursively handles children.

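#[derive(Module, Debug)]
struct Network<B: Backend> {
    layer1: LinearModel<B>,
    layer2: LinearModel<B>,
}

// Generic parameters are declared on impl, not on the type name.
impl<B: Backend> Network<B> {
    // Assumes LinearModel defines forward(Tensor<B, 2>) -> Tensor<B, 2>.
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.layer1.forward(input);
        self.layer2.forward(x)
    }
}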

This composition allows you to build complex architectures from simple blocks. The Module trait provides visit and map methods that walk every parameter, including those in nested modules, which is what saving and loading an entire network relies on.
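
What "recursively handles children" means can be sketched without burn at all. The ToyModule trait and the ToyLinear and ToyNetwork types below are hypothetical, standard-library-only stand-ins. The hand-written impl for ToyNetwork shows the kind of delegation a derive macro can generate for you: each parent asks its fields and combines the answers.

```rust
// Hypothetical trait for illustration; burn's Module trait has many more methods.
trait ToyModule {
    fn num_params(&self) -> usize;
}

struct ToyLinear {
    weight: Vec<f32>,
    bias: Vec<f32>,
}

impl ToyLinear {
    fn zeros(in_dim: usize, out_dim: usize) -> Self {
        Self {
            weight: vec![0.0; in_dim * out_dim],
            bias: vec![0.0; out_dim],
        }
    }
}

// A leaf module counts its own parameters.
impl ToyModule for ToyLinear {
    fn num_params(&self) -> usize {
        self.weight.len() + self.bias.len()
    }
}

struct ToyNetwork {
    layer1: ToyLinear,
    layer2: ToyLinear,
}

// A parent module delegates to each child field and sums the results.
impl ToyModule for ToyNetwork {
    fn num_params(&self) -> usize {
        self.layer1.num_params() + self.layer2.num_params()
    }
}
```

The same delegation shape works for any whole-model operation, such as collecting parameters to save or moving them to another device, which is why nesting modules costs no extra boilerplate.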

Adopt the Config pattern early. It pays off when you need to serialize hyperparameters or switch backends.

Pitfalls and compiler errors

Tensors move. If you try to use a tensor after passing it to forward, the compiler rejects you with E0382 (use of moved value). Clone the tensor if you need to keep the original.

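// B is your backend type and device is a &B::Device.
let input = Tensor::<B, 1>::from_floats([1.0], &device);
let output1 = model.forward(input.clone()); // The clone is moved; input stays usable.
let output2 = model.forward(input);         // input is moved here and cannot be used afterwards.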
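
If cloning before every call looks wasteful, the following standard-library sketch shows why a clone does not have to copy the data. SharedTensor and scale are hypothetical names for this illustration and do not mirror burn's internals; the assumption is only that a tensor handle can share its buffer through reference counting.

```rust
use std::sync::Arc;

// Hypothetical tensor handle: cloning copies the Arc pointer, not the buffer.
#[derive(Clone, Debug)]
struct SharedTensor {
    data: Arc<Vec<f32>>,
}

impl SharedTensor {
    fn from_floats(values: &[f32]) -> Self {
        Self { data: Arc::new(values.to_vec()) }
    }

    // Number of handles currently sharing this buffer.
    fn handles(&self) -> usize {
        Arc::strong_count(&self.data)
    }
}

// Takes ownership of its input, like a forward method that consumes a tensor.
// The input handle is dropped when the function returns.
fn scale(input: SharedTensor, factor: f32) -> SharedTensor {
    let data: Vec<f32> = input.data.iter().map(|v| v * factor).collect();
    SharedTensor { data: Arc::new(data) }
}
```

Passing input.clone() to scale hands over a second handle to the same buffer; the original handle remains valid, and the buffer itself is freed only when the last handle is dropped.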

Backend types must match. If you define a function that takes Tensor<B> but call it with a tensor from a different backend, you get a type mismatch. The compiler catches this at compile time. Trust the type system on backends. A mismatched backend type is a compile-time error, not a runtime crash.

Shape mismatches happen at runtime. burn checks the rank of a tensor at compile time, because the rank is part of the type, but it does not check the size of each dimension. If you pass a tensor with the wrong dimension sizes to a layer, the operation fails when it executes. Validate shapes in your forward pass or use assertions if you need strict guarantees.

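// Assumes weight is a Param<Tensor<B, 1>>.
fn forward(&self, input: Tensor<B, 1>) -> Tensor<B, 1> {
    let weight = self.weight.val();
    // dims() returns the dimension sizes as an array; fail early with a clear message.
    assert_eq!(input.dims(), weight.dims(), "input length must match weight length");
    weight * input
}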
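
An assertion panics, which is not always what you want in a service. As an alternative, here is a minimal standard-library sketch of a check that returns a Result instead. ShapeError and check_last_dim are hypothetical helpers, not part of burn; the dims array is assumed to be the kind of [usize; D] value a rank-D tensor reports for its dimension sizes.

```rust
// Hypothetical error type describing a mismatch in the last dimension.
#[derive(Debug, PartialEq)]
struct ShapeError {
    expected: usize,
    found: usize,
}

// The rank D is fixed at compile time by the array type;
// the sizes inside the array are only known at runtime.
fn check_last_dim<const D: usize>(dims: [usize; D], expected: usize) -> Result<(), ShapeError> {
    match dims.last() {
        Some(&found) if found == expected => Ok(()),
        Some(&found) => Err(ShapeError { expected, found }),
        // A rank-0 array has no last dimension; report it as size 0.
        None => Err(ShapeError { expected, found: 0 }),
    }
}
```

A caller can run the check before the forward pass and turn the error into a log line or an HTTP 400 response rather than letting the process abort.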

The burn API evolves quickly. Check your crate version against the documentation. Methods like Tensor::random or Tensor::zeros might change signatures between minor versions. Pin your dependency version in Cargo.toml to avoid breakage.

Treat the backend type parameter as a contract. If the types don't match, the compiler stops you before you hit the GPU.

Decision matrix

Use burn when you need a full training loop with automatic differentiation and want to swap backends without rewriting code. Use burn when you are building a model from scratch in Rust and want native integration with the ecosystem. Use burn when you need WebAssembly support for browser-based inference.

Use candle when you only need inference and want a lightweight library with minimal dependencies. Use candle when you are deploying a model to an edge device and every kilobyte counts.

Use tch when you are wrapping existing PyTorch models and don't mind the C++ binding overhead. Use tch when you need access to the full PyTorch ecosystem and can tolerate FFI calls.

Choose burn for portability and training. Choose candle for minimal inference. Choose tch for PyTorch compatibility.

Where to go next