web-dev-qa-db-ja.com

NumPyのアインサムを理解する

einsumの動作を正確に理解するのに苦労しています。私はドキュメントといくつかの例を見てきましたが、それは固執していないようです。

クラスで行った例を次に示します。

C = np.einsum("ij,jk->ki", A, B)

2つのarrayAおよびBの場合

これにはA^T * Bが必要だと思いますが、よくわかりません(そのうちの1つの転置が正しいのでしょうか?)。誰もがここで何が起こっているのか正確に説明できますか(一般的にeinsumを使用している場合)?

147
Lance Strait

(注:この回答は、私が少し前に書いたeinsumについての短い ブログ投稿 に基づいています。)

einsumは何をしますか?

ABの2つの多次元配列があるとします。では、次のことをしたいとしましょう...

  • multiply特定の方法でABを使用して、製品の新しい配列を作成します。そして多分
  • sum特定の軸に沿ったこの新しい配列。そして多分
  • transpose新しい配列の軸を特定の順序で並べ替えます。

einsumは、multiplysumtransposeなどのNumPy関数の組み合わせで可能になるため、einsumがこれをより速く、より効率的に行うのに役立つ可能性があります。

Aはどのように機能しますか?

以下に、単純な(ただし、完全に自明ではない)例を示します。次の2つの配列を使用します。

A = np.array([0, 1, 2])

B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])

BAを要素ごとに乗算してから、新しい配列の行に沿って合計します。 「通常の」NumPyでは、次のように記述します。

>>> (A[:, np.newaxis] * B).sum(axis=1)
array([ 0, 22, 76])

したがって、ここでは、einsumのインデックス付け操作が2つの配列の最初の軸を整列させ、乗算をブロードキャストできるようにします。次に、製品の配列の行が合計されて回答が返されます。

代わりにAを使用したい場合、次のように記述できます。

>>> np.einsum('i,ij->i', A, B)
array([ 0, 22, 76])

signature文字列'i,ij->i'はここでのキーであり、少し説明が必要です。 2つの半分に考えることができます。左側(->の左側)で、2つの入力配列にラベルを付けました。 ->の右側に、目的の配列にラベルを付けました。

次に何が起こるかを示します。

  • iには1つの軸があります。 Bというラベルを付けました。また、iには2つの軸があります。軸0をj、軸1をiとラベル付けしました。

  • 両方の入力配列のラベルeinsumrepeatingすることで、これら2つの軸をmultiplied一緒に。つまり、_A[:, np.newaxis] * Bが行うように、配列Aに配列Bの各列を乗算します。

  • jは、目的の出力ではラベルとして表示されないことに注意してください。 iを使用しました(最終的に1D配列になります)。ラベルをomittingすることにより、これに沿ってeinsumsumを伝えます軸。つまり、.sum(axis=1)と同じように、製品の行を合計します。

基本的に、einsumを使用するために知っておく必要があるのはこれだけです。少し遊ぶのに役立ちます。出力に両方のラベル'i,ij->ij'を残すと、製品の2D配列が返されます(A[:, np.newaxis] * Bと同じ)。出力ラベル'i,ij->を指定しない場合、単一の数値が返されます((A[:, np.newaxis] * B).sum()を実行するのと同じ)。

ただし、einsumの素晴らしい点は、最初に製品の一時配列を作成しないことです。そのまま製品を合計するだけです。これにより、メモリ使用量を大幅に節約できます。

少し大きい例

内積を説明するために、2つの新しい配列を次に示します。

A = array([[1, 1, 1],
           [2, 2, 2],
           [5, 5, 5]])

B = array([[0, 1, 0],
           [1, 1, 0],
           [1, 1, 1]])

np.einsum('ij,jk->ik', A, B)を使用してドット積を計算します。以下は、AおよびBのラベル付けと、関数から取得する出力配列を示す図です。

enter image description here

ラベルjが繰り返されていることがわかります。これは、Aの行にBの列を乗算していることを意味します。さらに、ラベルjは出力に含まれません-これらの製品を合計しています。ラベルiおよびkは出力用に保持されているため、2D配列が返されます。

この結果を、ラベルjnot加算された配列と比較すると、さらに明確になる場合があります。以下の左側には、np.einsum('ij,jk->ijk', A, B)(つまり、ラベルjを保持している)を書き込んだ結果の3D配列が表示されています。

