Making Fibonacci Fast (And Proving It's Correct)
In my last post, I showed off a simple recursive Fibonacci implementation. While it clearly expresses the mathematical definition, it has a serious performance problem that becomes obvious pretty quickly.
The Problem with Naive Recursion
Let's look at our original definition again:
def fib : Nat → Nat
  | 0 => 0
  | 1 => 1
  | n + 2 => fib n + fib (n + 1)
Evaluating fib 10 is plenty fast:
#eval fib 10 -- 55
Evaluating fib 30, however, causes a noticeable pause:
#eval fib 30 -- 832040
Evaluating fib 40 would take too long.
The issue is that this definition computes the same Fibonacci numbers over and over again. To calculate fib 5, we need fib 3 and fib 4. But to calculate fib 4, we need fib 2 and fib 3 again. The amount of redundant computation grows exponentially.
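To see the blowup concretely, here is a small sketch that counts how many calls fib makes instead of computing its value. The helper fibCalls is hypothetical and not part of the original code. It mirrors the shape of fib exactly and adds one for the current call.

```lean
-- Hypothetical helper, not part of the post's code: counts the calls
-- that the naive `fib` makes for a given input, including the outermost one.
def fibCalls : Nat → Nat
  | 0 => 1
  | 1 => 1
  | n + 2 => 1 + fibCalls n + fibCalls (n + 1)

#eval fibCalls 10 -- 177
#eval fibCalls 20 -- 21891
```

The count works out to 2 * fib (n + 1) - 1. That is 2,692,537 calls for fib 30 and 331,160,281 calls for fib 40.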
An Iterative Solution
We can solve this with an iterative approach that computes each Fibonacci number exactly once:
def fibIter (n : Nat) : Nat := loop n (0, 1)
where
  loop
    | 0, (i, _) => i
    | k + 1, (i, j) => loop k (j, i + j)
This version works by maintaining a pair (i, j) representing consecutive Fibonacci numbers and updating them as we count down from n to 0. The computation is linear in n rather than exponential.
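Tracing the pairs makes the countdown visible. The sketch below uses a hypothetical helper, loopTrace, which follows the same recursion as fibIter.loop but records every pair it sees instead of returning only the first component at the end.

```lean
-- Hypothetical helper, not part of the post's code: same recursion as
-- `fibIter.loop`, but it collects each pair (i, j) along the way.
def loopTrace : Nat → Nat × Nat → List (Nat × Nat)
  | 0, p => [p]
  | k + 1, (i, j) => (i, j) :: loopTrace k (j, i + j)

#eval loopTrace 5 (0, 1)
-- [(0, 1), (1, 1), (1, 2), (2, 3), (3, 5), (5, 8)]
```

The last pair is (5, 8), and fibIter.loop returns its first component, 5, which is fib 5. Each step performs a single addition, which is where the linear running time comes from.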
The result for 10 is computed very quickly:
#eval fibIter 10 -- 55
So is the result for 30:
#eval fibIter 30 -- 832040
And even the result for 100:
#eval fibIter 100 -- 354224848179261915075
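Before proving anything, a cheap spot check is to compare the two functions on small inputs. This is a minimal sketch that assumes fib and fibIter are defined earlier in the same file. Agreement on a finite range is evidence, not a proof.

```lean
-- Assumes `fib` and `fibIter` from this post are defined earlier in the file.
-- Checking a finite range of inputs is a spot check, not a proof.
#eval (List.range 10).map fibIter -- [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
#eval (List.range 25).all (fun n => fibIter n == fib n) -- true
```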
Proving They're Equivalent
Of course, we want to be absolutely sure that our optimization didn't change the mathematical meaning. Here's where Lean really shines—we can prove that both implementations compute the same function. The first step is to show the correctness of the inner loop.
This statement is a bit tricky.
It uses a loop invariant: when k iterations remain, the pair holds the consecutive Fibonacci numbers fib (n - k) and fib (n - k + 1), so once k reaches 0 its first component is fib n.
The premise k ≤ n is needed because subtraction on Nat truncates at zero. Without it, n - k would not be a sensible index into the sequence.
The proof uses the fun_induction tactic to perform induction that follows the recursive calls of fibIter.loop, and grind to dispatch the tedious details.
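On paper, writing F_m for fib m and loop for fibIter.loop, the invariant and the two cases of the induction look like this. The step case is where k ≤ n matters, since it makes n - (k + 1) + 1 equal to n - k:

```latex
\begin{aligned}
\text{Invariant } (k \le n):\quad & \mathrm{loop}\ k\ (F_{n-k},\ F_{n-k+1}) = F_n \\
k = 0:\quad & \mathrm{loop}\ 0\ (F_{n},\ F_{n+1}) = F_{n} \\
k + 1 \le n:\quad & \mathrm{loop}\ (k+1)\ (F_{n-(k+1)},\ F_{n-(k+1)+1}) \\
  & \quad = \mathrm{loop}\ k\ (F_{n-(k+1)+1},\ F_{n-(k+1)} + F_{n-(k+1)+1}) \\
  & \quad = \mathrm{loop}\ k\ (F_{n-k},\ F_{n-k+1}) \\
  & \quad = F_n \quad \text{(induction hypothesis, since } k \le n\text{)}
\end{aligned}
```

The middle step uses the recurrence F_{m+2} = F_m + F_{m+1} with m = n - (k + 1).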
theorem fibIterloop_eq_fib :
    k ≤ n →
    i = fib (n - k) →
    j = fib (n - k + 1) →
    fibIter.loop k (i, j) = fib n := by
  intro k_le_n hi hj
  generalize hp : (i, j) = p
  fun_induction fibIter.loop k p generalizing i j
  next => grind
  next ih =>
    apply ih _ rfl rfl <;> grind [fib]
With this lemma, we can handle the base cases of fib separately and use the lemma for everything else, proving that the two functions are equal:
theorem fibIter_eq_fib : fibIter = fib := by
  funext n
  unfold fibIter
  match n with
  | 0 | 1 => rfl
  | n' + 2 =>
    apply fibIterloop_eq_fib <;> simp [fib]
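In the n' + 2 case, the lemma is applied with both k and n equal to n' + 2 and the starting pair (0, 1). Writing F_m for fib m again, its three premises reduce to facts about the first two Fibonacci numbers:

```latex
\begin{aligned}
k \le n:\quad & n' + 2 \le n' + 2 \\
i = F_{n-k}:\quad & 0 = F_{(n'+2)-(n'+2)} = F_0 \\
j = F_{n-k+1}:\quad & 1 = F_{(n'+2)-(n'+2)+1} = F_1
\end{aligned}
```

These match the three goals left after applying the lemma, and each one holds by unfolding fib at 0 or 1.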
Why This Matters
Having both a clear mathematical specification (fib) and an efficient implementation (fibIter), along with a machine-verified proof that they're equivalent, gives us the best of both worlds. We can reason about the mathematical properties using the simple recursive definition, while actually computing results with the fast iterative version.
This equivalence proof also means that all the properties I proved about fib in my last post—like the fact that every third Fibonacci number is even—automatically apply to fibIter as well. Lean can use the fibIter_eq_fib proof to rewrite any statement about fibIter into one about fib, letting us leverage all our previous mathematical work.
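As a sketch of what that rewriting looks like, the theorem below assumes fib, fibIter, and fibIter_eq_fib from this post are in scope. The name fibIter_three_mul_even is hypothetical, and the hypothesis h stands in for the evenness result from the previous post, whose exact statement there may differ.

```lean
-- Assumes `fib`, `fibIter`, and `fibIter_eq_fib` from this post are in scope.
-- `fibIter_three_mul_even` is a hypothetical name; `h` stands in for the
-- evenness theorem about `fib` from the previous post.
theorem fibIter_three_mul_even (h : ∀ n, fib (3 * n) % 2 = 0) (n : Nat) :
    fibIter (3 * n) % 2 = 0 := by
  rw [fibIter_eq_fib] -- every occurrence of `fibIter` becomes `fib`
  exact h n
```

Only the first tactic step mentions fibIter at all. After the rewrite, the goal is exactly the statement about fib.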