# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# cython: language_level = 3

from cython.operator cimport dereference as deref
from libcpp.vector cimport vector as std_vector

from pyarrow import Buffer, py_buffer
from pyarrow._compute cimport Expression
from pyarrow.lib import frombytes, tobytes
from pyarrow.lib cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_substrait cimport *


# TODO GH-37235: Fix exception handling
cdef CDeclaration _create_named_table_provider(
    dict named_args, const std_vector[c_string]& names, const CSchema& schema
) noexcept:
    """
    Build a "table_source" declaration from a user supplied table provider.

    This is the callback bound into ``CConversionOptions.named_table_provider``
    by ``run_query``. It converts the C++ table name parts and schema into
    Python objects, calls ``named_args["provider"](names, schema)`` and wraps
    the returned table into a ``CDeclaration``.

    NOTE(review): the function is declared ``noexcept``, so an exception
    raised by the provider presumably cannot propagate to the caller from
    here (see the TODO above) -- confirm before relying on error reporting.
    """
    cdef:
        c_string c_name
        shared_ptr[CTable] c_in_table
        shared_ptr[CTableSourceNodeOptions] c_tablesourceopts
        shared_ptr[CExecNodeOptions] c_input_node_opts
        vector[CDeclaration.Input] no_c_inputs

    # Convert the C++ name parts into a list of Python strings.
    py_names = []
    for i in range(names.size()):
        c_name = names[i]
        py_names.append(frombytes(c_name))
    # The schema is copied into a shared_ptr so it can be wrapped for Python.
    py_schema = pyarrow_wrap_schema(make_shared[CSchema](schema))

    py_table = named_args["provider"](py_names, py_schema)
    c_in_table = pyarrow_unwrap_table(py_table)
    c_tablesourceopts = make_shared[CTableSourceNodeOptions](c_in_table)
    c_input_node_opts = static_pointer_cast[CExecNodeOptions, CTableSourceNodeOptions](
        c_tablesourceopts)
    # A table source has no inputs, hence the empty input vector.
    return CDeclaration(tobytes("table_source"),
                        no_c_inputs, c_input_node_opts)


def run_query(plan, *, table_provider=None, use_threads=True):
    """
    Execute a Substrait plan and read the results as a RecordBatchReader.

    Parameters
    ----------
    plan : Union[Buffer, bytes]
        The serialized Substrait plan to execute.
    table_provider : object (optional)
        A function to resolve any NamedTable relation to a table.
        The function will receive two arguments which will be a list
        of strings representing the table name and a pyarrow.Schema representing
        the expected schema and should return a pyarrow.Table.
    use_threads : bool, default True
        If True then multiple threads will be used to run the query.  If False then
        all CPU intensive work will be done on the calling thread.

    Returns
    -------
    RecordBatchReader
        A reader containing the result of the executed query

    Raises
    ------
    TypeError
        If ``plan`` is neither a ``pyarrow.Buffer`` nor ``bytes``.

    Examples
    --------
    >>> import pyarrow as pa
    >>> from pyarrow.lib import tobytes
    >>> import pyarrow.substrait as substrait
    >>> test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
    >>> test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]})
    >>> def table_provider(names, schema):
    ...     if not names:
    ...         raise Exception("No names provided")
    ...     elif names[0] == "t1":
    ...         return test_table_1
    ...     elif names[0] == "t2":
    ...         return test_table_2
    ...     else:
    ...         raise Exception("Unrecognized table name")
    ...
    >>> substrait_query = '''
    ...     {
    ...         "relations": [
    ...         {"rel": {
    ...             "read": {
    ...             "base_schema": {
    ...                 "struct": {
    ...                 "types": [
    ...                             {"i64": {}}
    ...                         ]
    ...                 },
    ...                 "names": [
    ...                         "x"
    ...                         ]
    ...             },
    ...             "namedTable": {
    ...                     "names": ["t1"]
    ...             }
    ...             }
    ...         }}
    ...         ]
    ...     }
    ... '''
    >>> buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
    >>> reader = pa.substrait.run_query(buf, table_provider=table_provider)
    >>> reader.read_all()
    pyarrow.Table
    x: int64
    ----
    x: [[1,2,3]]
    """

    cdef:
        CResult[shared_ptr[CRecordBatchReader]] c_res_reader
        shared_ptr[CRecordBatchReader] c_reader
        RecordBatchReader reader
        shared_ptr[CBuffer] c_buf_plan
        CConversionOptions c_conversion_options
        c_bool c_use_threads

    c_use_threads = use_threads
    if isinstance(plan, bytes):
        c_buf_plan = pyarrow_unwrap_buffer(py_buffer(plan))
    elif isinstance(plan, Buffer):
        c_buf_plan = pyarrow_unwrap_buffer(plan)
    else:
        raise TypeError(
            f"Expected 'pyarrow.Buffer' or bytes, got '{type(plan)}'")

    if table_provider is not None:
        # The dict is bound as the first argument of the C callback, which
        # looks up the Python callable under the "provider" key.
        named_table_args = {
            "provider": table_provider
        }
        c_conversion_options.named_table_provider = BindFunction[CNamedTableProvider](
            &_create_named_table_provider, named_table_args)

    with nogil:
        c_res_reader = ExecuteSerializedPlan(
            deref(c_buf_plan), default_extension_id_registry(),
            GetFunctionRegistry(), c_conversion_options, c_use_threads)

    c_reader = GetResultValue(c_res_reader)

    reader = RecordBatchReader.__new__(RecordBatchReader)
    reader.reader = c_reader
    return reader


