пятница, 24 июля 2015 г.

Мемоизация функций в haskell на конкретном примере

Возьмем пресловутую задачу вычисления n-ого числа Фибоначчи. Наивное решение на haskell
fib 0 = 0
fib 1 = 1
fib n = fib (n - 1) + fib (n - 2)
не будет работать для больших n: попробуйте посчитать fib 50! Проблема заключается в экспоненциальном росте числа рекурсивных вызовов fib при увеличении n и невозможности сведения этих вызовов при таком определении функции в простой цикл. И это при том, что интуитивно задача является простейшей, а для получения, скажем, 50-ого числа, нужно всего-то вычислить 48 последовательных значений fib, начиная с n = 2. То есть практически нам нужно сделать всего 48 вычислений fib и каким-то образом запомнить их результаты, чтобы не вызывать функцию fib вновь и вновь. Механизм запоминания результатов вызовов функции для разных ее аргументов с дальнейшей подстановкой уже вычисленных значений вместо повторных вызовов этой же функции с теми же аргументами называется мемоизацией (memoization). Удивительно, но мемоизация функций поддерживается компилятором ghc из коробки. Почему же тогда в приведенной реализации функции fib она не работает? Оказывается, компилятор не может мемоизировать вызовы функций со связанными аргументами, поскольку связанный аргумент может изменить семантику вызова функции и не гарантирует одинаковый результат в случае ее повторных вызовов. Зато функции, объявленные без указания аргументов (то есть бесточечные определения) будут успешно мемоизироваться! Очень подробно мемоизация функций рассматривается в этом руководстве, там же приводится решение задачи вычисления n-ого числа Фибоначчи с определением функции в бесточечном стиле. Я не стану больше говорить о числах Фибоначчи, однако замечу, что очевидный алгоритм решения этой задачи, следующий непосредственно из определения fib и заключающийся в последовательном нахождении всех чисел fib i для i от 2 до n, представляет собой классический пример так называемого динамического программирования. В этой статье я покажу решение похожей задачи о поиске минимального числа монет с номиналами из заданного списка c, которые в сумме составляют заданную стоимость m. Алгоритм решения так же основан на последовательном вычислении частных решений для стоимостей от 1 до m в предположении, что количество монет для стоимости 0 равно 0. Подробное описание алгоритма и псевдокод можно легко найти в интернете (например, здесь). Поэтому я не стану на этом останавливаться, а сразу же приведу наивное решение.
minNumCoinsPlain :: [Int] -> Int -> Int
minNumCoinsPlain c m = map (mnc c) [0 .. m] !! m
    where mnc c m = foldl (\a x -> min (minNumCoinsPlain c (m - x) + 1) a) m $
                    filter (<= m) c
Наивное оно по той же причине, что и приведенное решение задачи о числах Фибоначчи — слишком много рекурсивных вызовов minNumCoinsPlain при больших значениях m. Вы можете скачать исходный код примера отсюда, загрузить его в ghci и протестировать при не очень больших m (начните с 25, при 50 вычисления должны плотно висеть, при необходимости их можно прервать стандартной комбинацией клавиш Ctrl-C).
:l minNumCoins
[1 of 1] Compiling Main             ( minNumCoins.hs, interpreted )
Ok, modules loaded: Main.
:set +s
minNumCoinsPlain [1, 2, 5] 25
5
(4.22 secs, 1,137,893,584 bytes)
Больше четырех секунд вычислений для задачи, которую решит любой школьник не задумываясь! Давайте уберем аргумент m из определения функции: останется аргумент c — функция не стала бесточечной, но ведь c вообще не меняется между вызовами, может компилятор сможет мемоизировать minNumCoinsPlain в таком случае?
minNumCoins_ :: [Int] -> Int -> Int
minNumCoins_ c = (map (mnc c) [0 ..] !!)
    where mnc _ 0 = 0
          mnc c m = foldl (\a x -> min (minNumCoins_ c (m - x) + 1) a) m $
                    filter (<= m) c
Проверяем в ghci.
minNumCoins_ [1, 2, 5] 25
5
(3.66 secs, 1,084,731,832 bytes)
Чуда не произошло. Наличие в определении функции связанного аргумента c способно повлиять на семантику вызовов функции, поэтому компилятор не может выполнить ее мемоизацию. Кстати, метод избавления от связанного аргумента m с помощью применения сечения функции (!!) я взял из упомянутого руководства на haskell.org. Давайте избавимся и от аргумента c! Это сделать не так уж сложно, учитывая, что он не изменяется между вызовами функции.
minNumCoins :: [Int] -> Int -> Int
minNumCoins c =
    let minNumCoinsMemo = (map (mnc c) [0 ..] !!)
        mnc _ 0 = 0
        mnc c m = foldl (\a x -> min (minNumCoinsMemo (m - x) + 1) a) m $
                  filter (<= m) c
    in minNumCoinsMemo
