web-dev-qa-db-ja.com

scala実装の中央値

Scalaでの中央値の高速実装とは何ですか?

これは私が見つけたものです ロゼッタコード

_  def median(s: Seq[Double])  =
  {
    val (lower, upper) = s.sortWith(_<_).splitAt(s.size / 2)
    if (s.size % 2 == 0) (lower.last + upper.head) / 2.0 else upper.head
  }
_

ある種のことをするので、私はそれが好きではありません。線形時間で中央値を計算する方法があることを私は知っています。

編集:

さまざまなシナリオで使用できる一連の中央値関数が必要です。

  1. 線形時間で実行できる高速でインプレースの中央値計算
  2. 複数回トラバースできるストリームで機能する中央値ですが、メモリに保持できるのはO(log n)値のみです このように
  3. ストリームで機能する中央値。メモリには最大でO(log n)の値を保持でき、ストリームを最大で1回トラバースできます(これも可能ですか?)

コンパイルおよび中央値を正しく計算するコードのみを投稿してください。簡単にするために、すべての入力に奇数の値が含まれていると想定できます。

36
dsg

不変のアルゴリズム

最初のアルゴリズム表示 by Taylor Leese は二次式ですが、平均は線形です。ただし、それはピボットの選択によって異なります。そこで、ここでは、プラグ可能なピボット選択と、ランダムピボットと中央値ピボットの中央値(線形時間を保証する)の両方を備えたバージョンを提供します。

_import scala.annotation.tailrec

@tailrec def findKMedian(arr: Array[Double], k: Int)(implicit choosePivot: Array[Double] => Double): Double = {
    val a = choosePivot(arr)
    val (s, b) = arr partition (a >)
    if (s.size == k) a
    // The following test is used to avoid infinite repetition
    else if (s.isEmpty) {
        val (s, b) = arr partition (a ==)
        if (s.size > k) a
        else findKMedian(b, k - s.size)
    } else if (s.size < k) findKMedian(b, k - s.size)
    else findKMedian(s, k)
}

def findMedian(arr: Array[Double])(implicit choosePivot: Array[Double] => Double) = findKMedian(arr, (arr.size - 1) / 2)
_

ランダムピボット(2次、線形平均)、不変

これはランダムなピボット選択です。ランダムな要因を使用したアルゴリズムの分析は、確率と統計を主に扱うため、通常よりも注意が必要です。

_def chooseRandomPivot(arr: Array[Double]): Double = arr(scala.util.Random.nextInt(arr.size))
_

中央値の中央値(線形)、不変

上記のアルゴリズムで使用した場合に線形時間を保証する中央値の中央値法。まず、中央値アルゴリズムの中央値の基礎となる、最大5つの数値の中央値を計算するアルゴリズム。これは Rex Kerr in this answer によって提供されました-アルゴリズムはその速度に大きく依存します。

_def medianUpTo5(five: Array[Double]): Double = {
  def order2(a: Array[Double], i: Int, j: Int) = {
    if (a(i)>a(j)) { val t = a(i); a(i) = a(j); a(j) = t }
  }

  def pairs(a: Array[Double], i: Int, j: Int, k: Int, l: Int) = {
    if (a(i)<a(k)) { order2(a,j,k); a(j) }
    else { order2(a,i,l); a(i) }
  }

  if (five.length < 2) return five(0)
  order2(five,0,1)
  if (five.length < 4) return (
    if (five.length==2 || five(2) < five(0)) five(0)
    else if (five(2) > five(1)) five(1)
    else five(2)
  )
  order2(five,2,3)
  if (five.length < 5) pairs(five,0,1,2,3)
  else if (five(0) < five(2)) { order2(five,1,4); pairs(five,1,4,2,3) }
  else { order2(five,3,4); pairs(five,0,1,3,4) }
}
_

そして、中央値アルゴリズム自体の中央値。基本的に、選択されたピボットがリストの他の30%よりも大きく、小さくなることが保証されます。これは、前のアルゴリズムの線形性を保証するのに十分です。詳細については、別の回答で提供されているウィキペディアのリンクを調べてください。

_def medianOfMedians(arr: Array[Double]): Double = {
    val medians = arr grouped 5 map medianUpTo5 toArray;
    if (medians.size <= 5) medianUpTo5 (medians)
    else medianOfMedians(medians)
}
_

インプレースアルゴリズム

それで、これがアルゴリズムのインプレースバージョンです。アルゴリズムへの変更が最小限になるように、バッキング配列を使用してパーティションをインプレースで実装するクラスを使用しています。