def _parse_json_plan(plan):
    """
    Parse a JSON plan into equivalent serialized Protobuf.

    Parameters
    ----------
    plan : bytes
        Substrait plan in JSON.

    Returns
    -------
    Buffer
        A buffer containing the serialized Protobuf plan.
    """

    cdef:
        CResult[shared_ptr[CBuffer]] c_res_buffer
        c_string c_str_plan
        shared_ptr[CBuffer] c_buf_plan

    c_str_plan = plan
    # NOTE(review): the other functions in this file run the C++ call inside
    # ``nogil`` and unwrap the result with the GIL held; here it is the other
    # way around. Left unchanged -- confirm whether this is intentional.
    c_res_buffer = SerializeJsonPlan(c_str_plan)
    with nogil:
        c_buf_plan = GetResultValue(c_res_buffer)
    return pyarrow_wrap_buffer(c_buf_plan)


def serialize_expressions(exprs, names, schema, *, allow_arrow_extensions=False):
    """
    Serialize a collection of expressions into Substrait

    Substrait expressions must be bound to a schema.  For example,
    the Substrait expression ``a:i32 + b:i32`` is different from the
    Substrait expression ``a:i64 + b:i64``.  Pyarrow expressions are
    typically unbound.  For example, both of the above expressions
    would be represented as ``a + b`` in pyarrow.

    This means a schema must be provided when serializing an expression.
    It also means that the serialization may fail if a matching function
    call cannot be found for the expression.

    Parameters
    ----------
    exprs : list of Expression
        The expressions to serialize
    names : list of str
        Names for the expressions
    schema : Schema
        The schema the expressions will be bound to
    allow_arrow_extensions : bool, default False
        If False then only functions that are part of the core Substrait function
        definitions will be allowed.  Set this to True to allow pyarrow-specific functions
        and user defined functions but the result may not be accepted by other
        compute libraries.

    Returns
    -------
    Buffer
        An ExtendedExpression message containing the serialized expressions

    Raises
    ------
    ValueError
        If ``exprs`` and ``names`` do not have the same length.
    TypeError
        If an element of ``exprs`` is not an Expression, an element of
        ``names`` is not a str, or ``schema`` is not a Schema.
    """
    cdef:
        CResult[shared_ptr[CBuffer]] c_res_buffer
        shared_ptr[CBuffer] c_buffer
        CNamedExpression c_named_expr
        CBoundExpressions c_bound_exprs
        CConversionOptions c_conversion_options

    if len(exprs) != len(names):
        raise ValueError("exprs and names need to have the same length")
    # Validate before casting: the casts below are unchecked, so a wrong
    # Python type must be rejected here rather than reinterpreted.
    if not isinstance(schema, Schema):
        raise TypeError(f"Expected Schema, got '{type(schema)}' for schema")

    for expr, name in zip(exprs, names):
        if not isinstance(expr, Expression):
            raise TypeError(f"Expected Expression, got '{type(expr)}' in exprs")
        if not isinstance(name, str):
            raise TypeError(f"Expected str, got '{type(name)}' in names")
        # ``unwrap`` is a C-level method, so ``expr`` must be typed to call it.
        c_named_expr.expression = (<Expression> expr).unwrap()
        c_named_expr.name = tobytes(name)
        c_bound_exprs.named_expressions.push_back(c_named_expr)

    # ``sp_schema`` is a C-level attribute, only reachable on a typed Schema.
    c_bound_exprs.schema = (<Schema> schema).sp_schema

    c_conversion_options.allow_arrow_extensions = allow_arrow_extensions

    with nogil:
        c_res_buffer = SerializeExpressions(c_bound_exprs, c_conversion_options)

    c_buffer = GetResultValue(c_res_buffer)
    return pyarrow_wrap_buffer(c_buffer)


