pyspark.sql.GroupedData.applyInArrow
- GroupedData.applyInArrow(func, schema)
Maps each group of the current DataFrame using an Arrow UDF and returns the result as a DataFrame.

The function can take one of two forms: it can take a pyarrow.Table and return a pyarrow.Table, or it can take an iterator of pyarrow.RecordBatch and yield pyarrow.RecordBatch. Alternatively, each form can take a tuple of pyarrow.Scalar (the grouping key values) as the first argument, followed by the input described above. For each group, all columns are passed together in the pyarrow.Table or pyarrow.RecordBatch, and the returned pyarrow.Table or iterator of pyarrow.RecordBatch objects are combined into a DataFrame.

The schema should be a StructType describing the schema of the returned pyarrow.Table or pyarrow.RecordBatch. The column labels of the returned pyarrow.Table or pyarrow.RecordBatch must either match the field names in the defined schema if specified as strings, or match the field data types by position if not strings, e.g. integer indices. The number of rows in the returned pyarrow.Table or iterator of pyarrow.RecordBatch can be arbitrary; it need not match the number of rows in the group.

New in version 4.0.0.

Changed in version 4.1.0: Added support for the iterator of pyarrow.RecordBatch API.
- Parameters
- func : function
A Python native function that either takes a pyarrow.Table and returns a pyarrow.Table, or takes an iterator of pyarrow.RecordBatch and yields pyarrow.RecordBatch. Additionally, each form can take a tuple of grouping keys as the first argument, with the pyarrow.Table or iterator of pyarrow.RecordBatch as the second argument.
- schema : pyspark.sql.types.DataType or str
The return type of func in PySpark. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.
Notes
This function requires a full shuffle. With the pyarrow.Table API, all data of a group is loaded into memory at once, so if the data is skewed and certain groups are too large to fit in memory, out-of-memory (OOM) errors can occur. The iterator of pyarrow.RecordBatch API can mitigate this by letting the function process a group batch by batch, provided the function does not itself accumulate the whole group.
This API is unstable and intended for developers.
Examples
>>> from pyspark.sql.functions import ceil
>>> import pyarrow
>>> import pyarrow.compute as pc
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> def normalize(table):
...     v = table.column("v")
...     norm = pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, ddof=1))
...     return table.set_column(1, "v", norm)
>>> df.groupby("id").applyInArrow(
...     normalize, schema="id long, v double").show()
+---+-------------------+
| id|                  v|
+---+-------------------+
|  1|-0.7071067811865475|
|  1| 0.7071067811865475|
|  2|-0.8320502943378437|
|  2|-0.2773500981126146|
|  2| 1.1094003924504583|
+---+-------------------+
The function can also take and return an iterator of pyarrow.RecordBatch using type hints.
>>> from typing import Iterator
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> def sum_func(
...     batches: Iterator[pyarrow.RecordBatch]
... ) -> Iterator[pyarrow.RecordBatch]:
...     total = 0
...     for batch in batches:
...         total += pc.sum(batch.column("v")).as_py()
...     yield pyarrow.RecordBatch.from_pydict({"v": [total]})
>>> df.groupby("id").applyInArrow(
...     sum_func, schema="v double").show()
+----+
|   v|
+----+
| 3.0|
|18.0|
+----+
Alternatively, the user can pass a function that takes two arguments. In this case, the grouping key(s) are passed as the first argument and the data as the second. The grouping key(s) are passed as a tuple of Arrow scalars, e.g., pyarrow.Int32Scalar and pyarrow.FloatScalar. The data is still passed in as a pyarrow.Table (or, in the iterator form, as an iterator of pyarrow.RecordBatch) containing all columns from the original Spark DataFrame. This is useful when the user does not want to hardcode grouping key(s) in the function.
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> def mean_func(key, table):
...     # key is a tuple of one pyarrow.Int64Scalar, which is the value
...     # of 'id' for the current group
...     mean = pc.mean(table.column("v"))
...     return pyarrow.Table.from_pydict({"id": [key[0].as_py()], "v": [mean.as_py()]})
>>> df.groupby('id').applyInArrow(
...     mean_func, schema="id long, v double").show()
+---+---+
| id|  v|
+---+---+
|  1|1.5|
|  2|6.0|
+---+---+
>>> def sum_func(key, table):
...     # key is a tuple of two pyarrow.Int64Scalars, which are the values
...     # of 'id' and 'ceil(df.v / 2)' for the current group
...     total = pc.sum(table.column("v"))
...     return pyarrow.Table.from_pydict({
...         "id": [key[0].as_py()],
...         "ceil(v / 2)": [key[1].as_py()],
...         "v": [total.as_py()]
...     })
>>> df.groupby(df.id, ceil(df.v / 2)).applyInArrow(
...     sum_func, schema="id long, `ceil(v / 2)` long, v double").show()
+---+-----------+----+
| id|ceil(v / 2)|   v|
+---+-----------+----+
|  2|          5|10.0|
|  1|          1| 3.0|
|  2|          3| 5.0|
|  2|          2| 3.0|
+---+-----------+----+
>>> from typing import Iterator, Tuple
>>> def sum_func(
...     key: Tuple[pyarrow.Scalar, ...], batches: Iterator[pyarrow.RecordBatch]
... ) -> Iterator[pyarrow.RecordBatch]:
...     total = 0
...     for batch in batches:
...         total += pc.sum(batch.column("v")).as_py()
...     yield pyarrow.RecordBatch.from_pydict({"id": [key[0].as_py()], "v": [total]})
>>> df.groupby("id").applyInArrow(
...     sum_func, schema="id long, v double").show()
+---+----+
| id|   v|
+---+----+
|  1| 3.0|
|  2|18.0|
+---+----+