web-dev-qa-db-ja.com

期待値最大化手法の直感的な説明は何ですか?

期待値最大化(EM)は、データを分類するための一種の確率的手法です。それが分類器でない場合、私が間違っている場合は私を修正してください。

このEM技術の直感的な説明は何ですか?ここでexpectationとは何ですか?maximizedとは何ですか?

102
London guy

注:この答えの背後にあるコードは こちらにあります


赤と青の2つの異なるグループからサンプリングされたデータがあるとします。

enter image description here

ここで、赤または青のグループに属するデータポイントを確認できます。これにより、各グループを特徴付けるパラメーターを簡単に見つけることができます。たとえば、赤のグループの平均は約3、青のグループの平均は約7です(必要に応じて正確な平均を見つけることができます)。

これは、一般的に言えば、最尤推定として知られています。いくつかのデータが与えられたら、そのデータを最もよく説明するパラメーターの値を計算します。

ここで、どのグループからどの値がサンプリングされたかをできないと想像してください。すべてが紫色に見えます。

enter image description here

ここには、値のグループがtwoあるという知識がありますが、特定の値がどのグループに属しているかはわかりません。

このデータに最適な赤のグループと青のグループの平均を推定できますか?

はい、できます! 期待値最大化はそれを行う方法を提供します。アルゴリズムの背後にある非常に一般的な考え方は次のとおりです。

  1. 各パラメーターの初期推定値から始めます。
  2. 各パラメーターがデータポイントを生成する尤度を計算します。
  3. パラメーターによって生成される可能性に基づいて、各データポイントの重みを計算します。重みとデータを組み合わせます(expectation)。
  4. 重み調整されたデータ(maximisation)を使用して、パラメーターのより良い推定値を計算します。
  5. パラメーター推定値が収束するまで、プロセス2から4を繰り返します(プロセスは別の推定値の生成を停止します)。

これらの手順にはさらに説明が必要なため、上記の問題について説明します。

例:平均と標準偏差の推定

この例ではPythonを使用しますが、この言語に精通していない場合、コードはかなり簡単に理解できるはずです。

赤と青の2つのグループがあり、上の画像のように値が分布しているとします。具体的には、各グループには 正規分布 から引き出された値が含まれ、次のパラメーターがあります。

import numpy as np
from scipy import stats

np.random.seed(110) # for reproducible results

# set parameters
red_mean = 3
red_std = 0.8

blue_mean = 7
blue_std = 2

# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)

both_colours = np.sort(np.concatenate((red, blue))) # for later use...

これらの赤と青のグループの画像を再度示します(上にスクロールしなくて済むように)。

enter image description here

各ポイントの色(つまり、どのグループに属するか)がわかると、各グループの平均と標準偏差を非常に簡単に推定できます。赤と青の値をNumPyの組み込み関数に渡すだけです。例えば:

>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195

しかし、ポイントの色が見えない場合はどうでしょうか?つまり、赤または青ではなく、すべてのポイントが紫色に着色されています。

赤と青のグループの平均と標準偏差のパラメーターを回復するために、期待値の最大化を使用できます。

最初のステップ(ステップ1)は、各グループの平均と標準偏差のパラメーター値を推測することです。知的に推測する必要はありません。好きな数字を選ぶことができます:

# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9

# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7

これらのパラメーター推定値は、次のようなベル曲線を生成します。

enter image description here

これらは悪い見積もりです。両方の手段(垂直の点線)は、たとえば、賢明なポイントのグループについては、あらゆる種類の「中間」から遠く離れているように見えます。これらの見積もりを改善したいと考えています。

次のステップ(ステップ2)は、現在のパラメーターの推測の下に表示される各データポイントの尤度を計算することです。

likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

ここでは、赤と青の平均と標準偏差の現在の推測を使用して、正規分布の各データポイントを 確率密度関数 に単純に入れています。これは、たとえば、現在の推測では1.761のデータポイントはmuchが青(0.00003)よりも赤(0.189)である可能性が高いことを示しています。

各データポイントについて、これらの2つの尤度値を重み(ステップ3)に変換して、次のように合計して1にすることができます。

likelihood_total = likelihood_of_red + likelihood_of_blue

red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total

