web-dev-qa-db-ja.com

PySpark:フィルター関数を使用した後、列の平均を取る

次のコードを使用して、給与がしきい値を超えている人の平均年齢を取得しています。

dataframe.filter(df['salary'] > 100000).agg({"avg": "age"})

列の経過時間は数値(浮動)ですが、それでもこのエラーが発生します。

py4j.protocol.Py4JJavaError: An error occurred while calling o86.agg. 
: scala.MatchError: age (of class Java.lang.String)

groupBy関数とSQLクエリを使用せずにavgなどを取得する他の方法を知っていますか。

15

集計関数は値であり、列名はキーである必要があります。

dataframe.filter(df['salary'] > 100000).agg({"age": "avg"})

または、pyspark.sql.functionsを使用できます。

from pyspark.sql.functions import col, avg

dataframe.filter(df['salary'] > 100000).agg(avg(col("age")))

CASE .. WHENを使用することもできます

from pyspark.sql.functions import when

dataframe.select(avg(when(df['salary'] > 100000, df['age'])))
34
zero323