Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions crates/core/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,8 +695,29 @@ expr_fn_vec!(named_struct);
// Bindings generated by the `expr_fn!` / `expr_fn_vec!` macros (the macro
// definitions live elsewhere in this file and are not shown here).
// NOTE(review): `arrow_metadata` uses the `_vec` variant; the Python wrapper
// calls it with either one or two arguments, so it presumably needs a
// variable-length argument list — confirm against the macro definition.
expr_fn!(from_unixtime, unixtime);
expr_fn!(arrow_typeof, arg_1);
expr_fn!(arrow_cast, arg_1 datatype);
expr_fn_vec!(arrow_metadata);
expr_fn!(union_tag, arg1);
expr_fn!(random);

/// Builds a call to the core `get_field` function with `expr` as the
/// container argument and `name` as the field-name argument, in that order.
#[pyfunction]
fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr {
    // Argument order matters: the container expression first, then the name.
    let call_args = vec![expr.into(), name.into()];
    let udf = functions::core::get_field();
    udf.call(call_args).into()
}

/// Builds a call to the core `union_extract` function with `union_expr` as
/// the first argument and `field_name` as the second.
#[pyfunction]
fn union_extract(union_expr: PyExpr, field_name: PyExpr) -> PyExpr {
    // Argument order matters: the union expression first, then the field name.
    let call_args = vec![union_expr.into(), field_name.into()];
    let udf = functions::core::union_extract();
    udf.call(call_args).into()
}

/// Builds a zero-argument call to the core `version` function.
#[pyfunction]
fn version() -> PyExpr {
    // `version` takes no arguments, so it is invoked with an empty list.
    let udf = functions::core::version();
    udf.call(Vec::new()).into()
}

// Array Functions
array_fn!(array_append, array element);
array_fn!(array_to_string, array delimiter);
Expand Down Expand Up @@ -1014,6 +1035,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(array_agg))?;
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
m.add_wrapped(wrap_pyfunction!(arrow_metadata))?;
m.add_wrapped(wrap_pyfunction!(ascii))?;
m.add_wrapped(wrap_pyfunction!(asin))?;
m.add_wrapped(wrap_pyfunction!(asinh))?;
Expand Down Expand Up @@ -1142,6 +1164,10 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(trim))?;
m.add_wrapped(wrap_pyfunction!(trunc))?;
m.add_wrapped(wrap_pyfunction!(upper))?;
m.add_wrapped(wrap_pyfunction!(get_field))?;
m.add_wrapped(wrap_pyfunction!(union_extract))?;
m.add_wrapped(wrap_pyfunction!(union_tag))?;
m.add_wrapped(wrap_pyfunction!(version))?;
m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision
m.add_wrapped(wrap_pyfunction!(var_pop))?;
m.add_wrapped(wrap_pyfunction!(var_sample))?;
Expand Down
174 changes: 171 additions & 3 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
"arrays_overlap",
"arrays_zip",
"arrow_cast",
"arrow_metadata",
"arrow_typeof",
"ascii",
"asin",
Expand Down Expand Up @@ -163,6 +164,7 @@
"gcd",
"gen_series",
"generate_series",
"get_field",
"greatest",
"ifnull",
"in_list",
Expand Down Expand Up @@ -280,6 +282,7 @@
"reverse",
"right",
"round",
"row",
"row_number",
"rpad",
"rtrim",
Expand Down Expand Up @@ -322,12 +325,15 @@
"translate",
"trim",
"trunc",
"union_extract",
"union_tag",
"upper",
"uuid",
"var",
"var_pop",
"var_samp",
"var_sample",
"version",
"when",
# Window Functions
"window",
Expand Down Expand Up @@ -2628,22 +2634,184 @@ def arrow_typeof(arg: Expr) -> Expr:
return Expr(f.arrow_typeof(arg.expr))