Видите, я зафиксировал c в функции-обертке minNumCoins, в то время как основная рабочая лошадка — функция minNumCoinsMemo — объявлена в бесточечном стиле и вызывается рекурсивно из вспомогательной функции mnc c. Это значит, что ghc должен ее наконец-то мемоизировать!
minNumCoins [1, 2, 5] 25
5
(0.01 secs, 3,910,424 bytes)
Ура! Мгновенное выполнение. Проверим на большой стоимости m.
minNumCoins [1, 2, 5] 10001
2001
(16.53 secs, 0 bytes)
Не быстро, но и не бесконечно долго! Собственно, фиксация аргументов, которую мы здесь проделали вручную, может быть реализована с помощью функции fix из модуля Data.Function: читайте упомянутое выше руководство. Однако, при количестве аргументов от двух и более, прямое использование fix становится трудоемким и подверженным ошибкам. Поэтому лучше всего воспользоваться готовыми функциями из модуля Data.Function.Memoize. Для нашей задачи, в частности, нужна функция memoFix2, поскольку наша рекурсивная функция ожидает на входе два аргумента c и m. Давайте протестируем решение с memoFix2. Прежде всего следует импортировать эту функцию из модуля Data.Function.Memoize.
import Data.Function.Memoize (memoFix2)
Определение новой функции minNumCoinsMemo.
minNumCoinsMemo :: [Int] -> Int -> Int
minNumCoinsMemo = memoFix2 mnc
    where mnc _ _ 0 = 0
          mnc f c m = foldl (\a x -> min (f c (m - x) + 1) a) m $
                      filter (<= m) c
Обратите внимание, функция minNumCoinsMemo объявлена в бесточечном стиле, а у вспомогательной функции mnc появился новый аргумент f, который вышел на первое место — это и есть наша фиксированная функция, которая рекурсивно вызывается внутри mnc. Проверим ее скорость.
minNumCoinsMemo [1, 2, 5] 10001
2001
(1.00 secs, 527,741,984 bytes)
Вот это уже очень хорошо! Новая функция оказалась быстрее нашей самодельной функции minNumCoins на порядок. Но это еще не всё. Хотя с мемоизацией мы закончили. На самом деле, природа задачи о минимальном количестве монет такова, что мы можем попробовать вручную собирать уже подсчитанные результаты в список или в ассоциативный массив (Map): это легко сделать с помощью функций mapAccumL или mapAccumR из модуля Data.List. Очевидно, что поиск уже найденных значений внутри списка потребует линейной скорости, а внутри Map — логарифмической. Я не знаю как организован поиск мемоизированных значений изнутри, но даже если он происходит за константное время, логарифм при таких небольших значениях как 10001 может составить ему хорошую конкуренцию. Поэтому давайте протестируем этакую народную самописную реализацию мемоизации на основе хорошо оптимизированного, строгого типа IntMap. Импортируем нужные функции.
import Data.IntMap.Strict (singleton, insert, (!))
import Data.List (mapAccumL)
Определяем новую функцию
minNumCoinsMap :: [Int] -> Int -> Int
minNumCoinsMap c m = snd (mapAccumL (mnc c) (singleton 0 0) [0 .. m]) !! m
    where mnc c s m = foldl step (s, m) $ filter (<= m) c
            where step a x = let cur = min (s ! (m - x) + 1) $ snd a
                             in (insert m cur s, cur)
Почти все то же самое, но на этот раз вместо рекурсивных вызовов функций типа f c (m - x), внутри функции свертки step мы обращаемся к соответствующему элементу массива s в предположении, что он уже был вычислен в предыдущих итерациях. Проверяем функцию minNumCoinsMap.
minNumCoinsMap [1, 2, 5] 10001
2001
(0.15 secs, 20,208,520 bytes)
Весьма красноречивый результат. Народная мемоизация в этой задаче — чемпион. Напоследок приведу ссылку на замечательную статью из журнала Практика функционального программирования, в которой исследуется взаимосвязь рекурсии, мемоизации и динамического программирования.