web-dev-qa-db-ja.com

動的プログラミングを使用したテキストの位置揃えの実装

MIT OCW here )のコースで、動的プログラミングの概念を理解しようとしています。OCWビデオの説明はすばらしいですが、説明をコードに実装するまではあまり理解できませんが、実装する際には、講義ノート here のいくつかのノート、特にノートの3ページを参照します。

問題は、数学的表記の一部をコードに変換する方法がわからないことです。これが私が実装したソリューションの一部です(そして正しく実装されていると思います):

import math

paragraph = "Some long lorem ipsum text."
words = paragraph.split(" ")

# Count total length for all strings in a list of strings.
# This function will be used by the badness function below.
def total_length(str_arr):
    total = 0

    for string in str_arr:
        total = total + len(string)

    total = total + len(str_arr) # spaces
    return total

# Calculate the badness score for a Word.
# str_arr is assumed be send as Word[i:j] as in the notes
# we don't make i and j as argument since it will require
# global vars then.
def badness(str_arr, page_width):
    line_len = total_length(str_arr)
    if line_len > page_width:
        return float('nan') 
    else:
        return math.pow(page_width - line_len, 3)

今、わからないところは講義ノートの3〜5点目です。私は文字通り理解しておらず、どこにそれらを実装し始めるのか分かりません。これまでのところ、私は単語のリストを反復し、次のように各行の終わりとされるそれぞれの悪さを数えました:

def justifier(str_arr, page_width):
    paragraph = str_arr
    par_len = len(paragraph)
    result = [] # stores each line as list of strings
    for i in range(0, par_len):
        if i == (par_len - 1):
            result.append(paragraph)
        else:
            dag = [badness(paragraph[i:j], page_width) + justifier(paragraph[j:], page_width) for j in range(i + 1, par_len + 1)] 
            # Should I do a min(dag), get the index, and declares it as end of line?

しかし、私はどうすればその機能を継続できるのかわかりません。正直なところ、この行は理解できません。

dag = [badness(paragraph[i:j], page_width) + justifier(paragraph[j:], page_width) for j in range(i + 1, par_len + 1)] 