enter image description here

合計軸jは、右側に示すように、予想される内積を与えます。

いくつかの演習

einsumの感触をよりよく理解するには、添字表記を使用して使い慣れたNumPy配列操作を実装すると便利です。乗算軸と加算軸の組み合わせを含むものはすべて、einsumを使用して記述できます。

AとBを同じ長さの2つの1D配列とします。たとえば、A = np.arange(10)およびB = np.arange(5, 15)です。

  • Aの合計は次のように記述できます。

    np.einsum('i->', A)
    
  • 要素単位の乗算A * Bは、次のように記述できます。

    np.einsum('i,i->i', A, B)
    
  • 内積またはドット積np.inner(A, B)またはnp.dot(A, B)は、次のように記述できます。

    np.einsum('i,i->', A, B) # or just use 'i,i'
    
  • 外積np.outer(A, B)は、次のように記述できます。

    np.einsum('i,j->ij', A, B)
    

2D配列、CおよびDの場合、軸の長さが互換性がある場合(同じ長さまたはいずれかが長さ1である場合)、いくつかの例を次に示します。

  • C(主対角線の合計)np.trace(C)のトレースは、次のように記述できます。

    np.einsum('ii', C)
    
  • Cの要素ごとの乗算とDの転置、C * D.Tは、次のように記述できます。

    np.einsum('ij,ji->ij', C, D)
    
  • Cの各要素に、D配列(4D配列を作成するため)を乗算すると、C[:, :, None, None] * Dが記述できます。

    np.einsum('ij,kl->ijkl', C, D)  
    
279
Alex Riley

numpy.einsum() の概念を把握することは、直感的に理解できれば非常に簡単です。例として、行列乗算を含む簡単な説明から始めましょう。


numpy.einsum() を使用するには、いわゆるsubscripts stringを引数として渡し、その後に入力配列を渡すだけです。

ABの2つの2D配列があり、行列の乗算を実行するとします。そうしたらいい:

np.einsum("ij, jk -> ik", A, B)

ここで、添え字文字列ijは配列Aに対応しますが、添え字文字列jkは配列Bに対応します。また、ここで最も重要なことは、各添え字文字列must文字数が配列の次元と一致することです。 (つまり、2D配列の場合は2文字、3D配列の場合は3文字など)。そして、添え字文字列(この例ではj)の間で文字を繰り返す場合、 einsumは、これらのディメンションに沿って発生します。したがって、それらは合計削減されます。 (つまり、その次元はなくなった

この->の後の添え字文字列は、結果の配列になります。空のままにすると、すべてが合計され、結果としてスカラー値が返されます。そうでない場合、結果の配列は添え字文字列に応じた次元を持ちます。この例では、ikになります。行列の乗算では、配列Aの列数が配列Bの行数と一致する必要があることがわかっているため、これは直感的です(つまり、これをエンコードします) 添え字文字列)でchar jを繰り返すことによる知識


いくつかの一般的なテンソルまたはnd-array操作を簡潔に実装する際の np.einsum() の使用/能力を示すいくつかの例を次に示します。

入力

# a vector
In [197]: vec
Out[197]: array([0, 1, 2, 3])

# an array
In [198]: A
Out[198]: 
array([[11, 12, 13, 14],
       [21, 22, 23, 24],
       [31, 32, 33, 34],
       [41, 42, 43, 44]])

# another array
In [199]: B
Out[199]: 
array([[1, 1, 1, 1],
       [2, 2, 2, 2],
       [3, 3, 3, 3],
       [4, 4, 4, 4]])

1)行列の乗算np.matmul(arr1, arr2)に類似)

In [200]: np.einsum("ij, jk -> ik", A, B)
Out[200]: 
array([[130, 130, 130, 130],
       [230, 230, 230, 230],
       [330, 330, 330, 330],
       [430, 430, 430, 430]])

2)主対角線に沿って要素を抽出しますnp.diag(arr)に類似)

In [202]: np.einsum("ii -> i", A)
Out[202]: array([11, 22, 33, 44])

)アダマール積(つまり、2つの配列の要素単位の積)arr1 * arr2に類似)

