Competitive programming in Haskell: Infinite 2D array, Level 4
In a previous post, I challenged you to solve Infinite 2D Array using Haskell. After deriving a formula for \(F_{x,y}\) that involves only a linear number of terms, last time we discussed how to efficiently calculate Fibonacci numbers and binomial coefficients modulo a prime. Today, we’ll finally see some actual Haskell code for solving this problem.
The code is not very long, and seems rather simple, but what it doesn’t show is the large amount of time and effort I spent trying different versions until I figured out how to make it fast enough. Later in the post I will share some lessons learned.
Modular arithmetic
When a problem requires a fixed modulus like this, I typically prefer using a newtype M with a Num instance that does all operations using modular arithmetic, as explained in this post. However, that approach has a big downside: we cannot (easily) store our own newtype values in an unboxed array (UArray), since that requires defining a bunch of instances by hand. And the speedup we get from using unboxed vs boxed arrays is significant, especially for this problem.
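For contrast, the newtype approach looks roughly like the following minimal sketch. The names M and modulus are assumptions for illustration, not the code from the post linked above. Ordinary numeric code such as sum and (^) works on M unchanged; the trouble only starts when we want to put M values into a UArray.

```haskell
-- A hypothetical modular-arithmetic newtype; every operation reduces mod 10^9 + 7.
newtype M = M Int
  deriving (Eq, Show)

modulus :: Int
modulus = 10^9 + 7

instance Num M where
  M a + M b = M ((a + b) `mod` modulus)
  M a - M b = M ((a - b) `mod` modulus)   -- Haskell's mod is non-negative here
  M a * M b = M ((a * b) `mod` modulus)
  negate (M a) = M (negate a `mod` modulus)
  abs = id                                 -- no meaningful notion of sign mod p
  signum (M 0) = M 0
  signum _     = M 1
  fromInteger n = M (fromInteger (n `mod` toInteger modulus))
```

For example, (2 :: M) ^ 10 evaluates to M 1024, and 0 - 1 :: M wraps around to M 1000000006.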
So instead I just made some standalone functions to do arithmetic modulo \(10^9 + 7\):
p :: Int
p = 10^9 + 7
padd :: Int -> Int -> Int
padd x y = (x + y) `mod` p
pmul :: Int -> Int -> Int
pmul x y = (x*y) `mod` p
What about modular inverses? At first I defined a modular inverse operation based on my own implementation of the extended Euclidean algorithm, but at some point I did some benchmarking and realized that my egcd function was taking up the majority of the runtime, so I replaced it with a highly optimized version taken from the arithmoi package. Rather than pasting in the code, I’ll let you go look at it yourself if you’re interested.
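Given the efficient extendedGCD, we can now define modular inversion, and modular division in terms of it, like so:
inv :: Int -> Int
-- extendedGCD p a = (g, u, v) with p*u + a*v == g; since p is prime, g == 1
-- for 0 < a < p, so v is an inverse of a modulo p.
inv a = y `mod` p
  where
    (_,_,y) = extendedGCD p a
pdiv :: Int -> Int -> Int
pdiv x y = x `pmul` inv y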
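To make the contract of extendedGCD concrete, here is a naive recursive extended Euclidean algorithm. It is a hypothetical stand-in (named egcd, with a matching invNaive), not the optimized arithmoi code, and it assumes the same convention that inv relies on: egcd a b returns (g, u, v) with a*u + b*v == g, so egcd p a carries an inverse of a in its third component.

```haskell
p :: Int
p = 10^9 + 7

-- Naive extended Euclidean algorithm (hypothetical stand-in, not arithmoi's):
-- egcd a b = (g, u, v) where g = gcd a b and a*u + b*v == g.
egcd :: Int -> Int -> (Int, Int, Int)
egcd a 0 = (a, 1, 0)
egcd a b = (g, v, u - q * v)
  where
    (q, r)    = a `quotRem` b
    (g, u, v) = egcd b r   -- b*u + r*v == g, and r == a - q*b

-- Same shape as inv above, with egcd in place of extendedGCD.
invNaive :: Int -> Int
invNaive a = y `mod` p
  where
    (_, _, y) = egcd p a
```

For instance, invNaive 2 is 500000004, since 2 * 500000004 = 10^9 + 8, which is 1 modulo p. This version recurses once per division step; the arithmoi code is considerably more tuned.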
Fibonacci numbers and factorials
We want to compute Fibonacci numbers and factorials modulo \(p = 10^9 + 7\) and put them in tables so we can quickly look them up later. The simplest way to do this is to generate an infinite list of each (via the standard knot-tying approach in the case of Fibonacci numbers, and scanl' in the case of factorials) and then put them into an appropriate UArray:
fibList :: [Int]
fibList = 0 : 1 : zipWith padd fibList (tail fibList)
fib :: UArray Int Int
fib = listArray (0, 10^6) fibList
fac :: UArray Int Int
fac = listArray (0, 2*10^6) (scanl' pmul 1 [1 ..])
I should mention that at one point I defined fib this way instead:
fib' :: Array Int Int
fib' = array (0, 10^6) $ (0,0):(1,1):[ (i, (fib'!(i-1)) `padd` (fib'!(i-2))) | i <- [2 .. 10^6]]
This takes advantage of the fact that boxed arrays are lazy in their values, and can hence be constructed recursively, to directly define the array via dynamic programming. But this version is much slower, and uglier to boot! (If we really needed to initialize an unboxed array using recursion/dynamic programming, we could do that via runSTUArray, but it would be overkill for this problem.)
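For reference, initializing an unboxed array by dynamic programming with runSTUArray might look like this minimal sketch (fibST is a hypothetical name, and padd is repeated so the block stands alone). The table is filled in place in the ST monad and frozen into a UArray at the end, so no lazy thunks are involved.

```haskell
import Control.Monad (forM_)
import Data.Array.ST (newArray, readArray, runSTUArray, writeArray)
import Data.Array.Unboxed (UArray, (!))

p :: Int
p = 10^9 + 7

padd :: Int -> Int -> Int
padd x y = (x + y) `mod` p

-- Fibonacci numbers mod p in indices 0..n, filled by DP (assumes n >= 1).
fibST :: Int -> UArray Int Int
fibST n = runSTUArray $ do
  arr <- newArray (0, n) 0          -- every entry starts at 0, so F_0 = 0
  writeArray arr 1 1                -- F_1 = 1
  forM_ [2 .. n] $ \i -> do
    a <- readArray arr (i - 1)
    b <- readArray arr (i - 2)
    writeArray arr i (a `padd` b)   -- F_i = F_{i-1} + F_{i-2} (mod p)
  return arr
```

Looking up a value is then just fibST n ! k; for example, fibST 50 ! 10 is 55.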
Binomial coefficients modulo a prime
We can now efficiently compute binomial coefficients using fac and inv, like so:
mbinom :: Int -> Int -> Int
mbinom m k = (fac!m) `pdiv` ((fac!k) `pmul` (fac!(m-k)))
As mentioned in a previous post, this only works since the modulus is prime; otherwise, more complex techniques would be needed.
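Primality also enables a different way to invert: by Fermat’s little theorem, \(a^{p-2} \equiv a^{-1} \pmod p\) whenever \(p\) is prime and does not divide \(a\). The sketch below swaps that in for extendedGCD (it is not what the solution above uses) and checks the resulting binomial coefficients against Pascal’s rule. The names ppow, invFermat, facSmall, and binomFermat are hypothetical.

```haskell
import Data.Array.Unboxed (UArray, listArray, (!))
import Data.List (scanl')

p :: Int
p = 10^9 + 7

pmul :: Int -> Int -> Int
pmul x y = (x * y) `mod` p

-- Modular exponentiation by repeated squaring: ppow b e = b^e mod p.
ppow :: Int -> Int -> Int
ppow _ 0 = 1
ppow b e
  | even e    = half `pmul` half
  | otherwise = b `pmul` (half `pmul` half)
  where
    half = ppow b (e `div` 2)

-- Fermat's little theorem: a^(p-2) is the inverse of a mod p, valid only
-- because p is prime (and assuming a is not a multiple of p).
invFermat :: Int -> Int
invFermat a = ppow a (p - 2)

-- Factorials mod p up to 100, enough for small checks.
facSmall :: UArray Int Int
facSmall = listArray (0, 100) (scanl' pmul 1 [1 ..])

-- Same formula as mbinom, with invFermat in place of division via extendedGCD.
binomFermat :: Int -> Int -> Int
binomFermat m k = (facSmall ! m) `pmul` invFermat ((facSmall ! k) `pmul` (facSmall ! (m - k)))
```

Each invFermat call costs about \(\log_2 p \approx 30\) multiplications, so it is a reasonable alternative when an extended GCD implementation is not at hand.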
We could also precompute all inverse factorials, and then we can get rid of the pdiv call in mbinom (remember that pmul is very fast, whereas pdiv has to call extendedGCD):
ifac :: UArray Int Int
ifac = listArray (0, 2*10^6) (scanl' pdiv 1 [1 ..])
mbinom' :: Int -> Int -> Int
mbinom' m k = (fac!m) `pmul` (ifac!k) `pmul` (ifac!(m-k))
For this particular problem, it doesn’t make much difference either way, since the total number of pdiv calls stays about the same: roughly \(2 \times 10^6\) either way, one per call to mbinom or one per entry of ifac. But this can be an important optimization for problems where the number of calls to mbinom is much larger than the maximum size of its arguments.
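If the number of inversions did matter, a standard trick (not used in the solution above) reduces building ifac to a single inversion: invert the largest factorial once, then walk downward using \(1/(k-1)! = k \cdot (1/k!)\). A minimal sketch, with a small hypothetical bound maxN and a Fermat-based inv standing in for the extendedGCD version so the block is self-contained:

```haskell
import Data.Array.Unboxed (UArray, listArray, (!))
import Data.List (scanl')

p :: Int
p = 10^9 + 7

pmul :: Int -> Int -> Int
pmul x y = (x * y) `mod` p

-- Stand-in inverse via Fermat's little theorem (p is prime).
inv :: Int -> Int
inv a = go a (p - 2)
  where
    go _ 0 = 1
    go b e
      | even e    = h `pmul` h
      | otherwise = b `pmul` (h `pmul` h)
      where h = go b (e `div` 2)

-- Hypothetical small bound; the post uses 2*10^6.
maxN :: Int
maxN = 1000

fac :: UArray Int Int
fac = listArray (0, maxN) (scanl' pmul 1 [1 ..])

-- ifac ! k = 1/k! mod p.  Start from 1/maxN! (the only inversion) and
-- multiply by maxN, maxN-1, ..., 1 to get 1/(maxN-1)!, ..., 1/0!.
ifac :: UArray Int Int
ifac = listArray (0, maxN) (reverse (scanl' pmul (inv (fac ! maxN)) [maxN, maxN - 1 .. 1]))
```

The scanl' produces the inverse factorials from largest to smallest, which is why the list is reversed before being handed to listArray.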
Putting it all together
Finally, we can put all the pieces together to solve the problem like so:
import Control.Arrow ((>>>))
import Data.Array.Unboxed
import Data.List (foldl', scanl')
main = interact $ words >>> map read >>> solve >>> show
solve :: [Int] -> Int
solve [x,y] =
  sum [ (fib!k) `pmul` mbinom (x-k+y-1) (x-k) | k <- [1 .. x]] `padd`
  sum [ (fib!k) `pmul` mbinom (y-k+x-1) (y-k) | k <- [1 .. y]]
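A quick way to gain confidence in solve is to compare the closed form against the defining recurrence for small inputs, using exact Integer arithmetic so no modulus is needed. The recurrence is restated in the comments as an assumption, since this post does not repeat the problem statement; bruteTable, binom, and formula are hypothetical names. The brute force itself uses a lazy boxed array, which is exactly the kind of recursive definition boxed arrays are good at.

```haskell
import Data.Array (Array, array, (!))

-- Brute force straight from the recurrence (restated as an assumption):
--   F_{0,0} = 0, F_{0,1} = F_{1,0} = 1,
--   F_{i,0} = F_{i-1,0} + F_{i-2,0} and F_{0,j} = F_{0,j-1} + F_{0,j-2} for i, j >= 2,
--   F_{i,j} = F_{i-1,j} + F_{i,j-1} for i, j >= 1.
-- A lazy boxed array lets each entry refer to earlier entries.
bruteTable :: Int -> Array (Int, Int) Integer
bruteTable n = t
  where
    t = array ((0, 0), (n, n)) [ ((i, j), f i j) | i <- [0 .. n], j <- [0 .. n] ]
    f 0 0 = 0
    f 0 1 = 1
    f 1 0 = 1
    f i 0 = t ! (i - 1, 0) + t ! (i - 2, 0)
    f 0 j = t ! (0, j - 1) + t ! (0, j - 2)
    f i j = t ! (i - 1, j) + t ! (i, j - 1)

fibs :: [Integer]
fibs = 0 : 1 : zipWith (+) fibs (tail fibs)

-- Exact binomial coefficient; fine for small inputs.
binom :: Int -> Int -> Integer
binom m k = product [toInteger (m - k + 1) .. toInteger m] `div` product [1 .. toInteger k]

-- The closed form from solve, using exact Integer arithmetic instead of mod p.
formula :: Int -> Int -> Integer
formula x y =
  sum [ fibs !! k * binom (x - k + y - 1) (x - k) | k <- [1 .. x] ] +
  sum [ fibs !! k * binom (y - k + x - 1) (y - k) | k <- [1 .. y] ]
```

For example, both bruteTable 15 ! (3, 2) and formula 3 2 give 11.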
Lessons learned
The fact that the above code is fairly short (besides extendedGCD) belies the amount of time I spent trying to optimize it. Here are some things I learned while benchmarking and comparing different versions.
First, we should try really, really hard to use unboxed arrays (UArray) instead of boxed arrays (Array). Boxed arrays have one distinct advantage, which is that they can be constructed lazily, and hence recursively. This helps a lot for dynamic programming problems (which I have a lot to write about at some point in the future). But otherwise, they introduce a ton of overhead, since each element is stored as a pointer to a separately heap-allocated (and possibly unevaluated) value rather than as a raw machine integer.
In this particular problem, committing to use UArray meant (1) using explicit modular operations like padd and pmul instead of a newtype, and (2) constructing the fib array by calculating a list of values and then using it to construct the array, instead of defining the array via recursion/DP.
The optimized implementation of extendedGCD makes a big difference, too, which makes sense: a majority of the computation time for this problem is spent running it (via pdiv). I don’t know what general lesson to draw from this other than to affirm the value of profiling to figure out where optimizations would help the most.
I tried a whole bunch of other things which turn out to make very little difference in comparison to the above optimizations. For example:
- Optimizing padd and pmul to conditionally avoid an expensive mod operation when the arguments are not too big: this sped things up a tiny bit, but not much.
- Rewriting everything in terms of a tail-recursive loop that computes the required Fibonacci numbers and binomial coefficients incrementally, and hence does not require any lookup arrays:
-- Needs {-# LANGUAGE BangPatterns #-} at the top of the file for the ! patterns on go.
solve' :: [Int] -> Int
solve' [x,y] = go x y 0 1 0 1 (mbinom' (x+y-2) (x-1)) `padd`
               go y x 0 1 0 1 (mbinom' (x+y-2) (y-1))
  where
    -- Invariants:
    --   s  = sum so far
    --   k  = current k
    --   f' = F_{k-1}
    --   f  = F_k
    --   bx = binom (x-k+y-1) (x-k)
    go x y !s !k !f' !f !bx
      | k > x = s
      | otherwise
      = go x y (s `padd` (bx `pmul` f)) (k+1)
          f (f' `padd` f) ((bx `pdiv` (x-k+y-1)) `pmul` (x-k))
    -- Array-free binomial coefficient, used only for the two starting values.
    mbinom' n k = fac' n `pdiv` (fac' k `pmul` fac' (n-k))
    fac' k = foldl' pmul 1 [1 .. k]
This version is super ugly and erases most of the benefits of using Haskell in the first place, so I am happy to report that it runs in exactly the same amount of time as the solution I described earlier.
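The first optimization in the list above might look something like the following sketch. The names padd' and pmul' are hypothetical, and this is one plausible form rather than the exact code I benchmarked. Both functions assume their arguments are already reduced into \([0, p)\): then a sum is below \(2p\), so one subtraction replaces the mod, and a product below \(p\) needs no reduction at all.

```haskell
p :: Int
p = 10^9 + 7

-- Assumes 0 <= x, y < p, so 0 <= x + y < 2p and one subtraction suffices.
padd' :: Int -> Int -> Int
padd' x y
  | s >= p    = s - p
  | otherwise = s
  where
    s = x + y

-- Assumes 0 <= x, y < p; skips the mod when the product is already below p.
pmul' :: Int -> Int -> Int
pmul' x y
  | z < p     = z
  | otherwise = z `mod` p
  where
    z = x * y
```

Whether the branch beats an unconditional mod depends on how often the fast path is taken; as noted above, the gain for this problem was small.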