def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
    """Casts an expression to a specified data type.

    The ``data_type`` can be a string, a ``pyarrow.DataType``, or an
    ``Expr``. For simple types, :py:meth:`Expr.cast()
    <datafusion.expr.Expr.cast>` is more concise
    (e.g., ``col("a").cast(pa.float64())``). Use ``arrow_cast`` when
    you want to specify the target type as a string using DataFusion's
    type syntax, which can be more readable for complex types like
    ``"Timestamp(Nanosecond, None)"``.

    Args:
        expr: The expression to cast.
        data_type: The target type. A ``pyarrow.DataType`` is forwarded to
            :py:meth:`Expr.cast`; a ``str`` is wrapped in a string literal
            expression; an ``Expr`` is passed through unchanged.

    Returns:
        An expression that evaluates to ``expr`` cast to ``data_type``.

    Examples:
        >>> ctx = dfn.SessionContext()
        >>> df = ctx.from_pydict({"a": [1]})
        >>> result = df.select(
        ...     dfn.functions.arrow_cast(dfn.col("a"), "Float64").alias("c")
        ... )
        >>> result.collect_column("c")[0].as_py()
        1.0

        >>> import pyarrow as pa
        >>> result = df.select(
        ...     dfn.functions.arrow_cast(
        ...         dfn.col("a"), data_type=pa.float64()
        ...     ).alias("c")
        ... )
        >>> result.collect_column("c")[0].as_py()
        1.0

        The target type may also be given as a string literal expression:

        >>> data_type = dfn.string_literal("Float64")
        >>> result = df.select(
        ...     dfn.functions.arrow_cast(dfn.col("a"), data_type).alias("c")
        ... )
        >>> result.collect_column("c")[0].as_py()
        1.0
    """
    # A pyarrow type needs no string round-trip: use the plain cast.
    if isinstance(data_type, pa.DataType):
        return expr.cast(data_type)
    # The underlying binding takes the type name as an expression, so a
    # plain string is wrapped in a string literal first.
    if isinstance(data_type, str):
        data_type = Expr.string_literal(data_type)
    return Expr(f.arrow_cast(expr.expr, data_type.expr))


def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
    """Returns the metadata attached to the input expression.

    Without ``key`` the result is a Map holding every metadata key-value
    pair. With ``key`` the result is the value stored under that single
    metadata key.

    Args:
        expr: The expression whose metadata is read.
        key: Optional metadata key, given as a string or an expression.

    Returns:
        An expression yielding either the full metadata map or one value.

    Examples:
        >>> import pyarrow as pa
        >>> field = pa.field("val", pa.int64(), metadata={"k": "v"})
        >>> schema = pa.schema([field])
        >>> batch = pa.RecordBatch.from_arrays([pa.array([1])], schema=schema)
        >>> ctx = dfn.SessionContext()
        >>> df = ctx.create_dataframe([[batch]])
        >>> result = df.select(
        ...     dfn.functions.arrow_metadata(dfn.col("val")).alias("meta")
        ... )
        >>> ("k", "v") in result.collect_column("meta")[0].as_py()
        True

        >>> result = df.select(
        ...     dfn.functions.arrow_metadata(
        ...         dfn.col("val"), key="k"
        ...     ).alias("meta_val")
        ... )
        >>> result.collect_column("meta_val")[0].as_py()
        'v'
    """
    # Build the positional arguments once; the key is appended only when set.
    call_args = [expr.expr]
    if key is not None:
        key_expr = Expr.string_literal(key) if isinstance(key, str) else key
        call_args.append(key_expr.expr)
    return Expr(f.arrow_metadata(*call_args))


def get_field(expr: Expr, name: Expr | str) -> Expr:
    """Extracts a named field from a struct or map.

    For a static string name, the bracket operator ``expr["field"]`` is a
    convenient shorthand. Reach for ``get_field`` when the field name is
    itself a dynamic expression.

    Args:
        expr: The struct or map expression to read from.
        name: The field name, given as a string or an expression.

    Returns:
        An expression yielding the value of the requested field.

    Examples:
        >>> ctx = dfn.SessionContext()
        >>> df = ctx.from_pydict({"a": [1], "b": [2]})
        >>> df = df.with_column(
        ...     "s",
        ...     dfn.functions.named_struct(
        ...         [("x", dfn.col("a")), ("y", dfn.col("b"))]
        ...     ),
        ... )
        >>> result = df.select(
        ...     dfn.functions.get_field(dfn.col("s"), "x").alias("x_val")
        ... )
        >>> result.collect_column("x_val")[0].as_py()
        1

        Equivalent using bracket syntax:

        >>> result = df.select(
        ...     dfn.col("s")["x"].alias("x_val")
        ... )
        >>> result.collect_column("x_val")[0].as_py()
        1
    """
    # Plain strings are promoted to string literal expressions.
    name_expr = Expr.string_literal(name) if isinstance(name, str) else name
    return Expr(f.get_field(expr.expr, name_expr.expr))


