diff --git a/conftest.py b/conftest.py
index 73e90077a..0c9410636 100644
--- a/conftest.py
+++ b/conftest.py
@@ -19,6 +19,7 @@
 import datafusion as dfn
 import numpy as np
+import pyarrow as pa
 import pytest
 from datafusion import col, lit
 from datafusion import functions as F
 
@@ -29,6 +30,7 @@ def _doctest_namespace(doctest_namespace: dict) -> None:
     """Add common imports to the doctest namespace."""
     doctest_namespace["dfn"] = dfn
     doctest_namespace["np"] = np
+    doctest_namespace["pa"] = pa
     doctest_namespace["col"] = col
     doctest_namespace["lit"] = lit
     doctest_namespace["F"] = F
diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs
index 1300a1595..ce11ef04e 100644
--- a/crates/core/src/context.rs
+++ b/crates/core/src/context.rs
@@ -41,7 +41,7 @@ use datafusion::execution::context::{
 };
 use datafusion::execution::disk_manager::DiskManagerMode;
 use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
-use datafusion::execution::options::ReadOptions;
+use datafusion::execution::options::{ArrowReadOptions, ReadOptions};
 use datafusion::execution::runtime_env::RuntimeEnvBuilder;
 use datafusion::execution::session_state::SessionStateBuilder;
 use datafusion::prelude::{
@@ -974,6 +974,39 @@ impl PySessionContext {
         Ok(())
     }
 
+    #[pyo3(signature = (name, path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))]
+    pub fn register_arrow(
+        &self,
+        name: &str,
+        path: &str,
+        schema: Option<PyArrowType<Schema>>,
+        file_extension: &str,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
+        py: Python,
+    ) -> PyDataFusionResult<()> {
+        let mut options = ArrowReadOptions::default().table_partition_cols(
+            table_partition_cols
+                .into_iter()
+                .map(|(name, ty)| (name, ty.0))
+                .collect::<Vec<(String, DataType)>>(),
+        );
+        options.file_extension = file_extension;
+        options.schema = schema.as_ref().map(|x| &x.0);
+
+        let result = self.ctx.register_arrow(name, path, options);
+        wait_for_future(py, result)??;
+        Ok(())
+    }
+
+    pub fn register_batch(
+        &self,
+        name: &str,
+        batch: PyArrowType<RecordBatch>,
+    ) -> PyDataFusionResult<()> {
+        self.ctx.register_batch(name, batch.0)?;
+        Ok(())
+    }
+
     // Registers a PyArrow.Dataset
     pub fn register_dataset(
         &self,
@@ -1214,6 +1247,29 @@ impl PySessionContext {
         Ok(PyDataFrame::new(df))
     }
 
+    #[pyo3(signature = (path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))]
+    pub fn read_arrow(
+        &self,
+        path: &str,
+        schema: Option<PyArrowType<Schema>>,
+        file_extension: &str,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
+        py: Python,
+    ) -> PyDataFusionResult<PyDataFrame> {
+        let mut options = ArrowReadOptions::default().table_partition_cols(
+            table_partition_cols
+                .into_iter()
+                .map(|(name, ty)| (name, ty.0))
+                .collect::<Vec<(String, DataType)>>(),
+        );
+        options.file_extension = file_extension;
+        options.schema = schema.as_ref().map(|x| &x.0);
+
+        let result = self.ctx.read_arrow(path, options);
+        let df = wait_for_future(py, result)??;
+        Ok(PyDataFrame::new(df))
+    }
+
     pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
         let session = self.clone().into_bound_py_any(table.py())?;
         let table = PyTable::new(table, Some(session))?;
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index f190e3ca1..7a306f04c 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -903,6 +903,27 @@ def register_udtf(self, func: TableFunction) -> None:
         """Register a user defined table function."""
         self.ctx.register_udtf(func._udtf)
 
+    def register_batch(self, name: str, batch: pa.RecordBatch) -> None:
+        """Register a single :py:class:`pa.RecordBatch` as a table.
+
+        Args:
+            name: Name of the resultant table.
+            batch: Record batch to register as a table.
+
+        Examples:
+            >>> ctx = dfn.SessionContext()
+            >>> batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
+            >>> ctx.register_batch("batch_tbl", batch)
+            >>> ctx.sql("SELECT * FROM batch_tbl").collect()[0].column(0)
+            <pyarrow.lib.Int64Array object at 0x...>
+            [
+              1,
+              2,
+              3
+            ]
+        """
+        self.ctx.register_batch(name, batch)
+
     def deregister_udtf(self, name: str) -> None:
         """Remove a user-defined table function from the session.
 
@@ -1109,6 +1130,86 @@ def register_avro(
             name, str(path), schema, file_extension, table_partition_cols
         )
 
+    def register_arrow(
+        self,
+        name: str,
+        path: str | pathlib.Path,
+        schema: pa.Schema | None = None,
+        file_extension: str = ".arrow",
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+    ) -> None:
+        """Register an Arrow IPC file as a table.
+
+        The registered table can be referenced from SQL statements executed
+        against this context.
+
+        Args:
+            name: Name of the table to register.
+            path: Path to the Arrow IPC file.
+            schema: The data source schema.
+            file_extension: File extension to select.
+            table_partition_cols: Partition columns.
+
+        Examples:
+            >>> import tempfile, os
+            >>> ctx = dfn.SessionContext()
+            >>> table = pa.table({"x": [10, 20, 30]})
+            >>> with tempfile.TemporaryDirectory() as tmpdir:
+            ...     path = os.path.join(tmpdir, "data.arrow")
+            ...     with pa.ipc.new_file(path, table.schema) as writer:
+            ...         writer.write_table(table)
+            ...     ctx.register_arrow("arrow_tbl", path)
+            ...     ctx.sql("SELECT * FROM arrow_tbl").collect()[0].column(0)
+            <pyarrow.lib.Int64Array object at 0x...>
+            [
+              10,
+              20,
+              30
+            ]
+
+            Provide an explicit ``schema`` to override schema inference:
+
+            >>> with tempfile.TemporaryDirectory() as tmpdir:
+            ...     path = os.path.join(tmpdir, "data.arrow")
+            ...     with pa.ipc.new_file(path, table.schema) as writer:
+            ...         writer.write_table(table)
+            ...     ctx.register_arrow(
+            ...         "arrow_schema",
+            ...         path,
+            ...         schema=pa.schema([("x", pa.int64())]),
+            ...     )
+            ...     ctx.sql("SELECT * FROM arrow_schema").collect()[0].column(0)
+            <pyarrow.lib.Int64Array object at 0x...>
+            [
+              10,
+              20,
+              30
+            ]
+
+            Use ``file_extension`` to read files with a non-default extension:
+
+            >>> with tempfile.TemporaryDirectory() as tmpdir:
+            ...     path = os.path.join(tmpdir, "data.ipc")
+            ...     with pa.ipc.new_file(path, table.schema) as writer:
+            ...         writer.write_table(table)
+            ...     ctx.register_arrow(
+            ...         "arrow_ipc", path, file_extension=".ipc"
+            ...     )
+            ...     ctx.sql("SELECT * FROM arrow_ipc").collect()[0].column(0)
+            <pyarrow.lib.Int64Array object at 0x...>
+            [
+              10,
+              20,
+              30
+            ]
+        """
+        if table_partition_cols is None:
+            table_partition_cols = []
+        table_partition_cols = _convert_table_partition_cols(table_partition_cols)
+        self.ctx.register_arrow(
+            name, str(path), schema, file_extension, table_partition_cols
+        )
+
     def register_dataset(self, name: str, dataset: pa.dataset.Dataset) -> None:
         """Register a :py:class:`pa.dataset.Dataset` as a table.
 
@@ -1369,6 +1470,86 @@ def read_avro(
             self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
         )
 
+    def read_arrow(
+        self,
+        path: str | pathlib.Path,
+        schema: pa.Schema | None = None,
+        file_extension: str = ".arrow",
+        file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+    ) -> DataFrame:
+        """Create a :py:class:`DataFrame` for reading an Arrow IPC data source.
+
+        Args:
+            path: Path to the Arrow IPC file.
+            schema: The data source schema.
+            file_extension: File extension to select.
+            file_partition_cols: Partition columns.
+
+        Returns:
+            DataFrame representation of the read Arrow IPC file.
+
+        Examples:
+            >>> import tempfile, os
+            >>> ctx = dfn.SessionContext()
+            >>> table = pa.table({"a": [1, 2, 3]})
+            >>> with tempfile.TemporaryDirectory() as tmpdir:
+            ...     path = os.path.join(tmpdir, "data.arrow")
+            ...     with pa.ipc.new_file(path, table.schema) as writer:
+            ...         writer.write_table(table)
+            ...     df = ctx.read_arrow(path)
+            ...     df.collect()[0].column(0)
+            <pyarrow.lib.Int64Array object at 0x...>
+            [
+              1,
+              2,
+              3
+            ]
+
+            Provide an explicit ``schema`` to override schema inference:
+
+            >>> with tempfile.TemporaryDirectory() as tmpdir:
+            ...     path = os.path.join(tmpdir, "data.arrow")
+            ...     with pa.ipc.new_file(path, table.schema) as writer:
+            ...         writer.write_table(table)
+            ...     df = ctx.read_arrow(path, schema=pa.schema([("a", pa.int64())]))
+            ...     df.collect()[0].column(0)
+            <pyarrow.lib.Int64Array object at 0x...>
+            [
+              1,
+              2,
+              3
+            ]
+
+            Use ``file_extension`` to read files with a non-default extension:
+
+            >>> with tempfile.TemporaryDirectory() as tmpdir:
+            ...     path = os.path.join(tmpdir, "data.ipc")
+            ...     with pa.ipc.new_file(path, table.schema) as writer:
+            ...         writer.write_table(table)
+            ...     df = ctx.read_arrow(path, file_extension=".ipc")
+            ...     df.collect()[0].column(0)
+            <pyarrow.lib.Int64Array object at 0x...>
+            [
+              1,
+              2,
+              3
+            ]
+        """
+        if file_partition_cols is None:
+            file_partition_cols = []
+        file_partition_cols = _convert_table_partition_cols(file_partition_cols)
+        return DataFrame(
+            self.ctx.read_arrow(str(path), schema, file_extension, file_partition_cols)
+        )
+
+    def read_empty(self) -> DataFrame:
+        """Create an empty :py:class:`DataFrame` with no columns or rows.
+
+        See Also:
+            This is an alias for :meth:`empty_table`.
+        """
+        return self.empty_table()
+
     def read_table(
         self, table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset
     ) -> DataFrame:
diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py
index 3eaccdfa3..848ab4cee 100644
--- a/python/datafusion/user_defined.py
+++ b/python/datafusion/user_defined.py
@@ -213,7 +213,6 @@ def udf(*args: Any, **kwargs: Any):  # noqa: D417
     Examples:
         Using ``udf`` as a function:
 
-        >>> import pyarrow as pa
        >>> import pyarrow.compute as pc
         >>> from datafusion.user_defined import ScalarUDF
         >>> def double_func(x):
@@ -480,7 +479,6 @@ def udaf(*args: Any, **kwargs: Any):  # noqa: D417, C901
         instance in which this UDAF is used.
 
Examples: - >>> import pyarrow as pa >>> import pyarrow.compute as pc >>> from datafusion.user_defined import AggregateUDF, Accumulator, udaf >>> class Summarize(Accumulator): @@ -874,7 +872,6 @@ def udwf(*args: Any, **kwargs: Any): # noqa: D417 When using ``udwf`` as a decorator, do not pass ``func`` explicitly. Examples: - >>> import pyarrow as pa >>> from datafusion.user_defined import WindowUDF, WindowEvaluator, udwf >>> class BiasedNumbers(WindowEvaluator): ... def __init__(self, start: int = 0): diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 8491cc3a5..25f66a647 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -788,6 +788,68 @@ def test_read_avro(ctx): assert avro_df is not None +def test_read_arrow(ctx, tmp_path): + # Write an Arrow IPC file, then read it back + table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]}) + arrow_path = tmp_path / "test.arrow" + with pa.ipc.new_file(str(arrow_path), table.schema) as writer: + writer.write_table(table) + + df = ctx.read_arrow(str(arrow_path)) + result = df.collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + assert result[0].column(1) == pa.array(["x", "y", "z"]) + + # Also verify pathlib.Path works + df = ctx.read_arrow(arrow_path) + result = df.collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + + +def test_read_empty(ctx): + df = ctx.read_empty() + result = df.collect() + assert len(result) == 1 + assert result[0].num_columns == 0 + + df = ctx.empty_table() + result = df.collect() + assert len(result) == 1 + assert result[0].num_columns == 0 + + +def test_register_arrow(ctx, tmp_path): + # Write an Arrow IPC file, then register and query it + table = pa.table({"x": [10, 20, 30]}) + arrow_path = tmp_path / "test.arrow" + with pa.ipc.new_file(str(arrow_path), table.schema) as writer: + writer.write_table(table) + + ctx.register_arrow("arrow_tbl", str(arrow_path)) + result = ctx.sql("SELECT * FROM arrow_tbl").collect() + assert 
result[0].column(0) == pa.array([10, 20, 30]) + + # Also verify pathlib.Path works + ctx.register_arrow("arrow_tbl_path", arrow_path) + result = ctx.sql("SELECT * FROM arrow_tbl_path").collect() + assert result[0].column(0) == pa.array([10, 20, 30]) + + +def test_register_batch(ctx): + batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + ctx.register_batch("batch_tbl", batch) + result = ctx.sql("SELECT * FROM batch_tbl").collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + assert result[0].column(1) == pa.array([4, 5, 6]) + + +def test_register_batch_empty(ctx): + batch = pa.RecordBatch.from_pydict({"a": pa.array([], type=pa.int64())}) + ctx.register_batch("empty_batch_tbl", batch) + result = ctx.sql("SELECT * FROM empty_batch_tbl").collect() + assert result[0].num_rows == 0 + + def test_create_sql_options(): SQLOptions()