web-dev-qa-db-ja.com

pysparkのDataFrameの各グループの上位nを取得します

Pysparkには、次のようなデータを持つDataFrameがあります。

user_id object_id score
user_1  object_1  3
user_1  object_1  1
user_1  object_2  2
user_2  object_1  5
user_2  object_2  2
user_2  object_2  6

私が期待しているのは、同じuser_idで各グループに2つのレコードを返すことです。これには最高のスコアが必要です。その結果、結果は次のようになります。

user_id object_id score
user_1  object_1  3
user_1  object_2  2
user_2  object_2  6
user_2  object_1  5

私は本当にpysparkに慣れていないのですが、この問題の関連ドキュメントへのコードスニペットまたはポータルを教えてもらえますか?まことにありがとうございます!

33
KAs

user_idscoreに基づいて各行のランクを取得するには、 ウィンドウ関数 を使用し、その後、結果をフィルタリングして最初の行のみを保持する必要があると思います2つの値。

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

df.select('*', rank().over(window).alias('rank')) 
  .filter(col('rank') <= 2) 
  .show() 
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1|    3|   1|
#| user_1| object_2|    2|   2|
#| user_2| object_2|    6|   1|
#| user_2| object_1|    5|   2|
#+-------+---------+-----+----+

一般に、公式 プログラミングガイド は、Sparkの学習を開始するのに適した場所です。

データ

rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
51
mtoto

ランクの同等性を取得するときにrankの代わりに_row_number_を使用すると、Top-nはより正確になります。

_val n = 5
df.select(col('*'), row_number().over(window).alias('row_number')) \
  .where(col('row_number') <= n) \
  .limit(20) \
  .toPandas()
_

より良いフォーマットのために、Jupyterノートブックではlimit(20).toPandas()の代わりにshow()トリックに注意してください。

19
Martin Tapp

私は質問がpysparkを求められていることを知っており、Scalaで同様の答えを探していました。

ScalaのDataFrameの各グループの上位n個の値を取得します

@mtotoの回答のscalaバージョンを以下に示します。

import org.Apache.spark.sql.expressions.Window
import org.Apache.spark.sql.functions.rank
import org.Apache.spark.sql.functions.col

val window = Window.partitionBy("user_id").orderBy('score desc)
val rankByScore = rank().over(window)
df1.select('*, rankByScore as 'rank).filter(col("rank") <= 2).show() 
# you can change the value 2 to any number you want. Here 2 represents the top 2 values

より多くの例は here にあります。

2
Abu Shoeb