Показаны сообщения с ярлыком алгоритм. Показать все сообщения
Показаны сообщения с ярлыком алгоритм. Показать все сообщения

воскресенье, 22 сентября 2013 г.

Подсчет количества пар соседних чисел в произвольном массиве на haskell

Приятель подкинул задачку: имеется массив 8-битовых чисел произвольной длины, требуется посчитать количество пар чисел в этом массиве, отличающихся на 1 бит. Такие числа я назвал соседними в заголовке статьи. Например, числа 1 (в битовом представлении 01) и 3 (11) - соседние, а числа 1 и 2 (10) - нет.

Ниже привожу решение этой задачи на haskell.

Самое очевидное решение - попарное сравнение чисел - не интересно. Нужно что-то получше, что сузило бы количество сравнений или вообще не требовало бы их. Для этого нужно понять, какими свойствами обладают соседние числа. И с этим как раз все просто. Поскольку соседние числа в битовом представлении отличаются только в одном бите, то, соответственно, число установленных битов для них отличается ровно на единицу. Отсюда у нас появляется простой алгоритм - разбить (partition) весь заданный массив чисел по группам, в которых суммы битов всех элементов будут равны, а затем попарно сравнить все элементы из соседних групп и сложить результаты. В случае массива 8-битовых чисел количество таких групп равно 9 (от 0 до 8 установленных битов), обозначим эти группы g0, g1 .. g8, тогда количество соседних групп равно 8 ((g0, g1), (g1, g2) .. (g7, g8)).

Что ж, задача простая, приступим к ее решению. Я сначала приведу код, который я поместил в файл с названием bits_example.hs, а затем его прокомментирую.
module BitsExample where
import Data.Bits
import Data.List


partitionBySetBits' :: Bits a => Int -> [a] -> [[a]]
partitionBySetBits' (-1) _  = [[]]
partitionBySetBits' n xs    = fst x : partitionBySetBits' (- 1) (snd x)
    where x = partition (\-> popCount z == n) xs

partitionBySetBits = tail . reverse . partitionBySetBits' n

adjacentPairs :: Bits a => ([a], [a]) -> [(a, a)]
adjacentPairs (xs, ys) =
    [(x, y) | x <- xs, y <- ys, (popCount $ x `xor` y) == 1]

nAdjacentNumbers :: (Bits a, Num a) => Int -> [a] -> Int
nAdjacentNumbers n xs =
    let adjacentSets = zip partition $ tail partition
    in foldr (+) 0 (map (length . adjacentPairs) adjacentSets)
    where partition = partitionBySetBits n xs
В программе потребуются модули Data.Bits для манипулирования битовым представлением чисел и Data.List, в котором определена замечательная функция partition. Первая функция partitionBySetBits' полностью оправдывает свое название - она разбивает заданный массив чисел xs на n групп по количеству установленных битов в битовом представлении числа (как видите, я не стал ограничиваться 8-битовыми массивами и данное решение подходит для массивов чисел произвольной битовой базы). Количество установленных битов в числе возвращает функция popCount из модуля Data.Bits, она используется в качестве предиката в функции partition. Функция partition возвращает кортеж, состоящий из пары списков - в первом списке содержатся элементы из исходного списка, которые удовлетворяют условию предиката, во втором - которые не удовлетворяют. Наша функция должна возвращать список списков, поэтому мы помещаем первый список из пары в первый подсписок возвращаемого результата (с помощью функции fst), и далее (с помощью функции snd) последовательно прикрепляем остальные группы путем рекурсивного вызова partitionBySetBits' с уменьшенным значением n для элементов из второго списка пары, возвращенной функцией partition. Когда значение n достигает -1, функция partitionBySetBits' возвращает список, состоящий из пустого списка - это условие прекращения рекурсии. То, что рекурсия не прекращается при достижении n нуля важно - ведь в исходном массиве могут встречаться нули.