現在の推定値と新しく計算された重みを使用して、赤と青のグループの平均と標準偏差のnew推定値を計算できるようになりました(ステップ4)。

allデータポイントを使用して平均と標準偏差を2回計算しますが、異なる重みを使用します。1回は赤の重み、1回は青の重みです。

直観の重要な点は、データポイントの色の重みが大きいほど、データポイントがその色のパラメータの次の推定値に影響を与えることです。これには、パラメーターを正しい方向に「プル」する効果があります。

def estimate_mean(data, weight):
    """
    For each data point, multiply the point by the probability it
    was drawn from the colour's distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among our data points.
    """
    return np.sum(data * weight) / np.sum(weight)

def estimate_std(data, weight, mean):
    """
    For each data point, multiply the point's squared difference
    from a mean value by the probability it was drawn from
    that distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among the values for the difference of
    each data point from the mean.

    This is the estimate of the variance, take the positive square
    root to find the standard deviation.
    """
    variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
    return np.sqrt(variance)

# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)

# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)

パラメーターの新しい推定値があります。再び改善するには、ステップ2に戻ってプロセスを繰り返します。これは、推定値が収束するまで、またはいくつかの反復が実行されるまで行います(ステップ5)。

データの場合、このプロセスの最初の5回の反復は次のようになります(最近の反復ではより強い外観になります)。

enter image description here

平均はすでにいくつかの値に収束しており、曲線の形状(標準偏差で管理)もより安定していることがわかります。

20回繰り返した場合、次のようになります。

enter image description here

EMプロセスは次の値に収束しましたが、実際の値に非常に近い値になりました(色を確認できる場所-隠された変数はありません)。

          | EM guess | Actual |  Delta
----------+----------+--------+-------
Red mean  |    2.910 |  2.802 |  0.108
Red std   |    0.854 |  0.871 | -0.017
Blue mean |    6.838 |  6.932 | -0.094
Blue std  |    2.227 |  2.195 |  0.032

上記のコードでは、標準偏差の新しい推定値が、以前の反復の平均値の推定値を使用して計算されたことに気付いたかもしれません。最終的には、いくつかの中心点の周りの値の(加重)分散を見つけるだけなので、最初に平均値の新しい値を計算するかどうかは問題ではありません。パラメータの推定値はまだ収束します。

110
Alex Riley

EMは、モデル内の変数の一部が観察されない場合(潜在変数がある場合)に尤度関数を最大化するためのアルゴリズムです。

単に関数を最大化しようとしているのであれば、関数を最大化するために既存の機械を使用しないでください。さて、導関数を取得してゼロに設定することでこれを最大化しようとすると、多くの場合、1次条件には解決策がないことがわかります。モデルパラメータを解くには、観測されていないデータの分布を知る必要があるという鶏と卵の問題があります。しかし、観測されていないデータの分布は、モデルパラメーターの関数です。

E-Mは、観測されていないデータの分布を繰り返し推測し、実際の尤度関数の下限であるものを最大化してモデルパラメーターを推定し、収束するまで繰り返すことで、この問題を回避しようとします。

EMアルゴリズム

モデルパラメーターの値の推測から始めます

Eステップ:欠損値のある各データポイントについて、モデル方程式を使用して、モデルパラメーターの現在の推測と観測データを指定して、欠損データの分布を解きます(各欠損の分布を解くことに注意してください)値、予想される値ではありません)。各欠損値の分布ができたので、観測されていない変数に関する尤度関数のexpectationを計算できます。モデルパラメーターの推測が正しかった場合、この予想尤度は観測データの実際の尤度になります。パラメータが正しくなかった場合は、下限になります。

Mステップ:観測されていない変数を含まない予想尤度関数が得られたので、完全に観測された場合のように関数を最大化し、モデルパラメーターの新しい推定値を取得します。

収束するまで繰り返します。

35
Marc Shivers

以下は、期待値最大化アルゴリズムを理解するための簡単なレシピです。

1-これを読む EMチュートリアルペーパー DoおよびBatzoglou著。

2-あなたは頭に疑問符を持っているかもしれません。この数学スタック交換の説明を見てください ページ

3-アイテム1のEMチュートリアルペーパーの例を説明するPythonで書いたこのコードを見てください。

