web-dev-qa-db-ja.com

(a + sqrt(b))という形式の2つの値を可能な限り速く比較しますか?

私が書いているプログラムの一部として、2つの値をa + sqrt(b)の形式で比較する必要があります。ここで、abは符号なし整数です。これはタイトなループの一部であるため、この比較をできるだけ速く実行したいと思います。 (問題がある場合、私はx86-64マシンでコードを実行しており、符号なし整数は10 ^ 6以下です。また、_a1<a2_であることも知っています。)

スタンドアロン関数として、これは私が最適化しようとしているものです。私の数値は、double(またはfloat)でも正確に表すことができるほど小さい整数ですが、sqrtの結果の丸め誤差によって結果が変わることはありません。

_// known pre-condition: a1 < a2  in case that helps
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1+sqrt(b1) < a2+sqrt(b2);  // computed mathematically exactly
}
_

テストケースis_smaller(900000, 1000000, 900001, 998002)はtrueを返す必要がありますが、@ wimのコメントに示されているように、sqrtf()で計算するとfalseが返されます。したがって、_(int)sqrt()_を切り捨てて整数に戻します。

a1+sqrt(b1) = 90100およびa2+sqrt(b2) = 901000.00050050037512481206。これに最も近いフロートは正確に90100です。


sqrt()関数は、sqrtsd命令として完全にインライン化されている場合、現代のx86-64でも一般的に非常に高価であるため、sqrt()を呼び出さないようにしています可能。

二乗によってsqrtを削除すると、すべての計算が正確になるため、丸めエラーの危険性も回避されます。

代わりに関数がこのようなものだった場合...

_bool is_smaller(unsigned a1, unsigned b1, unsigned x) {
    return a1+sqrt(b1) < x;
}
_

...その後、私は単にreturn x-a1>=0 && static_cast<uint64_t>(x-a1)*(x-a1)>b1;

しかし、今は2つのsqrt(...)項があるため、同じ代数操作を行うことはできません。

次の式を使用して、値twiceを二乗できます。

_      a1 + sqrt(b1) = a2 + sqrt(b2)
<==>  a1 - a2 = sqrt(b2) - sqrt(b1)
<==>  (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1) * sqrt(b2)
<==>  (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1 * b2)
<==>  (a1 - a2) * (a1 - a2) - (b1 + b2) = - 2 * sqrt(b1 * b2)
<==>  ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 2 = sqrt(b1 * b2)
<==>  ((b1 + b2) - (a1 - a2) * (a1 - a2)) * ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 4 = b1 * b2
_

符号なし4による除算はビットシフトなので安価ですが、数値を2乗するので128ビット整数を使用する必要があり、いくつかの_>=0_チェックを導入する必要があります(不等式を比較しているため)平等の代わりに)。

この問題により良い代数を適用することで、これをより速く行う方法があるかもしれないと感じています。これをより速く行う方法はありますか?

45
Bernard

これはsqrtのないバージョンですが、sqrtが1つしかないバージョンよりも高速かどうかはわかりません(値の分布に依存する場合があります)。

これが数学です(両方のsqrtを削除する方法):

_ad = a2-a1
bd = b2-b1

a1+sqrt(b1) < a2+sqrt(b2)              // subtract a1
   sqrt(b1) < ad+sqrt(b2)              // square it
        b1  < ad^2+2*ad*sqrt(b2)+b2    // arrange
   ad^2+bd  > -2*ad*sqrt(b2)
_

ここで、右側は常に負です。左側が正の場合、trueを返す必要があります。

左側が負の場合、不等式を二乗できます。

_ad^4+bd^2+2*bd*ad^2 < 4*ad^2*b2
_

ここで注目すべき重要な点は、_a2>=a1+1000_の場合、_is_smaller_は常にtrueを返すことです(sqrt(b1)の最大値は1000であるため)。 _a2<=a1+1000_の場合、adは小さい数値であるため、_ad^4_は常に64ビットに収まります(128ビット演算の必要はありません)。これがコードです:

_bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    if (ad>1000) {
        return true;
    }

    int bd = b2 - b1;
    if (ad*ad+bd>0) {
        return true;
    }

    int ad2 = ad*ad;

    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}
_

編集:Peter Cordesが気づいたように、最初のifは必要ありません。2番目のはそれを処理するため、コードはより小さく、より速くなります。

_bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    int bd = b2 - b1;
    if ((long long int)ad*ad+bd>0) {
        return true;
    }

    int ad2 = ad*ad;
    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}
_
19
geza

私は疲れていて、おそらく間違いを犯しました。誰かが指摘してくれたらきっと….

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    a_diff = a1-a2;   // May be negative

    if(a_diff < 0) {
        if(b1 < b2) {
            return true;
        }
        temp = a_diff+sqrt(b1);
        if(temp < 0) {
            return true;
        }
        return temp*temp < b2;
    } else {
        if(b1 >= b2) {
            return false;
        }
    }