Что вернет функция partitionBySetBits'? Очевидно, перевернутый список групп (т.е. g8, g7 .. g0), заканчивающийся ненужным пустым списком. Поэтому ниже определена еще одна функция partitionBySetBits, которая переворачивает список и удаляет ненужный пустой элемент. Запустим ghci и проверим как она работает
$ ghci
GHCi, version 7.6.3: http://www.haskell.org/ghc/  :? for help
Loading package ghc-prim ... linking ... done.
Loading package integer-gmp ... linking ... done.
Loading package base ... linking ... done.
Prelude> :l bits_example.hs 
[1 of 1] Compiling BitsExample      ( bits_example.hs, interpreted )
Ok, modules loaded: BitsExample.
*BitsExample> :set +s
*BitsExample> partitionBySetBits 2 [0..3]
[[0],[1,2],[3]]
(0.05 secs, 9664352 bytes)
*BitsExample> partitionBySetBits 4 [0..15]
[[0],[1,2,4,8],[3,5,6,9,10,12],[7,11,13,14],[15]]
(0.01 secs, 2580960 bytes)
*BitsExample> partitionBySetBits 8 [0..255]
[[0],[1,2,4,8,16,32,64,128],[3,5,6,9,10,12,17,18,20,24,33,34,36,40,48,65,66,68,72,80,96,129,130,132,136,144,160,192],[7,11,13,14,19,21,22,25,26,28,35,37,38,41,42,44,49,50,52,56,67,69,70,73,74,76,81,82,84,88,97,98,100,104,112,131,133,134,137,138,140,145,146,148,152,161,162,164,168,176,193,194,196,200,208,224],[15,23,27,29,30,39,43,45,46,51,53,54,57,58,60,71,75,77,78,83,85,86,89,90,92,99,101,102,105,106,108,113,114,116,120,135,139,141,142,147,149,150,153,154,156,163,165,166,169,170,172,177,178,180,184,195,197,198,201,202,204,209,210,212,216,225,226,228,232,240],[31,47,55,59,61,62,79,87,91,93,94,103,107,109,110,115,117,118,121,122,124,143,151,155,157,158,167,171,173,174,179,181,182,185,186,188,199,203,205,206,211,213,214,217,218,220,227,229,230,233,234,236,241,242,244,248],[63,95,111,119,123,125,126,159,175,183,187,189,190,207,215,219,221,222,231,235,237,238,243,245,246,249,250,252],[127,191,223,239,247,251,253,254],[255]]
(0.03 secs, 3654448 bytes)
*BitsExample> 
Работает! Двигаемся дальше. Функция adjacentPairs вспомогательная - она принимает пару списков, соответствующих двум соседним группам и возвращает список пар - все соседние числа, составленные путем попарного комбинирования чисел из этих двух групп. Для определения того, являются ли два числа соседними, используется простой алгоритм: сначала числа складываются с помощью операции исключающего или (xor), а затем к результату применяется функция popCount, если полученное значение равно 1, то эти числа - соседние.

Функция nAdjacentNumbers считает искомое количество пар соседних чисел в произвольном массиве. Для этого исходный массив разбивается на группы с помощью функции partitionBySetBits, а затем, с помощью комбинации функций zip и tail, эти группы преобразуются в список соседних групп типа (g0, g1), (g1, g2) .. (g7, g8), который я назвал adjacentSets (напомню, что выражение zip "abc" $ tail "abc" делает как раз то, что нам нужно: возвращает список ["ab", "bc"]). Далее этот список с помощью map (length . adjacentPairs) преобразуется в список количества пар соседних чисел в соседних группах (здесь используется вспомогательная функция adjacentPairs), элементы которого суммируются в свертке foldr (+) 0.

Вот и все. Важно отметить, что класс этого алгоритма, как и класс примитивного нереализованного нами алгоритма с полным попарным перебором, равен O(n^2), поэтому ожидать чудес от него не стоит, хотя, естественно, nAdjacentNumbers будет выполняться значительно быстрее, чем полный перебор. Квадратичность алгоритма следует из реализации функции adjacentPairs, в которой попарно перебираются элементы соседних групп.

Давайте проверим, как работает наша функция.
*BitsExample> nAdjacentNumbers 0 []
0
(0.00 secs, 2096344 bytes)
*BitsExample> nAdjacentNumbers 1 [0..1]
1
(0.57 secs, 2060080 bytes)
*BitsExample> nAdjacentNumbers 2 [0..3]
4
(0.00 secs, 2573000 bytes)
*BitsExample> nAdjacentNumbers 4 [0..15]
32
(0.00 secs, 2063456 bytes)
*BitsExample> nAdjacentNumbers 8 [0..255]
1024
(0.00 secs, 7082504 bytes)
*BitsExample> nAdjacentNumbers 16 [0..65535]
524288
(1120.68 secs, 285882785392 bytes)
*BitsExample>
Какой интересный результат! Количество пар соседних чисел в массивах, составленных из уникальных чисел, заполняющих в битовом представлении все возможные комбинации, равно числу, являющемуся степенью двойки. Так, для 4-битовых чисел это число равно 2^5, для 8-битовых - 2^10, а для 16-битовых - 2^19. Также отметим, что последнее вычисление выполнялось больше 18 минут, поэтому считать количество пар соседних чисел для полной 32-битовой группы с использованием этого алгоритма нецелесообразно.