Warning:私はPythonではないため、コードが乱雑/最適ではない可能性があります開発者。しかし、それは仕事をします。

import numpy as np
import math

#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* #### 

def get_mn_log_likelihood(obs,probs):
    """ Return the (log)likelihood of obs, given the probs"""
    # Multinomial Distribution Log PMF
    # ln (pdf)      =             multinomial coeff            *   product of probabilities
    # ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]     

    multinomial_coeff_denom= 0
    prod_probs = 0
    for x in range(0,len(obs)): # loop through state counts in each observation
        multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
        prod_probs = prod_probs + obs[x]*math.log(probs[x])

    multinomial_coeff = math.log(math.factorial(sum(obs))) -  multinomial_coeff_denom
    likelihood = multinomial_coeff + prod_probs
    return likelihood

# 1st:  Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd:  Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd:  Coin A, {HTHHHHHTHH}, 8H,2T
# 4th:  Coin B, {HTHTTTHHTT}, 4H,6T
# 5th:  Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45

# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = Zip(head_counts,tail_counts)

# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50

# E-M begins!
delta = 0.001  
j = 0 # iteration counter
improvement = float('inf')
while (improvement>delta):
    expectation_A = np.zeros((5,2), dtype=float) 
    expectation_B = np.zeros((5,2), dtype=float)
    for i in range(0,len(experiments)):
        e = experiments[i] # i'th experiment
        ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
        ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B

        weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A 
        weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B                            

        expectation_A[i] = np.dot(weightA, e) 
        expectation_B[i] = np.dot(weightB, e)

    pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A)); 
    pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B)); 

    improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
    j = j+1
27
Zhubarb

技術的には「EM」という用語は少し指定不足ですが、一般的なEM原理のinstanceであるGaussian Mixture Modelingクラスター分析手法を参照していると思います。

実際、EMクラスター分析は分類子ではありません。一部の人々はクラスタリングを「教師なし分類」と見なしていることを知っていますが、実際にはクラスター分析はまったく異なるものです。

主な違いと、人々がクラスター分析で常に持つ大きな誤解の分類は、次のとおりです。クラスター分析では、「正しい解」はありません。それは知識ですdiscoveryメソッド、実際には何かを見つけることを意図していますnew!これにより、評価が非常に難しくなります。多くの場合、参照として既知の分類を使用して評価されますが、それは常に適切であるとは限りません。あなたが持っている分類は、データの内容を反映する場合としない場合があります。

例を挙げましょう。性別データを含む顧客の大規模なデータセットがあります。このデータセットを「男性」と「女性」に分割する方法は、既存のクラスと比較するときに最適です。 「予測」の考え方では、これは良いことです。新しいユーザーについては、性別を予測できるようになりました。 「知識発見」の考え方では、これは実際には悪いことです。なぜなら、あなたはデータの中のいくつかの新しい構造を発見したかったからです。例えばただし、データを高齢者と子供に分割すると、男性/女性のクラスに関して取得可能な限り悪化するが得点されます。ただし、それは優れたクラスタリング結果です(年齢が指定されていない場合)。

EMに戻りましょう。基本的に、データは複数の多変量正規分布で構成されていることを前提としています(特にクラスター数を修正する場合、これはveryの強い仮定であることに注意してください!)。次に、モデルとモデルへのオブジェクト割り当てを交互に改善するによって、このためのローカル最適モデルを見つけようとします。

分類コンテキストで最良の結果を得るには、クラスターの数を選択します大きいクラスの数よりも多く、またはクラスタリングを単一クラスのみに適用します(存在するかどうかを調べるために)クラス内のいくつかの構造!)。

「車」、「自転車」、「トラック」を区別するために分類器を訓練するとします。データが正確に3つの正規分布で構成されていると仮定する場合、ほとんど使用されません。ただし、車には複数のタイプがあります(およびトラックとバイク)と仮定できます。したがって、これらの3つのクラスの分類子をトレーニングする代わりに、車、トラック、自転車をそれぞれ10個のクラスター(または、10台の車、3つのトラック、3つの自転車など)にクラスター化してから、これらの30のクラスを区別するように分類器をトレーニングします。クラスの結果を元のクラスにマージします。また、Trikesなど、分類が特に難しいクラスターが1つあることもあります。彼らはやや車であり、やや自転車です。または配達用トラック、それはトラックよりも特大の車のようなものです。

