web-dev-qa-db-ja.com

Haskellでのメモ化?

Haskellの次の関数を効率的に解決する方法に関するすべてのポインター(多数の場合)(n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

Haskellでフィボナッチ数を解決するためのメモ化の例を見てきました。これには、必要なnまでのすべてのフィボナッチ数を(怠iに)計算する必要がありました。ただし、この場合、指定されたnについて、中間結果を計算する必要はほとんどありません。

ありがとう

128

Edwardの答え は非常に素晴らしい宝石であり、私はそれを複製し、関数をオープン再帰形式で記憶するmemoListおよびmemoTreeコンビネータの実装を提供しました。

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f
17
Tom Ellis

最も効率的な方法ではありませんが、次のことをメモします。

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

f !! 144を要求すると、f !! 143が存在することが確認されますが、その正確な値は計算されません。まだ未知の計算結果として設定されています。計算される正確な値は、必要な値のみです。

したがって、最初は、計算された量に関しては、プログラムは何も知りません。

f = .... 

f !! 12をリクエストすると、パターンマッチングが開始されます。

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

今、計算を開始します

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

これは、fに対して別の要求を再帰的に行うため、計算します

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

今、私たちは少しバックアップすることができます

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

つまり、プログラムは次のことを認識しています。

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

引き続き細流化:

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

つまり、プログラムは次のことを認識しています。

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

f!!6の計算を続けます:

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

つまり、プログラムは次のことを認識しています。

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

f!!12の計算を続けます:

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

つまり、プログラムは次のことを認識しています。

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

したがって、計算はかなり遅れて行われます。プログラムは、f !! 8の値が存在し、g 8と等しいことを知っていますが、g 8が何であるかはわかりません。

12
rampion

Edward Kmettの回答で述べたように、物事をスピードアップするには、高価な計算をキャッシュし、それらにすばやくアクセスできるようにする必要があります。

関数を単項でない状態に保つために、無限の遅延ツリーを構築するソリューション(適切なインデックス方法(前の投稿を参照)を使用)は、その目標を達成します。関数の非モナド的な性質を放棄すると、Haskellで利用可能な標準の連想コンテナを「状態のような」モナド(StateやSTなど)と組み合わせて使用​​できます。

主な欠点は、非単項関数を取得することですが、これ以上構造をインデックス化する必要はなく、連想コンテナの標準実装を使用できます。

そのためには、最初に関数を書き直して、あらゆる種類のモナドを受け入れる必要があります。

fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
   recs <- mapM recf $ div n <$> [2, 3, 4]
   return $ max n (sum recs)

テストでは、Data.Function.fixを使用してメモ化を行わない関数を定義できますが、もう少し冗長です。

noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm

その後、ステートモナドをData.Mapと組み合わせて使用​​して、処理を高速化できます。

import qualified Data.Map.Strict as MS

withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
   where
      recF i = do
         v <- MS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ MS.insert i v'
               return v'

わずかな変更で、代わりにData.HashMapで動作するようにコードを調整できます。

import qualified Data.HashMap.Strict as HMS

withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
   where
      recF i = do
         v <- HMS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ HMS.insert i v'
               return v'

永続的なデータ構造の代わりに、STモナドと組み合わせて可変データ構造(Data.HashTableなど)を試すこともできます。

import qualified Data.HashTable.ST.Linear as MHM

withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
   do ht <- MHM.new
      recF ht n
   where
      recF ht i = do
         k <- MHM.lookup ht i
         case k of
            Just k' -> return k'
            Nothing -> do 
               k' <- fm (recF ht) i
               MHM.insert ht i k'
               return k'

メモ化を行わない実装と比較すると、これらの実装はいずれも、膨大な入力に対して、数秒待たずにマイクロ秒単位で結果を得ることができます。

Criterionをベンチマークとして使用すると、Data.HashMapを使用した実装は、タイミングが非常に類似したData.MapとData.HashTableよりも実際にわずかに(約20%)優れたパフォーマンスを発揮することがわかりました。

ベンチマークの結果は少し驚くべきものでした。私の最初の感じは、HashTableは変更可能であるため、HashMapの実装よりも優れているということでした。この最後の実装には、パフォーマンスの欠陥が隠されている可能性があります。

8
Quentin

これは、エドワードクメットの優れた答えに対する補遺です。

私が彼のコードを試したとき、natsindexの定義はかなり不可解であるように思えたので、理解しやすい代替バージョンを作成しました。

index'nats'の観点からindexnatsを定義します。

index' t n[1..]の範囲で定義されています。 (index t[0..]の範囲で定義されていることを思い出してください。)nをビットの文字列として扱い、ビットを逆方向に読み取ることでツリーを検索します。ビットが1の場合、右側の分岐を取ります。ビットが0の場合、左側の分岐を取ります。最後のビット(1でなければなりません)に達すると停止します。

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

natsindexに対して定義されているためindex nats n == nが常にtrueであるように、nats'index'に対して定義されています。

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

現在、natsindexは、単にnats'index'ですが、値は1シフトされています。

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'
8
Pitarou

数年後、私はこれを見て、zipWithとヘルパー関数を使用して線形時間でこれをメモする簡単な方法があることに気付きました。

dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilateには、dilate n xs !! i == xs !! div i n

したがって、f(0)が与えられたとすると、これは計算を単純化して

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

元の問題の説明によく似ており、線形解(sum $ take n fsはO(n)を取ります)。

4
rampion

エドワード・ケメットの答えに対する別の補遺:自己完結型の例:

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

次のように使用して、単一の整数arg(たとえばfibonacci)を持つ関数をメモします。

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

負でない引数の値のみがキャッシュされます。

負の引数の値もキャッシュするには、次のように定義されたmemoIntを使用します。

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

2つの整数引数を持つ関数の値をキャッシュするには、次のように定義されたmemoIntIntを使用します。

memoIntInt f = memoInt (\n -> memoInt (f n))
2
Neal Young

エドワードKMETTに基づいていない、インデックス付けのないソリューション。

共通のサブツリーを共通の親に分解します(f(n/4)f(n/2)f(n/4)で共有され、f(n/6)f(2)およびf(3))。それらを親の単一変数として保存することにより、サブツリーの計算が1回行われます。

data Tree a =
  Node {datum :: a, child2 :: Tree a, child3 :: Tree a}

f :: Int -> Int
f n = datum root
  where root = f' n Nothing Nothing


-- Pass in the arg
  -- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
f' 0 _ _ = leaf
    where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
  where
    d = if n < 12 then n
            else max n (d2 + d3 + d4)
    [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
    [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
    c2 = case m2 of    -- Check for a passed-in subtree before recursing.
      Just c2' -> c2'
      Nothing -> f' n2 Nothing (Just c6)
    c3 = case m3 of
      Just c3' -> c3'
      Nothing -> f' n3 (Just c6) Nothing
    c4 = child2 c2
    c6 = f' n6 Nothing Nothing

    main =
      print (f 123801)
      -- Should print 248604.

コードは一般的なメモ化関数に簡単に拡張できません(少なくとも、それを行う方法はわかりません)。そして、サブ問題がどのように重複するかを本当に考えなければなりませんが、strategyは機能するはずです一般的な複数の非整数パラメーターの場合。 (2つの文字列パラメーターについて考えました。)

メモは各計算後に破棄されます。 (繰り返しますが、2つの文字列パラメーターについて考えていました。)

これが他の答えよりも効率的かどうかはわかりません。各ルックアップは技術的には1つまたは2つのステップ(「あなたの子供またはあなたの子供の子供を見てください」)だけですが、多くの余分なメモリ使用があるかもしれません。

編集:この解決策はまだ正しくありません。共有は不完全です。

編集:今では子を適切に共有する必要がありますが、この問題には多くの重要な共有があることに気付きました:n/2/2/2およびn/3/3は同じかもしれません。この問題は私の戦略によく合いません。

2
leewz