//  return a_diff+sqrt(b1) < sqrt(b2);

    temp = a_diff+sqrt(b1);
    return temp*temp < b2;
}

あなたが知っていれば a1 < a2その後、次のようになります。

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    a_diff = a2-a1;    // Will be positive

    if(b1 > b2) {
        return false;
    }
    if(b1 >= a_diff*a_diff) {
        return false;
    }
    temp = a_diff+sqrt(b2);
    return b1 < temp*temp;
}
4
Brendan

ここで説明 として整数sqrtsを計算するためのニュートンメソッドもあります。別のアプローチは、平方根を計算せずに、バイナリ検索を介してfloor(sqrt(n))を検索することです...「のみ」があります10 ^ 6未満の1000の完全な平方数。これはおそらくパフォーマンスが悪いですが、興味深いアプローチになります。私はこれらのどれも測定していませんが、ここに例があります:

#include <iostream>
#include <array>
#include <algorithm>        // std::lower_bound
#include <cassert>          


bool is_smaller_sqrt(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    return a1 + sqrt(b1) < a2 + sqrt(b2);
}

static std::array<int, 1001> squares;

template <typename C>
void squares_init(C& c)
{
    for (int i = 0; i < c.size(); ++i)
        c[i] = i*i;
}

inline bool greater(const int& l, const int& r)
{
    return r < l;
}

inline bool is_smaller_bsearch(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    // return a1 + sqrt(b1) < a2 + sqrt(b2)

    // find floor(sqrt(b1)) - binary search withing 1000 elems
    auto it_b1 = std::lower_bound(crbegin(squares), crend(squares), b1, greater).base();

    // find floor(sqrt(b2)) - binary search withing 1000 elems
    auto it_b2 = std::lower_bound(crbegin(squares), crend(squares), b2, greater).base();

    return (a2 - a1) > (it_b1 - it_b2);
}

unsigned int sqrt32(unsigned long n)
{
    unsigned int c = 0x8000;
    unsigned int g = 0x8000;

    for (;;) {
        if (g*g > n) {
            g ^= c;
        }

        c >>= 1;

        if (c == 0) {
            return g;
        }

        g |= c;
    }
}

bool is_smaller_sqrt32(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    return a1 + sqrt32(b1) < a2 + sqrt32(b2);
}

int main()
{
    squares_init(squares);

    // now can use is_smaller
    assert(is_smaller_sqrt(1, 4, 3, 1) == is_smaller_sqrt32(1, 4, 3, 1));
    assert(is_smaller_sqrt(1, 2, 3, 3) == is_smaller_sqrt32(1, 2, 3, 3));
    assert(is_smaller_sqrt(1000, 4, 1001, 1) == is_smaller_sqrt32(1000, 4, 1001, 1));
    assert(is_smaller_sqrt(1, 300, 3, 200) == is_smaller_sqrt32(1, 300, 3, 200));
}
2
StPiere

代数的操作を整数演算と組み合わせると、必然的に最速のソリューションが得られるかどうかはわかりません。その場合、多くのスカラー乗算が必要になります(これは非常に高速ではありません)。分岐予測が失敗する可能性があり、その場合はパフォーマンスが低下する可能性があります。明らかに、特定のケースで最も高速なソリューションを確認するには、ベンチマークを行う必要があります。

sqrtを少し速くする1つの方法は、gccまたはclangに_-fno-math-errno_オプションを追加することです。その場合、コンパイラは負の入力をチェックする必要はありません。 iccでは、これがデフォルト設定です。

スカラーsqrt命令sqrtpdの代わりに、ベクトル化されたsqrt命令sqrtsdを使用すると、パフォーマンスをさらに向上させることができます。 Peter Cordes 示されています clangがこのコードを自動ベクトル化して、このsqrtpdを生成できるようにします。

ただし、自動ベクトル化が成功するかどうかは、正しいコンパイラ設定と使用するコンパイラ(clang、gcc、iccなど)に大きく依存します。 _-march=nehalem_以前では、clangはベクトル化しません。

以下の組み込みコードを使用すると、より信頼性の高いベクトル化の結果が得られます。以下を参照してください。移植性のために、x86-64ベースラインであるSSE2サポートのみを想定しています。

_/* gcc -m64 -O3 -fno-math-errno smaller.c                      */
/* Adding e.g. -march=nehalem or -march=skylake might further  */
/* improve the generated code                                  */
/* Note that SSE2 in guaranteed to exist with x86-64           */
#include<immintrin.h>
#include<math.h>
#include<stdio.h>
#include<stdint.h>