16
Anony-Mousse

受け入れられた回答は Chuong EM Paper を参照しており、EMを説明するまともな仕事をしています。 youtube video もあり、これにより論文の詳細が説明されます。

要約すると、シナリオは次のとおりです。

1st:  {H,T,T,T,H,H,T,H,T,H} 5 Heads, 5 Tails; Did coin A or B generate me?
2nd:  {H,H,H,H,T,H,H,H,H,H} 9 Heads, 1 Tails
3rd:  {H,T,H,H,H,H,H,T,H,H} 8 Heads, 2 Tails
4th:  {H,T,H,T,T,T,H,H,T,T} 4 Heads, 6 Tails
5th:  {T,H,H,H,T,H,H,H,T,H} 7 Heads, 3 Tails

Two possible coins, A & B are used to generate these distributions.
A & B have an unknown parameter: their bias towards heads.

We don't know the biases, but we can simply start with a guess: A=60% heads, B=50% heads.

最初のトライアルの質問の場合、ヘッドの割合がBのバイアスに非常によく一致するので、直感的にBがそれを生成したと思いますが、その値は単なる推測であるため、確信が持てません。

それを念頭に置いて、私はこのようなEMソリューションを考えたいです:

  • フリップの各試行は、最もコインが好きな「投票」に到達します
    • これは、各コインがその分布にどれだけうまく適合するかに基づいています
    • または、コインの観点から見ると、他のコインと比較してこの試行を見ることの高い期待があります(対数尤度に基づく)。
  • 各トライアルが各コインをどれだけ気に入っているかに応じて、そのコインのパラメーター(バイアス)の推測を更新できます。
    • 試用版がコインを好むほど、コインのバイアスを更新して、コインを反映させます!
    • 基本的に、コインのバイアスは、すべての試行にわたってこれらの重み付き更新を組み合わせることによって更新されます。これは、(maximazation)と呼ばれるプロセスです。

これは単純化しすぎかもしれません(または一部のレベルでは根本的に間違っているかもしれません)が、これは直感的なレベルで役立つことを願っています!

2
lucidv01d

他の答えは良いです、私は別の視点を提供し、質問の直感的な部分に取り組むようにします。

EM(期待値最大化)アルゴリズム は、 duality を使用した反復アルゴリズムのクラスの変形です

抜粋(エンファシス鉱山):

数学では、一般的に、双対性は概念、定理または数学的構造を他の概念、定理または構造に1対1の方法で、しばしば(常にではないが)インボリューション演算によって変換します。 AがBである場合、Bの双対はAです。このようなインボリューションには固定小数点が含まれている場合があるため、Aの双対はA自体です。

通常、dual Bのobject Aは、何らかの方法でsymmetry or compatibleを保持するAに関連しています。たとえば、AB =const

(以前の意味で)双対性を使用する反復アルゴリズムの例は次のとおりです。

  1. 最大公約数のユークリッドアルゴリズム、およびその変形
  2. Gram–Schmidt Vector Basisアルゴリズムとバリアント
  3. 算術平均-幾何平均不等式、およびその変形
  4. 期待値最大化アルゴリズムとその変形情報幾何学的ビューについてはこちら も参照)
  5. (..他の同様のアルゴリズム..)

同様の方法で、 EMアルゴリズムは2つの二重最大化ステップとして見ることもできます

.. [EM]は、パラメーターと観測されていない変数の分布の結合関数を最大化するものと見なされます。Eステップは、観測されていない変数の分布に関してこの関数を最大化します。パラメーターに関するMステップ.

双対性を使用する反復アルゴリズムでは、平衡(または固定)収束点の明示的な(または暗黙の)仮定があります(EMの場合、これはジェンセンの不等式を使用して証明されます)

そのため、このようなアルゴリズムの概要は次のとおりです。

  1. Eのようなステップ:一定に保たれている特定のyに関して最適なソリューションxを見つけます。
  2. M-like step(dual):xに関してyを見つける前のステップ)一定に保持されます。
  3. 終了/収束ステップの基準:更新された値xyでステップ1、2を繰り返します収束するまで(または指定された反復回数に達するまで)

