Competitive Programming in Haskell: tree path decomposition, part II
In a previous post I discussed the first half of my solution to Factor-Full Tree. In this post, I will demonstrate how to decompose a tree into disjoint paths.
Technically, we should clarify that we are looking for directed paths in a rooted tree, that is, paths that only proceed down the tree. One could also ask about decomposing an unrooted tree into disjoint undirected paths; I haven’t thought about how to do that in general, but intuitively I expect it is not too much more difficult.
For this particular problem, we want to decompose a tree into maximum-length paths (i.e. we start by taking the longest possible path, then take the longest path from what remains, and so on); I will call this the max-chain decomposition (I don’t know if there is a standard term). However, there are other types of path decomposition, such as heavy-light decomposition, so we will try to keep the decomposition code somewhat generic.
Preliminaries
This post is literate Haskell; you can find the source code on GitHub. We begin with some language pragmas and imports.
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
module TreeDecomposition where
import Control.Arrow ((>>>), (***))
import Data.Bifunctor (second)
import Data.ByteString.Lazy.Char8 (ByteString)
import Data.ByteString.Lazy.Char8 qualified as BS
import Data.List (sortBy)
import Data.List.NonEmpty (NonEmpty)
import Data.List.NonEmpty qualified as NE
import Data.Map (Map, (!), (!?))
import Data.Map qualified as M
import Data.Ord (Down(..), comparing)
import Data.Tree (Tree(..), foldTree)
import Data.Tuple (swap)
import ScannerBS
Generic path decomposition
Remember, our goal is to split up a tree into a collection of disjoint linear paths, so that every node lies on exactly one path and each path runs downward from some node through a chain of its descendants.
What do we need in order to specify a decomposition of a tree into disjoint paths this way? Really, all we need is to choose at most one linked child for each node. In other words, at every node we can choose to continue the current path into a single child node (in which case all the other children will start their own new paths), or we could choose to terminate the current path (in which case every child will be the start of its own new path). We can represent such a choice with a function of type
type SubtreeSelector a = a -> [Tree a] -> Maybe (Tree a, [Tree a])
which takes as input the value at a node and the list of all the subtrees, and possibly returns a selected subtree along with the list of remaining subtrees. (Of course, there is nothing in the type that actually requires a SubtreeSelector to return one of the trees from its input paired with the rest, but nothing we will do depends on this being true. In fact, I expect there may be some interesting algorithms obtainable by running a “path decomposition” with a “selector” function that actually makes up new trees instead of just selecting one, similar to the chop function.)
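To make the contract concrete, here is a minimal sketch of two selectors at opposite extremes. The names firstChild and noChild are hypothetical and do not appear in the post: firstChild always continues the current path into the leftmost subtree (the other children start new paths), while noChild always ends the current path, so every child starts a path of its own.

```haskell
import Data.Tree (Tree (..))

type SubtreeSelector a = a -> [Tree a] -> Maybe (Tree a, [Tree a])

-- Hypothetical selector: continue into the leftmost child and hand back
-- the remaining children, each of which will start a new path.
firstChild :: SubtreeSelector a
firstChild _ [] = Nothing
firstChild _ (t : ts) = Just (t, ts)

-- Hypothetical selector: never continue, so every node ends its path.
noChild :: SubtreeSelector a
noChild _ _ = Nothing
```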
Given such a subtree selection function, a generic path decomposition function will then take a tree and turn it into a list of non-empty paths. (We could also imagine wanting information about the parent of each path, and a mapping from tree nodes to some kind of path ID, but we will keep things simple for now.)
pathDecomposition :: SubtreeSelector a -> Tree a -> [NonEmpty a]
Implementing pathDecomposition is a nice exercise; you might like to try it yourself! You can find my implementation at the end of this blog post.
Max-chain decomposition
Now, let’s use our generic path decomposition to implement a max-chain decomposition. At each node we want to select the tallest subtree; in order to do this efficiently, we can first annotate each tree node with its height, via a straightforward tree fold:
type Height = Int

labelHeight :: Tree a -> Tree (Height, a)
labelHeight = foldTree node
  where
    node a ts = case ts of
      [] -> Node (0, a) []
      _ -> Node (1 + maximum (map (fst . rootLabel) ts), a) ts
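As a quick check of the fold, here is a standalone copy of labelHeight applied to a small tree made up for illustration. Leaves get height 0, and every other node gets one more than its tallest child; flatten lists the annotated labels in preorder.

```haskell
import Data.Tree (Tree (..), flatten, foldTree)

type Height = Int

-- Copy of labelHeight: annotate each node with the height of its subtree.
labelHeight :: Tree a -> Tree (Height, a)
labelHeight = foldTree node
  where
    node a ts = case ts of
      [] -> Node (0, a) []
      _ -> Node (1 + maximum (map (fst . rootLabel) ts), a) ts

-- Made-up example: 1 has children 2 and 3, and 2 has a single child 4.
example :: Tree Int
example = Node 1 [Node 2 [Node 4 []], Node 3 []]

-- flatten (labelHeight example) == [(2,1),(1,2),(0,4),(0,3)]
```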
Our subtree selection function can now select the subtree with the largest Height annotation. Instead of implementing this directly, we might as well make a generic function for selecting the “best” element from a list (we will reuse it later):
selectMaxBy :: (a -> a -> Ordering) -> [a] -> Maybe (a, [a])
selectMaxBy _ [] = Nothing
selectMaxBy cmp (a : as) = case selectMaxBy cmp as of
  Nothing -> Just (a, [])
  Just (b, bs) -> case cmp a b of
    LT -> Just (b, a : bs)
    _ -> Just (a, b : bs)
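A few example calls make the behaviour of selectMaxBy concrete. The copy below is identical to the definition above, and the inputs are made up. Two consequences of the recursion are easy to miss: when several elements tie for the maximum, the earliest one is selected, and the list of remaining elements is not always in the original order.

```haskell
import Data.Ord (comparing)

-- Copy of selectMaxBy: pick a maximal element and return it with the rest.
selectMaxBy :: (a -> a -> Ordering) -> [a] -> Maybe (a, [a])
selectMaxBy _ [] = Nothing
selectMaxBy cmp (a : as) = case selectMaxBy cmp as of
  Nothing -> Just (a, [])
  Just (b, bs) -> case cmp a b of
    LT -> Just (b, a : bs)
    _ -> Just (a, b : bs)

ex1, ex2 :: Maybe (Int, [Int])
ex1 = selectMaxBy compare [3, 1, 4, 1, 5] -- Just (5, [3, 1, 4, 1])
ex2 = selectMaxBy compare [5, 3, 4]       -- Just (5, [4, 3]): rest reordered

-- Ties on the key: the earliest maximal element wins.
ex3 :: Maybe ((Int, Char), [(Int, Char)])
ex3 = selectMaxBy (comparing fst) [(1, 'a'), (2, 'b'), (2, 'c')]
-- Just ((2, 'b'), [(1, 'a'), (2, 'c')])
```

Neither quirk matters for the decompositions in this post, since the order of the leftover subtrees only affects the order in which the remaining paths are listed.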
We can now put the pieces together to implement max-chain decomposition. We first label the tree by height, then do a path decomposition that selects the tallest subtree at each node. We leave the height annotations in the final output since they might be useful; for example, we can tell how long each path is just by looking at the Height annotation on the first element. If we don’t need them we can easily get rid of them later. We also sort by descending Height, since getting the longest chains first was kind of the whole point.
maxChainDecomposition :: Tree a -> [NonEmpty (Height, a)]
maxChainDecomposition =
  labelHeight >>>
  pathDecomposition (const (selectMaxBy (comparing (fst . rootLabel)))) >>>
  sortBy (comparing (Down . fst . NE.head))
Factor-full tree solution
To flesh this out into a full solution to Factor-Full Tree, after computing the chain decomposition we need to assign prime factors to the chains. From those, we can compute the value for each node if we know which chain it is in and the value of its parent. To this end, we will need one more function, which computes a Map recording the parent of each node in a tree. Note that if we already know that all the edges in a given edge list are oriented the same way (from parent to child), we can build this much more simply as e.g. map swap >>> M.fromList; but when (as in general) we don’t know in advance which way the edges should be oriented, we might as well first build a Tree a via DFS with edgesToTree and then construct the parentMap afterwards, like this.
parentMap :: Ord a => Tree a -> Map a a
parentMap = foldTree node >>> snd
  where
    node :: Ord a => a -> [(a, Map a a)] -> (a, Map a a)
    node a b = (a, M.fromList (map (,a) as) <> mconcat ms)
      where
        (as, ms) = unzip b
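Here is a standalone check that parentMap agrees with the simpler edge-swapping shortcut whenever the edges happen to be oriented from parent to child. The name parentMapFromOrientedEdges is a hypothetical label for that shortcut, and the example tree and edge list are made up.

```haskell
{-# LANGUAGE TupleSections #-}

import Control.Arrow ((>>>))
import Data.Map (Map)
import qualified Data.Map as M
import Data.Tree (Tree (..), foldTree)
import Data.Tuple (swap)

-- Copy of parentMap: each node maps its children to itself, and the
-- children's maps are merged in.
parentMap :: Ord a => Tree a -> Map a a
parentMap = foldTree node >>> snd
  where
    node :: Ord a => a -> [(a, Map a a)] -> (a, Map a a)
    node a b = (a, M.fromList (map (,a) as) <> mconcat ms)
      where
        (as, ms) = unzip b

-- Hypothetical name for the shortcut: only valid if every edge is (parent, child).
parentMapFromOrientedEdges :: Ord a => [(a, a)] -> Map a a
parentMapFromOrientedEdges = map swap >>> M.fromList

example :: Tree Int
example = Node 1 [Node 2 [Node 4 []], Node 3 []]

-- parentMap example == M.fromList [(2,1),(3,1),(4,2)]
```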
Finally, we can solve Factor-Full Tree. Note that some code from my previous blog post is needed as well, and is included at the end of this post for completeness. Once we compute the max chain decomposition and the prime factor for each node, we use a lazy recursive Map to compute the value assigned to each node.
solve :: TC -> [Int]
solve TC{..} = M.elems assignment
  where
    -- Build the tree and compute its parent map
    t = edgesToTree Node edges 1
    parent = parentMap t

    -- Compute the max chain decomposition, and use it to assign a prime factor
    -- to each non-root node
    paths :: [[Node]]
    paths = map (NE.toList . fmap snd) $ maxChainDecomposition t

    factor :: Map Node Int
    factor = M.fromList . concat $ zipWith (\p -> map (,p)) primes paths

    -- Compute an assignment of each node to a value, using a lazy map
    assignment :: Map Node Int
    assignment = M.fromList $ (1,1) : [(v, factor!v * assignment!(parent!v)) | v <- [2..n]]
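The lazy recursive Map is the subtle step in solve: assignment is defined in terms of itself. This works because the default Data.Map API is lazy in its values, so M.fromList can build the map's structure from the keys alone, and demanding a node's value demands its parent's value in turn, bottoming out at the root's entry (1, 1). Below is a standalone sketch of just that step; assignValues is a hypothetical helper, not part of the post, that takes the primes, the node count, a parent map, and the chains as plain lists.

```haskell
import Data.Map (Map, (!))
import qualified Data.Map as M

-- Hypothetical helper mirroring the last steps of solve: the i-th chain gets
-- the i-th prime, and each non-root node's value is its factor times its
-- parent's value. The root (node 1) is fixed to 1.
assignValues :: [Int] -> Int -> Map Int Int -> [[Int]] -> Map Int Int
assignValues primes n parent paths = assignment
  where
    factor :: Map Int Int
    factor = M.fromList . concat $ zipWith (\p -> map (\v -> (v, p))) primes paths

    -- Lazy recursive map: values refer back to assignment itself, which is
    -- fine because Data.Map (not Data.Map.Strict) does not force them on insertion.
    assignment :: Map Int Int
    assignment =
      M.fromList $ (1, 1) : [(v, factor ! v * assignment ! (parent ! v)) | v <- [2 .. n]]
```

For example, with parent map {2 ↦ 1, 3 ↦ 1, 4 ↦ 2} and chains [[1,2,4],[3]], the chain through the root gets factor 2 and the other chain gets 3, so nodes 1 to 4 receive 1, 2, 3, 4.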
For an explanation of this code for primes, see this old blog post.
primes :: [Int]
primes = 2 : sieve primes [3 ..]
  where
    sieve (p : ps) xs =
      let (h, t) = span (< p * p) xs
       in h ++ sieve ps (filter ((/= 0) . (`mod` p)) t)
Bonus: heavy-light decomposition
We can easily use our generic path decomposition to compute a heavy-light decomposition as well:
type Size = Int

labelSize :: Tree a -> Tree (Size, a)
labelSize = foldTree $ \a ts -> Node (1 + sum (map (fst . rootLabel) ts), a) ts

heavyLightDecomposition :: Tree a -> [NonEmpty (Size, a)]
heavyLightDecomposition =
  labelSize >>>
  pathDecomposition (const (selectMaxBy (comparing (fst . rootLabel))))
I plan to write about this in a future post.
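The two decompositions can disagree. As an illustration on a made-up tree, consider a root with one tall, thin branch 1-2-3 and one short, bushy branch 4 with four leaf children: max-chain follows the taller branch, while heavy-light follows the larger one. The block repeats the definitions from this post, including the pathDecomposition implementation from the end of the post, so that it stands alone; labels is a hypothetical helper for reading off node labels.

```haskell
import Control.Arrow ((***), (>>>))
import Data.Bifunctor (second)
import Data.List (sortBy)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NE
import Data.Ord (Down (..), comparing)
import Data.Tree (Tree (..), foldTree)

type Height = Int
type Size = Int
type SubtreeSelector a = a -> [Tree a] -> Maybe (Tree a, [Tree a])

labelHeight :: Tree a -> Tree (Height, a)
labelHeight = foldTree node
  where
    node a ts = case ts of
      [] -> Node (0, a) []
      _ -> Node (1 + maximum (map (fst . rootLabel) ts), a) ts

-- Annotate every node with the number of nodes in its subtree.
labelSize :: Tree a -> Tree (Size, a)
labelSize = foldTree $ \a ts -> Node (1 + sum (map (fst . rootLabel) ts), a) ts

selectMaxBy :: (a -> a -> Ordering) -> [a] -> Maybe (a, [a])
selectMaxBy _ [] = Nothing
selectMaxBy cmp (a : as) = case selectMaxBy cmp as of
  Nothing -> Just (a, [])
  Just (b, bs) -> case cmp a b of
    LT -> Just (b, a : bs)
    _ -> Just (a, b : bs)

selectPath :: SubtreeSelector a -> Tree a -> (NonEmpty a, [Tree a])
selectPath select = go
  where
    go (Node a ts) = case select a ts of
      Nothing -> (NE.singleton a, ts)
      Just (t, ts') -> ((a NE.<|) *** (ts' ++)) (go t)

pathDecomposition :: SubtreeSelector a -> Tree a -> [NonEmpty a]
pathDecomposition select = go
  where
    go = selectPath select >>> second (concatMap go) >>> uncurry (:)

-- Continue into the tallest child, then sort longest chains first.
maxChainDecomposition :: Tree a -> [NonEmpty (Height, a)]
maxChainDecomposition =
  labelHeight
    >>> pathDecomposition (const (selectMaxBy (comparing (fst . rootLabel))))
    >>> sortBy (comparing (Down . fst . NE.head))

-- Continue into the largest child; no sorting.
heavyLightDecomposition :: Tree a -> [NonEmpty (Size, a)]
heavyLightDecomposition =
  labelSize
    >>> pathDecomposition (const (selectMaxBy (comparing (fst . rootLabel))))

-- Made-up tree: a tall, thin branch under 1 and a short, bushy branch under 4.
example :: Tree Int
example =
  Node 0
    [ Node 1 [Node 2 [Node 3 []]]
    , Node 4 [Node 5 [], Node 6 [], Node 7 [], Node 8 []]
    ]

labels :: [NonEmpty (b, Int)] -> [[Int]]
labels = map (map snd . NE.toList)

-- labels (maxChainDecomposition example)   == [[0,1,2,3],[4,5],[6],[7],[8]]
-- labels (heavyLightDecomposition example) == [[0,4,5],[1,2,3],[6],[7],[8]]
```

Unlike maxChainDecomposition, heavyLightDecomposition as written does not sort its output, so its chains appear in the order pathDecomposition produces them.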
Leftover code
Here’s my implementation of pathDecomposition; how did you do?

pathDecomposition select = go
  where
    go = selectPath select >>> second (concatMap go) >>> uncurry (:)

selectPath :: SubtreeSelector a -> Tree a -> (NonEmpty a, [Tree a])
selectPath select = go
  where
    go (Node a ts) = case select a ts of
      Nothing -> (NE.singleton a, ts)
      Just (t, ts') -> ((a NE.<|) *** (ts' ++)) (go t)
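As a small usage example of this implementation, here are two hypothetical selectors (not from the post) run on a made-up tree: firstChild continues every path into the leftmost child, and noChild never continues, which yields one singleton path per node, listed in preorder.

```haskell
import Control.Arrow ((***), (>>>))
import Data.Bifunctor (second)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NE
import Data.Tree (Tree (..))

type SubtreeSelector a = a -> [Tree a] -> Maybe (Tree a, [Tree a])

selectPath :: SubtreeSelector a -> Tree a -> (NonEmpty a, [Tree a])
selectPath select = go
  where
    go (Node a ts) = case select a ts of
      Nothing -> (NE.singleton a, ts)
      Just (t, ts') -> ((a NE.<|) *** (ts' ++)) (go t)

pathDecomposition :: SubtreeSelector a -> Tree a -> [NonEmpty a]
pathDecomposition select = go
  where
    go = selectPath select >>> second (concatMap go) >>> uncurry (:)

-- Hypothetical selector: always continue into the leftmost child.
firstChild :: SubtreeSelector a
firstChild _ [] = Nothing
firstChild _ (t : ts) = Just (t, ts)

-- Hypothetical selector: never continue.
noChild :: SubtreeSelector a
noChild _ _ = Nothing

example :: Tree Int
example = Node 1 [Node 2 [Node 4 []], Node 3 []]

-- map NE.toList (pathDecomposition firstChild example) == [[1,2,4],[3]]
-- map NE.toList (pathDecomposition noChild example)    == [[1],[2],[4],[3]]
```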
We also include some input parsing and tree-building code from last time.
main :: IO ()
main = BS.interact $ runScanner tc >>> solve >>> map (show >>> BS.pack) >>> BS.unwords

type Node = Int
data TC = TC { n :: Int, edges :: [(Node, Node)] }
  deriving (Eq, Show)

tc :: Scanner TC
tc = do
  n <- int
  edges <- (n - 1) >< pair int int
  return TC{..}

edgesToMap :: Ord a => [(a, a)] -> Map a [a]
edgesToMap = concatMap (\p -> [p, swap p]) >>> dirEdgesToMap

dirEdgesToMap :: Ord a => [(a, a)] -> Map a [a]
dirEdgesToMap = map (second (: [])) >>> M.fromListWith (++)

mapToTree :: Ord a => (a -> [b] -> b) -> Map a [a] -> a -> b
mapToTree nd m root = dfs root root
  where
    dfs parent root = nd root (maybe [] (map (dfs root) . filter (/= parent)) (m !? root))

edgesToTree :: Ord a => (a -> [b] -> b) -> [(a, a)] -> a -> b
edgesToTree nd = mapToTree nd . edgesToMap
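Here is a quick check of the tree-building helpers on a made-up, unoriented edge list rooted at 1. Two details are worth knowing: because M.fromListWith (++) puts later values in front, the children of each node come out in the reverse of the order in which their edges appear; and since mapToTree is parameterized by the node function nd, it can fold the tree directly (here, counting nodes) without building a Data.Tree at all.

```haskell
import Control.Arrow ((>>>))
import Data.Bifunctor (second)
import Data.Map (Map, (!?))
import qualified Data.Map as M
import Data.Tree (Tree (..))
import Data.Tuple (swap)

-- Copies of the helpers above: adjacency map from undirected edges, then a DFS
-- that skips the node we arrived from.
edgesToMap :: Ord a => [(a, a)] -> Map a [a]
edgesToMap = concatMap (\p -> [p, swap p]) >>> dirEdgesToMap

dirEdgesToMap :: Ord a => [(a, a)] -> Map a [a]
dirEdgesToMap = map (second (: [])) >>> M.fromListWith (++)

mapToTree :: Ord a => (a -> [b] -> b) -> Map a [a] -> a -> b
mapToTree nd m root = dfs root root
  where
    dfs parent root = nd root (maybe [] (map (dfs root) . filter (/= parent)) (m !? root))

edgesToTree :: Ord a => (a -> [b] -> b) -> [(a, a)] -> a -> b
edgesToTree nd = mapToTree nd . edgesToMap

-- Made-up, unoriented edge list for a tree rooted at 1.
exampleEdges :: [(Int, Int)]
exampleEdges = [(1, 2), (3, 1), (2, 4)]

-- edgesToTree Node exampleEdges 1 == Node 1 [Node 3 [], Node 2 [Node 4 []]]
-- edgesToTree (\_ cs -> 1 + sum cs) exampleEdges 1 == 4   (counts the nodes)
```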