justifierintとして返す方法(リストであるresultに戻り値を格納することをすでに決定しているためです。別の関数を作成してから再帰する必要があります再帰が必要ですか?

次に何をすべきかを示し、これが動的プログラミングであることを説明していただけませんか?再帰がどこにあるのか本当にわかりません副問題です。

前に感謝します。

22
bertzzie

ダイナミックプログラミング自体のコアアイデアを理解するのに問題がある場合は、ここに私の見解を示します。

動的プログラミングは基本的にスペースの複雑さを犠牲にしています時間の複雑さ(ただし、使用する余分なスペースは通常very節約する時間と比較してわずかであり、正しく実装されていれば、動的プログラミングはそれだけの価値があります)。各再帰呼び出しの値を(配列やディクショナリなどに)格納しておくと、再帰ツリーの別のブランチで同じ再帰呼び出しを実行するときに、2回目の計算を回避できます。

そして、あなたはしませんnot再帰を使用する必要があります。これは、ループだけを使用して取り組んでいた質問の実装です。私は、AlexSilvaによってリンクされたTextAlignment.pdfを非常に密接に追跡しました。この情報がお役に立てば幸いです。

def length(wordLengths, i, j):
    return sum(wordLengths[i- 1:j]) + j - i + 1


def breakLine(text, L):
    # wl = lengths of words
    wl = [len(Word) for Word in text.split()]

    # n = number of words in the text
    n = len(wl)    

    # total badness of a text l1 ... li
    m = dict()
    # initialization
    m[0] = 0    

    # auxiliary array
    s = dict()

    # the actual algorithm
    for i in range(1, n + 1):
        sums = dict()
        k = i
        while (length(wl, k, i) <= L and k > 0):
            sums[(L - length(wl, k, i))**3 + m[k - 1]] = k
            k -= 1
        m[i] = min(sums)
        s[i] = sums[min(sums)]

    # actually do the splitting by working backwords
    line = 1
    while n > 1:
        print("line " + str(line) + ": " + str(s[n]) + "->" + str(n))
        n = s[n] - 1
        line += 1
21
Joohwan

これにまだ興味がある人のために:キーはテキストの終わりから逆方向に移動することです(前述のように here )。その場合は、すでに記憶されている要素を比較するだけです。

たとえば、wordstextwidthに従ってラップされる文字列のリストです。次に、講義の表記では、タスクは3行のコードに削減されます。

_import numpy as np

textwidth = 80

DP = [0]*(len(words)+1)

for i in range(len(words)-1,-1,-1):
    DP[i] = np.min([DP[j] + badness(words[i:j],textwidth) for j in range(i+1,len(words)+1)])
_

と:

_def badness(line,textwidth):

    # Number of gaps
    length_line = len(line) - 1

    for Word in line:
        length_line += len(Word)

    if length_line > textwidth: return float('inf')

    return ( textwidth - length_line )**3
_

彼は、2番目のリストを追加して、最新の位置を追跡できると述べています。これを行うには、コードを次のように変更します。

_DP = [0]*(len(words)+1)
breaks = [0]*(len(words)+1)

for i in range(len(words)-1,-1,-1):
    temp = [DP[j] + badness(words[i:j],args.textwidth) for j in range(i+1,len(words)+1)]

    index = np.argmin(temp)

    # Index plus position in upper list
    breaks[i] = index + i + 1
    DP[i] = temp[index]
_

テキストを復元するには、改行位置のリストを使用します。

_def reconstruct_text(words,breaks):                                                                                                                

    lines = []
    linebreaks = []

    i = 0 
    while True:

        linebreaks.append(breaks[i])
        i = breaks[i]

        if i == len(words):
            linebreaks.append(0)
            break

    for i in range( len(linebreaks) ):
        lines.append( ' '.join( words[ linebreaks[i-1] : linebreaks[i] ] ).strip() )

    return lines
_

結果:(text = reconstruct_text(words,breaks)

_Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy
eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam
voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet
clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit
amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam
nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed
diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet
clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet.
_

空白を追加したくなるかもしれません。これはかなりトリッキーです(さまざまな美的ルールが考えられるため)。

_import re

def spacing(text,textwidth,maxspace=4):

    for i in range(len(text)):

        length_line = len(text[i])

        if length_line < textwidth:

            status_length = length_line
            whitespaces_remain = textwidth - status_length
            Nwhitespaces = text[i].count(' ')

            # If whitespaces (to add) per whitespace exeeds
            # maxspace, don't do anything.
            if whitespaces_remain/Nwhitespaces > maxspace-1:pass
            else:
                text[i] = text[i].replace(' ',' '*( 1 + int(whitespaces_remain/Nwhitespaces)) )
                status_length = len(text[i])

                # Periods have highest priority for whitespace insertion
                periods = text[i].split('.')

                # Can we add a whitespace behind each period?
                if len(periods) - 1 + status_length <= textwidth:
                    text[i] = '. '.join(periods).strip()

                status_length = len(text[i])
                whitespaces_remain = textwidth - status_length
                Nwords = len(text[i].split())
                Ngaps = Nwords - 1

                if whitespaces_remain != 0:factor = Ngaps / whitespaces_remain

                # List of whitespaces in line i
                gaps = re.findall('\s+', text[i])

                temp = text[i].split()
                for k in range(Ngaps):
                    temp[k] = ''.join([temp[k],gaps[k]])

                for j in range(whitespaces_remain):
                    if status_length >= textwidth:pass
                    else:
                        replace = temp[int(factor*j)]
                        replace = ''.join([replace, " "])
                        temp[int(factor*j)] = replace

                text[i] = ''.join(temp)

    return text
_

あなたに与えるもの:(text = spacing(text,textwidth)

_Lorem  ipsum  dolor  sit  amet, consetetur  sadipscing  elitr,  sed  diam nonumy
eirmod  tempor  invidunt  ut labore  et  dolore  magna aliquyam  erat,  sed diam
voluptua.   At  vero eos  et accusam  et justo  duo dolores  et ea  rebum.  Stet
clita  kasd  gubergren,  no  sea  takimata sanctus  est  Lorem  ipsum  dolor sit
amet.   Lorem  ipsum  dolor  sit amet,  consetetur  sadipscing  elitr,  sed diam
nonumy  eirmod  tempor invidunt  ut labore  et dolore  magna aliquyam  erat, sed
diam  voluptua.  At vero eos et accusam et  justo duo dolores et ea rebum.  Stet
clita  kasd gubergren, no sea  takimata sanctus est Lorem  ipsum dolor sit amet.
_
8
Suuuehgi

講義を見たところ、理解できるものなら何でも入れようと思いました。質問者と同じ形式でコードを入力しました。講義で説明したように、ここでは再帰を使用しました。
ポイント#3、再発を定義します。これは基本的にアプローチの底です。そこでは、より高い入力に関連する関数の値を以前に計算し、それを使用してより低い値の入力のを計算します。
講演ではそれを次のように説明しています:
DP(i)= min(DP(j)+ badness(i、j))
i + 1からnまで変化するjの場合。
ここで、iはnから0まで変化します(下から上へ!)。
DP(n)= 0として、
DP(n-1)= DP(n)+ badness(n-1、n)
次にD(n-2) from D(n-1) and D(n)そして、それらから最小限を取りなさい。
これで、i = 0まで下がることができ、それが悪さの最後の答えです!
ポイント4では、ご覧のとおり、ここで2つのループが発生しています。 1つはi、もう1つはi for jの内部。
したがって、i = 0のときj(max) = n、i = 1、j(max) = n-1、 ... i = n、j(max) = 0。
したがって、合計時間=これらの加算= n(n + 1)/ 2。
したがって、O(n ^ 2)。
ポイント#5は、DP [0]であるソリューションを識別するだけです!
お役に立てれば!

import math

justification_map = {}
min_map = {}

def total_length(str_arr):
    total = 0

    for string in str_arr:
        total = total + len(string)

    total = total + len(str_arr) - 1 # spaces
    return total

def badness(str_arr, page_width):
    line_len = total_length(str_arr)
    if line_len > page_width:
        return float('nan') 
    else:
        return math.pow(page_width - line_len, 3)

def justify(i, n, words, page_width):
    if i == n:

        return 0
    ans = []
    for j in range(i+1, n+1):
        #ans.append(justify(j, n, words, page_width)+ badness(words[i:j], page_width))
        ans.append(justification_map[j]+ badness(words[i:j], page_width))
    min_map[i] = ans.index(min(ans)) + 1
    return min(ans)

def main():
    print "Enter page width"
    page_width = input()
    print "Enter text"
    paragraph = input() 
    words = paragraph.split(' ')
    n = len(words)
    #justification_map[n] = 0 
    for i in reversed(range(n+1)):
        justification_map[i] = justify(i, n, words, page_width)

    print "Minimum badness achieved: ", justification_map[0]

    key = 0
    while(key <n):
        key = key + min_map[key]
        print key

if __name__ == '__main__':
    main()
1
Rindojiterika

これはあなたの定義によると私は思います。

import math

class Text(object):
    def __init__(self, words, width):
        self.words = words
        self.page_width = width
        self.str_arr = words
        self.memo = {}

    def total_length(self, str):
        total = 0
        for string in str:
            total = total + len(string)
        total = total + len(str) # spaces
        return total

    def badness(self, str):
        line_len = self.total_length(str)
        if line_len > self.page_width:
            return float('nan') 
        else:
            return math.pow(self.page_width - line_len, 3)

    def dp(self):
        n = len(self.str_arr)
        self.memo[n-1] = 0

        return self.judge(0)

    def judge(self, i):
        if i in self.memo:
            return self.memo[i]

        self.memo[i] = float('inf') 
        for j in range(i+1, len(self.str_arr)):
            bad = self.judge(j) + self.badness(self.str_arr[i:j])
            if bad < self.memo[i]:
                self.memo[i] = bad

        return self.memo[i]
0
user6043912

Java実装最大行幅をLとして、テキストTを正当化するという考えは、テキストのすべてのサフィックスを考慮することです(サフィックスを形成するために、文字ではなく単語を考慮してください)。動的プログラミングは「注意深い総当たり」に他なりません」ブルートフォースアプローチを検討する場合は、次のことを行う必要があります。

  1. 1、2、.. nワードを最初の行に置くことを検討してください。
  2. ケース1で説明した各ケース(たとえば、iワードが行1に配置される)について、1、2、.. n -iワードを2行目に配置し、残りの単語を3行目に配置する場合などを検討します。

代わりに、問題を検討して、Wordを行の先頭に置くコストを見つけてみましょう。一般に、DP(i)を(i-1)番目の単語を行の始まりと見なすためのコストとして定義できます。

DP(i)の再帰関係をどのように形成できますか?

J番目の単語が次の行の先頭である場合、現在の行にはwords [i:j)(jは含まれません)が含まれ、次の行の先頭であるj番目の単語のコストはDP(j)になります。したがって、DP(i)= DP(j)+現在の行にwords [i:j)を配置するコスト合計コストを最小化したいので、DP(i)は次のように定義できます。

繰り返し関係:

DP(i)= min {DP(j)+ [i + 1、n]内のすべてのjについて、words [i:j in the current line}を配置するコスト

J = nは、次の行に入力する単語が残っていないことを意味します。

基本ケース:DP(n)= 0 =>この時点で、書き込むワードは残っていません。

要約すると:

  1. サブ問題:接尾辞、単語[:i]
  2. 推測:次の行を開始する場所、選択肢の数n-i-> O(n)
  3. 繰り返し:DP(i)=最小{DP(j)+現在の行に単語[i:j)を置くコスト}メモ化を使用する場合、中括弧内の式はO(1)時間、およびループ実行O(n)回(選択回数#回)。inから0まで変化します=>したがって、全体の複雑度はO( n ^ 2)。

ここで、テキストを正当化するための最小コストを導出しましたが、上の式で選択されたj値を追跡することによって元の問題を解決し、後で同じ値を使用して正当化されたものを出力できるようにする必要があります。テキスト。親ポインターを保持するという考え方です。

これがソリューションの理解に役立つことを願っています。以下は、上記のアイデアの簡単な実装です。

 public class TextJustify {
    class IntPair {
        //The cost or badness
        final int x;

        //The index of Word at the beginning of a line
        final int y;
        IntPair(int x, int y) {this.x=x;this.y=y;}
    }
    public List<String> fullJustify(String[] words, int L) {
        IntPair[] memo = new IntPair[words.length + 1];

        //Base case
        memo[words.length] = new IntPair(0, 0);


        for(int i = words.length - 1; i >= 0; i--) {
            int score = Integer.MAX_VALUE;
            int nextLineIndex = i + 1;
            for(int j = i + 1; j <= words.length; j++) {
                int badness = calcBadness(words, i, j, L);
                if(badness < 0 || badness == Integer.MAX_VALUE) break;
                int currScore = badness + memo[j].x;
                if(currScore < 0 || currScore == Integer.MAX_VALUE) break;
                if(score > currScore) {
                    score = currScore;
                    nextLineIndex = j;
                }
            }
            memo[i] = new IntPair(score, nextLineIndex);
        }

        List<String> result = new ArrayList<>();
        int i = 0;
        while(i < words.length) {
            String line = getLine(words, i, memo[i].y);
            result.add(line);
            i = memo[i].y;
        }
        return result;
    }

    private int calcBadness(String[] words, int start, int end, int width) {
        int length = 0;
        for(int i = start; i < end; i++) {
            length += words[i].length();
            if(length > width) return Integer.MAX_VALUE;
            length++;
        }
        length--;
        int temp = width - length;
        return temp * temp;
    }


    private String getLine(String[] words, int start, int end) {
        StringBuilder sb = new StringBuilder();
        for(int i = start; i < end - 1; i++) {
            sb.append(words[i] + " ");
        }
        sb.append(words[end - 1]);

        return sb.toString();
    }
  }
0
self_noted