А теперь посчитаем количество пар соседних чисел в произвольных 8-битовых массивах. Для этого подключим модуль System.Random и определим функцию randomList
*BitsExample> :m +System.Random
*BitsExample System.Random> let randomList s = randomRs (0,255) (mkStdGen s)
Loading package array-0.4.0.1 ... linking ... done.
Loading package deepseq-1.3.0.1 ... linking ... done.
Loading package old-locale-1.0.0.5 ... linking ... done.
Loading package time-1.4.0.1 ... linking ... done.
Loading package random-1.0.1.1 ... linking ... done.
(0.22 secs, 23528592 bytes)
*BitsExample System.Random> nAdjacentNumbers 8 (take 2000 $ randomList 1 :: [Int])
62330
(1.09 secs, 200541944 bytes)
*BitsExample System.Random> nAdjacentNumbers 8 (take 20000 $ randomList 1 :: [Int])
6249274
(96.87 secs, 18507404864 bytes)
*BitsExample System.Random>
Массив, состоящий из 20000 случайных 8-битовых чисел, рассчитывался полторы минуты. Это очень плохой результат. Нужно улучшать алгоритм. Потенциально, алгоритм расчета может быть улучшен за счет более мелкого разбиения чисел по группам. Давайте подумаем, как это сделать. Допустим мы разобьем битовое представление числа на две равные части - левую и правую (так, для 8-битовых чисел это соответствует разбиению по 4 старшим битам и 4 младшим битам). Какими свойствами будут обладать битовые части для соседних чисел? Количество установленных битов в одной из них (левой или правой) будет совпадать в обоих числах, а в другой - отличаться на единицу. Это свойство мы используем в нашей новой реализации.

Напишем новую функцию разбиения по группам partitionBySetBits2, которая будет разбивать исходный массив чисел по количеству установленных битов в левой и правой части битовых представлений чисел.
partitionBySetBits'' :: (Bits a, Num a) => Int -> Int -> Int -> [a] -> [[a]]
partitionBySetBits'' _ _ (-1) _   = [[]]
partitionBySetBits'' k l n xs     =
    let xs' = partition (\-> popCount (vpart k l z) == n) xs
    in fst xs' : partitionBySetBits'' k l (- 1) (snd xs')
    where vpart k l z = z .&. (1 `rotate` l - 1) `shift` k

partitionBySetBits2' :: (Bits a, Num a) => Int -> [a] -> [[[a]]]
partitionBySetBits2' (-1) _  = [[[]]]
partitionBySetBits2' n xs    =
    let xs1 = partitionBySetBits'' 0 n2 n2 xs
    in map (partitionBySetBits'' n2 n2 n2) xs1
    where n2 = n `quot` 2