_case class ArrayView(arr: Array[Double], from: Int, until: Int) {
    def apply(n: Int) = 
        if (from + n < until) arr(from + n)
        else throw new ArrayIndexOutOfBoundsException(n)

    def partitionInPlace(p: Double => Boolean): (ArrayView, ArrayView) = {
      var upper = until - 1
      var lower = from
      while (lower < upper) {
        while (lower < until && p(arr(lower))) lower += 1
        while (upper >= from && !p(arr(upper))) upper -= 1
        if (lower < upper) { val tmp = arr(lower); arr(lower) = arr(upper); arr(upper) = tmp }
      }
      (copy(until = lower), copy(from = lower))
    }

    def size = until - from
    def isEmpty = size <= 0

    override def toString = arr mkString ("ArraySize(", ", ", ")")
}; object ArrayView {
    def apply(arr: Array[Double]) = new ArrayView(arr, 0, arr.size)
}

@tailrec def findKMedianInPlace(arr: ArrayView, k: Int)(implicit choosePivot: ArrayView => Double): Double = {
    val a = choosePivot(arr)
    val (s, b) = arr partitionInPlace (a >)
    if (s.size == k) a
    // The following test is used to avoid infinite repetition
    else if (s.isEmpty) {
        val (s, b) = arr partitionInPlace (a ==)
        if (s.size > k) a
        else findKMedianInPlace(b, k - s.size)
    } else if (s.size < k) findKMedianInPlace(b, k - s.size)
    else findKMedianInPlace(s, k)
}

def findMedianInPlace(arr: Array[Double])(implicit choosePivot: ArrayView => Double) = findKMedianInPlace(ArrayView(arr), (arr.size - 1) / 2)
_

ランダムピボット、インプレース

中央値の中央値は、私が定義したArrayViewクラスによって現在提供されているものよりも多くのサポートを必要とするため、私はインプレースアルゴリズムのラドムピボットのみを実装しています。

_def chooseRandomPivotInPlace(arr: ArrayView): Double = arr(scala.util.Random.nextInt(arr.size))
_

ヒストグラムアルゴリズム(O(log(n))メモリ)、不変

だから、ストリームについて。文字列の長さがわからない限り、一度しかトラバースできないストリームに対してO(n)メモリ以外のことを行うことはできません(その場合、私の本ではストリームではなくなります) 。

バケットの使用も少し問題がありますが、バケットを複数回トラバースできる場合は、バケットのサイズ、最大値、最小値を把握し、そこから作業できます。例えば:

_def findMedianHistogram(s: Traversable[Double]) = {
    def medianHistogram(s: Traversable[Double], discarded: Int, medianIndex: Int): Double = {
        // The buckets
        def numberOfBuckets = (math.log(s.size).toInt + 1) max 2
        val buckets = new Array[Int](numberOfBuckets)

        // The upper limit of each bucket
        val max = s.max
        val min = s.min
        val increment = (max - min) / numberOfBuckets
        val indices = (-numberOfBuckets + 1 to 0) map (max + increment * _)

        // Return the bucket a number is supposed to be in
        def bucketIndex(d: Double) = indices indexWhere (d <=)

        // Compute how many in each bucket
        s foreach { d => buckets(bucketIndex(d)) += 1 }

        // Now make the buckets cumulative
        val partialTotals = buckets.scanLeft(discarded)(_+_).drop(1)

        // The bucket where our target is at
        val medianBucket = partialTotals indexWhere (medianIndex <)

        // Keep track of how many numbers there are that are less 
        // than the median bucket
        val newDiscarded = if (medianBucket == 0) discarded else partialTotals(medianBucket - 1)

        // Test whether a number is in the median bucket
        def insideMedianBucket(d: Double) = bucketIndex(d) == medianBucket

        // Get a view of the target bucket
        val view = s.view filter insideMedianBucket

        // If all numbers in the bucket are equal, return that
        if (view forall (view.head ==)) view.head
        // Otherwise, recurse on that bucket
        else medianHistogram(view, newDiscarded, medianIndex)
    }

    medianHistogram(s, 0, (s.size - 1) / 2)
}
_

テストとベンチマーク

アルゴリズムをテストするために、 Scalacheck を使用し、各アルゴリズムの出力を、並べ替えを使用した簡単な実装の出力と比較しています。もちろん、これはソートバージョンが正しいことを前提としています。

上記の各アルゴリズムのベンチマークを、提供されているすべてのピボット選択に加えて、固定ピボット選択(配列の途中、切り捨て)でベンチマークしています。各アルゴリズムは、3つの異なる入力配列サイズで、それぞれに対して3回テストされます。

テストコードは次のとおりです。

_import org.scalacheck.{Prop, Pretty, Test}
import Prop._
import Pretty._

def test(algorithm: Array[Double] => Double, 
         reference: Array[Double] => Double): String = {
    def prettyPrintArray(arr: Array[Double]) = arr mkString ("Array(", ", ", ")")
    val resultEqualsReference = forAll { (arr: Array[Double]) => 
        arr.nonEmpty ==> (algorithm(arr) == reference(arr)) :| prettyPrintArray(arr)
    }
    Test.check(Test.Params(), resultEqualsReference)(Pretty.Params(verbosity = 0))
}