このようなアルゴリズムが(グローバル)最適に収束すると、両方の意味でベストな構成が見つかりました(つまり、 xドメイン/パラメーターおよびyドメイン/パラメーター)。ただし、アルゴリズムはglobal最適ではなくlocal最適を見つけることができます。

これは、アルゴリズムのアウトラインの直感的な説明だと思います

統計的議論と応用については、他の回答が良い説明を与えています(この回答の参考文献もチェックしてください)

2
Nikos M.

EMは、潜在変数Zを持つモデルQの尤度を最大化するために使用されます。

これは反復的な最適化です。

theta <- initial guess for hidden parameters
while not converged:
    #e-step
    Q(theta'|theta) = E[log L(theta|Z)]
    #m-step
    theta <- argmax_theta' Q(theta'|theta)

e-step:Zの現在の推定値が与えられ、予想される対数尤度関数を計算します

mステップ:このQを最大化するthetaを見つける

GMMの例:

e-step:現在のgmmパラメーターの推定値から、各データポイントのラベル割り当てを推定します

mステップ:新しいラベル割り当てを指定して、新しいシータを最大化します

K-meansもEMアルゴリズムであり、K-meansでのアニメーションの説明はたくさんあります。

1
SlimJim

Zhubarbの答えで引用されたDoとBatzoglouによる同じ記事を使用して、その問題のEMをJavaで実装しました。彼の答えへのコメントは、アルゴリズムがローカル最適で立ち往生していることを示しています。これは、パラメータthetaAとthetaBが同じ場合にも実装で発生します。

以下は、コードの標準出力であり、パラメーターの収束を示しています。

thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960

以下は、(Do and Batzoglou、2008)の問題を解決するためのEMのJava実装です。実装のコア部分は、パラメーターが収束するまでEMを実行するループです。

private Parameters _parameters;

public Parameters run()
{
    while (true)
    {
        expectation();

        Parameters estimatedParameters = maximization();

        if (_parameters.converged(estimatedParameters)) {
            break;
        }

        _parameters = estimatedParameters;
    }

    return _parameters;
}

以下はコード全体です。

import Java.util.*;

/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B,
respectively.
*****************************************************************************/
class Parameters
{
    double _thetaA = 0.0; // Probability of heads for coin A.
    double _thetaB = 0.0; // Probability of heads for coin B.

    double _delta = 0.00001;

    public Parameters(double thetaA, double thetaB)
    {
        _thetaA = thetaA;
        _thetaB = thetaB;
    }

    /*************************************************************************
    Returns true if this parameter is close enough to another parameter
    (typically the estimated parameter coming from the maximization step).
    *************************************************************************/
    public boolean converged(Parameters other)
    {
        if (Math.abs(_thetaA - other._thetaA) < _delta &&
            Math.abs(_thetaB - other._thetaB) < _delta)
        {
            return true;
        }

        return false;
    }

    public double getThetaA()
    {
        return _thetaA;
    }

    public double getThetaB()
    {
        return _thetaB;
    }

    public String toString()
    {
        return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
    }

}


/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
experimental observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
    double _numHeads = 0;
    double _numTails = 0;

    public Observation(String s)
    {
        for (int i = 0; i < s.length(); i++)
        {
            char c = s.charAt(i);

            if (c == 'H')
            {
                _numHeads++;
            }
            else if (c == 'T')
            {
                _numTails++;
            }
            else
            {
                throw new RuntimeException("Unknown character: " + c);
            }
        }
    }

    public Observation(double numHeads, double numTails)
    {
        _numHeads = numHeads;
        _numTails = numTails;
    }

    public double getNumHeads()
    {
        return _numHeads;
    }

    public double getNumTails()
    {
        return _numTails;
    }

    public String toString()
    {
        return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
    }

}

