Implementing a Vector Database in Rust

By Luis Soares, January 21, 2024. Original on Medium

Unlike traditional databases that store scalar data (like integers, strings, etc.), vector databases are designed to efficiently store and retrieve vector data — collections of numerical values representing points in a multi-dimensional space.

This article will explore how to implement a basic vector database in Rust.

Let’s dive right in! 🦀

What is a Vector Database?

A vector database is a type of database that is optimized for storing and querying vectors, which are arrays of numbers representing points in a high-dimensional space. These databases are essential in applications where similarity search in large datasets is a key operation, such as in recommendation systems, image retrieval, and natural language processing.

Key concepts in vector databases include:

  1. Vector Representation: Vectors in these databases represent data points. For instance, in image recognition, an image might be represented as a high-dimensional vector where each dimension corresponds to a feature of the image.
  2. Distance Metrics: To retrieve similar vectors, the database needs a way to quantify how ‘close’ or ‘similar’ two vectors are. Common metrics include Euclidean distance, Manhattan distance, and cosine similarity.
  3. Indexing and Search Algorithms: Efficient search in high-dimensional spaces is a challenging problem. Vector databases often employ specialized indexing strategies to speed up query times, such as KD-trees, R-trees, or hashing-based approaches.
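The three distance metrics named above can be sketched as plain functions over f32 slices. This is a minimal illustration, separate from the database built below. The function names euclidean, manhattan and cosine_similarity are hypothetical, and the code assumes both inputs have the same length. Cosine similarity grows as vectors become more alike, so when a distance is needed it is commonly used as 1 - similarity.

```rust
// Hypothetical helper functions (not from the article) illustrating common metrics.
// Assumes both slices have the same length; zip silently stops at the shorter one.

/// Euclidean (L2) distance: square root of the sum of squared differences.
fn euclidean(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f32>().sqrt()
}

/// Manhattan (L1) distance: sum of absolute coordinate differences.
fn manhattan(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum()
}

/// Cosine similarity: dot(a, b) / (|a| * |b|), in [-1, 1]; higher means more similar.
/// Returns None when either vector has zero length, since the ratio is undefined.
fn cosine_similarity(a: &[f32], b: &[f32]) -> Option<f32> {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        None
    } else {
        Some(dot / (norm_a * norm_b))
    }
}

fn main() {
    let a: [f32; 2] = [0.0, 0.0];
    let b: [f32; 2] = [3.0, 4.0];
    println!("euclidean = {}", euclidean(&a, &b)); // prints "euclidean = 5"
    println!("manhattan = {}", manhattan(&a, &b)); // prints "manhattan = 7"
    // Parallel vectors have a cosine similarity of (approximately) 1.
    println!("cosine    = {:?}", cosine_similarity(&[1.0, 2.0], &[2.0, 4.0]));
}
```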

Implementing a Vector Database in Rust

Step 1: Setting Up the Rust Environment

Before we start coding, ensure you have Rust installed. Rust’s package manager, Cargo, makes it easy to set up a new project:

cargo new vector_db
cd vector_db

Step 2: Defining the Vector Type

In Rust, we can represent a vector either as a fixed-size array or as a dynamically sized Vec<f32> from the standard library. For simplicity, let’s use fixed-size arrays ([f32; N]), where N is the dimension of the vector. A side benefit is that the compiler rejects vectors of the wrong length:

type Vector = [f32; 3]; // Example for 3D vectors
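For comparison, here is a rough sketch of the dynamic alternative. The DynVectorDB type and its Result-returning add_vector are hypothetical and not used elsewhere in this article. With Vec<f32>, the dimension can be chosen at runtime. The trade-off is that mismatched lengths must then be checked explicitly, whereas [f32; 3] rejects them at compile time.

```rust
/// Hypothetical variant of the database that stores dynamically sized vectors
/// (Vec<f32>) and enforces a single dimension at runtime.
struct DynVectorDB {
    dim: usize,
    vectors: Vec<Vec<f32>>,
}

impl DynVectorDB {
    fn new(dim: usize) -> Self {
        DynVectorDB { dim, vectors: Vec::new() }
    }

    /// Rejects vectors whose length differs from the configured dimension.
    /// With [f32; N], the type system performs this check for free.
    fn add_vector(&mut self, vector: Vec<f32>) -> Result<(), String> {
        if vector.len() != self.dim {
            return Err(format!("expected {} dimensions, got {}", self.dim, vector.len()));
        }
        self.vectors.push(vector);
        Ok(())
    }
}

fn main() {
    let mut db = DynVectorDB::new(3);
    assert!(db.add_vector(vec![1.0, 2.0, 3.0]).is_ok());
    // A 2-D vector is rejected at runtime instead of at compile time.
    assert!(db.add_vector(vec![1.0, 2.0]).is_err());
    println!("stored {} vector(s)", db.vectors.len());
}
```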

Step 3: Creating the Database Structure

We’ll create a struct VectorDB that will act as our database:

struct VectorDB {
    vectors: Vec<Vector>,
}

Step 4: Implementing Basic Operations

Now, let’s add methods to add and retrieve vectors:

impl VectorDB {
    fn new() -> Self {
        VectorDB { vectors: Vec::new() }
    }

    fn add_vector(&mut self, vector: Vector) {
        self.vectors.push(vector);
    }

    fn get_vector(&self, index: usize) -> Option<&Vector> {
        self.vectors.get(index)
    }
}
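Because get_vector looks vectors up by position, callers need to know where each vector was stored. One possible variant assumes that indices are never invalidated (nothing is ever removed) and has add_vector return the index it just used. This changed signature is a hypothetical tweak, not the version used in the rest of the article.

```rust
type Vector = [f32; 3];

struct VectorDB {
    vectors: Vec<Vector>,
}

impl VectorDB {
    fn new() -> Self {
        VectorDB { vectors: Vec::new() }
    }

    /// Hypothetical variant: returns the position at which the vector was stored,
    /// so it can later be passed to get_vector.
    fn add_vector(&mut self, vector: Vector) -> usize {
        self.vectors.push(vector);
        self.vectors.len() - 1
    }

    fn get_vector(&self, index: usize) -> Option<&Vector> {
        self.vectors.get(index)
    }
}

fn main() {
    let mut db = VectorDB::new();
    let id = db.add_vector([4.0, 5.0, 6.0]);
    assert_eq!(db.get_vector(id), Some(&[4.0, 5.0, 6.0]));
    // Out-of-range indices return None instead of panicking.
    assert_eq!(db.get_vector(id + 1), None);
}
```

If deletion were added later, these positional indices would shift, so stable IDs would probably be needed instead.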

Step 5: Adding a Search Function

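To find the vector closest to a given query vector, we’ll implement a simple linear (brute-force) search based on Euclidean distance. The search computes the distance from the query to every stored vector and keeps the smallest, so for n vectors of dimension d each query costs O(n × d) time: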

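impl VectorDB {
    // new, add_vector and get_vector live in the impl block above;
    // a type may have several impl blocks.

    /// Returns the stored vector closest to query, or None if the database is empty.
    fn find_closest(&self, query: Vector) -> Option<&Vector> {
        self.vectors.iter().min_by(|&a, &b| {
            let distance_a = VectorDB::euclidean_distance(&query, a);
            let distance_b = VectorDB::euclidean_distance(&query, b);
            // total_cmp orders every f32 (NaN included), so unlike
            // partial_cmp(..).unwrap() it cannot panic.
            distance_a.total_cmp(&distance_b)
        })
    }

    fn euclidean_distance(a: &Vector, b: &Vector) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum::<f32>().sqrt()
    }
}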
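The linear search extends naturally to returning the k closest vectors, which is how similarity queries are usually phrased. The sketch below is a hypothetical extension: find_k_closest and squared_distance do not appear in the original. It ranks results by squared Euclidean distance. Because the square root is monotonic, this produces the same ordering as the true distance. The function returns (index, squared distance) pairs, closest first.

```rust
type Vector = [f32; 3];

/// Squared Euclidean distance. Skipping the square root leaves the ordering of
/// distances unchanged (sqrt is monotonic), so it is enough for ranking.
fn squared_distance(a: &Vector, b: &Vector) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// Hypothetical k-nearest-neighbor variant of the linear search: returns up to
/// k (index, squared distance) pairs, closest first.
fn find_k_closest(vectors: &[Vector], query: &Vector, k: usize) -> Vec<(usize, f32)> {
    let mut scored: Vec<(usize, f32)> = vectors
        .iter()
        .enumerate()
        .map(|(i, v)| (i, squared_distance(query, v)))
        .collect();
    // total_cmp is a total order on f32, so no unwrap is needed.
    scored.sort_by(|a, b| a.1.total_cmp(&b.1));
    scored.truncate(k);
    scored
}

fn main() {
    let vectors: Vec<Vector> = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [2.0, 3.0, 4.0]];
    for (i, dist_sq) in find_k_closest(&vectors, &[2.0, 3.0, 4.0], 2) {
        // Prints index 2 (squared distance 0) first, then index 0 (squared distance 3).
        println!("#{i}: {:?} (squared distance {dist_sq})", vectors[i]);
    }
}
```

Sorting all n distances costs O(n log n). For small k, a bounded max-heap of size k (for example, std::collections::BinaryHeap with an ordered wrapper around f32) would reduce this to O(n log k).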

Step 6: Testing Our Vector Database

Finally, let’s test our database in the main function:

fn main() {
    let mut db = VectorDB::new();
    db.add_vector([1.0, 2.0, 3.0]);
    db.add_vector([4.0, 5.0, 6.0]);

    // Retrieving and printing a vector
    if let Some(vector) = db.get_vector(0) {
        println!("Vector at index 0: {:?}", vector);
    }
    // Finding and printing the closest vector
    if let Some(closest) = db.find_closest([2.0, 3.0, 4.0]) {
        println!("Closest vector: {:?}", closest);
    }
}

Implementing more efficient indexing

We’ll focus on a basic yet effective indexing strategy known as the KD-Tree (K-dimensional tree), a space-partitioning data structure for organizing points in a K-dimensional space. KD-Trees are particularly useful for nearest neighbor searches over multi-dimensional keys when the number of dimensions is modest. As the dimension grows, fewer branches can be pruned, and a query approaches the cost of a linear scan.

Step 1: Understanding KD-Trees

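A KD-Tree is a binary tree that recursively partitions K-dimensional points. Every non-leaf node generates a splitting hyperplane, perpendicular to one coordinate axis, that divides the space into two halves. Points to the left of this hyperplane (those with a smaller coordinate along that axis) are represented by the left subtree, while points to the right are represented by the right subtree. In the classic formulation every node also stores a point; the variant built below stores points only in its leaves.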

Key Concepts:

  1. Splitting Dimension: At each level of the tree, a different dimension is chosen for splitting the data. The choice typically cycles through the dimensions in order, so the dimension used at depth d is d mod K.
  2. Median Finding: To split the points at each node, the median along the chosen dimension is selected. Because each side then receives half of the points, the tree stays balanced, with a depth of about log2 n for n points.
  3. Recursive Partitioning: The process continues recursively, resulting in a tree where each leaf node represents a point in the space.
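Sorting the points at every level, as the build step in this article does, is the simplest way to find the median, but a full sort is not required. A possible alternative uses the standard library's slice::select_nth_unstable_by, which only partitions the slice around the requested position. This sketch assumes the point list is non-empty, and the helper name partition_at_median is hypothetical.

```rust
type Vector = [f32; 3];

/// Hypothetical helper: partitions points around the median along dim without
/// fully sorting. Afterwards, points[..mid] have coordinate <= the median,
/// points[mid] is the median, and points[mid + 1..] have coordinate >= it.
/// Panics if points is empty (select_nth_unstable_by requires mid < len).
fn partition_at_median(points: &mut [Vector], dim: usize) -> (usize, f32) {
    let mid = points.len() / 2;
    points.select_nth_unstable_by(mid, |a, b| a[dim].total_cmp(&b[dim]));
    (mid, points[mid][dim])
}

fn main() {
    let mut points: Vec<Vector> = vec![
        [7.0, 2.0, 0.0], [5.0, 4.0, 0.0], [9.0, 6.0, 0.0],
        [4.0, 7.0, 0.0], [8.0, 1.0, 0.0], [2.0, 3.0, 0.0],
    ];
    let (mid, median) = partition_at_median(&mut points, 0);
    // Sorted x-coordinates are 2, 4, 5, 7, 8, 9, so this prints "split index 3, median x = 7".
    println!("split index {mid}, median x = {median}");
}
```

Splitting the slice into points[..mid] and points[mid..] then places the median in the right half, the same convention as the sort-based build.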

Step 2: Defining the KD-Tree Structure in Rust

First, define the structure of the KD-Tree. You’ll need to represent both internal nodes (with splitting information) and leaf nodes (with actual vector data):

enum KDTreeNode {
    Leaf(Vector),
    Internal {
        left: Box<KDTreeNode>,
        right: Box<KDTreeNode>,
        split_value: f32,
        split_dimension: usize,
    },
}

struct KDTree {
    root: KDTreeNode,
}

Step 3: Building the KD-Tree

The process involves sorting points based on the splitting dimension and recursively building the tree:

impl KDTree {
    /// Recursively builds a KD-Tree node from a non-empty set of points.
    fn build(points: Vec<Vector>, depth: usize) -> KDTreeNode {
        assert!(!points.is_empty(), "cannot build a KD-Tree from zero points");
        if points.len() == 1 {
            return KDTreeNode::Leaf(points[0]);
        }

        // Cycle through the dimensions: K is the length of each vector (3 here).
        let k = points[0].len();
        let dim = depth % k;
        let mut sorted_points = points;
        sorted_points.sort_by(|a, b| a[dim].total_cmp(&b[dim]));
        let median_idx = sorted_points.len() / 2;
        let median_value = sorted_points[median_idx][dim];
        // Points before the median go left; the median and everything after it go right.
        KDTreeNode::Internal {
            left: Box::new(KDTree::build(sorted_points[..median_idx].to_vec(), depth + 1)),
            right: Box::new(KDTree::build(sorted_points[median_idx..].to_vec(), depth + 1)),
            split_value: median_value,
            split_dimension: dim,
        }
    }
}
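To check that median splits keep the tree balanced, the sketch below rebuilds a tree with the same median-split rule and measures it with two hypothetical helpers, height and leaf_count. It is a self-contained variant rather than the article's code: build is a free function here, and it uses Vec::split_off instead of copying both halves with to_vec. With n points, the left side receives floor(n/2) points and the right side ceil(n/2), so the height should be ceil(log2 n).

```rust
type Vector = [f32; 3];

#[allow(dead_code)] // the split fields and leaf data are not read in this check
enum KDTreeNode {
    Leaf(Vector),
    Internal {
        left: Box<KDTreeNode>,
        right: Box<KDTreeNode>,
        split_value: f32,
        split_dimension: usize,
    },
}

/// Median-split construction (free-function variant). Assumes points is non-empty.
fn build(mut points: Vec<Vector>, depth: usize) -> KDTreeNode {
    if points.len() == 1 {
        return KDTreeNode::Leaf(points[0]);
    }
    let dim = depth % points[0].len();
    points.sort_by(|a, b| a[dim].total_cmp(&b[dim]));
    let mid = points.len() / 2;
    let split_value = points[mid][dim];
    let right = points.split_off(mid); // points now holds only the left half
    KDTreeNode::Internal {
        left: Box::new(build(points, depth + 1)),
        right: Box::new(build(right, depth + 1)),
        split_value,
        split_dimension: dim,
    }
}

/// Hypothetical helper: number of edges on the longest root-to-leaf path.
fn height(node: &KDTreeNode) -> usize {
    match node {
        KDTreeNode::Leaf(_) => 0,
        KDTreeNode::Internal { left, right, .. } => 1 + height(left).max(height(right)),
    }
}

/// Hypothetical helper: number of leaves, which equals the number of points.
fn leaf_count(node: &KDTreeNode) -> usize {
    match node {
        KDTreeNode::Leaf(_) => 1,
        KDTreeNode::Internal { left, right, .. } => leaf_count(left) + leaf_count(right),
    }
}

fn main() {
    // 8 made-up points: with even splits, every leaf ends up at depth 3.
    let points: Vec<Vector> = (0..8).map(|i| [i as f32, (7 - i) as f32, 0.0]).collect();
    let tree = build(points, 0);
    println!("leaves = {}, height = {}", leaf_count(&tree), height(&tree)); // leaves = 8, height = 3
}
```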

Step 4: Implementing Search in the KD-Tree

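Searching for the nearest neighbor involves descending into the subtree on the query’s side of each split first, then backtracking into the other subtree only when the splitting plane is closer than the best match found so far: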

impl KDTree {
    fn nearest_neighbor(&self, query: &Vector) -> Option<&Vector> {
        self.nearest(query, &self.root, None, f32::MAX)
    }

    fn nearest<'a>(&'a self, query: &Vector, node: &'a KDTreeNode, best: Option<&'a Vector>, best_dist: f32) -> Option<&'a Vector> {
        match node {
            KDTreeNode::Leaf(point) => {
                // Keep this point only if it beats the best distance so far.
                let dist = VectorDB::euclidean_distance(query, point);
                if dist < best_dist {
                    Some(point)
                } else {
                    best
                }
            },
            KDTreeNode::Internal { left, right, split_value, split_dimension } => {
                // Search the side of the splitting plane that contains the query first.
                let next_node = if query[*split_dimension] < *split_value { left } else { right };
                let other_node = if query[*split_dimension] < *split_value { right } else { left };
                let updated_best = self.nearest(query, next_node, best, best_dist);
                let updated_best_dist = updated_best.map_or(best_dist, |v| VectorDB::euclidean_distance(query, v));
                // Backtrack only if the plane is closer than the best match found so far.
                if (query[*split_dimension] - split_value).abs() < updated_best_dist {
                    self.nearest(query, other_node, updated_best, updated_best_dist)
                } else {
                    updated_best
                }
            }
        }
    }
}
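The backtracking test in nearest, which compares |query[d] - split_value| with the best distance so far, rests on a one-line bound. Take a query q, a splitting dimension d with split value s, and any point p in the subtree on the opposite side of the plane. Then p_d and q_d lie on opposite sides of s (or p_d = s), which gives:

```latex
\lVert q - p \rVert_2
  = \sqrt{\sum_{i=1}^{K} (q_i - p_i)^2}
  \;\ge\; \lvert q_d - p_d \rvert
  \;\ge\; \lvert q_d - s \rvert
```

So once |q_d - s| is at least the best distance found, no point in the other subtree can be strictly closer, and that branch can safely be skipped.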

Step 5: Testing the KD-Tree

Finally, build a KD-Tree from a set of points and run a nearest-neighbor query against it:

fn main() {
    let points = vec![
        [1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
        [2.0, 3.0, 4.0],
        // ... more points
    ];
    // build returns the root node, so wrap it in a KDTree to call nearest_neighbor.
    let kd_tree = KDTree { root: KDTree::build(points, 0) };

    if let Some(nearest) = kd_tree.nearest_neighbor(&[3.0, 3.0, 3.0]) {
        println!("Nearest neighbor: {:?}", nearest);
    }
}
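The test above uses the tree on its own. One possible way to integrate it into VectorDB is sketched here with a hypothetical build_index method. The method builds the tree from a snapshot of the stored vectors, and the results are cross-checked against the linear find_closest. The code repeats the article's types so that it compiles on its own, with a few changes: euclidean_distance is a free function, build uses Vec::split_off and total_cmp, and nearest is an associated function. The grid of sample points and the queries are made up for the check.

```rust
type Vector = [f32; 3];

fn euclidean_distance(a: &Vector, b: &Vector) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f32>().sqrt()
}

enum KDTreeNode {
    Leaf(Vector),
    Internal { left: Box<KDTreeNode>, right: Box<KDTreeNode>, split_value: f32, split_dimension: usize },
}

struct KDTree {
    root: KDTreeNode,
}

impl KDTree {
    /// Median-split construction over a non-empty set of points.
    fn build(mut points: Vec<Vector>, depth: usize) -> KDTreeNode {
        if points.len() == 1 {
            return KDTreeNode::Leaf(points[0]);
        }
        let dim = depth % points[0].len();
        points.sort_by(|a, b| a[dim].total_cmp(&b[dim]));
        let mid = points.len() / 2;
        let split_value = points[mid][dim];
        let right = points.split_off(mid); // points keeps the left half
        KDTreeNode::Internal {
            left: Box::new(KDTree::build(points, depth + 1)),
            right: Box::new(KDTree::build(right, depth + 1)),
            split_value,
            split_dimension: dim,
        }
    }

    fn nearest_neighbor(&self, query: &Vector) -> Option<&Vector> {
        Self::nearest(query, &self.root, None, f32::INFINITY)
    }

    fn nearest<'a>(query: &Vector, node: &'a KDTreeNode, best: Option<&'a Vector>, best_dist: f32) -> Option<&'a Vector> {
        match node {
            KDTreeNode::Leaf(point) => {
                if euclidean_distance(query, point) < best_dist { Some(point) } else { best }
            }
            KDTreeNode::Internal { left, right, split_value, split_dimension } => {
                let diff = query[*split_dimension] - split_value;
                let (near, far) = if diff < 0.0 { (left, right) } else { (right, left) };
                let best = Self::nearest(query, near, best, best_dist);
                let best_dist = best.map_or(best_dist, |v| euclidean_distance(query, v));
                // Visit the far side only if the splitting plane is closer than the best match.
                if diff.abs() < best_dist { Self::nearest(query, far, best, best_dist) } else { best }
            }
        }
    }
}

struct VectorDB {
    vectors: Vec<Vector>,
}

impl VectorDB {
    /// Linear scan, used here as the reference answer.
    fn find_closest(&self, query: Vector) -> Option<&Vector> {
        self.vectors.iter().min_by(|a, b| {
            euclidean_distance(&query, a).total_cmp(&euclidean_distance(&query, b))
        })
    }

    /// Hypothetical helper: builds a KD-Tree over a copy of the stored vectors.
    /// The index is a snapshot, so vectors added later are not included.
    fn build_index(&self) -> Option<KDTree> {
        if self.vectors.is_empty() {
            None
        } else {
            Some(KDTree { root: KDTree::build(self.vectors.clone(), 0) })
        }
    }
}

fn main() {
    // Made-up sample data: a 5 x 5 x 5 grid with uneven spacing per axis.
    let mut db = VectorDB { vectors: Vec::new() };
    for i in 0..125u32 {
        db.vectors.push([(i % 5) as f32, (i / 5 % 5) as f32 * 1.5, (i / 25) as f32 * 0.7]);
    }
    let index = db.build_index().expect("database is not empty");
    for q in [[3.0, 3.0, 3.0], [0.2, 5.9, 1.1], [-1.0, 10.0, 0.5]] {
        let linear = db.find_closest(q).unwrap();
        let tree = index.nearest_neighbor(&q).unwrap();
        // Compare distances rather than identities, because exact ties may resolve differently.
        assert_eq!(euclidean_distance(&q, linear), euclidean_distance(&q, tree));
        println!("query {q:?} -> {tree:?}");
    }
}
```

Because the index is a snapshot, it must be rebuilt (or updated incrementally) before it can see vectors added after build_index.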

That’s all, fellow Rustaceans! 🦀
