diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index 8bb927718..74654ce46 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -695,8 +695,29 @@ expr_fn_vec!(named_struct); expr_fn!(from_unixtime, unixtime); expr_fn!(arrow_typeof, arg_1); expr_fn!(arrow_cast, arg_1 datatype); +expr_fn_vec!(arrow_metadata); +expr_fn!(union_tag, arg1); expr_fn!(random); +#[pyfunction] +fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr { + functions::core::get_field() + .call(vec![expr.into(), name.into()]) + .into() +} + +#[pyfunction] +fn union_extract(union_expr: PyExpr, field_name: PyExpr) -> PyExpr { + functions::core::union_extract() + .call(vec![union_expr.into(), field_name.into()]) + .into() +} + +#[pyfunction] +fn version() -> PyExpr { + functions::core::version().call(vec![]).into() +} + // Array Functions array_fn!(array_append, array element); array_fn!(array_to_string, array delimiter); @@ -1014,6 +1035,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(array_agg))?; m.add_wrapped(wrap_pyfunction!(arrow_typeof))?; m.add_wrapped(wrap_pyfunction!(arrow_cast))?; + m.add_wrapped(wrap_pyfunction!(arrow_metadata))?; m.add_wrapped(wrap_pyfunction!(ascii))?; m.add_wrapped(wrap_pyfunction!(asin))?; m.add_wrapped(wrap_pyfunction!(asinh))?; @@ -1142,6 +1164,10 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(trim))?; m.add_wrapped(wrap_pyfunction!(trunc))?; m.add_wrapped(wrap_pyfunction!(upper))?; + m.add_wrapped(wrap_pyfunction!(get_field))?; + m.add_wrapped(wrap_pyfunction!(union_extract))?; + m.add_wrapped(wrap_pyfunction!(union_tag))?; + m.add_wrapped(wrap_pyfunction!(version))?; m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision m.add_wrapped(wrap_pyfunction!(var_pop))?; m.add_wrapped(wrap_pyfunction!(var_sample))?; diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 1b267731e..aa7f28746 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -98,6 +98,7 @@ "arrays_overlap", "arrays_zip", "arrow_cast", + "arrow_metadata", "arrow_typeof", "ascii", "asin", @@ -163,6 +164,7 @@ "gcd", "gen_series", "generate_series", + "get_field", "greatest", "ifnull", "in_list", @@ -280,6 +282,7 @@ "reverse", "right", "round", + "row", "row_number", "rpad", "rtrim", @@ -322,12 +325,15 @@ "translate", "trim", "trunc", + "union_extract", + "union_tag", "upper", "uuid", "var", "var_pop", "var_samp", "var_sample", + "version", "when", # Window Functions "window", @@ -2628,22 +2634,184 @@ def arrow_typeof(arg: Expr) -> Expr: return Expr(f.arrow_typeof(arg.expr)) -def arrow_cast(expr: Expr, data_type: Expr) -> Expr: +def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr: """Casts an expression to a specified data type. + The ``data_type`` can be a string, a ``pyarrow.DataType``, or an + ``Expr``. For simple types, :py:meth:`Expr.cast() + ` is more concise + (e.g., ``col("a").cast(pa.float64())``). Use ``arrow_cast`` when + you want to specify the target type as a string using DataFusion's + type syntax, which can be more readable for complex types like + ``"Timestamp(Nanosecond, None)"``. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1]}) - >>> data_type = dfn.string_literal("Float64") >>> result = df.select( - ... dfn.functions.arrow_cast(dfn.col("a"), data_type).alias("c") + ... dfn.functions.arrow_cast(dfn.col("a"), "Float64").alias("c") + ... ) + >>> result.collect_column("c")[0].as_py() + 1.0 + + >>> import pyarrow as pa + >>> result = df.select( + ... dfn.functions.arrow_cast( + ... dfn.col("a"), data_type=pa.float64() + ... ).alias("c") ... ) >>> result.collect_column("c")[0].as_py() 1.0 """ + if isinstance(data_type, pa.DataType): + return expr.cast(data_type) + if isinstance(data_type, str): + data_type = Expr.string_literal(data_type) return Expr(f.arrow_cast(expr.expr, data_type.expr)) +def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr: + """Returns the metadata of the input expression. + + If called with one argument, returns a Map of all metadata key-value pairs. + If called with two arguments, returns the value for the specified metadata key. + + Examples: + >>> import pyarrow as pa + >>> field = pa.field("val", pa.int64(), metadata={"k": "v"}) + >>> schema = pa.schema([field]) + >>> batch = pa.RecordBatch.from_arrays([pa.array([1])], schema=schema) + >>> ctx = dfn.SessionContext() + >>> df = ctx.create_dataframe([[batch]]) + >>> result = df.select( + ... dfn.functions.arrow_metadata(dfn.col("val")).alias("meta") + ... ) + >>> ("k", "v") in result.collect_column("meta")[0].as_py() + True + + >>> result = df.select( + ... dfn.functions.arrow_metadata( + ... dfn.col("val"), key="k" + ... ).alias("meta_val") + ... ) + >>> result.collect_column("meta_val")[0].as_py() + 'v' + """ + if key is None: + return Expr(f.arrow_metadata(expr.expr)) + if isinstance(key, str): + key = Expr.string_literal(key) + return Expr(f.arrow_metadata(expr.expr, key.expr)) + + +def get_field(expr: Expr, name: Expr | str) -> Expr: + """Extracts a field from a struct or map by name. + + When the field name is a static string, the bracket operator + ``expr["field"]`` is a convenient shorthand. Use ``get_field`` + when the field name is a dynamic expression. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1], "b": [2]}) + >>> df = df.with_column( + ... "s", + ... dfn.functions.named_struct( + ... [("x", dfn.col("a")), ("y", dfn.col("b"))] + ... ), + ... ) + >>> result = df.select( + ... dfn.functions.get_field(dfn.col("s"), "x").alias("x_val") + ... ) + >>> result.collect_column("x_val")[0].as_py() + 1 + + Equivalent using bracket syntax: + + >>> result = df.select( + ... dfn.col("s")["x"].alias("x_val") + ... ) + >>> result.collect_column("x_val")[0].as_py() + 1 + """ + if isinstance(name, str): + name = Expr.string_literal(name) + return Expr(f.get_field(expr.expr, name.expr)) + + +def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr: + """Extracts a value from a union type by field name. + + Returns the value of the named field if it is the currently selected + variant, otherwise returns NULL. + + Examples: + >>> import pyarrow as pa + >>> ctx = dfn.SessionContext() + >>> types = pa.array([0, 1, 0], type=pa.int8()) + >>> offsets = pa.array([0, 0, 1], type=pa.int32()) + >>> arr = pa.UnionArray.from_dense( + ... types, offsets, [pa.array([1, 2]), pa.array(["hi"])], + ... ["int", "str"], [0, 1], + ... ) + >>> batch = pa.RecordBatch.from_arrays([arr], names=["u"]) + >>> df = ctx.create_dataframe([[batch]]) + >>> result = df.select( + ... dfn.functions.union_extract(dfn.col("u"), "int").alias("val") + ... ) + >>> result.collect_column("val").to_pylist() + [1, None, 2] + """ + if isinstance(field_name, str): + field_name = Expr.string_literal(field_name) + return Expr(f.union_extract(union_expr.expr, field_name.expr)) + + +def union_tag(union_expr: Expr) -> Expr: + """Returns the tag (active field name) of a union type. + + Examples: + >>> import pyarrow as pa + >>> ctx = dfn.SessionContext() + >>> types = pa.array([0, 1, 0], type=pa.int8()) + >>> offsets = pa.array([0, 0, 1], type=pa.int32()) + >>> arr = pa.UnionArray.from_dense( + ... types, offsets, [pa.array([1, 2]), pa.array(["hi"])], + ... ["int", "str"], [0, 1], + ... ) + >>> batch = pa.RecordBatch.from_arrays([arr], names=["u"]) + >>> df = ctx.create_dataframe([[batch]]) + >>> result = df.select( + ... dfn.functions.union_tag(dfn.col("u")).alias("tag") + ... ) + >>> result.collect_column("tag").to_pylist() + ['int', 'str', 'int'] + """ + return Expr(f.union_tag(union_expr.expr)) + + +def version() -> Expr: + """Returns the DataFusion version string. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.empty_table() + >>> result = df.select(dfn.functions.version().alias("v")) + >>> "Apache DataFusion" in result.collect_column("v")[0].as_py() + True + """ + return Expr(f.version()) + + +def row(*args: Expr) -> Expr: + """Returns a struct with the given arguments. + + See Also: + This is an alias for :py:func:`struct`. + """ + return struct(*args) + + def random() -> Expr: """Returns a random value in the range ``0.0 <= x < 1.0``. diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 2100da9ae..4e99fa9e3 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -20,7 +20,7 @@ import numpy as np import pyarrow as pa import pytest -from datafusion import SessionContext, column, literal, string_literal +from datafusion import SessionContext, column, literal from datafusion import functions as f np.seterr(invalid="ignore") @@ -1291,11 +1291,8 @@ def test_make_time(df): def test_arrow_cast(df): df = df.select( - # we use `string_literal` to return utf8 instead of `literal` which returns - # utf8view because datafusion.arrow_cast expects a utf8 instead of utf8view - # https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179 - f.arrow_cast(column("b"), string_literal("Float64")).alias("b_as_float"), - f.arrow_cast(column("b"), string_literal("Int32")).alias("b_as_int"), + f.arrow_cast(column("b"), "Float64").alias("b_as_float"), + f.arrow_cast(column("b"), "Int32").alias("b_as_int"), ) result = df.collect() assert len(result) == 1 @@ -1305,6 +1302,19 @@ def test_arrow_cast(df): assert result.column(1) == pa.array([4, 5, 6], type=pa.int32()) +def test_arrow_cast_with_pyarrow_type(df): + df = df.select( + f.arrow_cast(column("b"), pa.float64()).alias("b_as_float"), + f.arrow_cast(column("b"), pa.int32()).alias("b_as_int"), + f.arrow_cast(column("b"), pa.string()).alias("b_as_str"), + ) + result = df.collect()[0] + + assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64()) + assert result.column(1) == pa.array([4, 5, 6], type=pa.int32()) + assert result.column(2) == pa.array(["4", "5", "6"], type=pa.string()) + + def test_case(df): df = df.select( f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8)), @@ -1810,6 +1820,89 @@ def test_conditional_functions(df_with_nulls, expr, expected): assert result.column(0) == expected +def test_get_field(df): + df = df.with_column( + "s", + f.named_struct( + [ + ("x", column("a")), + ("y", column("b")), + ] + ), + ) + result = df.select( + f.get_field(column("s"), "x").alias("x_val"), + f.get_field(column("s"), "y").alias("y_val"), + ).collect()[0] + + assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view()) + assert result.column(1) == pa.array([4, 5, 6]) + + +def test_arrow_metadata(): + ctx = SessionContext() + field = pa.field("val", pa.int64(), metadata={"key1": "value1", "key2": "value2"}) + schema = pa.schema([field]) + batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], schema=schema) + df = ctx.create_dataframe([[batch]]) + + # One-argument form: returns a Map of all metadata key-value pairs + result = df.select( + f.arrow_metadata(column("val")).alias("meta"), + ).collect()[0] + assert result.column(0).type == pa.map_(pa.utf8(), pa.utf8()) + meta = result.column(0)[0].as_py() + assert ("key1", "value1") in meta + assert ("key2", "value2") in meta + + # Two-argument form: returns the value for a specific metadata key + result = df.select( + f.arrow_metadata(column("val"), "key1").alias("meta_val"), + ).collect()[0] + assert result.column(0)[0].as_py() == "value1" + + +def test_version(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1]}) + result = df.select(f.version().alias("v")).collect()[0] + version_str = result.column(0)[0].as_py() + assert "Apache DataFusion" in version_str + + +def test_row(df): + result = df.select( + f.row(column("a"), column("b")).alias("r"), + f.struct(column("a"), column("b")).alias("s"), + ).collect()[0] + # row is an alias for struct, so they should produce the same output + assert result.column(0) == result.column(1) + + +def test_union_tag(): + ctx = SessionContext() + types = pa.array([0, 1, 0], type=pa.int8()) + offsets = pa.array([0, 0, 1], type=pa.int32()) + children = [pa.array([1, 2]), pa.array(["hello"])] + arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1]) + df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]]) + + result = df.select(f.union_tag(column("u")).alias("tag")).collect()[0] + assert result.column(0).to_pylist() == ["int", "str", "int"] + + +def test_union_extract(): + ctx = SessionContext() + types = pa.array([0, 1, 0], type=pa.int8()) + offsets = pa.array([0, 0, 1], type=pa.int32()) + children = [pa.array([1, 2]), pa.array(["hello"])] + arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1]) + df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]]) + + result = df.select(f.union_extract(column("u"), "int").alias("val")).collect()[0] + assert result.column(0).to_pylist() == [1, None, 2] + + @pytest.mark.parametrize("func", [f.array_any_value, f.list_any_value]) def test_any_value_aliases(func): ctx = SessionContext()