diff --git a/_duckdb-stubs/__init__.pyi b/_duckdb-stubs/__init__.pyi index 81d69be7..ea7d3c88 100644 --- a/_duckdb-stubs/__init__.pyi +++ b/_duckdb-stubs/__init__.pyi @@ -13,7 +13,7 @@ if typing.TYPE_CHECKING: import pandas import pyarrow.lib from collections.abc import Callable, Iterable, Sequence, Mapping - from duckdb import sqltypes, func + from duckdb import sqltypes, func, template from builtins import list as lst # needed to avoid mypy error on DuckDBPyRelation.list method shadowing # the field_ids argument to to_parquet and write_parquet has a recursive structure @@ -241,8 +241,12 @@ class DuckDBPyConnection: def dtype(self, type_str: str) -> sqltypes.DuckDBPyType: ... def duplicate(self) -> DuckDBPyConnection: ... def enum_type(self, name: str, type: sqltypes.DuckDBPyType, values: lst[typing.Any]) -> sqltypes.DuckDBPyType: ... - def execute(self, query: Statement | str, parameters: object = None) -> DuckDBPyConnection: ... - def executemany(self, query: Statement | str, parameters: object = None) -> DuckDBPyConnection: ... + def execute( + self, query: Statement | str | template.SqlTemplate | template.CompiledSql, parameters: object = None + ) -> DuckDBPyConnection: ... + def executemany( + self, query: Statement | str | template.SqlTemplate | template.CompiledSql, parameters: object = None + ) -> DuckDBPyConnection: ... def extract_statements(self, query: str) -> lst[Statement]: ... def fetch_arrow_table(self, rows_per_batch: typing.SupportsInt = 1000000) -> pyarrow.lib.Table: """Deprecated: use to_arrow_table() instead.""" @@ -331,7 +335,9 @@ class DuckDBPyConnection: union_by_name: bool = False, compression: str | None = None, ) -> DuckDBPyRelation: ... - def from_query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... + def from_query( + self, query: str | template.SqlTemplate | template.CompiledSql, *, alias: str = "", params: object = None + ) -> DuckDBPyRelation: ... def get_table_names(self, query: str, *, qualified: bool = False) -> set[str]: ... def install_extension( self, @@ -360,7 +366,9 @@ class DuckDBPyConnection: def pl( self, rows_per_batch: typing.SupportsInt = 1000000, *, lazy: bool = False ) -> polars.DataFrame | polars.LazyFrame: ... - def query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... + def query( + self, query: str | template.SqlTemplate | template.CompiledSql, *, alias: str = "", params: object = None + ) -> DuckDBPyRelation: ... def query_progress(self) -> float: ... def read_csv( self, @@ -462,7 +470,13 @@ class DuckDBPyConnection: def row_type( self, fields: dict[str, sqltypes.DuckDBPyType] | lst[sqltypes.DuckDBPyType] ) -> sqltypes.DuckDBPyType: ... - def sql(self, query: Statement | str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... + def sql( + self, + query: Statement | str | template.SqlTemplate | template.CompiledSql, + *, + alias: str = "", + params: object = None, + ) -> DuckDBPyRelation: ... def sqltype(self, type_str: str) -> sqltypes.DuckDBPyType: ... def string_type(self, collation: str = "") -> sqltypes.DuckDBPyType: ... def struct_type( @@ -1160,7 +1174,7 @@ def enum_type( connection: DuckDBPyConnection | None = None, ) -> sqltypes.DuckDBPyType: ... def execute( - query: Statement | str, + query: Statement | str | template.SqlTemplate | template.CompiledSql, parameters: object = None, *, connection: DuckDBPyConnection | None = None, @@ -1282,7 +1296,7 @@ def from_parquet( connection: DuckDBPyConnection | None = None, ) -> DuckDBPyRelation: ... def from_query( - query: Statement | str, + query: Statement | str | template.SqlTemplate | template.CompiledSql, *, alias: str = "", params: object = None, @@ -1350,7 +1364,7 @@ def project( df: pandas.DataFrame, *args: _ExpressionLike, groups: str = "", connection: DuckDBPyConnection | None = None ) -> DuckDBPyRelation: ... def query( - query: Statement | str, + query: Statement | str | template.SqlTemplate | template.CompiledSql, *, alias: str = "", params: object = None, @@ -1474,7 +1488,7 @@ def row_type( def rowcount(*, connection: DuckDBPyConnection | None = None) -> int: ... def set_default_connection(connection: DuckDBPyConnection) -> None: ... def sql( - query: Statement | str, + query: Statement | str | template.SqlTemplate | template.CompiledSql, *, alias: str = "", params: object = None, diff --git a/duckdb/template.py b/duckdb/template.py new file mode 100644 index 00000000..e386cecd --- /dev/null +++ b/duckdb/template.py @@ -0,0 +1,450 @@ +"""Template system for duckdb SQL statements, based on Python's string.templatelib.""" + +from __future__ import annotations + +import dataclasses +from collections import Counter +from collections.abc import Iterable +from typing import TYPE_CHECKING, Literal, NoReturn, Protocol, TypeVar, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import Iterator + + from typing_extensions import TypeIs + +__all__ = [ + "CompiledSql", + "IntoInterpolation", + "Param", + "SupportsDuckdbTemplate", + "compile", + "param", + "template", +] + + +@dataclasses.dataclass(frozen=True, slots=True) +class CompiledSql: + """Represents a compiled SQL statement, with the final SQL string and a dict of params to be passed to duckdb. + + You will typically not create this directly, but will get it as the result + of calling .compile() on a SqlTemplate or ResolvedSqlTemplate. + + Example: + >>> age = 37 + >>> c = compile(t"SELECT * FROM users WHERE age >= {age}") + >>> c + >>> CompiledSql(sql="SELECT * FROM users WHERE age >= $p0_age", params={"p0_age": 37}) + duckdb.query(c.sql, c.params) + """ + + sql: str + params: dict[str, object] = dataclasses.field(default_factory=dict) + + def __str__(self) -> NoReturn: + """Disallow accidentally converting to a string, since that would lose the parameters.""" + msg = "CompiledSql cannot be directly converted to a string, since it may contain parameters. Please use the .sql attribute for the SQL string, and the .params attribute for the parameters." # noqa: E501 + raise NotImplementedError(msg) + + +@runtime_checkable +class SupportsDuckdbTemplate(Protocol): + """Something that can be converted into a Template by implementing the __duckdb_template__ method.""" + + def __duckdb_template__( + self, /, **future_kwargs + ) -> ( + str + | IntoInterpolation + | Param + | SupportsDuckdbTemplate + | object + | Iterable[str | IntoInterpolation | Param | SupportsDuckdbTemplate | object] + ): + """Convert self into something that template() understands.""" + + +@dataclasses.dataclass(frozen=True, slots=True) +class Param: + """Represents a parameter to be passed to duckdb, with an optional name and an optional exact flag.""" + + value: object + name: str | None = None + exact: bool = False + + def __post_init__(self) -> None: + """Ensure passed args were valid.""" + if self.exact: + if self.name is None: + msg = "Param with exact=True must have a name." + raise ValueError(msg) + else: + assert_param_name_legal(self.name) + + +def param(value: object, name: str | None = None, *, exact: bool = False) -> Param: + """Helper function to create an Param with an optional name and an optional exact flag.""" + return Param(value=value, name=name, exact=exact) + + +def template(*parts: str | IntoInterpolation | Param | SupportsDuckdbTemplate | object) -> SqlTemplate: + """Convert a sequence of things into a SqlTemplate. + + We go through the parts and convert it into a sequence of str and Interpolations, + which we then hand off to SqlTemplate. + - If the thing has a `.__duckdb_template__()` method, call it, + and call template() recursively on the result. + - If the thing is a `str`, treat it as raw SQL and return a SqlTemplate with that string. + - If it's an Interpolation, leave it as is, treating it as an interpolation. + - It it's a Param, wrap it in an Interpolation. + - Otherwise, treat the thing as a param, and then wrap the Param in an Interpolation. + + Examples: + A very simple example is just passing a string, which will be treated as raw SQL: + + >>> t = template("SELECT * FROM users WHERE id = 123") + >>> repr(t) + SqlTemplate('SELECT * FROM users WHERE id = 123') + >>> t.compile() + CompiledSql(sql='SELECT * FROM users WHERE id = 123', params={}) + + In python 3.14+, [tstrings](https://docs.python.org/3/library/string.templatelib.html) + are very useful here. + Any interpolation inside a tstring will be richly interpreted, + either treated as a param, or expanded as a subquery: + + >>> user_id = 123 + >>> t = template(t"SELECT * FROM users WHERE id = {user_id}") + >>> repr(t) + SqlTemplate('SELECT * FROM users WHERE id = ', Param(value=123, name='user_id'), '') + >>> t.compile() + CompiledSql(sql='SELECT * FROM users WHERE id = $p0_user_id', params={'p0_user_id': 123}) + + This is very friendly with chaining relations: + >>> all_people = duckdb.sql("SELECT * FROM people") + >>> age = 18 + >>> adults = template(t"SELECT * FROM ({all_people}) WHERE age >= {age}") + >>> names = template(t"SELECT name FROM ({adults})") + >>> names.compile() + CompiledSql(sql='SELECT name FROM (SELECT * FROM (SELECT * FROM people) WHERE age >= $p0_age)', params={'p0_age': 18}) + + We also support iterables of strings and Interpolations/Params/etc, + which will be joined together into a single template. + This is very useful for versions of python before 3.14 that don't have tstrings, + since it allows you to build up a template from smaller pieces: + + >>> t = template("SELECT * FROM (", all_people, ") WHERE age >= ", age) + >>> t.compile() + CompiledSql(sql='SELECT * FROM (SELECT * FROM people) WHERE age >= $p0_age', params={'p0_age': 18}) + + You can define evaluation logic for your custom types by defining a `.__duckdb_template__()` method. + If this method is defined, the result of that call will be used to create the template. + + >>> class Record: + ... def __init__(self, table_name: str, id: int): + ... self.table_name = table_name + ... self.id = id + ... + ... def __duckdb_template__(self, **kwargs): + ... id = self.id + ... # note the use of !s to indicate that the table name should be treated as raw SQL + ... return t"SELECT * FROM {self.table_name!s} WHERE id = {id}" + >>> t = template(Record("users", 123)) + >>> t.compile() + CompiledSql(sql='SELECT * FROM users WHERE id = $p0_id', params={'p0_id': 123}) + """ # noqa: E501 + expanded = [] + for part in parts: + expanded.extend(_expand_part(part)) + return SqlTemplate(*expanded) + + +def compile(*part: str | IntoInterpolation | Param | SupportsDuckdbTemplate | object) -> CompiledSql: + """Compile a sequence of things into a final SQL string with named parameter placeholders, and a list of Params. + + This is a convenience function that combines template() and .compile() into one step. + + For more details and examples, see template(). + """ + t = template(*part) + return t.compile() + + +def _is_iterable_nonstring(value: object) -> TypeIs[Iterable]: + return isinstance(value, Iterable) and not isinstance(value, str | bytes) + + +def _expand_part(part: object) -> Iterable[str | IntoInterpolation]: + if isinstance(part, SupportsDuckdbTemplate): + raw = part.__duckdb_template__() + if isinstance(raw, str): # noqa: SIM114 + yield raw + elif isinstance(raw, IntoInterpolation): + yield raw + elif isinstance(raw, Param): + yield ParamInterpolation(raw) + elif _is_iterable_nonstring(raw): + for item in raw: + yield from _expand_part(item) + else: + p = param(value=raw) + yield ParamInterpolation(p) + elif isinstance(part, str): # noqa: SIM114 + yield part + elif isinstance(part, IntoInterpolation): + yield part + elif isinstance(part, Param): + yield ParamInterpolation(part) + elif _is_iterable_nonstring(part): + for item in part: + yield from _expand_part(item) + else: + p = param(value=part) + yield ParamInterpolation(p) + + +class ParamInterpolation: + """A simple wrapper that implements the IntoInterpolation protocol for a given Param.""" + + def __init__(self, param: Param): # noqa: ANN204 + self.value = param + self.expression = param.name + self.conversion = None + self.format_spec = "" + + def __repr__(self) -> str: + return repr(self.value) + + +def _resolve(parts: Iterable[str | IntoInterpolation]) -> ResolvedSqlTemplate: + """Resolve a stream of strings and Interpolations, recursively resolving inner interpolations.""" + resolved: list[str | Param] = [] + for part in parts: + if isinstance(part, str): + resolved.append(part) + else: + resolved.extend(_resolve_interpolation(part)) + return ResolvedSqlTemplate(resolved) + + +def _resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | Param]: + value = interp.value + if isinstance(value, Param): + # Preserve direct ParamInterpolation behavior while allowing nested named Params to keep exact names. + if isinstance(interp, ParamInterpolation): + return (value,) + if value.name is None: + return (param(value.value),) + return (param(value.value, name=value.name, exact=True),) + + # if conversion specified (!s, !r, !a), treat as raw sql, eg + # name = "Alice" + # t"SELECT * FROM users where name = '{name!s}'" should be + # "SELECT * FROM users where name = 'Alice'", eg with no param, + # since the user is explicitly asking for the value to be directly interpolated into the SQL string, + # rather than passed as a param. + # This is useful for cases where the value is not something that can be passed as a param, + # eg an identifier like a table or column name, or a SQL expression like "CURRENT_DATE", + # or if the user just wants to write raw SQL and doesn't care about safety + # Note that this is potentially unsafe if the value comes from an untrusted source, + # since it could lead to SQL injection vulnerabilities, so it should be used with caution. + # + # Follow Python's f-string semantics: apply conversion first, then format_spec. + # e.g. {x!r:.10} means format(repr(x), ".10") + if interp.conversion == "s": + converted = str(value) + elif interp.conversion == "r": + converted = repr(value) + elif interp.conversion == "a": + converted = ascii(value) + else: + converted = None + + if converted is not None: + return (format(converted, interp.format_spec),) + + if isinstance(value, str): + # do NOT pass to template, since that would treat it as a raw SQL. + return (param(value, name=interp.expression),) + templ = template(value) + # If the resolved inner is just a single Interpolation, then just return + # the original value so that we preserve the expression name. + # For example, if we have + # ```python + # age = 37 + # people = t"SELECT * FROM people WHERE age = {age}" + # resolve_template(people) + # ``` + # should resolve to: + # "SELECT * FROM people WHERE age = $age", with a param $age=37, + # eg with a friendly param name, rather than + # "SELECT * FROM people WHERE age = $p0", with a param $p0=37 + if len(templ.strings) == 2 and templ.strings[0] == "" and templ.strings[1] == "" and len(templ.interpolations) == 1: + return (param(value, name=interp.expression),) + else: + # We got something nested, eg + # age = 37 + # people = t"SELECT * FROM people WHERE age = {age}" + # names = t"SELECT name FROM ({foo})" + # names should resolve to: + # "SELECT name FROM (SELECT * FROM people WHERE age = $age)", with a param $age=37 + return _resolve(templ) + + +@runtime_checkable +class IntoInterpolation(Protocol): + """Something that can be converted into a string.templatelib.Interpolation.""" + + value: object + conversion: Literal["s", "r", "a"] | None + expression: str | None + format_spec: str + + +def assert_param_name_legal(name: str) -> None: + """Eg `$param_1` is legal, but `$1param`, `$param-1`, `$param 1`, and `$p ; DROP TABLE users` are not.""" + # not implemented yet + # Not exactly sure what part of the stack this should get called in, + # or perhaps we shouldn't even check, just pass it to duckdb and let it error if it's illegal + + +class SqlTemplate: + """A sequence of strings and Interpolations.""" + + def __init__(self, *parts: str | IntoInterpolation) -> None: + for part in parts: + if not isinstance(part, str | IntoInterpolation): + msg = f"Unexpected part type in SqlTemplate: {type(part).__name__}. Expected str or IntoInterpolation." + raise TypeError(msg) + self.strings, self.interpolations = parse_parts(parts) + + def __iter__(self) -> Iterator[str | IntoInterpolation]: + """Iterate over the strings and interpolations in order.""" + for s, i in zip(self.strings, self.interpolations, strict=False): + yield s + yield i + yield self.strings[-1] + + def resolve(self) -> ResolvedSqlTemplate: + """Recursively resolve Interpolations into Params, returning a ResolvedSqlTemplate.""" + return _resolve(self) + + def compile(self) -> CompiledSql: + """Compile this template into a final SQL string with named parameter placeholders, and a list of Params.""" + resolved = self.resolve() + return compile_parts(resolved) + + def __str__(self) -> NoReturn: + msg = "SqlTemplate cannot be directly converted to a string, since it may contain unresolved interpolations. Please call .resolve() or .compile() first." # noqa: E501 + raise NotImplementedError(msg) + + def __repr__(self) -> str: + part_strings = [repr(part) for part in self] + return f"SqlTemplate({', '.join(part_strings)})" + + +class ResolvedSqlTemplate: + """A SqlTemplate that has been resolved to only strings and Params.""" + + def __init__(self, parts: Iterable[str | Param]) -> None: + self.parts = tuple(parts) + + def compile(self) -> CompiledSql: + """Compile this template into a final SQL string with named parameter placeholders, and a list of Params.""" + return compile_parts(self.parts) + + def __str__(self) -> NoReturn: + msg = "ResolvedSqlTemplate cannot be directly converted to a string, since it may contain unresolved interpolations. Please call .compile() first." # noqa: E501 + raise NotImplementedError(msg) + + def __repr__(self) -> str: + part_strings = [] + for part in self.parts: + if isinstance(part, str): + part_strings.append(repr(part)) + else: + part_strings.append(f"{{{part.name}={part.value}}}") + return f"ResolvedSqlTemplate({', '.join(part_strings)})" + + def __iter__(self) -> Iterator[str | Param]: + start = 0 + end = len(self.parts) + while start < end and self.parts[start] == "": + start += 1 + while end > start and self.parts[end - 1] == "": + end -= 1 + yield from self.parts[start:end] + + def __eq__(self, other: object) -> bool: + if isinstance(other, CompiledSql): + return self.compile() == other + if isinstance(other, ResolvedSqlTemplate): + return self.parts == other.parts + return False + + +T = TypeVar("T") + + +def parse_parts(parts: Iterable[str | T]) -> tuple[tuple[str, ...], tuple[T, ...]]: + """Parse an iterable of strings and others into separate tuples of strings and others. + + This merges adjacent strings and ensuring that the number of strings is one more than the number of others. + """ + strings, others = [], [] + last_thing: Literal["string", "other"] | None = None + for part in parts: + if isinstance(part, str): + if last_thing == "string": + # Merge adjacent string parts + strings[-1] += part + else: + strings.append(part) + last_thing = "string" + else: + if last_thing is None or last_thing == "other": + # this is the first part or there were two adjacent others, + # so we need an empty string spacer + strings.append("") + others.append(part) + last_thing = "other" + if last_thing is None: + # Empty input — return a single empty string to maintain the invariant + strings.append("") + elif last_thing == "other": + # If the last part was an other, we need to end with an empty string + strings.append("") + assert len(strings) == len(others) + 1 + return tuple(strings), tuple(others) + + +def compile_parts(parts: Iterable[str | Param], /) -> CompiledSql: + """Compile parts into a final SQL string with named parameter placeholders, and a list of Params.""" + sql_parts: list[str] = [] + params_items = [] + + def next_name(suffix: str | None = None) -> str: + # still count exact params in the count, so we get p0, my_param, p2, p3, my_param_2, p5, etc + base = f"p{len(params_items)}" + if suffix is not None: + return f"{base}_{suffix}" + else: + return base + + for part in parts: + if isinstance(part, str): + sql_parts.append(part) + else: + if passed_name := part.name: + param_name = passed_name if part.exact else next_name(passed_name) + else: + param_name = next_name() + assert_param_name_legal(param_name) + sql_parts.append(f"${param_name}") + params_items.append((param_name, part.value)) + param_name_counts = Counter(name for name, _ in params_items) + dupes = [name for name, count in param_name_counts.items() if count > 1] + if dupes: + msg = f"Duplicate parameter names found: {dupes}. Please ensure all parameter names are unique." + raise ValueError(msg) + return CompiledSql(sql="".join(sql_parts), params=dict(params_items)) diff --git a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp index dd7c9d2e..3d0e0509 100644 --- a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp +++ b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp @@ -363,6 +363,7 @@ struct DuckDBPyConnection : public enable_shared_from_this { FunctionNullHandling null_handling, PythonExceptionHandling exception_handling, bool side_effects); void RegisterArrowObject(const py::object &arrow_object, const string &name); + pair ExtractCompiledSqlAndParams(const py::object &query, py::object params); vector> GetStatements(const py::object &query); static PythonEnvironmentType environment; diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index 6883ba45..400dcd09 100644 --- a/src/duckdb_py/pyconnection.cpp +++ b/src/duckdb_py/pyconnection.cpp @@ -708,11 +708,74 @@ shared_ptr DuckDBPyConnection::ExecuteFromString(const strin return Execute(py::str(query)); } +pair DuckDBPyConnection::ExtractCompiledSqlAndParams(const py::object &query, py::object params) { + py::object compiled = query; + if (!py::hasattr(compiled, "sql") || !py::hasattr(compiled, "params")) { + if (!py::hasattr(query, "compile")) { + return {string(), params}; + } + compiled = query.attr("compile")(); + } + + if (!py::hasattr(compiled, "sql") || !py::hasattr(compiled, "params")) { + return {string(), params}; + } + + auto compiled_sql = py::cast(compiled.attr("sql")); + auto compiled_params_obj = compiled.attr("params"); + if (!py::is_dict_like(compiled_params_obj)) { + throw InvalidInputException("Compiled SQL parameters must be a dictionary"); + } + + auto compiled_params = py::cast(compiled_params_obj); + if (compiled_params.empty()) { + return {compiled_sql, params}; + } + + if (params.is_none()) { + return {compiled_sql, compiled_params}; + } + + if (py::is_dict_like(params)) { + auto merged_params = py::dict(); + for (auto &item : compiled_params) { + merged_params[item.first] = item.second; + } + auto provided_params = py::cast(params); + for (auto &item : provided_params) { + if (merged_params.contains(item.first)) { + throw py::value_error("Cannot merge compiled SQL parameters with duplicate parameter names"); + } + merged_params[item.first] = item.second; + } + return {compiled_sql, merged_params}; + } + + if (py::is_list_like(params)) { + if (py::len(params) == 0) { + return {compiled_sql, compiled_params}; + } + throw py::value_error("Cannot merge compiled SQL named parameters with positional parameters"); + } + + throw InvalidInputException("Prepared parameters can only be passed as a list or a dictionary"); +} + shared_ptr DuckDBPyConnection::Execute(const py::object &query, py::object params) { py::gil_scoped_acquire gil; con.SetResult(nullptr); - auto statements = GetStatements(query); + auto normalized_query = ExtractCompiledSqlAndParams(query, params); + auto &compiled_sql = normalized_query.first; + auto &merged_params = normalized_query.second; + vector> statements; + if (!compiled_sql.empty()) { + statements = GetStatements(py::str(compiled_sql)); + params = merged_params; + } else { + statements = GetStatements(query); + } + if (statements.empty()) { // TODO: should we throw? return nullptr; @@ -1603,7 +1666,17 @@ unique_ptr DuckDBPyConnection::RunQuery(const py::object &quer alias = "unnamed_relation_" + StringUtil::GenerateRandomName(16); } - auto statements = GetStatements(query); + auto normalized_query = ExtractCompiledSqlAndParams(query, params); + auto &compiled_sql = normalized_query.first; + auto &merged_params = normalized_query.second; + vector> statements; + if (!compiled_sql.empty()) { + statements = GetStatements(py::str(compiled_sql)); + params = merged_params; + } else { + statements = GetStatements(query); + } + if (statements.empty()) { // TODO: should we throw? return nullptr; @@ -1616,7 +1689,7 @@ unique_ptr DuckDBPyConnection::RunQuery(const py::object &quer // Attempt to create a Relation for lazy execution if possible shared_ptr relation; - bool has_params = !py::none().is(params) && py::len(params) > 0; + bool has_params = !params.is_none() && py::len(params) > 0; if (!has_params) { // No params (or empty params) — use lazy QueryRelation path { diff --git a/src/duckdb_py/pyexpression/initialize.cpp b/src/duckdb_py/pyexpression/initialize.cpp index 11cf5dc3..d709c582 100644 --- a/src/duckdb_py/pyexpression/initialize.cpp +++ b/src/duckdb_py/pyexpression/initialize.cpp @@ -314,6 +314,7 @@ void DuckDBPyExpression::Initialize(py::module_ &m) { Print the stringified version of the expression. )"; expression.def("show", &DuckDBPyExpression::Print, docs); + expression.def("__duckdb_template__", &DuckDBPyExpression::ToString); docs = R"( Set the order by modifier to ASCENDING. diff --git a/src/duckdb_py/pyrelation/initialize.cpp b/src/duckdb_py/pyrelation/initialize.cpp index 4393889a..cd9b1e67 100644 --- a/src/duckdb_py/pyrelation/initialize.cpp +++ b/src/duckdb_py/pyrelation/initialize.cpp @@ -342,6 +342,7 @@ void DuckDBPyRelation::Initialize(py::handle &m) { .def("show", &DuckDBPyRelation::Print, "Display a summary of the data", py::kw_only(), py::arg("max_width") = py::none(), py::arg("max_rows") = py::none(), py::arg("max_col_width") = py::none(), py::arg("null_value") = py::none(), py::arg("render_mode") = py::none()) + .def("__duckdb_template__", &DuckDBPyRelation::ToSQL) .def("__str__", &DuckDBPyRelation::ToString) .def("__repr__", &DuckDBPyRelation::ToString); diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index 5087de50..03e87873 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -329,6 +329,7 @@ void DuckDBPyType::Initialize(py::handle &m) { auto type_module = py::class_>(m, "DuckDBPyType", py::module_local()); type_module.def("__repr__", &DuckDBPyType::ToString, "Stringified representation of the type object"); + type_module.def("__duckdb_template__", &DuckDBPyType::ToString); type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other"), py::is_operator()); type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), diff --git a/tests/fast/test_template_e2e.py b/tests/fast/test_template_e2e.py new file mode 100644 index 00000000..96f14314 --- /dev/null +++ b/tests/fast/test_template_e2e.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import pytest + +import duckdb +from duckdb.template import param, template + + +def test_connection_sql_accepts_sql_template() -> None: + conn = duckdb.connect() + query = template("SELECT i FROM range(5) t(i) WHERE i >= ", 2, " ORDER BY i") + assert conn.sql(query).fetchall() == [(2,), (3,), (4,)] + + +def test_connection_query_accepts_sql_template() -> None: + conn = duckdb.connect() + query = template("SELECT i FROM range(3) t(i) WHERE i < ", 2, " ORDER BY i") + assert conn.query(query).fetchall() == [(0,), (1,)] + + +def test_connection_from_query_accepts_sql_template() -> None: + conn = duckdb.connect() + query = template("SELECT i FROM range(4) t(i) WHERE i % ", 2, " = 0 ORDER BY i") + assert conn.from_query(query).fetchall() == [(0,), (2,)] + + +def test_connection_execute_accepts_sql_template() -> None: + conn = duckdb.connect() + query = template("SELECT ", 42) + assert conn.execute(query).fetchone() == (42,) + + +def test_module_level_sql_apis_accept_sql_template() -> None: + conn = duckdb.connect() + query = template("SELECT i FROM range(5) t(i) WHERE i BETWEEN ", 1, " AND ", 3, " ORDER BY i") + + assert duckdb.sql(query, connection=conn).fetchall() == [(1,), (2,), (3,)] + assert duckdb.query(query, connection=conn).fetchall() == [(1,), (2,), (3,)] + assert duckdb.from_query(query, connection=conn).fetchall() == [(1,), (2,), (3,)] + + +def test_module_level_execute_accepts_sql_template() -> None: + conn = duckdb.connect() + query = template("SELECT ", 5) + assert duckdb.execute(query, connection=conn).fetchone() == (5,) + + +def test_connection_sql_accepts_alias_kwarg_with_template() -> None: + conn = duckdb.connect() + inner = conn.sql(template("SELECT 42 AS x"), alias="my_alias") + assert inner.alias == "my_alias" + outer = conn.sql(template("SELECT x FROM (", inner, ")")) + assert outer.fetchall() == [(42,)] + + +def test_connection_sql_template_can_merge_additional_params() -> None: + conn = duckdb.connect() + query = template("SELECT ", 10, " + $another") + assert conn.sql(query, params={"another": 5}).fetchall() == [(15,)] + + +def test_connection_sql_template_param_name_conflict_with_additional_params_raises() -> None: + conn = duckdb.connect() + query = template("SELECT ", param(10, "num", exact=True), " + $num") + with pytest.raises((duckdb.InvalidInputException, ValueError)): + conn.sql(query, params={"num": 5}).fetchall() + + +def test_cant_merge_with_positional_params() -> None: + conn = duckdb.connect() + # It doesn't even have a name, but still should error + query = template("SELECT ", 10, " + ?") + with pytest.raises(ValueError, match="Cannot merge compiled SQL named parameters with positional parameters"): + conn.sql(query, params=[5]).fetchall() + + +def test_sql_apis_accept_compiled_sql() -> None: + conn = duckdb.connect() + compiled = template("SELECT i FROM range(5) t(i) WHERE i >= ", 3, " ORDER BY i").compile() + + assert conn.sql(compiled).fetchall() == [(3,), (4,)] + assert conn.query(compiled).fetchall() == [(3,), (4,)] + assert conn.from_query(compiled).fetchall() == [(3,), (4,)] + assert conn.execute(compiled).fetchall() == [(3,), (4,)] + + +def test_relation_interpolation_works_end_to_end() -> None: + conn = duckdb.connect() + rel = conn.sql("SELECT i FROM range(6) t(i)") + query = template("SELECT i FROM (", rel, ") WHERE i % ", 2, " = 0 ORDER BY i") + assert conn.sql(query).fetchall() == [(0,), (2,), (4,)] + + +def test_builtin_duckdbpytype_object_interpolates_in_template() -> None: + conn = duckdb.connect() + integer_type = duckdb.sqltype("INTEGER") + query = template("SELECT 42::", integer_type) + assert conn.sql(query).fetchall() == [(42,)] + + +def test_builtin_expression_object_interpolates_in_template() -> None: + conn = duckdb.connect() + expr = duckdb.ColumnExpression("i") + query = template("SELECT ", expr, " FROM range(3) t(i) ORDER BY i") + assert conn.sql(query).fetchall() == [(0,), (1,), (2,)] + + +def test_builtin_sqlexpression_object_interpolates_in_template() -> None: + conn = duckdb.connect() + expr = duckdb.SQLExpression("i + 1") + query = template("SELECT ", expr, " FROM range(3) t(i) ORDER BY i") + assert conn.sql(query).fetchall() == [(1,), (2,), (3,)] diff --git a/tests/fast/test_template_python.py b/tests/fast/test_template_python.py new file mode 100644 index 00000000..2896455b --- /dev/null +++ b/tests/fast/test_template_python.py @@ -0,0 +1,1024 @@ +"""Pure-python tests that don't require a compiled extension module.""" + +from __future__ import annotations + +import pytest + +from duckdb.template import ( + CompiledSql, + IntoInterpolation, + Param, + ParamInterpolation, + ResolvedSqlTemplate, + SqlTemplate, + SupportsDuckdbTemplate, + compile, + compile_parts, + param, + parse_parts, + template, +) + +# ── helpers ─────────────────────────────────────────────────────────────────── + + +class FakeInterpolation: + """Minimal object satisfying IntoInterpolation protocol.""" + + def __init__(self, value, *, expression=None, conversion=None, format_spec="") -> None: + self.value = value + self.expression = expression + self.conversion = conversion + self.format_spec = format_spec + + +class SimpleRelation: + """A minimal SupportsDuckdbTemplate implementation returning a string.""" + + def __init__(self, sql: str) -> None: + self._sql = sql + + def __duckdb_template__(self, **kwargs) -> str: + return self._sql + + +class Cafe: + """Test for ascii(obj) conversion.""" + + def __repr__(self) -> str: + return "Café" + + @classmethod + def ascii(cls) -> str: + return r"Caf\xe9" + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Param +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestParam: + def test_basic_creation(self): + p = Param(value=42) + assert p.value == 42 + assert p.name is None + assert p.exact is False + + def test_named_param(self): + p = Param(value="hello", name="greeting") + assert p.name == "greeting" + assert p.exact is False + + def test_positional_creation(self): + p = Param(42, "greeting", True) + assert p.value == 42 + assert p.name == "greeting" + assert p.exact is True + + def test_exact_param_requires_name(self): + with pytest.raises(ValueError, match="exact=True must have a name"): + Param(value=1, exact=True) + + def test_exact_param_with_name(self): + p = Param(value=1, name="x", exact=True) + assert p.name == "x" + assert p.exact is True + + def test_frozen(self): + p = Param(value=1) + with pytest.raises(AttributeError): + p.value = 2 # ty:ignore[invalid-assignment] + + def test_param_helper_function(self): + p = param(42, "answer", exact=True) + assert isinstance(p, Param) + assert p.value == 42 + assert p.name == "answer" + assert p.exact is True + + def test_param_repr(self): + p = Param(value=42, name="x") + r = repr(p) + expected = "Param(value=42, name='x', exact=False)" + assert r == expected + + def test_param_equality(self): + """Frozen dataclasses support equality by default.""" + assert Param(value=1, name="x") == Param(value=1, name="x") + assert Param(value=1) != Param(value=2) + + def test_param_various_value_types(self): + """Params should accept any Python object as a value.""" + for val in [None, 3.14, True, [1, 2], {"a": 1}, b"bytes", object()]: + p = Param(value=val) + assert p.value is val + + def test_param_factory_function_defaults(self): + """param() should allow defaults for name and exact.""" + expected = Param(value=42, name=None, exact=False) + assert param(42) == expected + assert param(value=42) == expected + + def test_param_factory_function_named(self): + p = param(42, name="x") + assert p.value == 42 + assert p.name == "x" + assert p.exact is False + + def test_param_factory_function_exact(self): + p = param(42, name="x", exact=True) + assert p.value == 42 + assert p.name == "x" + assert p.exact is True + + def test_param_factory_function_exact_requires_name(self): + with pytest.raises(ValueError, match="exact=True must have a name"): + param(42, exact=True) + + def test_param_factory_function_exact_with_name(self): + with pytest.raises(TypeError): + param(42, "x", True) # ty:ignore[too-many-positional-arguments] + + +# ═══════════════════════════════════════════════════════════════════════════════ +# ParamInterpolation +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestParamInterpolation: + def test_wraps_param(self): + p = param(42, "x") + pi = ParamInterpolation(p) + assert pi.value is p + assert pi.expression == "x" + assert pi.conversion is None + assert pi.format_spec == "" + + def test_unnamed_param_expression_is_none(self): + p = param(42) + pi = ParamInterpolation(p) + assert pi.expression is None + + def test_satisfies_into_interpolation_protocol(self): + pi = ParamInterpolation(param(1, "x")) + assert isinstance(pi, IntoInterpolation) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# parse_parts +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestParseParts: + def test_all_strings(self): + strings, others = parse_parts(["hello", " ", "world"]) + assert strings == ("hello world",) + assert others == () + + def test_all_others(self): + a, b = object(), object() + strings, others = parse_parts([a, b]) + # two others → need three string spacers: "", "", "" + assert strings == ("", "", "") + assert others == (a, b) + + def test_alternating(self): + a = object() + strings, others = parse_parts(["before", a, "after"]) + assert strings == ("before", "after") + assert others == (a,) + + def test_string_then_other(self): + a = object() + strings, others = parse_parts(["sql", a]) + assert strings == ("sql", "") + assert others == (a,) + + def test_other_then_string(self): + a = object() + strings, others = parse_parts([a, "sql"]) + assert strings == ("", "sql") + assert others == (a,) + + def test_adjacent_strings_merged(self): + strings, others = parse_parts(["a", "b", "c"]) + assert strings == ("abc",) + assert others == () + + def test_adjacent_others_get_empty_string_spacers(self): + a, b, c = object(), object(), object() + strings, others = parse_parts([a, b, c]) + assert strings == ("", "", "", "") + assert others == (a, b, c) + + def test_invariant_strings_one_more_than_others(self): + """The fundamental invariant: len(strings) == len(others) + 1.""" + cases = [ + ["a"], + ["a", object()], + [object(), "a"], + [object(), object()], + ["a", object(), "b", object(), "c"], + ] + for parts in cases: + strings, others = parse_parts(parts) + assert len(strings) == len(others) + 1, f"Failed for parts={parts}" + + def test_empty_input(self): + """Empty input should return a single empty string and no others.""" + strings, others = parse_parts([]) + assert strings == ("",) + assert others == () + + def test_single_string(self): + strings, others = parse_parts(["SELECT 1"]) + assert strings == ("SELECT 1",) + assert others == () + + def test_single_other(self): + a = object() + strings, others = parse_parts([a]) + assert strings == ("", "") + assert others == (a,) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# SqlTemplate construction +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestSqlTemplateConstruction: + def test_plain_string(self): + t = SqlTemplate("SELECT 1") + assert t.strings == ("SELECT 1",) + assert t.interpolations == () + + def test_multiple_strings_merged(self): + t = SqlTemplate("SELECT ", "1") + assert t.strings == ("SELECT 1",) + assert t.interpolations == () + + def test_with_interpolation(self): + interp = FakeInterpolation(value=42, expression="x") + t = SqlTemplate("a ", interp, " b") + assert len(t.interpolations) == 1 + assert t.interpolations[0] is interp + + def test_bare_param_errors(self): + with pytest.raises(TypeError, match="Unexpected part type"): + SqlTemplate("SELECT ", param(42)) # ty:ignore[invalid-argument-type] + + def test_wrapped_param(self): + wrapped = ParamInterpolation(param(42, "x")) + t = SqlTemplate("SELECT ", wrapped, " FROM t") + assert len(t.interpolations) == 1 + + def test_rejects_invalid_types(self): + """Items that are not str, IntoInterpolation, or Param should raise TypeError.""" + with pytest.raises(TypeError, match="Unexpected part type"): + SqlTemplate(42) # ty:ignore[invalid-argument-type] + + def test_no_args(self): + """Empty SqlTemplate should produce a single empty string.""" + t = SqlTemplate() + assert t.strings == ("",) + assert t.interpolations == () + + +# ═══════════════════════════════════════════════════════════════════════════════ +# SqlTemplate iteration and repr +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestSqlTemplateIterRepr: + def test_iter_plain_string(self): + t = SqlTemplate("hello") + parts = list(t) + assert parts == ["hello"] + + def test_iter_with_interpolations(self): + t = template("a ", param(1, "x"), " b") + parts = list(t) + assert len(parts) == 3 + assert parts[0] == "a " + assert isinstance(parts[1], IntoInterpolation) + assert parts[2] == " b" + + def test_str_raises(self): + t = SqlTemplate("hello") + with pytest.raises(NotImplementedError): + str(t) + + def test_repr_plain_string(self): + t = SqlTemplate("SELECT 1") + r = repr(t) + assert "SqlTemplate" in r + assert "'SELECT 1'" in r + + +# ═══════════════════════════════════════════════════════════════════════════════ +# template() factory — basic cases +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestTemplateFactory: + def test_plain_string(self): + t = template("SELECT 1") + compiled = t.compile() + expected = CompiledSql("SELECT 1", {}) + assert expected == compiled + + def test_param(self): + t = template(param(42, "answer")) + compiled = t.compile() + expected = CompiledSql("$p0_answer", {"p0_answer": 42}) + assert expected == compiled + + def test_into_interpolation(self): + interp = FakeInterpolation(value=42, expression="x") + t = template(interp) + compiled = t.compile() + expected = CompiledSql("$p0_x", {"p0_x": 42}) + assert expected == compiled + + def test_supports_duckdb_template_string(self): + rel = SimpleRelation("SELECT 1") + t = template(rel) + compiled = t.compile() + expected = CompiledSql("SELECT 1", {}) + assert expected == compiled + + def test_supports_duckdb_template_returns_iterable(self): + class MultiPart: + def __duckdb_template__(self, **kwargs) -> list[str | Param]: + return ["SELECT * FROM ", param(42, "x")] + + t = template(MultiPart()) + compiled = t.compile() + expected = CompiledSql("SELECT * FROM $p0_x", {"p0_x": 42}) + assert expected == compiled + + def test_supports_duckdb_template_returns_interpolation(self): + class InterpReturner: + def __duckdb_template__(self, **kwargs) -> FakeInterpolation: + return FakeInterpolation(value="hello", expression="val") + + t = template(InterpReturner()) + compiled = t.compile() + expected = CompiledSql("$p0_val", {"p0_val": "hello"}) + assert expected == compiled + + def test_iterable_of_strings(self): + t = template("SELECT ", "1") + compiled = t.compile() + expected = CompiledSql("SELECT 1", {}) + assert expected == compiled + + def test_iterable_with_params(self): + t = template("SELECT * FROM t WHERE id = ", param(5, "id")) + compiled = t.compile() + expected = CompiledSql("SELECT * FROM t WHERE id = $p0_id", {"p0_id": 5}) + assert expected == compiled + + def test_iterable_with_bare_values(self): + """Bare values in an iterable should be treated as params.""" + t = template("SELECT * FROM t WHERE id = ", 42) + compiled = t.compile() + expected = CompiledSql("SELECT * FROM t WHERE id = $p0", {"p0": 42}) + assert expected == compiled + + def test_bare_value_at_top_level_becomes_param(self): + """A bare value passed directly to template() becomes a param.""" + t = template(42) # type: ignore[arg-type] + compiled = t.compile() + expected = CompiledSql("$p0", {"p0": 42}) + assert expected == compiled + + def test_bytes_not_treated_as_iterable(self): + """Bytes should not be iterated — _is_iterable excludes bytes.""" + t = template(b"hello") # type: ignore[arg-type] + compiled = t.compile() + expected = CompiledSql("$p0", params={"p0": b"hello"}) + assert compiled == expected + + +# ═══════════════════════════════════════════════════════════════════════════════ +# template() factory — t-string integration (Python 3.14+) +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestTemplateWithInterpolations: + """Tests for template() with interpolations, using FakeInterpolation to simulate t-string behavior.""" + + def test_simple_param(self): + interp = FakeInterpolation(value=123, expression="user_id") + compiled = compile("SELECT * FROM users WHERE id = ", interp, "") + expected = CompiledSql("SELECT * FROM users WHERE id = $p0_user_id", {"p0_user_id": 123}) + assert expected == compiled + + def test_multiple_params(self): + name_interp = FakeInterpolation(value="Alice", expression="name") + age_interp = FakeInterpolation(value=30, expression="age") + t = template("SELECT * FROM users WHERE name = ", name_interp, " AND age = ", age_interp, "") + compiled = t.compile() + expected = CompiledSql( + "SELECT * FROM users WHERE name = $p0_name AND age = $p1_age", + {"p0_name": "Alice", "p1_age": 30}, + ) + assert expected == compiled + + def test_no_params(self): + compiled = compile("SELECT 1") + expected = CompiledSql("SELECT 1") + assert expected == compiled + + def test_string_conversion_s(self): + """!s should inline the value as raw SQL (no param).""" + interp = FakeInterpolation(value="users", expression="table", conversion="s") + compiled = compile("SELECT * FROM ", interp, "") + expected = CompiledSql("SELECT * FROM users") + assert expected == compiled + + def test_repr_conversion_r(self): + """!r should inline repr(value) as raw SQL.""" + interp = FakeInterpolation(value="hello", expression="val", conversion="r") + compiled = compile("SELECT ", interp, "") + # repr of "hello" is "'hello'" + expected = CompiledSql("SELECT 'hello'") + assert expected == compiled + + def test_ascii_conversion_a(self): + interp = FakeInterpolation(value=Cafe(), expression="val", conversion="a") + compiled = compile("SELECT ", interp) + expected = CompiledSql("SELECT " + Cafe.ascii()) + assert expected == compiled + + def test_string_value_becomes_param_not_raw_sql(self): + """A string interpolation WITHOUT conversion should be a param, not raw SQL.""" + interp = FakeInterpolation(value="Alice", expression="name") + compiled = compile("SELECT * FROM users WHERE name = ", interp, "") + expected = CompiledSql("SELECT * FROM users WHERE name = $p0_name", {"p0_name": "Alice"}) + assert expected == compiled + + def test_nested_template_via_supports_duckdb_template(self): + """An interpolated SupportsDuckdbTemplate should be expanded inline.""" + inner = SimpleRelation("SELECT * FROM people") + interp = FakeInterpolation(value=inner, expression="inner") + compiled = compile("SELECT name FROM (", interp, ")") + expected = CompiledSql("SELECT name FROM (SELECT * FROM people)", {}) + assert expected == compiled + + def test_nested_chaining(self): + """Chaining templates: inner template params should propagate.""" + inner = template("SELECT * FROM people WHERE age >= ", 18) + interp = FakeInterpolation(value=inner, expression="inner") + t = template("SELECT name FROM (", interp, ")") + compiled = t.compile() + # Skipping CompiledSql equality: exact param name depends on counter offsets across nesting levels + assert "SELECT * FROM people WHERE age >= $" in compiled.sql + assert 18 in compiled.params.values() + + def test_param_name_derived_from_expression(self): + """Param names should be based on the expression in the interpolation.""" + interp = FakeInterpolation(value=99, expression="my_value") + t = template("SELECT ", interp, "") + compiled = t.compile() + expected = CompiledSql("SELECT $p0_my_value", {"p0_my_value": 99}) + assert expected == compiled + + def test_explicit_param_in_interpolation(self): + """An explicit Param() used in an interpolation should be treated as a param.""" + p = param(42, "answer") + interp = FakeInterpolation(value=p, expression="p") + t = template("SELECT ", interp, "") + compiled = t.compile() + # Skipping exact equality: param name may come from Param.name or interpolation expression + assert 42 in compiled.params.values() + + def test_format_spec_on_conversion(self): + """Format spec combined with conversion: conversion first, then format.""" + interp = FakeInterpolation(value=3.14159, expression="val", conversion="s", format_spec=".5") + t = template("SELECT ", interp, "") + compiled = t.compile() + # Python semantics: str(3.14159) = "3.14159", then format("3.14159", ".5") = "3.141" + expected = CompiledSql("SELECT 3.141") + assert expected == compiled + + def test_format_spec_without_conversion_is_ignored(self): + """Format spec without conversion should ideally apply, but currently it's silently dropped.""" + interp = FakeInterpolation(value=3.14159, expression="val", format_spec=".2f") + t = template("SELECT ", interp, "") + compiled = t.compile() + expected = CompiledSql("SELECT $p0_val", {"p0_val": 3.14159}) + assert expected == compiled + + +# ═══════════════════════════════════════════════════════════════════════════════ +# resolve / _resolve_interpolation +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestResolve: + def test_plain_string_resolves_to_itself(self): + t = SqlTemplate("SELECT 1") + resolved = t.resolve() + assert list(resolved) == ["SELECT 1"] + + def test_conversion_s_resolves_to_string(self): + interp = FakeInterpolation(value=42, expression="x", conversion="s") + t = SqlTemplate("SELECT ", interp) + resolved = t.compile() + expected = CompiledSql("SELECT 42") + assert resolved == expected + + def test_conversion_r_resolves_to_repr(self): + interp = FakeInterpolation(value="hello", expression="x", conversion="r") + t = SqlTemplate("SELECT ", interp) + resolved = t.resolve() + expected = CompiledSql("SELECT 'hello'") + assert resolved == expected + + def test_conversion_a_resolves_to_ascii(self): + interp = FakeInterpolation(value=Cafe(), expression="x", conversion="a") + actual = compile("SELECT ", interp) + expected = CompiledSql("SELECT " + Cafe.ascii()) + assert actual == expected + + def test_string_value_without_conversion_becomes_param(self): + """A string value in an interpolation (no conversion) must become a param, not raw SQL.""" + interp = FakeInterpolation(value="DROP TABLE users", expression="val") + t = SqlTemplate("SELECT ", interp) + resolved = t.resolve() + expected = CompiledSql("SELECT $p0_val", params={"p0_val": "DROP TABLE users"}) + assert resolved == expected + + def test_nested_supports_duckdb_template(self): + rel = SimpleRelation("SELECT 1") + interp = FakeInterpolation(value=rel, expression="rel") + t = SqlTemplate("SELECT * FROM (", interp, ")") + resolved = t.resolve() + expected = CompiledSql("SELECT * FROM (SELECT 1)") + assert resolved == expected + + def test_expression_name_preserved_for_simple_param(self): + """When interpolation resolves to a single param, the expression name should be kept.""" + interp = FakeInterpolation(value=42, expression="my_age") + t = SqlTemplate("age = ", interp) + resolved = t.resolve() + expected = CompiledSql("age = $p0_my_age", params={"p0_my_age": 42}) + assert resolved == expected + + +# ═══════════════════════════════════════════════════════════════════════════════ +# ResolvedSqlTemplate +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestResolvedSqlTemplate: + def test_basic(self): + r = ResolvedSqlTemplate(["SELECT ", param(42, "x")]) + parts = list(r) + assert len(parts) == 2 + + def test_compile(self): + r = ResolvedSqlTemplate(["SELECT ", param(42, "x")]) + compiled = r.compile() + assert isinstance(compiled, CompiledSql) + assert 42 in compiled.params.values() + + def test_str_raises(self): + r = ResolvedSqlTemplate(["SELECT 1"]) + with pytest.raises(NotImplementedError): + str(r) + + def test_repr(self): + r = ResolvedSqlTemplate(["SELECT ", param(42, "x")]) + rep = repr(r) + assert "ResolvedSqlTemplate" in rep + assert "x=42" in rep + + def test_iter(self): + parts_in = ["a", param(1, "x"), "b"] + r = ResolvedSqlTemplate(parts_in) + assert list(r) == parts_in + + +# ═══════════════════════════════════════════════════════════════════════════════ +# compile_parts +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestCompileParts: + def test_all_strings(self): + result = compile_parts(["SELECT 1"]) + assert result == CompiledSql("SELECT 1") + + def test_single_unnamed_param(self): + result = compile_parts(["SELECT ", param(42)]) + expected = CompiledSql("SELECT $p0", params={"p0": 42}) + assert result == expected + + def test_single_named_param(self): + result = compile_parts(["SELECT ", param(42, "x")]) + expected = CompiledSql("SELECT $p0_x", params={"p0_x": 42}) + assert result == expected + + def test_exact_param_uses_literal_name(self): + result = compile_parts(["SELECT ", param(42, "my_param", exact=True)]) + expected = CompiledSql("SELECT $my_param", params={"my_param": 42}) + assert result == expected + + def test_multiple_params_numbered_sequentially(self): + result = compile_parts(["SELECT * WHERE a = ", param(1, "a"), " AND b = ", param(2, "b")]) + expected = CompiledSql("SELECT * WHERE a = $p0_a AND b = $p1_b", params={"p0_a": 1, "p1_b": 2}) + assert result == expected + + def test_duplicate_param_names_raises(self): + with pytest.raises(ValueError, match="Duplicate parameter names"): + compile_parts([param(1, "x", exact=True), param(2, "x", exact=True)]) + + def test_unnamed_params_get_sequential_names(self): + result = compile_parts(["a = ", param(1), " AND b = ", param(2)]) + expected = CompiledSql("a = $p0 AND b = $p1", params={"p0": 1, "p1": 2}) + assert result == expected + + def test_exact_param_causes_counter_gap(self): + """Known issue: exact params still increment the counter, causing non-sequential auto names.""" + result = compile_parts( + [ + "a = ", + param(1, "x"), # → p0_x + " AND b = ", + param(2, "b", exact=True), # → b (exact), but counter increments + " AND c = ", + param(3, "y"), # → p2_y (not p1_y!) + ] + ) + expected = CompiledSql("a = $p0_x AND b = $b AND c = $p2_y", params={"p0_x": 1, "b": 2, "p2_y": 3}) + assert result == expected + + def test_empty_parts(self): + result = compile_parts([]) + expected = CompiledSql("") + assert result == expected + + def test_adjacent_strings(self): + result = compile_parts(["SELECT ", "1"]) + expected = CompiledSql("SELECT 1") + assert result == expected + + def test_param_with_none_value(self): + result = compile_parts(["SELECT ", param(None, "x")]) + expected = CompiledSql("SELECT $p0_x", params={"p0_x": None}) + assert result == expected + + +# ═══════════════════════════════════════════════════════════════════════════════ +# SupportsDuckdbTemplate protocol +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestSupportsDuckdbTemplate: + def test_protocol_check(self): + rel = SimpleRelation("SELECT 1") + assert isinstance(rel, SupportsDuckdbTemplate) + + def test_non_implementor(self): + assert not isinstance("hello", SupportsDuckdbTemplate) + assert not isinstance(42, SupportsDuckdbTemplate) + + def test_template_calls_dunder(self): + class Tracking: + def __init__(self) -> None: + self.called = False + + def __duckdb_template__(self, **kwargs) -> str: + self.called = True + return "SELECT 1" + + obj = Tracking() + template(obj) + assert obj.called + + def test_future_kwargs_accepted(self): + """Implementations should accept **kwargs for future extensibility.""" + + class Strict: + def __duckdb_template__(self, **kwargs) -> str: + assert isinstance(kwargs, dict) + return "SELECT 1" + + template(Strict()) + + def test_returns_interpolation(self): + """__duckdb_template__ can return an interpolation.""" + + class InterpolationReturner: + def __duckdb_template__(self, **kwargs) -> FakeInterpolation: + return FakeInterpolation(value=42, expression="val") + + t = template(InterpolationReturner()) + compiled = t.compile() + expected = CompiledSql("$p0_val", params={"p0_val": 42}) + assert compiled == expected + + def test_supports_duckdb_template_priority_over_iterable(self): + class IterableWithTemplate: + def __duckdb_template__(self, **kwargs) -> str: + return "SELECT 1" + + def __iter__(self) -> list[str]: + return ["SELECT 2"] + + result = template(IterableWithTemplate()).compile() + expected = CompiledSql("SELECT 1") + assert result == expected + + +# ═══════════════════════════════════════════════════════════════════════════════ +# IntoInterpolation protocol +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestIntoInterpolation: + def test_protocol_check_positive(self): + interp = FakeInterpolation(value=1, expression="x") + assert isinstance(interp, IntoInterpolation) + + def test_protocol_check_negative(self): + assert not isinstance("hello", IntoInterpolation) + assert not isinstance(42, IntoInterpolation) + + def test_param_interpolation_satisfies(self): + pi = ParamInterpolation(param(1)) + assert isinstance(pi, IntoInterpolation) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# CompiledSql +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestCompiledSql: + def test_basic(self): + c = CompiledSql("SELECT $p0", params={"p0": 42}) + assert c.sql == "SELECT $p0" + assert c.params == {"p0": 42} + + def test_positional_params(self): + c = CompiledSql("SELECT $p0", {"p0": 42}) + assert c.sql == "SELECT $p0" + assert c.params == {"p0": 42} + + def test_empty_params(self): + c = CompiledSql("SELECT 1") + assert c.sql == "SELECT 1" + assert c.params == {} + + def test_optional_params(self): + c = CompiledSql("SELECT 1") + assert c.sql == "SELECT 1" + assert c.params == {} + + def test_frozen(self): + c = CompiledSql("SELECT 1") + with pytest.raises(AttributeError): + c.sql = "SELECT 2" # ty:ignore[invalid-assignment] + + def test_equality(self): + a = CompiledSql("SELECT 1") + b = CompiledSql("SELECT 1") + assert a == b + + def test_repr(self): + c = CompiledSql("SELECT 1") + expected = "CompiledSql(sql='SELECT 1', params={})" + assert repr(c) == expected + + def test_str(self): + c = CompiledSql("SELECT 1") + with pytest.raises(NotImplementedError): + str(c) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# End-to-end: compile +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestEndToEndCompile: + def test_plain_sql(self): + result = compile("SELECT * FROM users") + assert result == CompiledSql("SELECT * FROM users", {}) + + def test_strings_joined(self): + result = compile("SELECT * FROM ", "users") + assert result == CompiledSql("SELECT * FROM users", {}) + + def test_param_in_list(self): + result = compile("SELECT * FROM users WHERE id = ", param(5, "id")) + expected = CompiledSql("SELECT * FROM users WHERE id = $p0_id", {"p0_id": 5}) + assert expected == result + + def test_multiple_params_in_list(self): + result = compile("SELECT * FROM users WHERE name = ", param("Alice", "name"), " AND age > ", param(18, "age")) + expected = CompiledSql( + "SELECT * FROM users WHERE name = $p0_name AND age > $p1_age", + {"p0_name": "Alice", "p1_age": 18}, + ) + assert expected == result + + def test_supports_duckdb_template_end_to_end(self): + class MyTable: + def __duckdb_template__(self, **kwargs) -> str: + return "SELECT * FROM my_table" + + result = compile(MyTable()) + expected = CompiledSql("SELECT * FROM my_table", {}) + assert expected == result + + def test_unparameterized_strings(self): + result = compile("SELECT * FROM users WHERE id = ", 42, " AND name = ", "Alice") + expected = CompiledSql("SELECT * FROM users WHERE id = $p0 AND name = Alice", {"p0": 42}) + assert expected == result + + def test_parameterized_strings(self): + result = compile("SELECT * FROM users WHERE id = ", 42, " AND name = ", param("Alice")) + expected = CompiledSql("SELECT * FROM users WHERE id = $p0 AND name = $p1", {"p0": 42, "p1": "Alice"}) + assert expected == result + + def test_nested_template_relations(self): + """The key use case: chaining template queries.""" + inner = template("SELECT * FROM people WHERE age >= ", 18) + outer = template("SELECT name FROM (", inner, ")") + result = outer.compile() + # Skipping exact equality: param name depends on counter offsets across nesting levels + assert "SELECT * FROM people WHERE age >=" in result.sql + assert "SELECT name FROM (" in result.sql + assert 18 in result.params.values() + assert len(result.params) == 1 + + def test_deeply_nested_tstrings(self): + """Three levels of nesting.""" + val = 100 + level1 = template("SELECT * FROM t WHERE x = ", val) + level2 = template("SELECT * FROM (", level1, ") WHERE y = 1") + level3 = template("SELECT count(*) FROM (", level2, ")") + result = level3.compile() + # Skipping exact equality: param name depends on counter offsets across nesting levels + assert "SELECT * FROM t WHERE x = $" in result.sql + assert 100 in result.params.values() + assert len(result.params) == 1 + + def test_multiple_nested_with_separate_params(self): + """Two nested templates each with their own params.""" + a = 1 + b = 2 + t1 = template("SELECT * FROM t1 WHERE a = ", a) + t2 = template("SELECT * FROM t2 WHERE b = ", b) + outer = template("SELECT * FROM (", t1, ") JOIN (", t2, ")") + result = outer.compile() + # Skipping exact equality: param names depend on counter offsets across nesting levels + assert 1 in result.params.values() + assert 2 in result.params.values() + assert len(result.params) == 2 + + def test_sql_injection_prevented_by_default(self): + """Without !s, even SQL-like strings become params, not raw SQL.""" + evil = "1; DROP TABLE users; --" + interp = FakeInterpolation(value=evil, expression="evil") + result = compile("SELECT * FROM users WHERE id = ", interp, "") + expected = CompiledSql("SELECT * FROM users WHERE id = $p0_evil", {"p0_evil": evil}) + assert expected == result + + def test_sql_injection_possible_with_s_conversion(self): + """With !s, values ARE inlined — this is intentional but dangerous.""" + evil = "1; DROP TABLE users; --" + interp = FakeInterpolation(value=evil, expression="evil", conversion="s") + result = compile("SELECT * FROM users WHERE id = ", interp, "") + expected = CompiledSql("SELECT * FROM users WHERE id = 1; DROP TABLE users; --", {}) + assert expected == result + + def test_exact_param_name(self): + result = compile("SELECT * FROM t WHERE id = ", param(value=42, name="my_id", exact=True)) + expected = CompiledSql("SELECT * FROM t WHERE id = $my_id", {"my_id": 42}) + assert expected == result + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Edge cases and known issues +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestEdgeCases: + def test_empty_string_template(self): + result = compile("") + expected = CompiledSql("") + assert result == expected + + def test_param_with_none_value(self): + result = compile(param(value=None, name="x")) + expected = CompiledSql("$p0_x", params={"p0_x": None}) + assert result == expected + + def test_param_with_list_value(self): + result = compile(param(value=[1, 2, 3], name="ids")) + expected = CompiledSql("$p0_ids", params={"p0_ids": [1, 2, 3]}) + assert result == expected + + def test_param_with_dict_value(self): + d = {"key": "value"} + result = compile(param(value=d, name="data")) + expected = CompiledSql("$p0_data", params={"p0_data": d}) + assert result == expected + + def test_bool_param(self): + result = compile("SELECT * FROM t WHERE active = ", True) + expected = CompiledSql("SELECT * FROM t WHERE active = $p0", params={"p0": True}) + assert result == expected + + def test_float_param(self): + interp = FakeInterpolation(value=3.14, expression="threshold") + result = template("SELECT * FROM t WHERE score > ", interp, "").compile() + expected = CompiledSql("SELECT * FROM t WHERE score > $p0_threshold", params={"p0_threshold": 3.14}) + assert result == expected + + def test_none_param(self): + interp = FakeInterpolation(value=None, expression="val") + result = template("SELECT * FROM t WHERE x IS ", interp, "").compile() + expected = CompiledSql("SELECT * FROM t WHERE x IS $p0_val", params={"p0_val": None}) + assert result == expected + + def test_param_object_in_interpolation_preserves_name(self): + """An explicit Param used in an interpolation should keep its name.""" + interp = FakeInterpolation(value=param(42, "custom_name"), expression="p") + result = template("SELECT ", interp, "").compile() + expected = CompiledSql("SELECT $custom_name", params={"custom_name": 42}) + assert result == expected + + def test_same_expression_used_twice(self): + """Using the same expression twice should create two separate params.""" + interp1 = FakeInterpolation(value=42, expression="x") + interp2 = FakeInterpolation(value=42, expression="x") + result = template("SELECT * FROM t WHERE a = ", interp1, " AND b = ", interp2, "").compile() + expected = CompiledSql("SELECT * FROM t WHERE a = $p0_x AND b = $p1_x", params={"p0_x": 42, "p1_x": 42}) + assert result == expected + + def test_mixed_conversion_and_param(self): + """Mix of !s (inlined) and regular (parameterized) in same template.""" + table_interp = FakeInterpolation(value="users", expression="table", conversion="s") + id_interp = FakeInterpolation(value=5, expression="user_id") + result = template("SELECT * FROM ", table_interp, " WHERE id = ", id_interp, "").compile() + expected = CompiledSql("SELECT * FROM users WHERE id = $p0_user_id", params={"p0_user_id": 5}) + assert result == expected + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Conversion semantics — documenting current (potentially incorrect) behavior +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestConversionSemantics: + """Verify conversion + format_spec follows Python f-string semantics.""" + + def test_s_conversion_on_int(self): + interp = FakeInterpolation(value=42, expression="x", conversion="s") + actual = compile(interp) + expected = CompiledSql("42") + assert actual == expected + + def test_r_conversion_on_string(self): + """repr('hello') = "'hello'".""" + interp = FakeInterpolation(value="hello", expression="x", conversion="r") + actual = compile(interp) + expected = CompiledSql("'hello'") + assert actual == expected + + def test_r_conversion_on_int(self): + """repr(42) = '42', no quotes.""" + interp = FakeInterpolation(value=42, expression="x", conversion="r") + actual = compile(interp) + expected = CompiledSql("42") + assert actual == expected + + def test_s_conversion_with_format_spec(self): + """Conversion first, then format_spec: str(3.14159) then format with '.5' truncates.""" + interp = FakeInterpolation(value=3.14159, expression="x", conversion="s", format_spec=".5") + actual = compile(interp) + expected = CompiledSql("3.141") + assert actual == expected + + def test_r_conversion_with_format_spec(self): + """Python semantics: repr first, then format.""" + interp = FakeInterpolation(value="hi", expression="x", conversion="r", format_spec=".4") + actual = compile(interp) + expected = CompiledSql("'hi'") + assert actual == expected + + def test_format_spec_ignored_for_parameterized_values(self): + """When no conversion is specified, format_spec is silently ignored — the value is parameterized as-is.""" + interp = FakeInterpolation(value=3.14159, expression="x", format_spec=".2f") + actual = compile(interp) + expected = CompiledSql("$p0_x", params={"p0_x": 3.14159}) + assert actual == expected diff --git a/tests/fast/test_template_tstrings.py314 b/tests/fast/test_template_tstrings.py314 new file mode 100644 index 00000000..27d9c0d9 --- /dev/null +++ b/tests/fast/test_template_tstrings.py314 @@ -0,0 +1,178 @@ +"""Tests for template.py that use t-string literals (Python 3.14+ only).""" + +from __future__ import annotations + +import sys + +import pytest + +from duckdb.template import ( + Param, + param, + template, +) + +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 14), + reason="t-strings require Python 3.14+", +) + + +class SimpleRelation: + """A minimal SupportsDuckdbTemplate implementation returning a string.""" + + def __init__(self, sql: str): + self._sql = sql + + def __duckdb_template__(self, **kwargs): + return self._sql + + +class TestTStringBasics: + def test_simple_param(self): + user_id = 123 + t = template(t"SELECT * FROM users WHERE id = {user_id}") + compiled = t.compile() + assert "user_id" in compiled.sql + assert "$" in compiled.sql + assert 123 in compiled.params.values() + + def test_multiple_params(self): + name = "Alice" + age = 30 + t = template(t"SELECT * FROM users WHERE name = {name} AND age = {age}") + compiled = t.compile() + assert len(compiled.params) == 2 + assert "Alice" in compiled.params.values() + assert 30 in compiled.params.values() + + def test_no_params(self): + t = template(t"SELECT 1") + compiled = t.compile() + assert compiled.sql == "SELECT 1" + assert compiled.params == {} + + def test_param_name_derived_from_expression(self): + my_value = 99 + t = template(t"SELECT {my_value}") + compiled = t.compile() + assert any("my_value" in k for k in compiled.params) + + def test_explicit_param_in_tstring(self): + p = param(42, "answer") + t = template(t"SELECT {p}") + compiled = t.compile() + assert 42 in compiled.params.values() + + +class TestTStringConversions: + def test_string_conversion_s(self): + table = "users" + t = template(t"SELECT * FROM {table!s}") + compiled = t.compile() + assert compiled.sql == "SELECT * FROM users" + assert compiled.params == {} + + def test_repr_conversion_r(self): + val = "hello" + t = template(t"SELECT {val!r}") + compiled = t.compile() + assert "'hello'" in compiled.sql + assert compiled.params == {} + + def test_ascii_conversion_a(self): + val = "café" + t = template(t"SELECT {val!a}") + compiled = t.compile() + assert compiled.params == {} + assert "\\xe9" in compiled.sql or "caf" in compiled.sql + + def test_string_value_becomes_param_not_raw_sql(self): + name = "Alice" + t = template(t"SELECT * FROM users WHERE name = {name}") + compiled = t.compile() + assert "Alice" not in compiled.sql + assert "Alice" in compiled.params.values() + + def test_format_spec_on_conversion(self): + val = 3.14159 + t = template(t"SELECT {val!s:.5}") + compiled = t.compile() + assert compiled.params == {} + assert "3.141" in compiled.sql + + def test_format_spec_without_conversion_is_ignored(self): + val = 3.14159 + t = template(t"SELECT {val:.2f}") + compiled = t.compile() + assert 3.14159 in compiled.params.values() + assert "3.14" not in compiled.sql + + +class TestTStringNesting: + def test_nested_template_via_supports_duckdb_template(self): + inner = SimpleRelation("SELECT * FROM people") + t = template(t"SELECT name FROM ({inner})") + compiled = t.compile() + assert "SELECT * FROM people" in compiled.sql + assert compiled.params == {} + + def test_nested_tstring_chaining(self): + age = 18 + inner = template(t"SELECT * FROM people WHERE age >= {age}") + t = template(t"SELECT name FROM ({inner})") + compiled = t.compile() + assert "SELECT * FROM people WHERE age >= $" in compiled.sql + assert 18 in compiled.params.values() + + def test_duckdb_template_returns_tstring(self): + class TStringReturner: + def __duckdb_template__(self, **kwargs): + val = 42 + return t"SELECT {val}" + + t = template(TStringReturner()) + compiled = t.compile() + assert 42 in compiled.params.values() + + +class TestTStringEdgeCases: + def test_float_param(self): + threshold = 3.14 + result = template(t"SELECT * FROM t WHERE score > {threshold}").compile() + assert 3.14 in result.params.values() + + def test_none_param(self): + val = None + result = template(t"SELECT * FROM t WHERE x IS {val}").compile() + assert None in result.params.values() + + def test_param_object_preserves_name(self): + p = Param(value=42, name="custom_name") + result = template(t"SELECT {p}").compile() + assert 42 in result.params.values() + + def test_same_variable_used_twice(self): + x = 42 + result = template(t"SELECT * FROM t WHERE a = {x} AND b = {x}").compile() + assert len(result.params) == 2 + assert all(v == 42 for v in result.params.values()) + + def test_mixed_conversion_and_param(self): + table = "users" + user_id = 5 + result = template(t"SELECT * FROM {table!s} WHERE id = {user_id}").compile() + assert "users" in result.sql + assert "5" not in result.sql + assert 5 in result.params.values() + + def test_sql_injection_prevented_by_default(self): + evil = "1; DROP TABLE users; --" + result = template(t"SELECT * FROM users WHERE id = {evil}").compile() + assert "DROP TABLE" not in result.sql + assert evil in result.params.values() + + def test_sql_injection_possible_with_s_conversion(self): + evil = "1; DROP TABLE users; --" + result = template(t"SELECT * FROM users WHERE id = {evil!s}").compile() + assert "DROP TABLE" in result.sql