In [203]: np.einsum("ij, ij -> ij", A, B)
Out[203]: 
array([[ 11,  12,  13,  14],
       [ 42,  44,  46,  48],
       [ 93,  96,  99, 102],
       [164, 168, 172, 176]])

4)要素ごとの二乗np.square(arr)またはarr ** 2に類似)

In [210]: np.einsum("ij, ij -> ij", B, B)
Out[210]: 
array([[ 1,  1,  1,  1],
       [ 4,  4,  4,  4],
       [ 9,  9,  9,  9],
       [16, 16, 16, 16]])

5)トレース(主対角要素の合計)np.trace(arr)に類似)

In [217]: np.einsum("ii -> ", A)
Out[217]: 110

6)行列の転置np.transpose(arr)に類似)

In [221]: np.einsum("ij -> ji", A)
Out[221]: 
array([[11, 21, 31, 41],
       [12, 22, 32, 42],
       [13, 23, 33, 43],
       [14, 24, 34, 44]])

7)(ベクトルの)外積np.outer(vec1, vec2)に類似)

In [255]: np.einsum("i, j -> ij", vec, vec)
Out[255]: 
array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]])

8)内積(ベクトルの)np.inner(vec1, vec2)に類似)

In [256]: np.einsum("i, i -> ", vec, vec)
Out[256]: 14

9)軸0に沿った合計np.sum(arr, axis=0)に類似)

In [260]: np.einsum("ij -> j", B)
Out[260]: array([10, 10, 10, 10])

10)軸1に沿った合計np.sum(arr, axis=1)に類似)

In [261]: np.einsum("ij -> i", B)
Out[261]: array([ 4,  8, 12, 16])

11)バッチマトリックス乗算

In [287]: BM = np.stack((A, B), axis=0)

In [288]: BM
Out[288]: 
array([[[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]],

       [[ 1,  1,  1,  1],
        [ 2,  2,  2,  2],
        [ 3,  3,  3,  3],
        [ 4,  4,  4,  4]]])

In [289]: BM.shape
Out[289]: (2, 4, 4)

# batch matrix multiply using einsum
In [292]: BMM = np.einsum("bij, bjk -> bik", BM, BM)

In [293]: BMM
Out[293]: 
array([[[1350, 1400, 1450, 1500],
        [2390, 2480, 2570, 2660],
        [3430, 3560, 3690, 3820],
        [4470, 4640, 4810, 4980]],

       [[  10,   10,   10,   10],
        [  20,   20,   20,   20],
        [  30,   30,   30,   30],
        [  40,   40,   40,   40]]])

In [294]: BMM.shape
Out[294]: (2, 4, 4)

12)軸2に沿った合計np.sum(arr, axis=2)に類似)

In [330]: np.einsum("ijk -> ij", BM)
Out[330]: 
array([[ 50,  90, 130, 170],
       [  4,   8,  12,  16]])

13)配列内のすべての要素を合計しますnp.sum(arr)に類似)

In [335]: np.einsum("ijk -> ", BM)
Out[335]: 480

14)複数の軸での合計(周辺化)
np.sum(arr, axis=(axis0, axis1, axis2, axis3, axis4, axis6, axis7))に類似)

# 8D array
In [354]: R = np.random.standard_normal((3,5,4,6,8,2,7,9))

# marginalize out axis 5 (i.e. "n" here)
In [363]: esum = np.einsum("ijklmnop -> n", R)

# marginalize out axis 5 (i.e. sum over rest of the axes)
In [364]: nsum = np.sum(R, axis=(0,1,2,3,4,6,7))

In [365]: np.allclose(esum, nsum)
Out[365]: True

15)ダブルドット積np.sum(hadamard-product) cf. )と同様

In [772]: A
Out[772]: 
array([[1, 2, 3],
       [4, 2, 2],
       [2, 3, 4]])

In [773]: B
Out[773]: 
array([[1, 4, 7],
       [2, 5, 8],
       [3, 6, 9]])

In [774]: np.einsum("ij, ij -> ", A, B)
Out[774]: 124

16)2Dおよび3D配列乗算

このような乗算は、結果を検証したい線形連立方程式(Ax = b)を解くときに非常に便利です。

# inputs
In [115]: A = np.random.Rand(3,3)
In [116]: b = np.random.Rand(3, 4, 5)

