from __future__ import annotations from typing import TYPE_CHECKING from typing import Any from typing import Generic from typing import Iterable from typing import Iterator from typing import TypeVar from typing import cast from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame from narwhals.utils import tupleify if TYPE_CHECKING: from narwhals.typing import IntoExpr DataFrameT = TypeVar("DataFrameT") LazyFrameT = TypeVar("LazyFrameT") class GroupBy(Generic[DataFrameT]): def __init__(self, df: DataFrameT, *keys: str, drop_null_keys: bool) -> None: self._df = cast(DataFrame[Any], df) self._keys = keys self._grouped = self._df._compliant_frame.group_by( *self._keys, drop_null_keys=drop_null_keys ) def agg( self, *aggs: IntoExpr | Iterable[IntoExpr], **named_aggs: IntoExpr ) -> DataFrameT: """Compute aggregations for each group of a group by operation. Arguments: aggs: Aggregations to compute for each group of the group by operation, specified as positional arguments. named_aggs: Additional aggregations, specified as keyword arguments. Returns: A new Dataframe. Examples: Group by one column or by multiple columns and call `agg` to compute the grouped sum of another column. >>> import pandas as pd >>> import polars as pl >>> import narwhals as nw >>> df_pd = pd.DataFrame( ... { ... "a": ["a", "b", "a", "b", "c"], ... "b": [1, 2, 1, 3, 3], ... "c": [5, 4, 3, 2, 1], ... } ... ) >>> df_pl = pl.DataFrame( ... { ... "a": ["a", "b", "a", "b", "c"], ... "b": [1, 2, 1, 3, 3], ... "c": [5, 4, 3, 2, 1], ... } ... ) We define library agnostic functions: >>> @nw.narwhalify ... def func(df): ... return df.group_by("a").agg(nw.col("b").sum()).sort("a") >>> @nw.narwhalify ... def func_mult_col(df): ... return df.group_by("a", "b").agg(nw.sum("c")).sort("a", "b") We can then pass either pandas or Polars to `func` and `func_mult_col`: >>> func(df_pd) a b 0 a 2 1 b 5 2 c 3 >>> func(df_pl) shape: (3, 2) ┌─────┬─────┠│ a ┆ b │ │ --- ┆ --- │ │ str ┆ i64 │ â•žâ•â•â•â•â•â•ªâ•â•â•â•â•â•¡ │ a ┆ 2 │ │ b ┆ 5 │ │ c ┆ 3 │ └─────┴─────┘ >>> func_mult_col(df_pd) a b c 0 a 1 8 1 b 2 4 2 b 3 2 3 c 3 1 >>> func_mult_col(df_pl) shape: (4, 3) ┌─────┬─────┬─────┠│ a ┆ b ┆ c │ │ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ i64 │ â•žâ•â•â•â•â•â•ªâ•â•â•â•â•â•ªâ•â•â•â•â•â•¡ │ a ┆ 1 ┆ 8 │ │ b ┆ 2 ┆ 4 │ │ b ┆ 3 ┆ 2 │ │ c ┆ 3 ┆ 1 │ └─────┴─────┴─────┘ """ aggs, named_aggs = self._df._flatten_and_extract(*aggs, **named_aggs) return self._df._from_compliant_dataframe( # type: ignore[return-value] self._grouped.agg(*aggs, **named_aggs), ) def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: yield from ( # type: ignore[misc] (tupleify(key), self._df._from_compliant_dataframe(df)) for (key, df) in self._grouped.__iter__() ) class LazyGroupBy(Generic[LazyFrameT]): def __init__(self, df: LazyFrameT, *keys: str, drop_null_keys: bool) -> None: self._df = cast(LazyFrame[Any], df) self._keys = keys self._grouped = self._df._compliant_frame.group_by( *self._keys, drop_null_keys=drop_null_keys ) def agg( self, *aggs: IntoExpr | Iterable[IntoExpr], **named_aggs: IntoExpr ) -> LazyFrameT: """Compute aggregations for each group of a group by operation. If a library does not support lazy execution, then this is a no-op. Arguments: aggs: Aggregations to compute for each group of the group by operation, specified as positional arguments. named_aggs: Additional aggregations, specified as keyword arguments. Returns: A new LazyFrame. Examples: Group by one column or by multiple columns and call `agg` to compute the grouped sum of another column. >>> import polars as pl >>> import narwhals as nw >>> from narwhals.typing import IntoFrameT >>> lf_pl = pl.LazyFrame( ... { ... "a": ["a", "b", "a", "b", "c"], ... "b": [1, 2, 1, 3, 3], ... "c": [5, 4, 3, 2, 1], ... } ... ) We define library agnostic functions: >>> def agnostic_func_one_col(lf_native: IntoFrameT) -> IntoFrameT: ... lf = nw.from_native(lf_native) ... return nw.to_native(lf.group_by("a").agg(nw.col("b").sum()).sort("a")) >>> def agnostic_func_mult_col(lf_native: IntoFrameT) -> IntoFrameT: ... lf = nw.from_native(lf_native) ... return nw.to_native(lf.group_by("a", "b").agg(nw.sum("c")).sort("a", "b")) We can then pass a lazy frame and materialise it with `collect`: >>> agnostic_func_one_col(lf_pl).collect() shape: (3, 2) ┌─────┬─────┠│ a ┆ b │ │ --- ┆ --- │ │ str ┆ i64 │ â•žâ•â•â•â•â•â•ªâ•â•â•â•â•â•¡ │ a ┆ 2 │ │ b ┆ 5 │ │ c ┆ 3 │ └─────┴─────┘ >>> agnostic_func_mult_col(lf_pl).collect() shape: (4, 3) ┌─────┬─────┬─────┠│ a ┆ b ┆ c │ │ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ i64 │ â•žâ•â•â•â•â•â•ªâ•â•â•â•â•â•ªâ•â•â•â•â•â•¡ │ a ┆ 1 ┆ 8 │ │ b ┆ 2 ┆ 4 │ │ b ┆ 3 ┆ 2 │ │ c ┆ 3 ┆ 1 │ └─────┴─────┴─────┘ """ aggs, named_aggs = self._df._flatten_and_extract(*aggs, **named_aggs) return self._df._from_compliant_dataframe( # type: ignore[return-value] self._grouped.agg(*aggs, **named_aggs), )