cdef class BoundExpressions(_Weakrefable):
    """
    A collection of named expressions and the schema they are bound to

    This is equivalent to the Substrait ExtendedExpression message
    """

    cdef:
        CBoundExpressions c_bound_exprs

    def __init__(self):
        # Instances are only created internally through ``wrap``.
        msg = 'BoundExpressions is an abstract class thus cannot be initialized.'
        raise TypeError(msg)

    cdef void init(self, CBoundExpressions bound_expressions):
        self.c_bound_exprs = bound_expressions

    @property
    def schema(self):
        """
        The common schema that all expressions are bound to
        """
        return pyarrow_wrap_schema(self.c_bound_exprs.schema)

    @property
    def expressions(self):
        """
        A dict from expression name to expression

        If several expressions share a name, only the last one is kept.
        """
        expr_dict = {}
        for named_expr in self.c_bound_exprs.named_expressions:
            name = frombytes(named_expr.name)
            expr = Expression.wrap(named_expr.expression)
            expr_dict[name] = expr
        return expr_dict

    @staticmethod
    cdef wrap(const CBoundExpressions& bound_expressions):
        # ``__new__`` bypasses ``__init__``, which always raises.
        cdef BoundExpressions self = BoundExpressions.__new__(BoundExpressions)
        self.init(bound_expressions)
        return self


def deserialize_expressions(buf):
    """
    Deserialize an ExtendedExpression Substrait message into a BoundExpressions object

    Parameters
    ----------
    buf : Buffer or bytes
        The message to deserialize

    Returns
    -------
    BoundExpressions
        The deserialized expressions, their names, and the bound schema

    Raises
    ------
    TypeError
        If ``buf`` is neither a ``pyarrow.Buffer`` nor ``bytes``.
    """
    cdef:
        shared_ptr[CBuffer] c_buffer
        CResult[CBoundExpressions] c_res_bound_exprs
        CBoundExpressions c_bound_exprs

    if isinstance(buf, bytes):
        c_buffer = pyarrow_unwrap_buffer(py_buffer(buf))
    elif isinstance(buf, Buffer):
        c_buffer = pyarrow_unwrap_buffer(buf)
    else:
        raise TypeError(
            f"Expected 'pyarrow.Buffer' or bytes, got '{type(buf)}'")

    with nogil:
        c_res_bound_exprs = DeserializeExpressions(deref(c_buffer))

    c_bound_exprs = GetResultValue(c_res_bound_exprs)

    return BoundExpressions.wrap(c_bound_exprs)


def get_supported_functions():
    """
    Get a list of Substrait functions that the underlying
    engine currently supports.

    Returns
    -------
    list[str]
        A list of function ids encoded as '{uri}#{name}'
    """

    cdef:
        ExtensionIdRegistry* c_id_registry
        std_vector[c_string] c_ids

    c_id_registry = default_extension_id_registry()
    c_ids = c_id_registry.GetSupportedSubstraitFunctions()

    functions_list = []
    for c_id in c_ids:
        functions_list.append(frombytes(c_id))
    return functions_list