/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
    // Current estimated parameters.
    private Parameters _parameters;

    // Observations from the trials. These observations are set once.
    private final List<Observation> _observations;

    // Estimated observations per coin. These observations are the output
    // of the expectation step.
    private List<Observation> _expectedObservationsForCoinA;
    private List<Observation> _expectedObservationsForCoinB;

    private static Java.io.PrintStream o = System.out;

    /*************************************************************************
    Principal constructor.
    @param observations The observations from the trial.
    @param parameters The initial guessed parameters.
    *************************************************************************/
    public EM(List<Observation> observations, Parameters parameters)
    {
        _observations = observations;
        _parameters = parameters;
    }

    /*************************************************************************
    Run EM until parameters converge.
    *************************************************************************/
    public Parameters run()
    {

        while (true)
        {
            expectation();

            Parameters estimatedParameters = maximization();

            o.printf("%s\n", estimatedParameters);

            if (_parameters.converged(estimatedParameters)) {
                break;
            }

            _parameters = estimatedParameters;
        }

        return _parameters;

    }

    /*************************************************************************
    Given the observations and current estimated parameters, compute new
    estimated completions (distribution over the classes) and observations.
    *************************************************************************/
    private void expectation()
    {

        _expectedObservationsForCoinA = new ArrayList<Observation>();
        _expectedObservationsForCoinB = new ArrayList<Observation>();

        for (Observation observation : _observations)
        {
            int numHeads = (int)observation.getNumHeads();
            int numTails = (int)observation.getNumTails();

            double probabilityOfObservationForCoinA=
                binomialProbability(10, numHeads, _parameters.getThetaA());

            double probabilityOfObservationForCoinB=
                binomialProbability(10, numHeads, _parameters.getThetaB());

            double normalizer = probabilityOfObservationForCoinA +
                                probabilityOfObservationForCoinB;

            // Compute the completions for coin A and B (i.e. the probability
            // distribution of the two classes, summed to 1.0).

            double completionCoinA = probabilityOfObservationForCoinA /
                                     normalizer;
            double completionCoinB = probabilityOfObservationForCoinB /
                                     normalizer;

            // Compute new expected observations for the two coins.

            Observation expectedObservationForCoinA =
                new Observation(numHeads * completionCoinA,
                                numTails * completionCoinA);

            Observation expectedObservationForCoinB =
                new Observation(numHeads * completionCoinB,
                                numTails * completionCoinB);

            _expectedObservationsForCoinA.add(expectedObservationForCoinA);
            _expectedObservationsForCoinB.add(expectedObservationForCoinB);
        }
    }

    /*************************************************************************
    Given new estimated observations, compute new estimated parameters.
    *************************************************************************/
    private Parameters maximization()
    {

        double sumCoinAHeads = 0.0;
        double sumCoinATails = 0.0;
        double sumCoinBHeads = 0.0;
        double sumCoinBTails = 0.0;

        for (Observation observation : _expectedObservationsForCoinA)
        {
            sumCoinAHeads += observation.getNumHeads();
            sumCoinATails += observation.getNumTails();
        }

        for (Observation observation : _expectedObservationsForCoinB)
        {
            sumCoinBHeads += observation.getNumHeads();
            sumCoinBTails += observation.getNumTails();
        }

        return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
                              sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));

        //o.printf("parameters: %s\n", _parameters);

    }

    /*************************************************************************
    Since the coin-toss experiment posed in this article is a Bernoulli trial,
    use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
    *************************************************************************/
    private static double binomialProbability(int n, int k, double p)
    {
        double q = 1.0 - p;
        return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
    }

    private static long nChooseK(int n, int k)
    {
        long numerator = 1;

        for (int i = 0; i < k; i++)
        {
            numerator = numerator * n;
            n--;
        }

        long denominator = factorial(k);

        return (long)(numerator / denominator);
    }

    private static long factorial(int n)
    {
        long result = 1;
        for (; n >0; n--)
        {
            result = result * n;
        }

        return result;
    }

    /*************************************************************************
    Entry point into the program.
    *************************************************************************/
    public static void main(String argv[])
    {
        // Create the observations and initial parameter guess
        // from the (Do and Batzoglou, 2008) article.

        List<Observation> observations = new ArrayList<Observation>();
        observations.add(new Observation("HTTTHHTHTH"));
        observations.add(new Observation("HHHHTHHHHH"));
        observations.add(new Observation("HTHHHHHTHH"));
        observations.add(new Observation("HTHTTTHHTT"));
        observations.add(new Observation("THHHTHHHTH"));

        Parameters initialParameters = new Parameters(0.6, 0.5);

        EM em = new EM(observations, initialParameters);

        Parameters finalParameters = em.run();

        o.printf("Final result:\n%s\n", finalParameters);
    }
}