# solve for x
In [117]: x = np.linalg.solve(A, b.reshape(b.shape[0], -1)).reshape(b.shape)

# 2D and 3D array multiplication :)
In [118]: Ax = np.einsum('ij, jkl', A, x)

# indeed the same!
In [119]: np.allclose(Ax, b)
Out[119]: True

逆に、この検証に np.matmul() を使用する必要がある場合、次のような同じ結果を得るために、いくつかのreshape操作を実行する必要があります。

# reshape 3D array `x` to 2D, perform matmul
# then reshape the resultant array to 3D
In [123]: Ax_matmul = np.matmul(A, x.reshape(x.shape[0], -1)).reshape(x.shape)

# indeed correct!
In [124]: np.allclose(Ax, Ax_matmul)
Out[124]: True

ボーナス:数学の詳細はこちら: Einstein-Summation そして間違いなくここ: Tensor-Notation

33
kmario23

相互作用を強調するために、異なるが互換性のある寸法を持つ2つの配列を作成できます

In [43]: A=np.arange(6).reshape(2,3)
Out[43]: 
array([[0, 1, 2],
       [3, 4, 5]])


In [44]: B=np.arange(12).reshape(3,4)
Out[44]: 
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

計算では、(2,3)と(3,4)の「ドット」(製品の合計)を取り、(4,2)配列を生成します。 iAの最初のdim、最後のCです。 kBの最後、Cの最初。 jは、合計によって「消費」されます。

In [45]: C=np.einsum('ij,jk->ki',A,B)
Out[45]: 
array([[20, 56],
       [23, 68],
       [26, 80],
       [29, 92]])

これはnp.dot(A,B).Tと同じです-転置されるのは最終出力です。

jの詳細を確認するには、C添え字をijkに変更します。

In [46]: np.einsum('ij,jk->ijk',A,B)
Out[46]: 
array([[[ 0,  0,  0,  0],
        [ 4,  5,  6,  7],
        [16, 18, 20, 22]],

       [[ 0,  3,  6,  9],
        [16, 20, 24, 28],
        [40, 45, 50, 55]]])

これは次のものでも生成できます。

A[:,:,None]*B[None,:,:]

つまり、kディメンションをAの最後に追加し、iBの前に追加すると、(2,3,4)配列になります。

0 + 4 + 16 = 209 + 28 + 55 = 92など; jを合計し、転置して前の結果を取得します。

np.sum(A[:,:,None] * B[None,:,:], axis=1).T

# C[k,i] = sum(j) A[i,j (,k) ] * B[(i,)  j,k]
7
hpaulj

見つけました NumPy:取引の秘trick(パートII) 有益

->を使用して、出力配列の順序を示します。したがって、「ij、i-> j」は左側(LHS)および右側(RHS)であると考えてください。 LHSでラベルが繰り返されると、製品要素が賢明に計算され、合計されます。 RHS(出力)側のラベルを変更することにより、入力配列に対して進む軸、つまり軸0、1などに沿った合計を定義できます。

import numpy as np

>>> a
array([[1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]])
>>> b
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
>>> d = np.einsum('ij, jk->ki', a, b)

I、j、kの3つの軸があり、jが繰り返されていることに注意してください(左側)。 i,jは、aの行と列を表します。 bj,k

製品を計算し、j軸を調整するには、aに軸を追加する必要があります。 (bは、最初の軸に沿ってブロードキャストされますか?)

a[i, j, k]
   b[j, k]

>>> c = a[:,:,np.newaxis] * b
>>> c
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 0,  2,  4],
        [ 6,  8, 10],
        [12, 14, 16]],

       [[ 0,  3,  6],
        [ 9, 12, 15],
        [18, 21, 24]]])

jは右側にないため、3x3x3配列の2番目の軸であるjを合計します

>>> c = c.sum(1)
>>> c
array([[ 9, 12, 15],
       [18, 24, 30],
       [27, 36, 45]])

最後に、インデックスは(アルファベット順で)右側で逆になるため、転置します。

>>> c.T
array([[ 9, 18, 27],
       [12, 24, 36],
       [15, 30, 45]])

>>> np.einsum('ij, jk->ki', a, b)
array([[ 9, 18, 27],
       [12, 24, 36],
       [15, 30, 45]])
>>>
5
wwii