Matrix Multiplication With Rust and Go


🚀 Introduction

I'm not even sure how I ended up here, but I decided to review Go today (I'm a bit Rust-y with that language, LOL).

So, I thought of doing this matrix multiplication exercise. Why? To review how it works under the hood and what approaches can be used.

Matrix multiplication shows up everywhere:

  • Artificial Intelligence → Neural networks basically do Inputs × Weights = Outputs.
  • Computer Graphics → 3D transformations like rotation, scaling, and projection all use matrices.

In this post, I’ll go through three different ways to write a matrix multiplication algorithm and explain each one.


⚠️ Requirements for Matrix Multiplication

Before multiplying two matrices:

  1. Number of columns of Matrix A must equal number of rows of Matrix B.
    • If not, multiplication is impossible.
  2. The output matrix C has as many rows as Matrix A and as many columns as Matrix B.
    • C dimensions: (rows of A) × (columns of B).
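
For example, a 2×3 matrix A times a 3×4 matrix B produces a 2×4 matrix C, while a 2×3 matrix times another 2×3 matrix cannot be multiplied at all.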

🟢 Naive Approach (Triple Nested Loop)

The most straightforward way to multiply two matrices is the triple nested loop. It’s not efficient, but it’s the clearest way to understand what’s going on.

Pseudocode

for i in rows of A:
  for j in cols of B:
    C[i][j] = 0
    for k in shared dimension:
      C[i][j] += A[i][k] * B[k][j]

Step-by-Step Explanation

1. Pick a row from Matrix A

  • i represents the row index in Matrix A.
  • For each i, we look at the entire row of A.

2. Pick a column from Matrix B

  • j represents the column index in Matrix B.
  • For each i, loop through each column of B to compute C[i][j].

3. Initialize C[i][j] to 0

  • Before computing the dot product, ensure C[i][j] starts at zero.

4. Compute the dot product of row A and column B

  • k loops through the shared dimension (number of columns in A or rows in B).
  • Multiply A[i][k] × B[k][j] and add to C[i][j].

5. Repeat for all rows and columns

  • After finishing all k, C[i][j] has the correct value.
  • Move to next column j. Once all columns are done, move to the next row i.

6. Result

  • After finishing all loops, we get the final output matrix C.
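
Worked example with the 2×2 matrices used in the code below: C[0][0] is the dot product of row 0 of A and column 0 of B, so C[0][0] = 1·5 + 2·7 = 19, and likewise C[1][1] = 3·6 + 4·8 = 50.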

Video Demo

(Embedded video: a visual explanation of the steps above.)

Golang Example

package main

import (
    "fmt"
    "log"
)

func multiplyMatrices(A, B [][]int) ([][]int, error) {
    // Validation: number of columns in A must equal number of rows in B
    if len(A) == 0 || len(B) == 0 || len(A[0]) != len(B) {
        return nil, fmt.Errorf(
            "cannot multiply: columns of A (%d) != rows of B (%d)",
            len(A[0]), len(B),
        )
    }

    n, m, p := len(A), len(B[0]), len(B)
    C := make([][]int, n)
    for i := range C {
        C[i] = make([]int, m)
    }

    for i := 0; i < n; i++ {
        for j := 0; j < m; j++ {
            for k := 0; k < p; k++ {
                C[i][j] += A[i][k] * B[k][j]
            }
        }
    }
    return C, nil
}

func main() {
    A := [][]int{{1, 2}, {3, 4}}
    B := [][]int{{5, 6}, {7, 8}}

    C, err := multiplyMatrices(A, B)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println("Result:", C)
}

Output:

Result: [[19 22] [43 50]]

Rust Example

fn multiply_matrices(a: &[Vec<i32>], b: &[Vec<i32>]) -> Vec<Vec<i32>> {
    // Validation: number of columns in A must equal number of rows in B
    assert_eq!(a[0].len(), b.len(), "columns of A must equal rows of B");

    let n = a.len();
    let m = b[0].len();
    let p = b.len();
    let mut c = vec![vec![0; m]; n];

    for i in 0..n {
        for j in 0..m {
            for k in 0..p {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    c
}

fn main() {
    let a = vec![vec![1, 2], vec![3, 4]];
    let b = vec![vec![5, 6], vec![7, 8]];
    let c = multiply_matrices(&a, &b);
    println!("Result: {:?}", c);
}

Output:

Result: [[19, 22], [43, 50]]

🔵 Strassen's Algorithm (Divide and Conquer)

80-20 of Strassen’s Algorithm: The 20% Key Ideas You Really Need to Know

Divide-and-Conquer

Split each matrix into four equal quadrants:

A = [ A11  A12 ]
    [ A21  A22 ]

B = [ B11  B12 ]
    [ B21  B22 ]

This recursive structure is the heart of the algorithm.

Reduce the number of multiplications

Instead of doing 8 multiplications (A11·B11, A12·B21, etc.), Strassen uses 7 carefully chosen multiplications (P1–P7) plus some additions and subtractions. This is what reduces the asymptotic complexity from O(n³) to roughly O(n^2.81).
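
In equation form: each multiplication of size n becomes 7 multiplications of size n/2 plus O(n²) additions, giving the recurrence T(n) = 7·T(n/2) + O(n²), which solves to T(n) = O(n^(log₂ 7)) ≈ O(n^2.807).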

Combine submatrices

The 7 products are recombined using simple addition/subtraction to form the resulting quadrants: C11, C12, C21, C22.

Base case

For small matrices (like 2×2 or 1×1), just use normal multiplication.

Trade-offs

  • More additions/subtractions → slightly more memory and bookkeeping.
  • Slight numerical instability with floating-point inputs.
  • For small n, standard multiplication is usually faster because of its lower overhead.
  • Doesn't handle odd-sized matrices: this implementation assumes the matrix dimensions are powers of two (a common workaround is to pad the matrices with zeros up to the next power of two).

Strassen's algorithm is a more efficient, recursive method for matrix multiplication. It follows a divide-and-conquer strategy to reduce the number of required multiplications, making it faster for large matrices.

How It Works

Instead of the standard 8 multiplications for a 2×2 matrix, Strassen's algorithm cleverly uses only 7. This reduction has a significant impact on performance as the matrix size increases.

The algorithm involves these main steps:

  1. Divide: Split the input matrices A and B into four equal-sized submatrices.
  2. Conquer: Compute 7 matrix products recursively (the "Strassen products").
  3. Combine: Combine the results of the 7 products to form the resulting matrix C.

Pseudocode

FUNCTION Strassen(A, B):
    n ← size of matrix A

    IF n ≤ 2 THEN
        RETURN standard_matrix_multiplication(A, B)
    END IF

    mid ← n / 2

    // Partition A into submatrices
    A11 ← top-left submatrix of A
    A12 ← top-right submatrix of A
    A21 ← bottom-left submatrix of A
    A22 ← bottom-right submatrix of A

    // Partition B into submatrices
    B11 ← top-left submatrix of B
    B12 ← top-right submatrix of B
    B21 ← bottom-left submatrix of B
    B22 ← bottom-right submatrix of B

    // Compute the 7 Strassen products (recursive calls)
    P1 ← Strassen(A11, B12 - B22)               // P1 = A11 * (B12 - B22)
    P2 ← Strassen(A11 + A12, B22)               // P2 = (A11 + A12) * B22
    P3 ← Strassen(A21 + A22, B11)               // P3 = (A21 + A22) * B11
    P4 ← Strassen(A22, B21 - B11)               // P4 = A22 * (B21 - B11)
    P5 ← Strassen(A11 + A22, B11 + B22)        // P5 = (A11 + A22) * (B11 + B22)
    P6 ← Strassen(A12 - A22, B21 + B22)        // P6 = (A12 - A22) * (B21 + B22)
    P7 ← Strassen(A11 - A21, B11 + B12)        // P7 = (A11 - A21) * (B11 + B12)

    // Compute the resulting quadrants of C
    C11 ← P5 + P4 - P2 + P6                     // top-left quadrant
    C12 ← P1 + P2                               // top-right quadrant
    C21 ← P3 + P4                               // bottom-left quadrant
    C22 ← P5 + P1 - P3 - P7                     // bottom-right quadrant

    // Combine quadrants into full matrix
    C ← [ [C11, C12],
          [C21, C22] ]

    RETURN C
END FUNCTION

Golang Example

package main

import "fmt"

// Function to add two matrices
func add(A, B [][]int) [][]int {
	n := len(A)
	C := make([][]int, n)
	for i := 0; i < n; i++ {
		C[i] = make([]int, n)
		for j := 0; j < n; j++ {
			C[i][j] = A[i][j] + B[i][j]
		}
	}
	return C
}

// Function to subtract two matrices
func subtract(A, B [][]int) [][]int {
	n := len(A)
	C := make([][]int, n)
	for i := 0; i < n; i++ {
		C[i] = make([]int, n)
		for j := 0; j < n; j++ {
			C[i][j] = A[i][j] - B[i][j]
		}
	}
	return C
}

// Strassen's algorithm for matrix multiplication
func strassen(A, B [][]int) [][]int {
	n := len(A)

	// Base case
	if n == 1 {
		C := make([][]int, 1)
		C[0] = make([]int, 1)
		C[0][0] = A[0][0] * B[0][0]
		return C
	}

	// New size for submatrices
	newSize := n / 2
	A11 := make([][]int, newSize)
	A12 := make([][]int, newSize)
	A21 := make([][]int, newSize)
	A22 := make([][]int, newSize)
	B11 := make([][]int, newSize)
	B12 := make([][]int, newSize)
	B21 := make([][]int, newSize)
	B22 := make([][]int, newSize)

	for i := 0; i < newSize; i++ {
		A11[i] = make([]int, newSize)
		A12[i] = make([]int, newSize)
		A21[i] = make([]int, newSize)
		A22[i] = make([]int, newSize)
		B11[i] = make([]int, newSize)
		B12[i] = make([]int, newSize)
		B21[i] = make([]int, newSize)
		B22[i] = make([]int, newSize)

		for j := 0; j < newSize; j++ {
			A11[i][j] = A[i][j]
			A12[i][j] = A[i][j+newSize]
			A21[i][j] = A[i+newSize][j]
			A22[i][j] = A[i+newSize][j+newSize]
			B11[i][j] = B[i][j]
			B12[i][j] = B[i][j+newSize]
			B21[i][j] = B[i+newSize][j]
			B22[i][j] = B[i+newSize][j+newSize]
		}
	}

	// Recursive calls for Strassen's formulas
	P1 := strassen(A11, subtract(B12, B22))
	P2 := strassen(add(A11, A12), B22)
	P3 := strassen(add(A21, A22), B11)
	P4 := strassen(A22, subtract(B21, B11))
	P5 := strassen(add(A11, A22), add(B11, B22))
	P6 := strassen(subtract(A12, A22), add(B21, B22))
	P7 := strassen(subtract(A11, A21), add(B11, B12))

	// Combining the results
	C11 := add(subtract(add(P5, P4), P2), P6)
	C12 := add(P1, P2)
	C21 := add(P3, P4)
	C22 := subtract(subtract(add(P5, P1), P3), P7)

	// Final result matrix
	C := make([][]int, n)
	for i := 0; i < newSize; i++ {
		C[i] = make([]int, n)
		C[i+newSize] = make([]int, n)
		for j := 0; j < newSize; j++ {
			C[i][j] = C11[i][j]
			C[i][j+newSize] = C12[i][j]
			C[i+newSize][j] = C21[i][j]
			C[i+newSize][j+newSize] = C22[i][j]
		}
	}
	return C
}

func main() {
	A := [][]int{{1, 2}, {3, 4}}
	B := [][]int{{5, 6}, {7, 8}}
	C := strassen(A, B)
	fmt.Println("Result:", C)
}
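
Output:

Result: [[19 22] [43 50]]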

🟡 Parallelization

Matrix multiplication is a highly parallelizable problem. Each element of the output matrix C can be computed independently. This means we can use multiple threads or goroutines to compute different parts of the matrix simultaneously, which can lead to significant performance improvements on multi-core processors.

How It Works

The basic idea is to divide the work of the outer loops among multiple threads. For example, we can assign each thread a range of rows from matrix A to process. Each thread will then compute the corresponding rows of the output matrix C.

Golang Example (using Goroutines)

package main

import (
    "fmt"
    "sync"
)

func multiplyMatricesParallel(A, B [][]int) ([][]int, error) {
    if len(A) == 0 || len(B) == 0 || len(A[0]) != len(B) {
        return nil, fmt.Errorf(
            "cannot multiply: columns of A (%d) != rows of B (%d)",
            len(A[0]), len(B),
        )
    }

    n, m, p := len(A), len(B[0]), len(B)
    C := make([][]int, n)
    for i := range C {
        C[i] = make([]int, m)
    }

    var wg sync.WaitGroup
    wg.Add(n)

    for i := 0; i < n; i++ {
        go func(row int) {
            defer wg.Done()
            for j := 0; j < m; j++ {
                for k := 0; k < p; k++ {
                    C[row][j] += A[row][k] * B[k][j]
                }
            }
        }(i)
    }

    wg.Wait()
    return C, nil
}
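
A minimal sketch of a main to try it out, reusing the 2×2 matrices from the earlier examples:

func main() {
    A := [][]int{{1, 2}, {3, 4}}
    B := [][]int{{5, 6}, {7, 8}}

    // Expected result: [[19 22] [43 50]]
    C, err := multiplyMatricesParallel(A, B)
    if err != nil {
        fmt.Println("error:", err)
        return
    }
    fmt.Println("Result:", C)
}

Note that each goroutine writes only to its own row of C, so the goroutines never touch the same memory and no mutex is needed.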

Rust Example (using Rayon)

To run this example, you'll need to add the rayon crate to your Cargo.toml:

[dependencies]
rayon = "1.5"

Then the code itself:

use rayon::prelude::*;

fn multiply_matrices_parallel(a: &[Vec<i32>], b: &[Vec<i32>]) -> Vec<Vec<i32>> {
    let n = a.len();
    let m = b[0].len();
    let p = b.len();
    let mut c = vec![vec![0; m]; n];

    // par_iter_mut hands each row of C to a separate Rayon task,
    // so the rows are filled in parallel without any explicit locking.
    c.par_iter_mut().enumerate().for_each(|(i, row)| {
        for j in 0..m {
            for k in 0..p {
                row[j] += a[i][k] * b[k][j];
            }
        }
    });

    c
}

🤔 What's Next?

Now that we’ve covered the naive approach, Strassen’s algorithm, and basic parallelization, here are some directions you and I can explore to level up our matrix multiplication game:

1. Benchmarking & Performance Profiling

  • Try measuring how fast the naive vs Strassen vs parallel versions are.
  • In Rust, we can use the criterion crate for precise benchmarking.
  • In Go, we can use testing.B for benchmark tests.
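
A quick sketch of what the Go side could look like, assuming the naive multiplyMatrices function from the first example is in the same package and this code lives in a file ending in _test.go (run it with go test -bench=.):

package main

import "testing"

func BenchmarkNaiveMultiply(b *testing.B) {
    // Small fixed inputs; a real benchmark would use much larger matrices.
    A := [][]int{{1, 2}, {3, 4}}
    B := [][]int{{5, 6}, {7, 8}}

    b.ResetTimer()
    for i := 0; i < b.N; i++ {
        if _, err := multiplyMatrices(A, B); err != nil {
            b.Fatal(err)
        }
    }
}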

2. GPU Acceleration

  • The current code runs on the CPU. Compared to production-level matrix multiplication (e.g., NumPy, which calls into heavily optimized BLAS routines, or GPU-based libraries), it's far behind, and GPUs in particular are perfect for this kind of problem. I haven't tried them yet, but my idea is that we could explore using CUDA.
  • Rust → wgpu or cuda bindings.
  • Go → gorgonia or custom CUDA/OpenCL bindings.