Matrix Multiplication With Rust and Go

🚀 Introduction
I'm not even sure how I ended up here, but I decided to review Go today (I'm a bit Rust-y with that language, LOL).
So, I thought of doing this matrix multiplication exercise. Why? To review how it works under the hood and what approaches can be used.
Matrix multiplication shows up everywhere:
- Artificial Intelligence → Neural networks basically do Inputs × Weights = Outputs.
- Computer Graphics → 3D transformations like rotation, scaling, and projection all use matrices.
In this post, I’ll go through three different ways to write a matrix multiplication algorithm and explain each one.
⚠️ Requirements for Matrix Multiplication
Before multiplying two matrices:
- The number of columns of Matrix A must equal the number of rows of Matrix B.
- If not, multiplication is impossible.
- The output matrix C will have as many rows as Matrix A and as many columns as Matrix B, so C's dimensions are (rows of A) × (columns of B).
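The rule above fits in a tiny Go helper. This is only a sketch: `resultShape` is a hypothetical name used for illustration, taking the four dimensions as plain integers. It checks the shared dimension and reports the shape of C.

```go
package main

import "fmt"

// resultShape is a hypothetical helper: it returns the (rows, cols) of C = A × B,
// or an error when the columns of A do not match the rows of B.
func resultShape(aRows, aCols, bRows, bCols int) (int, int, error) {
	if aCols != bRows {
		return 0, 0, fmt.Errorf("cannot multiply: columns of A (%d) != rows of B (%d)", aCols, bRows)
	}
	// C takes its row count from A and its column count from B.
	return aRows, bCols, nil
}

func main() {
	// A is 2×3 and B is 3×4, so C is 2×4.
	rows, cols, err := resultShape(2, 3, 3, 4)
	fmt.Println(rows, cols, err) // prints "2 4 <nil>"

	// A is 2×3 and B is 2×2: 3 columns vs 2 rows, so this fails.
	_, _, err = resultShape(2, 3, 2, 2)
	fmt.Println(err) // prints "cannot multiply: columns of A (3) != rows of B (2)"
}
```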
🟢 Naive Approach (Triple Nested Loop)
The most straightforward way to multiply two matrices is the triple nested loop. It’s not efficient (two n × n matrices take n³ multiplications), but it’s the clearest way to understand what’s going on.
Pseudocode
for i in rows of A:
    for j in cols of B:
        C[i][j] = 0
        for k in shared dimension:
            C[i][j] += A[i][k] * B[k][j]
Step-by-Step Explanation
1. Pick a row from Matrix A
- i represents the row index in Matrix A.
- For each i, we look at the entire row of A.
2. Pick a column from Matrix B
- j represents the column index in Matrix B.
- For each i, loop through each column of B to compute C[i][j].
3. Initialize C[i][j] to 0
- Before computing the dot product, ensure C[i][j] starts at zero.
4. Compute the dot product of row i of A and column j of B
- k loops through the shared dimension (the number of columns in A, which equals the number of rows in B).
- Multiply A[i][k] × B[k][j] and add the result to C[i][j].
5. Repeat for all rows and columns
- After finishing all k, C[i][j] has the correct value.
- Move to the next column j. Once all columns are done, move to the next row i.
6. Result
- After finishing all loops, we get the final output matrix C.
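To make step 4 concrete, here is the arithmetic for the 2×2 matrices used in the code examples (A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]]), where the shared dimension is p = 2. Math indices start at 1, so C_{11} corresponds to C[0][0] in the code.

```latex
C_{ij} = \sum_{k=1}^{p} A_{ik} \, B_{kj}

C_{11} = 1 \cdot 5 + 2 \cdot 7 = 19 \qquad C_{12} = 1 \cdot 6 + 2 \cdot 8 = 22
C_{21} = 3 \cdot 5 + 4 \cdot 7 = 43 \qquad C_{22} = 3 \cdot 6 + 4 \cdot 8 = 50
```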
Golang Example
package main

import (
	"fmt"
	"log"
)

// multiplyMatrices returns C = A × B using the naive triple nested loop.
func multiplyMatrices(A, B [][]int) ([][]int, error) {
	// Validation: guard against empty input before indexing A[0] or B[0]
	if len(A) == 0 || len(B) == 0 {
		return nil, fmt.Errorf("cannot multiply: empty matrix")
	}
	// Number of columns in A must equal number of rows in B
	if len(A[0]) != len(B) {
		return nil, fmt.Errorf(
			"cannot multiply: columns of A (%d) != rows of B (%d)",
			len(A[0]), len(B),
		)
	}
	// n = rows of A, m = columns of B, p = shared dimension
	n, m, p := len(A), len(B[0]), len(B)
	// make zero-fills every row, so each C[i][j] already starts at 0
	C := make([][]int, n)
	for i := range C {
		C[i] = make([]int, m)
	}
	for i := 0; i < n; i++ {
		for j := 0; j < m; j++ {
			for k := 0; k < p; k++ {
				C[i][j] += A[i][k] * B[k][j]
			}
		}
	}
	return C, nil
}

func main() {
	A := [][]int{{1, 2}, {3, 4}}
	B := [][]int{{5, 6}, {7, 8}}
	C, err := multiplyMatrices(A, B)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("Result:", C)
}
Output:
Result: [[19 22] [43 50]]
Rust Example
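// multiply_matrices returns C = A × B using the naive triple nested loop.
// It panics if a matrix is empty or the inner dimensions differ.
fn multiply_matrices(a: &[Vec<i32>], b: &[Vec<i32>]) -> Vec<Vec<i32>> {
    assert!(!a.is_empty() && !b.is_empty(), "matrices must be non-empty");
    assert_eq!(a[0].len(), b.len(), "columns of A must equal rows of B");
    let n = a.len(); // rows of A
    let m = b[0].len(); // columns of B
    let p = b.len(); // shared dimension
    let mut c = vec![vec![0; m]; n];
    for i in 0..n {
        for j in 0..m {
            for k in 0..p {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    c
}

fn main() {
    let a = vec![vec![1, 2], vec![3, 4]];
    let b = vec![vec![5, 6], vec![7, 8]];
    let c = multiply_matrices(&a, &b);
    println!("Result: {:?}", c);
}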
Output:
Result: [[19, 22], [43, 50]]
🔵 Strassen's Algorithm (Divide and Conquer)
The 80/20 of Strassen’s Algorithm: The 20% of Key Ideas You Really Need to Know
Divide-and-Conquer
Split each matrix into four equal quadrants:
A → [A11, A12; A21, A22]
B → [B11, B12; B21, B22]
This recursive structure is the heart of the algorithm.
Reduce the number of multiplications
Instead of doing 8 multiplications of submatrices (A11B11, A12B21, etc.), Strassen uses 7 carefully chosen multiplications (P1–P7) plus some extra additions/subtractions. Applied recursively, this is what reduces the asymptotic complexity from O(n³) to O(n^(log₂ 7)) ≈ O(n^2.81).
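Where the exponent comes from: each level of the recursion does a fixed number of half-size products plus Θ(n²) work for the additions, and the master theorem turns that recurrence into the overall cost.

```latex
% Plain divide and conquer: 8 half-size products + Θ(n²) additions
T(n) = 8\,T\!\left(\tfrac{n}{2}\right) + \Theta(n^2) \;\Longrightarrow\; T(n) = \Theta\!\left(n^{\log_2 8}\right) = \Theta(n^3)

% Strassen: 7 half-size products + Θ(n²) additions/subtractions
T(n) = 7\,T\!\left(\tfrac{n}{2}\right) + \Theta(n^2) \;\Longrightarrow\; T(n) = \Theta\!\left(n^{\log_2 7}\right) \approx \Theta(n^{2.81})
```

Dropping a single multiplication per level is enough to lower the exponent, because the saving compounds at every level of the recursion.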
Combine submatrices
The 7 products are recombined using simple addition/subtraction to form the resulting quadrants: C11, C12, C21, C22.
Base case
For small matrices (like 2×2 or 1×1), just use normal multiplication.
Trade-offs
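- More additions/subtractions and temporary submatrices → extra memory and bookkeeping.
- Less numerically stable than standard multiplication when working with floating-point numbers.
- For small n, standard multiplication is faster, because the recursion and extra additions cost more than the one multiplication saved per level.
- Doesn't handle arbitrary sizes: this implementation assumes square matrices whose dimensions are powers of two.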
Strassen's algorithm is a recursive method for matrix multiplication that is asymptotically faster than the triple loop. It follows a divide-and-conquer strategy to reduce the number of required multiplications, which pays off for large matrices.
How It Works
Instead of the standard 8 multiplications for a 2×2 matrix (where each entry can itself be a submatrix), Strassen's algorithm cleverly uses only 7. Because that saving is repeated at every level of the recursion, its impact on performance grows as the matrix size increases.
The algorithm involves these main steps:
- Divide: Split the input matrices A and B into four equal-sized submatrices.
- Conquer: Compute 7 matrix products recursively (the "Strassen products").
- Combine: Combine the results of the 7 products to form the resulting matrix C.
Pseudocode
FUNCTION Strassen(A, B):
    n ← size of matrix A
    IF n ≤ 2 THEN
        RETURN standard_matrix_multiplication(A, B)
    END IF
    mid ← n / 2

    // Partition A into submatrices
    A11 ← top-left submatrix of A
    A12 ← top-right submatrix of A
    A21 ← bottom-left submatrix of A
    A22 ← bottom-right submatrix of A

    // Partition B into submatrices
    B11 ← top-left submatrix of B
    B12 ← top-right submatrix of B
    B21 ← bottom-left submatrix of B
    B22 ← bottom-right submatrix of B

    // Compute the 7 Strassen products (recursive calls)
    P1 ← Strassen(A11, B12 - B22)          // P1 = A11 * (B12 - B22)
    P2 ← Strassen(A11 + A12, B22)          // P2 = (A11 + A12) * B22
    P3 ← Strassen(A21 + A22, B11)          // P3 = (A21 + A22) * B11
    P4 ← Strassen(A22, B21 - B11)          // P4 = A22 * (B21 - B11)
    P5 ← Strassen(A11 + A22, B11 + B22)    // P5 = (A11 + A22) * (B11 + B22)
    P6 ← Strassen(A12 - A22, B21 + B22)    // P6 = (A12 - A22) * (B21 + B22)
    P7 ← Strassen(A11 - A21, B11 + B12)    // P7 = (A11 - A21) * (B11 + B12)

    // Compute the resulting quadrants of C
    C11 ← P5 + P4 - P2 + P6    // top-left quadrant
    C12 ← P1 + P2              // top-right quadrant
    C21 ← P3 + P4              // bottom-left quadrant
    C22 ← P5 + P1 - P3 - P7    // bottom-right quadrant

    // Combine quadrants into full matrix
    C ← [ [C11, C12],
          [C21, C22] ]
    RETURN C
END FUNCTION
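As a sanity check, the seven formulas can be worked by hand on the 2×2 example used throughout this post (A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]]), where every "submatrix" is a single number. The recombined quadrants match the naive result [[19, 22], [43, 50]].

```latex
\begin{aligned}
P_1 &= A_{11}(B_{12} - B_{22}) = 1 \cdot (6 - 8) = -2 \\
P_2 &= (A_{11} + A_{12})\,B_{22} = (1 + 2) \cdot 8 = 24 \\
P_3 &= (A_{21} + A_{22})\,B_{11} = (3 + 4) \cdot 5 = 35 \\
P_4 &= A_{22}(B_{21} - B_{11}) = 4 \cdot (7 - 5) = 8 \\
P_5 &= (A_{11} + A_{22})(B_{11} + B_{22}) = (1 + 4)(5 + 8) = 65 \\
P_6 &= (A_{12} - A_{22})(B_{21} + B_{22}) = (2 - 4)(7 + 8) = -30 \\
P_7 &= (A_{11} - A_{21})(B_{11} + B_{12}) = (1 - 3)(5 + 6) = -22 \\[4pt]
C_{11} &= P_5 + P_4 - P_2 + P_6 = 65 + 8 - 24 - 30 = 19 \\
C_{12} &= P_1 + P_2 = -2 + 24 = 22 \\
C_{21} &= P_3 + P_4 = 35 + 8 = 43 \\
C_{22} &= P_5 + P_1 - P_3 - P_7 = 65 - 2 - 35 + 22 = 50
\end{aligned}
```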
Golang Example
package main

import "fmt"

// Function to add two matrices
func add(A, B [][]int) [][]int {
	n := len(A)
	C := make([][]int, n)
	for i := 0; i < n; i++ {
		C[i] = make([]int, n)
		for j := 0; j < n; j++ {
			C[i][j] = A[i][j] + B[i][j]
		}
	}
	return C
}

// Function to subtract two matrices
func subtract(A, B [][]int) [][]int {
	n := len(A)
	C := make([][]int, n)
	for i := 0; i < n; i++ {
		C[i] = make([]int, n)
		for j := 0; j < n; j++ {
			C[i][j] = A[i][j] - B[i][j]
		}
	}
	return C
}

// Strassen's algorithm for matrix multiplication.
// It assumes A and B are both n×n with n a power of two (1, 2, 4, 8, ...).
func strassen(A, B [][]int) [][]int {
	n := len(A)
	// Base case: a 1×1 product is a single scalar multiplication
	if n == 1 {
		C := make([][]int, 1)
		C[0] = make([]int, 1)
		C[0][0] = A[0][0] * B[0][0]
		return C
	}
	// New size for submatrices
	newSize := n / 2
	A11 := make([][]int, newSize)
	A12 := make([][]int, newSize)
	A21 := make([][]int, newSize)
	A22 := make([][]int, newSize)
	B11 := make([][]int, newSize)
	B12 := make([][]int, newSize)
	B21 := make([][]int, newSize)
	B22 := make([][]int, newSize)
	// Copy each quadrant of A and B into its own submatrix
	for i := 0; i < newSize; i++ {
		A11[i] = make([]int, newSize)
		A12[i] = make([]int, newSize)
		A21[i] = make([]int, newSize)
		A22[i] = make([]int, newSize)
		B11[i] = make([]int, newSize)
		B12[i] = make([]int, newSize)
		B21[i] = make([]int, newSize)
		B22[i] = make([]int, newSize)
		for j := 0; j < newSize; j++ {
			A11[i][j] = A[i][j]
			A12[i][j] = A[i][j+newSize]
			A21[i][j] = A[i+newSize][j]
			A22[i][j] = A[i+newSize][j+newSize]
			B11[i][j] = B[i][j]
			B12[i][j] = B[i][j+newSize]
			B21[i][j] = B[i+newSize][j]
			B22[i][j] = B[i+newSize][j+newSize]
		}
	}
	// Recursive calls for Strassen's formulas: 7 products instead of 8
	P1 := strassen(A11, subtract(B12, B22))
	P2 := strassen(add(A11, A12), B22)
	P3 := strassen(add(A21, A22), B11)
	P4 := strassen(A22, subtract(B21, B11))
	P5 := strassen(add(A11, A22), add(B11, B22))
	P6 := strassen(subtract(A12, A22), add(B21, B22))
	P7 := strassen(subtract(A11, A21), add(B11, B12))
	// Combining the results
	C11 := add(subtract(add(P5, P4), P2), P6)
	C12 := add(P1, P2)
	C21 := add(P3, P4)
	C22 := subtract(subtract(add(P5, P1), P3), P7)
	// Final result matrix: copy the four quadrants into C
	C := make([][]int, n)
	for i := 0; i < newSize; i++ {
		C[i] = make([]int, n)
		C[i+newSize] = make([]int, n)
		for j := 0; j < newSize; j++ {
			C[i][j] = C11[i][j]
			C[i][j+newSize] = C12[i][j]
			C[i+newSize][j] = C21[i][j]
			C[i+newSize][j+newSize] = C22[i][j]
		}
	}
	return C
}

func main() {
	A := [][]int{{1, 2}, {3, 4}}
	B := [][]int{{5, 6}, {7, 8}}
	C := strassen(A, B)
	fmt.Println("Result:", C)
}
Output:
Result: [[19 22] [43 50]]
🟡 Parallelization
Matrix multiplication is a highly parallelizable problem. Each element of the output matrix C can be computed independently. This means we can use multiple threads or goroutines to compute different parts of the matrix simultaneously, which can lead to significant performance improvements on multi-core processors.
How It Works
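The basic idea is to divide the work of the outer loop among multiple threads. For example, we can assign each thread a range of rows from matrix A to process; that thread then computes the corresponding rows of the output matrix C. Because no two threads write to the same row of C, no locking is needed. The Go example below uses the simplest split: one goroutine per row.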
Golang Example (using Goroutines)
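package main

import (
	"fmt"
	"log"
	"sync"
)

// multiplyMatricesParallel computes C = A × B with one goroutine per row of C.
func multiplyMatricesParallel(A, B [][]int) ([][]int, error) {
	if len(A) == 0 || len(B) == 0 {
		return nil, fmt.Errorf("cannot multiply: empty matrix")
	}
	if len(A[0]) != len(B) {
		return nil, fmt.Errorf(
			"cannot multiply: columns of A (%d) != rows of B (%d)",
			len(A[0]), len(B),
		)
	}
	n, m, p := len(A), len(B[0]), len(B)
	C := make([][]int, n)
	for i := range C {
		C[i] = make([]int, m)
	}
	var wg sync.WaitGroup
	wg.Add(n)
	for i := 0; i < n; i++ {
		// Each goroutine writes only to its own row of C, so no mutex is needed.
		go func(row int) {
			defer wg.Done()
			for j := 0; j < m; j++ {
				for k := 0; k < p; k++ {
					C[row][j] += A[row][k] * B[k][j]
				}
			}
		}(i)
	}
	wg.Wait() // block until every row has been computed
	return C, nil
}

func main() {
	A := [][]int{{1, 2}, {3, 4}}
	B := [][]int{{5, 6}, {7, 8}}
	C, err := multiplyMatricesParallel(A, B)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("Result:", C)
}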
Rust Example (using Rayon)
To run this example, you'll need to add the rayon crate to your Cargo.toml:
[dependencies]
rayon = "1.5"
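use rayon::prelude::*;

// Parallel naive multiplication: rayon hands out rows of C to its thread pool.
fn multiply_matrices_parallel(a: &[Vec<i32>], b: &[Vec<i32>]) -> Vec<Vec<i32>> {
    assert!(!a.is_empty() && !b.is_empty(), "matrices must be non-empty");
    assert_eq!(a[0].len(), b.len(), "columns of A must equal rows of B");
    let n = a.len();
    let m = b[0].len();
    let p = b.len();
    let mut c = vec![vec![0; m]; n];
    // Each closure call gets exclusive (&mut) access to one row of C.
    c.par_iter_mut().enumerate().for_each(|(i, row)| {
        for j in 0..m {
            for k in 0..p {
                row[j] += a[i][k] * b[k][j];
            }
        }
    });
    c
}

fn main() {
    let a = vec![vec![1, 2], vec![3, 4]];
    let b = vec![vec![5, 6], vec![7, 8]];
    let c = multiply_matrices_parallel(&a, &b);
    println!("Result: {:?}", c);
}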
🤔 What's Next?
Now that we’ve covered the naive approach, Strassen’s algorithm, and basic parallelization, here are some directions you and I can explore to level up our matrix multiplication game:
1. Benchmarking & Performance Profiling
- Try measuring how fast the naive vs Strassen vs parallel versions are.
- In Rust, we can use the criterion crate for precise benchmarking.
- In Go, we can use testing.B for benchmark tests.
2. GPU Acceleration
- The code in this post runs on the CPU, and it is far behind production-level matrix multiplication. NumPy, for example, hands floating-point matrix multiplication to heavily optimized BLAS libraries (such as OpenBLAS or Intel MKL), and frameworks like PyTorch can run it on GPUs, which are a great fit for this kind of massively parallel problem. I haven't tried GPU programming yet, but my idea is that we could explore CUDA.
- Rust → wgpu or CUDA bindings.
- Go → gorgonia or custom CUDA/OpenCL bindings.