web-dev-qa-db-ja.com

なぜこの素朴な行列の乗算はベースRよりも速いのですか?

Rでは、行列の乗算は非常に最適化されています。つまり、BLAS/LAPACKへの呼び出しにすぎません。ただし、この非常に素朴な行列とベクトルの乗算用のC++コードは、確実に30%高速に思えます。

 library(Rcpp)

 # Simple C++ code for matrix multiplication
 mm_code = 
 "NumericVector my_mm(NumericMatrix m, NumericVector v){
   int nRow = m.rows();
   int nCol = m.cols();
   NumericVector ans(nRow);
   double v_j;
   for(int j = 0; j < nCol; j++){
     v_j = v[j];
     for(int i = 0; i < nRow; i++){
       ans[i] += m(i,j) * v_j;
     }
   }
   return(ans);
 }
 "
 # Compiling
 my_mm = cppFunction(code = mm_code)

 # Simulating data to use
 nRow = 10^4
 nCol = 10^4

 m = matrix(rnorm(nRow * nCol), nrow = nRow)
 v = rnorm(nCol)

 system.time(my_ans <- my_mm(m, v))
#>    user  system elapsed 
#>   0.103   0.001   0.103 
 system.time(r_ans <- m %*% v)
#>   user  system elapsed 
#>  0.154   0.001   0.154 

 # Double checking answer is correct
 max(abs(my_ans - r_ans))
 #> [1] 0

ベースRの%*%は、スキップしているある種のデータチェックを実行しますか?

編集:

何が起こっているのかを理解した後(SOに感謝!)、これはRの%*%の最悪のシナリオ、つまりベクトルによる行列であることに注意してください。たとえば、@ RalfStubnerは、RcppArmadillo実装の行列-ベクトル乗算の使用は、私が示した単純な実装よりもさらに高速であり、ベースRよりもかなり高速であることを示唆していますが、ベースRの%*%とほぼ同じです。 -matrix乗算(両方の行列が大きくて正方の場合):

 arma_code <- 
   "arma::mat arma_mm(const arma::mat& m, const arma::mat& m2) {
 return m * m2;
 };"
 arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")

 nRow = 10^3 
 nCol = 10^3

 mat1 = matrix(rnorm(nRow * nCol), 
               nrow = nRow)
 mat2 = matrix(rnorm(nRow * nCol), 
               nrow = nRow)

 system.time(arma_mm(mat1, mat2))
#>   user  system elapsed 
#>   0.798   0.008   0.814 
 system.time(mat1 %*% mat2)
#>   user  system elapsed 
#>   0.807   0.005   0.822  

したがって、Rの現在の(v3.5.0)%*%は、matrix-matrixにほぼ最適ですが、チェックをスキップしても大丈夫な場合は、matrix-vectorを大幅に高速化できます。

30
Cliff AB

names.cの概要( ここでは特に )は、do_matprodによって呼び出され、%*%ファイルにあるC関数であるarray.cを示しています。 (興味深いことに、crossprodtcrossprodの両方が同じ関数にディスパッチすることがわかりました)。 ここにリンクがありますdo_matprodのコードへ。

関数をスクロールすると、単純な実装では行われない次のような多くのことが行われていることがわかります。

  1. 行と列の名前を保持します。
  2. %*%の呼び出しによって操作される2つのオブジェクトが、そのようなメソッドが提供されているクラスのものである場合に、代替のS4メソッドへのディスパッチを可能にします。 (それが関数の この部分 で起こっていることです。)
  3. 実数行列と複素数行列の両方を処理します。
  4. 行列と行列、ベクトルと行列、行列とベクトル、およびベクトルとベクトルの乗算の処理方法に関する一連のルールを実装します。 (Rのクロス乗算では、LHSのベクトルは行ベクトルとして扱われますが、RHSでは列ベクトルとして扱われます。これは、これを行うコードです。)

関数の終わり近く 、それは matprod または cmatprod のいずれかにディスパッチされます。興味深いことに(少なくとも私にとっては)、実際の行列の場合、ifいずれかの行列にNaNまたはInfの値が含まれ、次にmatprodディスパッチ( ここsimple_matprod と呼ばれる関数に、これはあなた自身のものと同じくらい単純で簡単です。それ以外の場合は、いくつかのBLAS Fortranルーチンの1つにディスパッチします。これは、均一に「正常に動作する」行列要素が保証される場合は、おそらく高速です。

27
Josh O'Brien

ジョシュの答えは、なぜRの行列乗算がこの素朴なアプローチほど速くないかを説明しています。 RcppArmadilloを使用してどれだけの利益が得られるか知りたいと思っていました。コードは非常に単純です:

arma_code <- 
  "arma::vec arma_mm(const arma::mat& m, const arma::vec& v) {
       return m * v;
   };"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")

基準:

> microbenchmark::microbenchmark(my_mm(m,v), m %*% v, arma_mm(m,v), times = 10)
Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 71.23347 75.22364  90.13766  96.88279  98.07348  98.50182    10
       m %*% v 92.86398 95.58153 106.00601 111.61335 113.66167 116.09751    10
 arma_mm(m, v) 41.13348 41.42314  41.89311  41.81979  42.39311  42.78396    10

したがって、RcppArmadilloを使用すると、構文が改善され、パフォーマンスが向上します。

好奇心は私に良くなった。ここでBLASを直接使用するためのソリューション:

blas_code = "
NumericVector blas_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  char trans = 'N';
  double one = 1.0, zero = 0.0;
  int ione = 1;
  F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
           &ione, &zero, ans.begin(), &ione);
  return ans;
}"
blas_mm <- cppFunction(code = blas_code, includes = "#include <R_ext/BLAS.h>")

基準:

Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 72.61298 75.40050  89.75529  96.04413  96.59283  98.29938    10
       m %*% v 95.08793 98.53650 109.52715 111.93729 112.89662 128.69572    10
 arma_mm(m, v) 41.06718 41.70331  42.62366  42.47320  43.22625  45.19704    10
 blas_mm(m, v) 41.58618 42.14718  42.89853  42.68584  43.39182  44.46577    10

ArmadilloとBLAS(私の場合はOpenBLAS)はほとんど同じです。そして、BLASコードは、Rが最後に行うことでもあります。したがって、Rの2/3はエラーチェックなどです。

7
Ralf Stubner