partitionBySetBits2 n xs =
    map (tail . reverse) (tail $ reverse $ partitionBySetBits2' n xs)
Новая функция partitionBySetBits'' умеет делать разбиение по количеству установленных битов в части битового представления числа. Для этого она принимает два дополнительных параметра: k - порядковый номер начального бита части битового представления и l - длину этой части. Кроме того, по сравнению с partitionBySetBits' мне пришлось добавить ограничение Num a для типовой переменной a - это нужно для функции vpart. Функция vpart переводит битовую часть, определяемую k и l, в число, количество бит в котором определяет предикат для функции partition. В остальном логика partitionBySetBits'' совпадает с логикой partitionBySetBits'. Функция partitionBySetBits'' вызывается дважды (для формирования групп чисел по левой и правой частям битового представления) из вспомогательной функции partitionBySetBits2'. Заметим, что partitionBySetBits2' возвращает список списков списков чисел благодаря новому алгоритму разбиения. Функция partitionBySetBits2 переворачивает все подсписки и удаляет ненужный пустой элемент, являющийся артефактом рекурсии в partitionBySetBits''. Проверим как она работает.
*BitsExample System.Random> partitionBySetBits2 2 [0..3]
[[[0],[2]],[[1],[3]]]
(0.02 secs, 4127056 bytes)
*BitsExample System.Random> partitionBySetBits2 4 [0..15]
[[[0],[4,8],[12]],[[1,2],[5,6,9,10],[13,14]],[[3],[7,11],[15]]]
(0.02 secs, 4157584 bytes)
*BitsExample System.Random> partitionBySetBits2 8 [0..255]
[[[0],[16,32,64,128],[48,80,96,144,160,192],[112,176,208,224],[240]],[[1,2,4,8],[17,18,20,24,33,34,36,40,65,66,68,72,129,130,132,136],[49,50,52,56,81,82,84,88,97,98,100,104,145,146,148,152,161,162,164,168,193,194,196,200],[113,114,116,120,177,178,180,184,209,210,212,216,225,226,228,232],[241,242,244,248]],[[3,5,6,9,10,12],[19,21,22,25,26,28,35,37,38,41,42,44,67,69,70,73,74,76,131,133,134,137,138,140],[51,53,54,57,58,60,83,85,86,89,90,92,99,101,102,105,106,108,147,149,150,153,154,156,163,165,166,169,170,172,195,197,198,201,202,204],[115,117,118,121,122,124,179,181,182,185,186,188,211,213,214,217,218,220,227,229,230,233,234,236],[243,245,246,249,250,252]],[[7,11,13,14],[23,27,29,30,39,43,45,46,71,75,77,78,135,139,141,142],[55,59,61,62,87,91,93,94,103,107,109,110,151,155,157,158,167,171,173,174,199,203,205,206],[119,123,125,126,183,187,189,190,215,219,221,222,231,235,237,238],[247,251,253,254]],[[15],[31,47,79,143],[63,95,111,159,175,207],[127,191,223,239],[255]]]
(0.05 secs, 7290832 bytes)
*BitsExample System.Random>
Переходим к реализации nAdjacentNumbers2. В предыдущем алгоритме мы использовали тот факт, что все соседние числа могут находиться только в соседних группах. Наш новый алгоритм может использовать этот же факт. В самом деле, после фиксации одной из подгрупп, соответствующей количеству битов в левой или правой битовой части числа, в другой подгруппе количество битов для соседних чисел должно отличаться ровно на единицу. Вот и рецепт получения соседних групп. Я не хочу искать аналитическое выражение для поиска соседних подгрупп (мы сделаем это в следующем варианте алгоритма). В случае 8-битовых чисел легко увидеть, что соседние группы соответствуют списку ((0, 0), (0, 1)), ((0, 0), (1, 0)), ((0, 1), (0, 2)), ((0, 1), (1, 1)) .., состоящему из 40 элементов (первый элемент в каждой паре соответствует числу установленных битов в правой части битового представления числа, второй - в левой). Мы просто перечислим все эти элементы в нашей новой функции nAdjacentNumbers2.
nAdjacentNumbers2 :: (Bits a, Num a) => [a] -> Int
nAdjacentNumbers2 xs =
    let adjacentSets = [((0, 0), (0, 1)), ((0, 0), (1, 0)), ((0, 1), (0, 2)),
                        ((0, 1), (1, 1)), ((0, 2), (0, 3)), ((0, 2), (1, 2)),
                        ((0, 3), (0, 4)), ((0, 3), (1, 3)), ((0, 4), (1, 4)),
                        ((1, 0), (1, 1)), ((1, 0), (2, 0)), ((1, 1), (1, 2)),
                        ((1, 1), (2, 1)), ((1, 2), (1, 3)), ((1, 2), (2, 2)),
                        ((1, 3), (1, 4)), ((1, 3), (2, 3)), ((1, 4), (2, 4)),
                        ((2, 0), (2, 1)), ((2, 0), (3, 0)), ((2, 1), (2, 2)),
                        ((2, 1), (3, 1)), ((2, 2), (2, 3)), ((2, 2), (3, 2)),
                        ((2, 3), (2, 4)), ((2, 3), (3, 3)), ((2, 4), (3, 4)),
                        ((3, 0), (3, 1)), ((3, 0), (4, 0)), ((3, 1), (3, 2)),
                        ((3, 1), (4, 1)), ((3, 2), (3, 3)), ((3, 2), (4, 2)),
                        ((3, 3), (3, 4)), ((3, 3), (4, 3)), ((3, 4), (4, 4)),
                        ((4, 0), (4, 1)), ((4, 1), (4, 2)), ((4, 2), (4, 3)),
                        ((4, 3), (4, 4))]
    in foldr (+) 0 (map (xyLength) adjacentSets)
    where partition  = partitionBySetBits2 8 xs
          xyLength z = let x  = fst z
                           y  = snd z
                           x1 = fst x
                           x2 = snd x
                           y1 = fst y
                           y2 = snd y
                       in length $ adjacentPairs (partition!!x1!!x2,
                                                  partition!!y1!!y2)
В функции nAdjacentNumbers2 нет параметра n, так как она может быть использована для поиска числа пар соседних чисел только в 8-битовых массивах, и это отражено внутри ее тела при вызове partitionBySetBits2 8 xs. Строкой выше выполняется основная задача функции: список adjacentSets трансформируется в список количества пар соседних чисел, посчитанных для всех соседних групп с помощью функции xyLength, которая находит соседние группы в списке partition путем непосредственного обращения к его элементам с использованием оператора !!, а затем подсчитывает количество пар соседних чисел в этих подгруппах с помощью нашей старой функции adjacentPairs. Затем все найденные числа складываются с помощью свертки foldr (+) 0. Из-за использования adjacentPairs класс нового алгоритма по-прежнему квадратичный.

Сравним значения, полученные с использованием нового алгоритма с теми, которые мы уже считали.
*BitsExample System.Random> nAdjacentNumbers2 [0..255]
1024
(0.06 secs, 8795888 bytes)
*BitsExample System.Random> nAdjacentNumbers2 (take 2000 $ randomList 1 :: [Int])
62330
(0.74 secs, 140816296 bytes)
*BitsExample System.Random> nAdjacentNumbers2 (take 20000 $ randomList 1 :: [Int])
6249274
(68.79 secs, 12915067416 bytes)
*BitsExample System.Random>
Отлично! Результаты те же, работает чуть быстрее.

А теперь напишем линейный алгоритм. Идея та же, что и в предыдущем, главное отличие - разбивать на подгруппы будем по каждому отдельному биту в битовом представлении. Такие группы примечательны тем, что все элементы в соседних группах являются соседними числами (это очевидно из наших определений соседних групп и соседних чисел), а значит квадратичный алгоритм adjacentPairs больше не нужен - достаточно перемножить длины списков соседних групп. Кроме того, для таких групп достаточно легко вывести выражение для получения полного списка всех соседних групп. Для этого достаточно представить все возможные группы в виде битовых полей. Для 8-битовых чисел длина полного списка таких групп равна 256, что соответствует всем различным вариантам расположения битов в 8-битовом поле. Начиная с группы (0, 0, 0, 0, 0, 0, 0, 0) будем помещать единицу слева направо - это и будут все соседние группы для нулевой группы, затем проделаем то же самое с группой (0, 0, 0, 0, 0, 0, 0, 1) и так далее вплоть до группы (1, 1, 1, 1, 1, 1, 1, 1). Полученные таким образом соседние группы будут соответствовать выражению
[(x, z) | x <- [0..254] :: [Int], y <- [1, 2, 4, 8, 16, 32, 64, 128],
          let z = x .|. y, z /= x]
Его мы и будем использовать для генерации списка соседних групп. Ниже приводится реализация линейного алгоритма подсчета количесива пар соседних чисел в массиве 8-битовых чисел.
partitionBySetBits8' :: (Bits a, Num a) => Int -> [a] -> [[[[[[[[[a]]]]]]]]]
partitionBySetBits8' (-1) _  = [[[[[[[[[]]]]]]]]]
partitionBySetBits8' n xs    =
    let xs1 = partitionBySetBits'' 0 n8 n8 xs
    in let xs2 = map (partitionBySetBits'' n8 n8 n8) xs1
       in let xs3 = map (map (partitionBySetBits'' (n8 * 2) n8 n8)) xs2
          in let xs4 = map (map (map
                                (partitionBySetBits'' (n8 * 3) n8 n8))) xs3
             in let xs5 = map (map (map (map
                                (partitionBySetBits'' (n8 * 4) n8 n8)))) xs4
                in let xs6 = map (map (map (map (map
                                (partitionBySetBits'' (n8 * 5) n8 n8))))) xs5
                   in let xs7 = map (map (map (map (map (map
                                (partitionBySetBits'' (n8 * 6) n8 n8)))))) xs6
             in map (map (map (map (map (map (map
                                (partitionBySetBits'' (n8 * 7) n8 n8))))))) xs7
    where n8 = n `quot` 8

partitionBySetBits8 n xs    =
    let xs1 = partitionBySetBits8' n xs
    in let xs2 = map (tail . reverse) xs1
       in let xs3 = map (map (tail . reverse)) xs2
          in let xs4 = map (map (map (tail . reverse))) xs3
             in let xs5 = map (map (map (map (tail . reverse)))) xs4
                in let xs6 = map (map (map (map (map (tail . reverse))))) xs5
                   in let xs7 = map (map (map (map (map (map
                                                    (tail . reverse)))))) xs6
                      in tail $ reverse $ map (map (map (map (map (map (map
                                                    (tail . reverse))))))) xs7

nAdjacentNumbers8 :: (Bits a, Num a) => [a] -> Int
nAdjacentNumbers8 xs =
    let adjacentSets = [(x, z) | x <- [0..254] :: [Int],
                       y <- [1, 2, 4, 8, 16, 32, 64, 128],
                       let z = x .|. y, z /= x]
    in foldr (+) 0 (map (xyLength) adjacentSets)
    where partition  = partitionBySetBits8 8 xs
          xyLength z = let x              = fst z
                           y              = snd z
                           nonzeroToOne 0 = 0
                           nonzeroToOne _ = 1
                           x1             = nonzeroToOne $ x .&. 1
                           x2             = nonzeroToOne $ x .&. 2
                           x3             = nonzeroToOne $ x .&. 4
                           x4             = nonzeroToOne $ x .&. 8
                           x5             = nonzeroToOne $ x .&. 16
                           x6             = nonzeroToOne $ x .&. 32
                           x7             = nonzeroToOne $ x .&. 64
                           x8             = nonzeroToOne $ x .&. 128
                           y1             = nonzeroToOne $ y .&. 1
                           y2             = nonzeroToOne $ y .&. 2
                           y3             = nonzeroToOne $ y .&. 4
                           y4             = nonzeroToOne $ y .&. 8
                           y5             = nonzeroToOne $ y .&. 16
                           y6             = nonzeroToOne $ y .&. 32
                           y7             = nonzeroToOne $ y .&. 64
                           y8             = nonzeroToOne $ y .&. 128
                       in (length $ partition!!x1!!x2!!x3!!x4!!x5!!x6!!x7!!x8) *
                          (length $ partition!!y1!!y2!!y3!!y4!!y5!!y6!!y7!!y8)
Хоть он выглядит уже не так симпатично, как предыдущие, зато должен работать очень быстро! Проверим.
*BitsExample System.Random> nAdjacentNumbers8 [0..255]
1024
(0.16 secs, 16539728 bytes)
*BitsExample System.Random> nAdjacentNumbers8 (take 2000 $ randomList 1 :: [Int])
62330
(0.20 secs, 31029640 bytes)
*BitsExample System.Random> nAdjacentNumbers8 (take 20000 $ randomList 1 :: [Int])
6249274
(0.91 secs, 209783000 bytes)
*BitsExample System.Random> nAdjacentNumbers8 (take 200000 $ randomList 1 :: [Int])
624987946
(7.82 secs, 1996417336 bytes)
*BitsExample System.Random> nAdjacentNumbers8 (take 2000000 $ randomList 1 :: [Int])
62500089507
(76.73 secs, 19851517872 bytes)
*BitsExample System.Random>
На этот раз действительно быстро. То, что предыдущий алгоритм считал 68 секунд, теперь было посчитано быстрее чем за секунду. Следующий рубеж - массив из 200000 чисел - был посчитан быстрее чем за 8 секунд, а массив из двух миллионов чисел - за 76 секунд, и это хорошо отражает линейность нового алгоритма.

Исходный код примера здесь.

воскресенье, 15 мая 2011 г.

Сравнение подходов в реализации алгоритмов на C++ и Haskell

Классический пример краткой и ясной реализации алгоритма на Haskell - это быстрая сортировка:
quickSort [] = []
quickSort (x:xs) = quickSort [ y | y <- xs, y < x ] ++ [ x ] ++
                   quickSort [ y | y <- xs, y >= x ]
Обычно рядом с этим кодом приводят какую-нибудь 20-строчную реализацию быстрой сортировки на C или C++. И конечно же сравнение идёт не в пользу C++. Откровенно говоря, здесь заключена немалая доля лукавства: как можно сравнивать высокоуровневые определители списков (list comprehensions - выражения в квадратных скобках в данном примере) с низкоуровневыми итерациями в коде C++? С другой стороны, определители списков являются встроенным механизмом Haskell, поэтому подобное сравнение всё-таки допустимо.

Ниже приводятся реализации одного алгоритма на C++ и Haskell. Алгоритм был взят из книги Ф. Меньшикова "Олимпиадные задачи по программированию". Задача формулируется следующим образом:
Дана последовательность из N целых чисел чисел. Необходимо удалить из последовательности минимальное количество чисел так, чтобы оставшаяся часть последовательности оказалась строго возрастающей. Иными словами нужно найти самую длинную возрастающую подпоследовательность.
Базовый алгоритм для решения этой задачи приводится в этой же книге:
Начиная с первого элемента последовательности и заканчивая последним найти максимальную длину возрастающей подпоследовательности из предшествующих элементов, меньших данного, и прибавить к ней 1. Эта величина будет соответствовать максимальной длине возрастающей подпоследовательности, которую можно построить начиная с 1-ого элемента исходной последовательность вплоть до данного элемента. Так, для 1-ого элемента последовательности из предшествующих элементов не существует, поэтому ему соответствует длина 1. Для произвольного i-ого элемента нужно совершить обратный проход к началу последовательности, найти меньший по значению элемент с максимальным значением длины и прибавить к нему 1. После того как длины найдены (при условии, что они сохранены в некотором вспомогательном массиве), нужно в исходной последовательности найти элемент, которому соответствует  максимальное значение длины и, двигаясь к началу последовательности, находить первые элементы, для которых соответствующая длина уменьшилась на единицу. Эти элементы (вместе с исходным) и будут составлять искомую подпоследовательность наибольшей длины.
Несколько замечаний. Во-первых, существуют модификации этого алгоритма, в которых финальный проход от элемента, которому соответствует максимальная длина возрастающей подпоследовательности к началу исходной последовательности не требуется. Поскольку моей целью является не написание самого оптимального алгоритма, а качественное сравнение подобных реализаций на C++ и Haskell, то я использовал базовый алгоритм. Во-вторых, очевидно, что в базовом варианте класс этого алгоритма соответствует O(n^2), так как присутствуют вложенные итерации (обратные проходы к началу последовательности) внутри базовой итерации по всем элементам. В-третьих, могут существовать несколько подпоследовательностей максимальной длины, нам нужно выбрать любую из них.

Итак, реализация этого алгоритма на C++ выглядит следующим образом:
typedef std::vector< int >  VIn;

typedef VIn::value_type     VInElem;

typedef std::vector< int >  VCnt;

void  findLongestIncSeq( const VIn &  vIn, VIn &  vRes )
{
    VCnt                 vCnt( vIn.size() );
    VCnt::iterator       vCntIt( vCnt.begin() );
    VIn::const_iterator  vMaxIt( vIn.begin() );
    VCnt::iterator       vCntMaxIt( vCnt.begin() );
    int                  maxLen( 0 );

    for ( VIn::const_iterator  k( vIn.begin() ); k != vIn.end(); ++k, ++vCntIt )
    {
        int                           maxVal( 0 );
        VCnt::const_reverse_iterator  vCntBackIt( vCntIt );

        for ( VIn::const_reverse_iterator  l( k ); l != vIn.rend();
                                                            ++l, ++vCntBackIt )
        {
            if ( *k <= *l )
                continue;

            if ( *vCntBackIt > maxVal )
                maxVal = *vCntBackIt;
        }

        *vCntIt = maxVal + 1;
        if ( *vCntIt > maxLen )
        {
            maxLen = *vCntIt;
            vMaxIt = k;
            vCntMaxIt = vCntIt;
        }
    }

    std::cout << "Lengths:  ";
    printVec( vCnt );

    if ( vMaxIt == vIn.begin() )
        return;

    std::stack< VInElem >   reverseSeq;
    VCnt::reverse_iterator  vCntBackIt( vCntMaxIt + 1 );

    for ( VIn::const_reverse_iterator  k( vMaxIt + 1 ); k != vIn.rend();
                                                            ++k, ++vCntBackIt )
    {
        if ( *vCntBackIt == maxLen )
        {
            reverseSeq.push( *k );
            --maxLen;
        }
    }

    while ( ! reverseSeq.empty() )
    {
        vRes.push_back( reverseSeq.top() );
        reverseSeq.pop();
    }
}
Вспомогательная функция printVec() реализуется так:
template  < typename  VElem >
struct  PrintVec
{
    void  operator()( VElem  value )
    {
        std::cout << value << " ";
    }
};

template  < typename  Vec >
void  printVec( Vec  v )
{
    std::for_each( v.begin(), v.end(), PrintVec< typename Vec::value_type >() );
    std::cout << std::endl;
}
В функцию findLongestIncSeq() передаются ссылки на исходную последовательность vIn и последовательность vRes, которую нужно заполнить. Внутри функции создаем вспомогательный массив целых чисел vCnt, в котором будем хранить длины максимальной возрастающей последовательности до i-ого элемента включительно. Внутри первого цикла for происходит основная итерация, в которой заполняется массив vCnt, кроме того попутно вычисляются максимальная длина возрастающей подпоследовательности maxLen и соответствующие ей элементы в vIn и vCnt: vMaxIt и vCntMaxIt - они будут нужны для определения элемента, с которого нужно начать финальный проход по исходной последовательности вниз к первому элементу. Этот финальный проход осуществляется в последнем цикле for функции findLongestIncSeq(), в котором заполняется вспомогательный стек reverseSeq. Теперь, чтобы заполнить искомую последовательность vRes, нужно опустошить reverseSeq и сложить его элементы в vRes, что и происходит в последнем цикле while.

Я не буду приводить функцию main() - это детали. Укажу только, что для тестовой последовательности [67 5 34 89 65 12 90 75 8 9 3] выводится следующий результат:
Original: 67 5 34 89 65 12 90 75 8 9 3 
Lengths:  1 1 2 3 3 2 4 4 2 3 1 
Sequence: 5 34 65 90
В строке Original выводится исходная последовательность, в строке Lengths - вспомогательная последовательность длин возрастающих подпоследовательностей, а в строке Sequence - результирующая возрастающая подпоследовательность максимальной длины.

А теперь внимание. Подобный алгоритм написанный на Haskell (я привожу весь код вместе с тестовой частью, так что его можно сразу скомпилировать):
{-find longest increasing subsequence-}

getPrevLength [] = 0
getPrevLength x  = maximum $ doGetLengths x

doGetLengths []     = []
doGetLengths (x:xs) = getPrevLength [ y | y <- xs, y < x ] + 1 : doGetLengths xs

getLengths = reverse . doGetLengths . reverse

doCutTail [] _                                          = []
doCutTail a@(x:xs) maxLen
    | getPrevLength [ y | y <- xs, y < x ] + 1 < maxLen = doCutTail xs maxLen
    | otherwise                                         = a

cutTail [] = []
cutTail x  = reverse $ doCutTail ( reverse x ) ( maximum $ getLengths x )

doFindLongestIncSec' [] _   = []
doFindLongestIncSec' (x:xs) maxLen
    | maxLen - prevLen == 1 = x : doFindLongestIncSec' xs prevLen
    | otherwise             = doFindLongestIncSec' xs maxLen
        where prevLen = getPrevLength [ y | y <- xs, y < x ] + 1

doFindLongestIncSec [] = []
doFindLongestIncSec x  =
    doFindLongestIncSec' ( reverse x ) ( maximum ( getLengths x ) + 1 )

findLongestIncSec = reverse . doFindLongestIncSec . cutTail

{-do some testing-}

doTest x = do
    putStrLn $ "Original: " ++ show x
    putStrLn $ "Lengths:  " ++ ( show $ getLengths x )
    putStrLn $ "Trimmed:  " ++ ( show $ cutTail x )
    putStrLn $ "Sequence: " ++ ( show $ findLongestIncSec x )

main = mapM_ doTest [ [ 10, 5, 89, 16, 78, 67, 56, 34, 12, 10 ],
                      [ 67, 5, 34, 89, 65, 12, 90, 75, 8, 9, 3 ] ]
Я назвал этот алгоритм подобным, поскольку его невозможно реализовать тем же самым способом, что и в C++ (во всяком случае без использования монад). В Haskell отсутствуют итерации и состояния (то бишь переменные), и это означает, что создать вспомогательный массив длин подпоследовательностей и использовать его на том же уровне программного потока не удастся. По той же причине невозможно сначала запомнить индекс элемента с максимальным значением возрастающей подпоследовательности, а затем использовать его по своему усмотрению в параллельной последовательности вызовов функций. В Haskell возможность параллельной последовательности вызовов функций появляется только с использованием монад (см. реализацию doTest выше), а в общем случае вызовы функций могут быть только вложены друг в друга. Это означает, что обычно решение любой задачи в Haskell представляет собой вызов единственной функции, которая в свою очередь вызывает единственную функцию и т.д. Это создает впечатление решения одним махом. В самом деле, в приведенном коде всё решение заключается в вызове единственной функции findLongestIncSeq x. Параллельные вызовы getLengths и cutTail, организованные в doTest нужны только для демонстрации промежуточных состояний исходной последовательности и никоим образом не влияют на результат финального вызова findLongestIncSeq.

Приведенные выше особенности программирования на Haskell требуют более частого использования рекурсивных вызовов, чем в C++, где рекурсия чаще всего заменяется итерацией. Так, функция getPrevLength, которая используется в программе для поиска длины максимальной подпоследовательности, является косвенно рекурсивной благодаря вызову doGetLengths, которая в свою очередь является рекурсивной и косвенно, и явно. Функция cutTail возвращает подпоследовательность исходной последовательности от начала до последнего элемента, которому соответствует максимальная длина возрастающей подпоследовательности. Основная функция findLongestIncSeq является композицией функций reverse, doFindLongestIncSeq и cutTail - это отражено в ее определении.

Нельзя не заметить, что в коде часто употребляется обращение последовательности с помощью reverse. Если учесть, что скорость reverse должна соответствовать O(n), это должно настораживать. Проблема в том, что семантика конструктора списков Haskell предполагает добавление элементов в начало, а не в конец списка, а в нашем алгоритме чаще всего приходится выделять последний элемент подпоследовательности и проходить вниз к началу списка. Таким образом, от reverse можно избавиться, видоизменив исходный алгоритм. Однако проблема обращения списка не такая уж серьезная, как кажется сначала. Дело в том, что reverse не вызывается рекурсивно, а только счётное количество раз, соответственно общая скорость алгоритма остается прежней - O(n^2).

Еще одно важное замечание. Что если мы захотим изменить тип элементов последовательности? Например, захотим искать максимальную подпоследовательность вещественных чисел или строк. В коде на C++ такой вариант предусмотрен с помощью объявления типов VIn и VInElem - стоит поменять тип элемента вектора и всё должно заработать. В коде на Haskell вообще ничего менять не нужно! И это не потому, что в Haskell нет статической проверки типов данных, она там есть, причем более сильная, чем в C++. Просто в данном случае работает параметрический полиморфизм, реализованный в Haskell.
Осталось привести вывод программы на Haskell:
Original: [10,5,89,16,78,67,56,34,12,10]
Lengths:  [1,1,2,2,3,3,3,3,2,2]
Trimmed:  [10,5,89,16,78,67,56,34]
Sequence: [5,16,34]
Original: [67,5,34,89,65,12,90,75,8,9,3]
Lengths:  [1,1,2,3,3,2,4,4,2,3,1]
Trimmed:  [67,5,34,89,65,12,90,75]
Sequence: [5,34,65,75]
Видно, что последний элемент одной и той же тестовой последовательности отличается от результата программы на C++. Это связано с тем, что в C++ варианте последний проход начинался от первого элемента с максимальным значением возрастающей подпоследовательности, а в Haskell варианте - с последнего. Как я уже говорил, могут существовать несколько возрастающих подпоследовательностей максимальной длины, поэтому оба случая являются правильным решением задачи.