import Java.lang.System.currentTimeMillis

def bench[A](n: Int)(body: => A): Long = {
  val start = currentTimeMillis()
  1 to n foreach { _ => body }
  currentTimeMillis() - start
}

import scala.util.Random.nextDouble

def benchmark(algorithm: Array[Double] => Double,
              arraySizes: List[Int]): List[Iterable[Long]] = 
    for (size <- arraySizes)
    yield for (iteration <- 1 to 3)
        yield bench(50000)(algorithm(Array.fill(size)(nextDouble)))

def testAndBenchmark: String = {
    val immutablePivotSelection: List[(String, Array[Double] => Double)] = List(
        "Random Pivot"      -> chooseRandomPivot,
        "Median of Medians" -> medianOfMedians,
        "Midpoint"          -> ((arr: Array[Double]) => arr((arr.size - 1) / 2))
    )
    val inPlacePivotSelection: List[(String, ArrayView => Double)] = List(
        "Random Pivot (in-place)" -> chooseRandomPivotInPlace,
        "Midpoint (in-place)"     -> ((arr: ArrayView) => arr((arr.size - 1) / 2))
    )
    val immutableAlgorithms = for ((name, pivotSelection) <- immutablePivotSelection)
        yield name -> (findMedian(_: Array[Double])(pivotSelection))
    val inPlaceAlgorithms = for ((name, pivotSelection) <- inPlacePivotSelection)
        yield name -> (findMedianInPlace(_: Array[Double])(pivotSelection))
    val histogramAlgorithm = "Histogram" -> ((arr: Array[Double]) => findMedianHistogram(arr))
    val sortingAlgorithm = "Sorting" -> ((arr: Array[Double]) => arr.sorted.apply((arr.size - 1) / 2))
    val algorithms = sortingAlgorithm :: histogramAlgorithm :: immutableAlgorithms ::: inPlaceAlgorithms

    val formattingString = "%%-%ds  %%s" format (algorithms map (_._1.length) max)

    // Tests
    val testResults = for ((name, algorithm) <- algorithms)
        yield formattingString format (name, test(algorithm, sortingAlgorithm._2))

    // Benchmarks
    val arraySizes = List(100, 500, 1000)
    def formatResults(results: List[Long]) = results map ("%8d" format _) mkString

    val benchmarkResults: List[String] = for {
        (name, algorithm) <- algorithms
        results <- benchmark(algorithm, arraySizes).transpose
    } yield formattingString format (name, formatResults(results))

    val header = formattingString format ("Algorithm", formatResults(arraySizes.map(_.toLong)))

    "Tests" :: "*****" :: testResults ::: 
    ("" :: "Benchmark" :: "*********" :: header :: benchmarkResults) mkString ("", "\n", "\n")
}
_

結果

テスト:

_Tests
*****
Sorting                OK, passed 100 tests.
Histogram              OK, passed 100 tests.
Random Pivot           OK, passed 100 tests.
Median of Medians      OK, passed 100 tests.
Midpoint               OK, passed 100 tests.
Random Pivot (in-place)OK, passed 100 tests.
Midpoint (in-place)    OK, passed 100 tests.
_

ベンチマーク:

_Benchmark
*********
Algorithm                   100     500    1000
Sorting                    1038    6230   14034
Sorting                    1037    6223   13777
Sorting                    1039    6220   13785
Histogram                  2918   11065   21590
Histogram                  2596   11046   21486
Histogram                  2592   11044   21606
Random Pivot                904    4330    8622
Random Pivot                902    4323    8815
Random Pivot                896    4348    8767
Median of Medians          3591   16857   33307
Median of Medians          3530   16872   33321
Median of Medians          3517   16793   33358
Midpoint                   1003    4672    9236
Midpoint                   1010    4755    9157
Midpoint                   1017    4663    9166
Random Pivot (in-place)     392    1746    3430
Random Pivot (in-place)     386    1747    3424
Random Pivot (in-place)     386    1751    3431
Midpoint (in-place)         378    1735    3405
Midpoint (in-place)         377    1740    3408
Midpoint (in-place)         375    1736    3408
_

分析

すべてのアルゴリズム(ソートバージョンを除く)には、平均線形時間計算量と互換性のある結果があります。

最悪の場合に線形時間計算量を保証する中央値の中央値は、ランダムピボットよりもはるかに遅くなります。

固定ピボットの選択はランダムピボットよりもわずかに劣りますが、非ランダム入力ではパフォーマンスが大幅に低下する可能性があります。

インプレースバージョンは約230%〜250%高速ですが、さらなるテスト(図示せず)は、この利点がアレイのサイズとともに増大することを示しているようです。

ヒストグラムアルゴリズムにはとても驚きました。線形時間計算量平均を表示し、中央値の中央値よりも33%高速です。ただし、入力ランダムです。最悪のケースは2次式です。コードのデバッグ中にいくつかの例を見ました。

59