web-dev-qa-db-ja.com

pyspark抽出ROC曲線?

ROC曲線上のポイントをpysparkのSpark MLから取得する方法はありますか?ドキュメントには、Scalaの例がありますが、Pythonではありません:- https://spark.Apache.org/docs/2.1.0/mllib-evaluation-metrics.html

そうですか?私は確かにそれを実装する方法を考えることができますが、事前に構築された関数があればそれがより速いと想像しなければなりません。私は300万のスコアと数十のモデルを扱っているので、速度が重要です。

ありがとう!

4
seth127

ROC曲線がTPRに対するFPRのプロットである限り、次のように必要な値を抽出できます。

your_model.summary.roc.select('FPR').collect()
your_model.summary.roc.select('TPR').collect())

your_modelは、たとえば、次のようなものから取得したモデルである可能性があります。

from pyspark.ml.classification import LogisticRegression
log_reg = LogisticRegression()
your_model = log_reg.fit(df)

ここで、たとえばmatplotlibを使用して、TPRに対してFPRをプロットする必要があります。

追記

これは、your_model(およびその他の!)という名前のモデルを使用してROC曲線をプロットするための完全な例です。また、ROCプロット内に参照「ランダム推測」線をプロットしました。

import matplotlib.pyplot as plt
plt.figure(figsize=(5,5))
plt.plot([0, 1], [0, 1], 'r--')
plt.plot(your_model.summary.roc.select('FPR').collect(),
         your_model.summary.roc.select('TPR').collect())
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.show()
4
Andrea

ロジスティック回帰以外のモデル(モデルの要約がないディシジョンツリーやランダムフォレストなど)で機能するより一般的なソリューションの場合、 BinaryClassificationMetrics from Spark = MLlib。

PySparkバージョンは、 Scalaバージョン が実装するすべてのメソッドを実装しているわけではないため、 JavaModelWrapper.call(name)関数を使用する必要があることに注意してください。 =。また、py4jはscala.Tuple2クラスの解析をサポートしていないようであるため、手動で処理する必要があります。

例:

from pyspark.mllib.evaluation import BinaryClassificationMetrics

# Scala version implements .roc() and .pr()
# Python: https://spark.Apache.org/docs/latest/api/python/_modules/pyspark/mllib/common.html
# Scala: https://spark.Apache.org/docs/latest/api/Java/org/Apache/spark/mllib/evaluation/BinaryClassificationMetrics.html
class CurveMetrics(BinaryClassificationMetrics):
    def __init__(self, *args):
        super(CurveMetrics, self).__init__(*args)

    def _to_list(self, rdd):
        points = []
        # Note this collect could be inefficient for large datasets 
        # considering there may be one probability per datapoint (at most)
        # The Scala version takes a numBins parameter, 
        # but it doesn't seem possible to pass this from Python to Java
        for row in rdd.collect():
            # Results are returned as type scala.Tuple2, 
            # which doesn't appear to have a py4j mapping
            points += [(float(row._1()), float(row._2()))]
        return points

    def get_curve(self, method):
        rdd = getattr(self._Java_model, method)().toJavaRDD()
        return self._to_list(rdd)

使用法:

import matplotlib.pyplot as plt

# Create a Pipeline estimator and fit on train DF, predict on test DF
model = estimator.fit(train)
predictions = model.transform(test)

# Returns as a list (false positive rate, true positive rate)
preds = predictions.select('label','probability').rdd.map(lambda row: (float(row['probability'][1]), float(row['label'])))
points = CurveMetrics(preds).get_curve('roc')

plt.figure()
x_val = [x[0] for x in points]
y_val = [x[1] for x in points]
plt.title(title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.plot(x_val, y_val)

ROC curve generated with PySpark BinaryClassificationMetrics

ScalaのBinaryClassificationMetricsは、他のいくつかの便利なメソッドも実装しています。

metrics = CurveMetrics(preds)
metrics.get_curve('fMeasureByThreshold')
metrics.get_curve('precisionByThreshold')
metrics.get_curve('recallByThreshold')
2
Alex Ross