int is_smaller_v5(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    uint64_t a64    =  (((uint64_t)a2)<<32) | ((uint64_t)a1); /* Avoid too much port 5 pressure by combining 2 32 bit integers in one 64 bit integer */
    uint64_t b64    =  (((uint64_t)b2)<<32) | ((uint64_t)b1); 
    __m128i ax      = _mm_cvtsi64_si128(a64);         /* Move integer from gpr to xmm register                  */
    __m128i bx      = _mm_cvtsi64_si128(b64);         
    __m128d a       = _mm_cvtepi32_pd(ax);            /* Convert 2 integers to double                           */
    __m128d b       = _mm_cvtepi32_pd(bx);            /* We don't need _mm_cvtepu32_pd since a,b < 1e6          */
    __m128d sqrt_b  = _mm_sqrt_pd(b);                 /* Vectorized sqrt: compute 2 sqrt-s with 1 instruction   */
    __m128d sum     = _mm_add_pd(a, sqrt_b);
    __m128d sum_lo  = sum;                            /* a1 + sqrt(b1) in the lower 64 bits                     */
    __m128d sum_hi  =  _mm_unpackhi_pd(sum, sum);     /* a2 + sqrt(b2) in the lower 64 bits                     */
    return _mm_comilt_sd(sum_lo, sum_hi);
}


int is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1+sqrt(b1) < a2+sqrt(b2);
}


int main(){
    unsigned a1; unsigned b1; unsigned a2; unsigned b2;
    a1 = 11; b1 = 10; a2 = 10; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 11; a2 = 10; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 10; a2 = 11; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 10; a2 = 10; b2 = 11;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));

    return 0;
}
_


生成されたアセンブリについては このGodboltリンク を参照してください。

コンパイラオプション_gcc -m64 -O3 -fno-math-errno -march=nehalem_を使用したIntel Skylakeの単純なスループットテストでは、is_smaller_v5()のスループットが元のis_smaller()より2.6倍優れていることがわかりました。ループオーバーヘッドを含む18 CPUサイクル。ただし、入力が_a1, a2, b1, b2_が以前のis_smaller(_v5)の結果に依存する(あまりにも?)単純なレイテンシテストでは、改善は見られませんでした。 (39.7サイクルvs 39サイクル)。

2
wim

おそらく他の回答よりは良くありませんが、別のアイデア(および大量の事前分析)を使用しています。

// Compute approximate integer square root of input in the range [0,10^6].
// Uses a piecewise linear approximation to sqrt() with bounded error in each piece:
//   0 <= x <= 784 : x/28
//   784 < x <= 7056 : 21 + x/112
//   7056 < x <= 28224 : 56 + x/252
//   28224 < x <= 78400 : 105 + x/448
//   78400 < x <= 176400 : 168 + x/700
//   176400 < x <= 345744 : 245 + x/1008
//   345744 < x <= 614656 : 336 + x/1372
//   614656 < x <= 1000000 : (784000+x)/1784
// It is the case that sqrt(x) - 7.9992711366390365897... <= pseudosqrt(x) <= sqrt(x).
unsigned pseudosqrt(unsigned x) {
    return 
        x <= 78400 ? 
            x <= 7056 ?
                x <= 764 ? x/28 : 21 + x/112
              : x <= 28224 ? 56 + x/252 : 105 + x/448
          : x <= 345744 ?
                x <= 176400 ? 168 + x/700 : 245 + x/1008
              : x <= 614656 ? 336 + x/1372 : (x+784000)/1784 ;
}

// known pre-conditions: a1 < a2, 
//                  0 <= b1 <= 1000000
//                  0 <= b2 <= 1000000
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
// Try three refinements:
// 1: a1 + sqrt(b1) <= a1 + 1000, 
//    so is a1 + 1000 < a2 ?  
//    Convert to a2 - a1 > 1000 .
// 2: a1 + sqrt(b1) <= a1 + pseudosqrt(b1) + 8 and
//    a2 + pseudosqrt(b2) <= a2 + sqrt(b2), 
//    so is  a1 + pseudosqrt(b1) + 8 < a2 + pseudosqrt(b2) ?
//    Convert to a2 - a1 > pseudosqrt(b1) - pseudosqrt(b2) + 8 .
// 3: Actually do the work.
//    Convert to a2 - a1 > sqrt(b1) - sqrt(b2)
// Use short circuit evaluation to stop when resolved.
    unsigned ad = a2 - a1;
    return (ad > 1000)
           || (ad > pseudosqrt(b1) - pseudosqrt(b2) + 8)
           || ((int) ad > (int)(sqrt(b1) - sqrt(b2)));
}

(私は便利なコンパイラを持っていないので、これはおそらくタイプミスを含んでいます。)

1
Eric Towers