def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr:
    """Extracts the value of a named field from a union type.

    The value is returned when the named field is the currently selected
    variant; otherwise the result is NULL.

    Args:
        union_expr: The union-typed expression to read from.
        field_name: The variant's field name, as a string or an expression.

    Returns:
        An expression yielding the variant's value or NULL.

    Examples:
        >>> import pyarrow as pa
        >>> ctx = dfn.SessionContext()
        >>> types = pa.array([0, 1, 0], type=pa.int8())
        >>> offsets = pa.array([0, 0, 1], type=pa.int32())
        >>> arr = pa.UnionArray.from_dense(
        ...     types, offsets, [pa.array([1, 2]), pa.array(["hi"])],
        ...     ["int", "str"], [0, 1],
        ... )
        >>> batch = pa.RecordBatch.from_arrays([arr], names=["u"])
        >>> df = ctx.create_dataframe([[batch]])
        >>> result = df.select(
        ...     dfn.functions.union_extract(dfn.col("u"), "int").alias("val")
        ... )
        >>> result.collect_column("val").to_pylist()
        [1, None, 2]
    """
    # Plain strings are promoted to string literal expressions.
    if isinstance(field_name, str):
        field_expr = Expr.string_literal(field_name)
    else:
        field_expr = field_name
    return Expr(f.union_extract(union_expr.expr, field_expr.expr))


def union_tag(union_expr: Expr) -> Expr:
    """Returns the tag of a union type, i.e. the name of its active field.

    Args:
        union_expr: The union-typed expression to inspect.

    Returns:
        An expression yielding the active field name for each row.

    Examples:
        >>> import pyarrow as pa
        >>> ctx = dfn.SessionContext()
        >>> types = pa.array([0, 1, 0], type=pa.int8())
        >>> offsets = pa.array([0, 0, 1], type=pa.int32())
        >>> arr = pa.UnionArray.from_dense(
        ...     types, offsets, [pa.array([1, 2]), pa.array(["hi"])],
        ...     ["int", "str"], [0, 1],
        ... )
        >>> batch = pa.RecordBatch.from_arrays([arr], names=["u"])
        >>> df = ctx.create_dataframe([[batch]])
        >>> result = df.select(
        ...     dfn.functions.union_tag(dfn.col("u")).alias("tag")
        ... )
        >>> result.collect_column("tag").to_pylist()
        ['int', 'str', 'int']
    """
    raw_tag = f.union_tag(union_expr.expr)
    return Expr(raw_tag)


def version() -> Expr:
    """Returns the DataFusion version string.

    Returns:
        An expression that evaluates to the version string.

    Examples:
        >>> ctx = dfn.SessionContext()
        >>> df = ctx.empty_table()
        >>> result = df.select(dfn.functions.version().alias("v"))
        >>> "Apache DataFusion" in result.collect_column("v")[0].as_py()
        True
    """
    raw_version = f.version()
    return Expr(raw_version)


def row(*args: Expr) -> Expr:
    """Returns a struct built from the given arguments.

    Args:
        *args: The expressions forwarded, in order, to :py:func:`struct`.

    Returns:
        The expression produced by :py:func:`struct`.

    See Also:
        This is an alias for :py:func:`struct`.
    """
    # Pure alias: every argument is forwarded unchanged.
    struct_expr = struct(*args)
    return struct_expr


def random() -> Expr:
"""Returns a random value in the range ``0.0 <= x < 1.0``.

Expand Down
Loading
Loading