Competitive programming in Haskell: cycle decomposition with mutable arrays

Posted on July 18, 2020

In my previous post I challenged you to solve Chair Hopping: if a bunch of people permute themselves according to the same rule twice, how many different rules could they be following which would result in the observed final permutation? Or, more formally, given a permutation \(\rho\) on \([1 \dots n]\), how many permutations \(\sigma\) are there such that \(\sigma^2 = \rho\)?

Since this has to do with permutations, it should be unsurprising that cycle decomposition comes into the picture. And we have discussed cycle decomposition of permutations before; using those techniques to decompose the given permutation into cycles should be straightforward, right?

Not so fast!

Here is the code we used previously to compute the size of the cycle containing a given element:

dist :: Perm -> Int -> Int -> Int
dist p i j = length $ takeWhile (/= j) (iterate (p!) i)

cycleLen :: Perm -> Int -> Int
cycleLen p i = succ $ dist p (p!i) i

There’s nothing particularly wrong with this code, and no way to speed it up per se. Computing the distance between \(i\) and \(j\) in permutation \(p\) takes \(O(n)\) time, since we may have to scan through a significant fraction of the entire permutation if \(i\) and \(j\) are in a large cycle; this is unavoidable. cycleLen then just calls dist, so if all we want is the length of a single cycle, \(O(n)\) is the best we can do.

However, the problem comes when we want to find the cycle lengths of many elements. cycleLen takes \(O(n)\) time for each element we call it on. In the worst case, if the entire permutation consists of one giant cycle, calling cycleLen on every element will take \(O(n^2)\) time overall. And this is particularly silly since the work of following the cycle will be entirely repeated every time, only starting from a different place! When \(n = 200\), as in The Power of Substitution, an \(O(n^2)\) algorithm is no big deal; but when \(n = 10^5\) it’s entirely too slow. Using \(10^8\) operations per second as our rule of thumb, we expect an \(O(n^2)\) algorithm on an input with \(n = 10^5\) to take on the order of \((10^5)^2 / 10^8 = 100\) seconds. An input size of \(10^5\) is extremely common in competitive programming problems: not so big that I/O is going to be a huge bottleneck, but big enough that you need to come up with an algorithm faster than \(O(n^2)\) (for example, \(O(n)\) or \(O(n \lg n)\) are both fine).
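
To make the repeated work concrete, here is a small sketch of the naive approach. The definitions of dist and cycleLen are the ones shown above; allCycleLensNaive and bigCycle are hypothetical helpers introduced only for this illustration.

```haskell
import Data.Array.Unboxed

type Perm = UArray Int Int

dist :: Perm -> Int -> Int -> Int
dist p i j = length $ takeWhile (/= j) (iterate (p!) i)

cycleLen :: Perm -> Int -> Int
cycleLen p i = succ $ dist p (p!i) i

-- | Hypothetical helper: the cycle length of every element, computed
--   the slow way.  Each call to 'cycleLen' walks the whole cycle again.
allCycleLensNaive :: Perm -> [Int]
allCycleLensNaive p = map (cycleLen p) (range (bounds p))

-- | Hypothetical helper: the single n-cycle 1 -> 2 -> ... -> n -> 1,
--   which is the worst case for 'allCycleLensNaive'.
bigCycle :: Int -> Perm
bigCycle n = listArray (1, n) ([2 .. n] ++ [1])
```

On bigCycle n, every one of the n calls to cycleLen follows all n elements of the cycle, which is where the \(O(n^2)\) total comes from.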

Permutations and fast cycle decomposition

The idea is to do the work of decomposing a permutation into cycles only once, in \(O(n)\) time, and store the results in a data structure that allows us to look up the needed information quickly. (This general technique of preprocessing some data into a structure allowing for fast subsequent query/lookup is ubiquitous in competitive programming, and indeed in all of computer science.) The catch? I don’t know of a good way to do this without using mutable arrays! But if we write it generically we can potentially reuse it (I have in fact reused this code several times already on other problems).
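
For comparison only, here is a rough sketch of what a version without mutable arrays might look like if we track visited elements in a persistent Data.IntSet. This is not the approach taken in this post: it is a swapped-in technique, each set operation costs more than an array write, and the whole thing runs in roughly \(O(n \lg n)\) rather than \(O(n)\). The name cyclesPure is made up for this example, and the code assumes its input really is a permutation.

```haskell
import           Data.Array.Unboxed
import qualified Data.IntSet        as IS
import           Data.List          (foldl')

type Perm = UArray Int Int

-- | Hypothetical helper: list the cycles of a permutation, each one
--   starting at its smallest element.  Assumes the input is a valid
--   permutation; otherwise following a "cycle" might never return.
cyclesPure :: Perm -> [[Int]]
cyclesPure p = reverse . snd $ foldl' step (IS.empty, []) (range (bounds p))
  where
    step (seen, acc) k
      -- k was already reached while following an earlier cycle.
      | k `IS.member` seen = (seen, acc)
      -- k is new: follow its cycle once and mark every element on it.
      | otherwise          = (foldl' (flip IS.insert) seen cyc, cyc : acc)
      where
        cyc = k : takeWhile (/= k) (iterate (p!) (p!k))
```

Each element is followed only once, so the repeated work is gone, but we get a list of cycles rather than the constant-time lookup tables built below.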

Let’s make a library for representing permutations. This code can be found in Perm.hs. First, some imports and the main Perm type itself, which is just an alias for UArray Int Int. UArray represents (immutable) unboxed arrays, that is, arrays whose elements can be stored “unboxed” in a contiguous block of memory. “Boxed” arrays are those where the array actually stores pointers and the elements themselves are allocated somewhere else. Of course we prefer using unboxed arrays whenever possible!

{-# LANGUAGE BangPatterns #-}

module Perm where

import           Control.Arrow
import           Control.Monad.ST
import           Data.Array.Base
import           Data.Array.MArray
import           Data.Array.ST
import           Data.Array.Unboxed

-- | 'Perm' represents a /1-indexed/ permutation.  It can also be
--   thought of as an endofunction on the set @{1 .. n}@.
type Perm = UArray Int Int

Just based on the problems where I used it, I’ve chosen to make Perm values 1-indexed, though of course we could easily have made a different choice. We can now define a few utility functions for working with permutations: fromList constructs a Perm from a list; andThen composes permutations; and inverse computes the inverse of a permutation. We’ll only need fromList to solve Chair Hopping, but the others may come in handy for other problems.

-- | Construct a 'Perm' from a list containing a permutation of the
--   numbers 1..n.  The resulting 'Perm' sends @i@ to whatever number
--   is at index @i-1@ in the list.
fromList :: [Int] -> Perm
fromList xs = listArray (1,length xs) xs

-- | Compose two permutations (corresponds to backwards function
--   composition).  Only defined if the permutations have the same
--   size.
andThen :: Perm -> Perm -> Perm
andThen p1 p2 = listArray (bounds p1) (map ((p1!) >>> (p2!)) (range (bounds p1)))

-- | Compute the inverse of a permutation.
inverse :: Perm -> Perm
inverse p = array (bounds p) [ (p!k, k) | k <- range (bounds p) ]
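
As a quick sanity check of these utilities, here is a small usage sketch. The three functions are repeated from above so the snippet stands alone; identity, isSquareRootOf, sigma, and rho are hypothetical names introduced just for the example.

```haskell
import Control.Arrow ((>>>))
import Data.Array.Unboxed

type Perm = UArray Int Int

fromList :: [Int] -> Perm
fromList xs = listArray (1, length xs) xs

andThen :: Perm -> Perm -> Perm
andThen p1 p2 = listArray (bounds p1) (map ((p1!) >>> (p2!)) (range (bounds p1)))

inverse :: Perm -> Perm
inverse p = array (bounds p) [ (p!k, k) | k <- range (bounds p) ]

-- | Hypothetical helper: the identity permutation on 1..n.
identity :: Int -> Perm
identity n = fromList [1 .. n]

-- | Hypothetical helper: does applying s twice give r?
isSquareRootOf :: Perm -> Perm -> Bool
s `isSquareRootOf` r = (s `andThen` s) == r

sigma, rho :: Perm
sigma = fromList [2, 3, 1]   -- the 3-cycle 1 -> 2 -> 3 -> 1
rho   = fromList [3, 1, 2]   -- what sigma does when applied twice
```

Here sigma `isSquareRootOf` rho holds, which is exactly the relationship \(\sigma^2 = \rho\) from the Chair Hopping problem, and sigma `andThen` inverse sigma is identity 3.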

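When decomposing a permutation into cycles, we assign each cycle a unique ID number, and compute four mappings: from each element to the ID number of the cycle containing it; from each cycle ID to the length of that cycle; from each element to its (0-based) index within its cycle; and from each cycle size to the number of cycles of that size.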

These mappings are collected in the CycleDecomp data type:

data CycleDecomp = CD
  { cycleID     :: UArray Int Int  -- ^ Each number maps to the ID# of the cycle it is part of
  , cycleLen    :: UArray Int Int  -- ^ Each cycle ID maps to the length of that cycle
  , cycleIndex  :: UArray Int Int  -- ^ Each element maps to its (0-based) index in its cycle
  , cycleCounts :: UArray Int Int  -- ^ Each size maps to the number of cycles of that size
  }
  deriving Show

We can use these to quickly look up information about the cycle decomposition of a permutation. For example, if we want to know the size of the cycle containing element e, we can look it up with cycleLen!(cycleID!e). Or if we know that a and b are in the same cycle and we want to know the distance from a to b, we can compute it as (cycleIndex!b - cycleIndex!a) `mod` (cycleLen!(cycleID!a)). (Here cycleLen, cycleID, and cycleIndex stand for the corresponding arrays of one particular CycleDecomp value; in real code each field accessor is first applied to that value.)
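
The two lookups just described might be packaged up as follows. This is only a sketch: cycleSizeOf and cycleDist are hypothetical helper names, and the example value is written out by hand for the permutation [2,3,1,5,4] (cycles (1 2 3) and (4 5)), assuming cycle IDs are handed out in order of each cycle's smallest element.

```haskell
import Data.Array.Unboxed

data CycleDecomp = CD
  { cycleID     :: UArray Int Int
  , cycleLen    :: UArray Int Int
  , cycleIndex  :: UArray Int Int
  , cycleCounts :: UArray Int Int
  }

-- | Hand-built decomposition of [2,3,1,5,4]; an assumption made for
--   illustration, not output copied from a real run.
example :: CycleDecomp
example = CD
  { cycleID     = listArray (1, 5) [1, 1, 1, 2, 2]
  , cycleLen    = listArray (1, 2) [3, 2]
  , cycleIndex  = listArray (1, 5) [0, 1, 2, 0, 1]
  , cycleCounts = listArray (1, 5) [0, 1, 1, 0, 0]
  }

-- | Hypothetical helper: size of the cycle containing element e.
cycleSizeOf :: CycleDecomp -> Int -> Int
cycleSizeOf cd e = cycleLen cd ! (cycleID cd ! e)

-- | Hypothetical helper: number of steps from a to b along their
--   shared cycle.  Only meaningful when a and b are in the same cycle.
cycleDist :: CycleDecomp -> Int -> Int -> Int
cycleDist cd a b =
  (cycleIndex cd ! b - cycleIndex cd ! a) `mod` cycleSizeOf cd a
```

Because `mod` with a positive divisor never returns a negative result in Haskell, cycleDist also works when b comes before a in the cycle: going from 3 to 1 in the example takes one step, since the permutation sends 3 to 1.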

Finally, here’s my code to actually compute all this information about a cycle decomposition in \(O(n)\) time. It looks at each element in turn, and whenever it finds one that is so far unprocessed, it does a DFS in the permutation, following the cycle from that element. To be honest, it’s kind of ugly; that’s what we get for working with mutable arrays in Haskell. I am very much interested if anyone has any ideas on how to make this (1) faster or (2) prettier. (I am aware those two criteria may be at odds!) I’m using STUArray, which allows mutation inside a monadic ST block; at the end we freeze the mutable arrays into normal immutable UArrays. (Note there are also unsafe variants of reading, writing, and freezing which do fewer checks, but using them didn’t seem to speed things up; I’m very open to suggestions.)

-- | Cycle decomposition of a permutation in O(n), using mutable arrays.
permToCycles :: Perm -> CycleDecomp
permToCycles p = cd where

  (_,n) = bounds p

  cd = runST $ do
    cid <- newArray (1,n) 0
    cix <- newArray (1,n) 0
    ccs <- newArray (1,n) 0

    lens <- findCycles cid cix ccs 1 1
    cid' <- freeze cid
    cix' <- freeze cix
    ccs' <- freeze ccs
    return $ CD cid' (listArray (1,length lens) lens) cix' ccs'

  findCycles :: STUArray s Int Int -> STUArray s Int Int -> STUArray s Int Int
    -> Int -> Int -> ST s [Int]
  findCycles cid cix ccs l !k   -- l = next available cycle ID; k = cur element
    | k > n     = return []
    | otherwise = do
        -- check if k is already marked as part of a cycle
        id <- readArray cid k
        case id of
          0 -> do
            -- k is unvisited.  Explore its cycle and label it as l.
            len <- labelCycle cid cix l k 0

            -- Remember that we have one more cycle of this size.
            count <- readArray ccs len
            writeArray ccs len (count+1)

            -- Continue with the next label and the next element, and
            -- remember the size of this cycle
            (len:) <$> findCycles cid cix ccs (l+1) (k+1)

          -- k is already visited: just go on to the next element
          _ -> findCycles cid cix ccs l (k+1)

  -- Explore a single cycle, label all its elements and return its size.
  labelCycle cid cix l k !i = do

    -- Keep going as long as the next element is unlabelled.
    id <- readArray cid k
    case id of
      0 -> do

        -- Label the current element with l.
        writeArray cid k l
        -- The index of the current element is i.
        writeArray cix k i

        -- Look up the next element in the permutation and continue.
        (1+) <$> labelCycle cid cix l (p!k) (i+1)
      _ -> return 0

This code is overly generic in some sense—we don’t actually need all this information to solve Chair Hopping, for example—but again, I am trying to make it as reusable as possible.
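
To see the mutable-array pattern in isolation, here is a pared-down sketch that computes only the element-to-cycle-ID mapping. It is not part of Perm.hs: cycleIDs is a hypothetical name, and it uses runSTUArray (from Data.Array.ST) in place of an explicit freeze, which is convenient when there is a single array to return.

```haskell
import           Control.Monad      (foldM_, when)
import           Data.Array.ST      (newArray, readArray, runSTUArray, writeArray)
import           Data.Array.Unboxed

type Perm = UArray Int Int

-- | Hypothetical helper: label every element with the ID of its cycle.
--   IDs start at 1 and are assigned in order of first visit.
cycleIDs :: Perm -> UArray Int Int
cycleIDs p = runSTUArray $ do
  let (lo, hi) = bounds p
  cid <- newArray (lo, hi) 0            -- 0 means "not yet visited"
  let -- Follow the cycle from k, labelling unvisited elements with l.
      label l k = do
        c <- readArray cid k
        when (c == 0) $ do
          writeArray cid k l
          label l (p ! k)
      -- Process element k; return the next unused cycle ID.
      visit l k = do
        c <- readArray cid k
        if c == 0
          then label l k >> return (l + 1)
          else return l
  foldM_ visit 1 [lo .. hi]
  return cid
```

Every element is written exactly once and read a constant number of times, so this runs in \(O(n)\) time, the same bound as permToCycles, while leaving out the index, length, and count bookkeeping.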

Now, how can we use cycle decomposition to solve Chair Hopping? That will have to wait for another post!