From 50aa53476a70fa2bb391f62a8005a98ee852d631 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Thu, 26 Mar 2026 23:08:17 -0800 Subject: [PATCH 01/29] try a mock API --- tstrings.py | 135 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 tstrings.py diff --git a/tstrings.py b/tstrings.py new file mode 100644 index 00000000..2f0c347b --- /dev/null +++ b/tstrings.py @@ -0,0 +1,135 @@ +from collections.abc import Iterable +from typing import Literal, Protocol, TypedDict, runtime_checkable + +from typing_extensions import Self + + +# from string.templatelib import Interpolation, Template +class PInterpolation(Protocol): + value: object + conversion: Literal["s", "r", "a"] | None + + +@runtime_checkable +class PTemplate(Protocol): + strings: list[str] + interpolations: list[PInterpolation] + + +class Param(TypedDict): + name: str + value: object + + +class ResolvedQuery(TypedDict): + sql: str + params: list[Param] + + +class PQueryContext(Protocol): + def current_params(self) -> list[Param]: + """Return the list of parameters currently in the context.""" + + def resolve_next( + self, + thing: object, + /, + ) -> str: + """Resolve the given thing as the next step in resolving a query.""" + + def add_param(self, value: object, /, name: str | None = None) -> Param: + """Create a new parameter with the given value and an optional name. Does NOT add it to the context.""" + + +class ResolveOneFunc(Protocol): + def __call__(self, thing: object, ctx: PQueryContext, /, **ignored_kwargs) -> str: ... 
+ + +class QueryContext: + def __init__(self, params: Iterable[Param] | None = None, resolver: ResolveOneFunc | None = None): + self._params: list[Param] = list(params) if params is not None else [] + self._resolver: ResolveOneFunc = resolver if resolver is not None else resolve_one + + def current_params(self) -> list[Param]: + # Full deep copy so the user can't mess up the internal state + return [{"name": p["name"], "value": p["value"]} for p in self._params] + + def resolve_next(self, thing: object, /) -> str: + return self._resolver(thing, self) + + def add_param(self, value: object, /, name: str | None = None) -> Param: + param_name = ( + name + if name is not None + else generate_unique_param_name((p["name"] for p in self._params), template="param_{idx}") + ) + p = Param(name=param_name, value=value) + self._params.append(p) + return p + + +def resolve_query(thing: object) -> ResolvedQuery: + ctx = QueryContext() + sql = resolve_one(thing, ctx) + return {"sql": sql, "params": ctx.current_params()} + + +def resolve_one(thing: object, ctx: PQueryContext, /, **ignored_kwargs) -> str: + if isinstance(thing, SupportsDuckdbResolve): + return thing.__duckdb_resolve__(ctx, **ignored_kwargs) + if isinstance(thing, PTemplate): + return resolve_template(thing, ctx) + param = ctx.add_param(thing) + return f"${param['name']}" + + +def resolve_template(template: PTemplate, ctx: PQueryContext) -> str: + """Resolve a Template, recursively resolving any interpolations and applying any specified conversions.""" + sql_parts = [] + + for i, static_part in enumerate(template.strings): + sql_parts.append(static_part) + if i < len(template.interpolations): + interp = template.interpolations[i] + value = interp.value + + # Apply conversion if specified (!s, !r, !a) + if interp.conversion == "s": + value = str(value) + elif interp.conversion == "r": + value = repr(value) + elif interp.conversion == "a": + value = ascii(value) + + sql_parts.append(ctx.resolve_next(value)) + + return 
"".join(sql_parts) + + +def generate_unique_param_name(existing_names: Iterable[str], *, template: str | None = None) -> str: + """Generate a unique parameter name that does not conflict with existing_names. + + If template is provided, it should be a string with a single {idx} placeholder + that will be replaced with an integer index to generate candidate names. + If template is None, a default template of "param_{idx}" will be used. + """ + if template is None: + template = "param_{idx}" + existing_names = set(existing_names) + idx = 1 + while True: + candidate = template.format(idx=idx) + if candidate not in existing_names: + return candidate + idx += 1 + + +@runtime_checkable +class SupportsDuckdbResolve(Protocol): + def __duckdb_resolve__(self, ctx: PQueryContext, /, **future_kwargs) -> str: ... + + +class DuckdDbPyRelation: + def __duckdb_resolve__(self, ctx: PQueryContext, /, **future_kwargs) -> str: + # this would just return the existing SQL. + return "SELECT * FROM some_table" From 495606f94277d37db26ba3fe1f875737c5bd7024 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 10:11:43 -0800 Subject: [PATCH 02/29] overhaul and improve the API --- template.py | 199 ++++++++++++++++++++++++++++++++++++++++++++++++++++ tstrings.py | 135 ----------------------------------- 2 files changed, 199 insertions(+), 135 deletions(-) create mode 100644 template.py delete mode 100644 tstrings.py diff --git a/template.py b/template.py new file mode 100644 index 00000000..1ce9c0fc --- /dev/null +++ b/template.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, Protocol, TypedDict, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Sequence + + +class Param(TypedDict): + """Represents a parameter to be passed to duckdb, with a name and a value.""" + + value: object + name: str + + +class CompiledSql(TypedDict): + """Represents a compiled SQL statement, with the final SQL 
string and a list of Params to be passed to duckdb.""" + + sql: str + params: list[Param] + + +@runtime_checkable +class SupportsDuckdbTemplate(Protocol): + def __duckdb_template__( + self, /, **future_kwargs + ) -> str | IntoInterpolation | Iterable[str | IntoInterpolation]: ... + + +def resolve_to_template(thing: object, /, **ignored_kwargs) -> SqlTemplate: + if isinstance(thing, SupportsDuckdbTemplate): + raw = thing.__duckdb_template__(**ignored_kwargs) + return SqlTemplate.from_part_or_parts(raw) + if isinstance(thing, IntoTemplate): + return resolve_into_template(thing) + if isinstance(thing, IntoInterpolation): + return resolve_interpolation(thing) + return SqlTemplate(OurInterpolation(thing)) + + +def compile(thing: object) -> CompiledSql: + resolved_template = resolve_to_template(thing) + return compile_sql_template(resolved_template) + + +def resolve_into_template(template: IntoTemplate) -> SqlTemplate: + """Resolve a Template, recursively resolving interpolations and flattening nested templates.""" + parts: list[str | IntoInterpolation] = [] + for part in template: + if isinstance(part, str): + parts.append(part) + else: + inner_parts = resolve_interpolation(part) + parts.extend(inner_parts) + return SqlTemplate(*parts) + + +def resolve_interpolation(interp: IntoInterpolation) -> SqlTemplate: + # if conversion specified (!s, !r, !a), treat as raw sql, eg + # t"SELECT {"mycol"!s} FROM foo" should be "SELECT mycol FROM foo + if interp.conversion == "s": + return SqlTemplate(str(interp.value)) + elif interp.conversion == "r": + return SqlTemplate(repr(interp.value)) + elif interp.conversion == "a": + return SqlTemplate(ascii(interp.value)) + + templ = resolve_to_template(interp.value) + # If the resolved inner is just a single Interpolation, then just return + # the original value so that we preserve the expression name. 
+ # For example, if we have + # ```python + # age = 37 + # people = t"SELECT * FROM people WHERE age = {age}" + # resolve_template(people) + # ``` + # should resolve to: + # "SELECT * FROM people WHERE age = $age", with a param $age=37, + # eg with a friendly param name, rather than + # "SELECT * FROM people WHERE age = $p0", with a param $p0=37 + if len(templ.strings) == 2 and templ.strings[0] == "" and templ.strings[1] == "" and len(templ.interpolations) == 1: + return SqlTemplate(interp) + else: + # We got something nested, eg + # age = 37 + # people = t"SELECT * FROM people WHERE age = {age}" + # names = t"SELECT name FROM ({foo})" + # names should resolve to: + # "SELECT name FROM (SELECT * FROM people WHERE age = $age)", with a param $age=37 + return templ + + +class DuckdDbPyRelation: + def __duckdb_template__(self, /, **future_kwargs) -> str: + # this would just return the existing SQL. + return "SELECT * FROM some_table" + + +class OurInterpolation: + def __init__( + self, + value: object, + conversion: Literal["s", "r", "a"] | None = None, + expression: str | None = None, + ): + self.value = value + self.conversion = conversion + self.expression = expression + + +class SqlTemplate: + def __init__(self, *parts: str | IntoInterpolation): + strings, interpolations = [], [] + last_thing: Literal["string", "interpolation"] | None = None + for part in parts: + if isinstance(part, str): + if last_thing == "string": + # Merge adjacent string parts for efficiency, since the template engine allows that + strings[-1] += part + strings.append(part) + last_thing = "string" + else: + if last_thing is None or last_thing == "interpolation": + # this is the first part, + # or there were two adjacent interpolations, + # so we need an empty string spacer + strings.append("") + interpolations.append(OurInterpolation(part.value, part.conversion)) + last_thing = "interpolation" + if last_thing == "interpolation": + # If the last part was an interpolation, we need to end with 
an empty string + strings.append("") + assert len(strings) == len(interpolations) + 1 + self.strings = strings + self.interpolations = interpolations + + def __iter__(self): + for s, i in zip(self.strings, self.interpolations): + yield s + yield i + yield self.strings[-1] + + @classmethod + def from_part_or_parts( + cls, part_or_parts: str | IntoInterpolation | Iterable[str | IntoInterpolation] + ) -> SqlTemplate: + if isinstance(part_or_parts, (str, IntoInterpolation)): + return cls(part_or_parts) + else: + return cls(*part_or_parts) + + def __str__(self): + msg = f"{self.__class__.__name__} cannot be directly converted to string. It needs to be processed by the SQL engine to produce the final SQL string." # noqa: E501 + raise NotImplementedError(msg) + + +def compile_sql_template(template: SqlTemplate) -> CompiledSql: + """Compile a resolved SqlTemplate into a final SQL string with named parameter placeholders, and a list of Params.""" + sql_parts: list[str] = [] + params: list[Param] = [] + for part in template: + if isinstance(part, str): + sql_parts.append(part) + else: + param_name = f"p{len(params)}" + if part.expression is not None: + param_name += f"_{part.expression}" + assert_param_name_legal(param_name) + sql_parts.append(f"${param_name}") + params.append({"name": param_name, "value": part.value}) + return { + "sql": "".join(sql_parts), + "params": params, + } + + +# from string.templatelib import Interpolation, Template +@runtime_checkable +class IntoInterpolation(Protocol): + """Something that can be converted into a string.templatelib.Interpolation.""" + + value: object + conversion: Literal["s", "r", "a"] | None + expression: str | None + + +@runtime_checkable +class IntoTemplate(Protocol): + strings: Sequence[str] + interpolations: Sequence[IntoInterpolation] + + def __iter__(self) -> Iterator[str | IntoInterpolation]: ... 
+ + +def assert_param_name_legal(name: str) -> None: + """Eg `$param_1` is legal, but `$1param`, `$param-1`, `$param 1`, and `$p ; DROP TABLE users` are not.""" + # not implemented yet + # Not exactly sure what part of the stack this should get called in, + # or perhaps we shouldn't even check, just pass it to duckdb and let it error if it's illegal diff --git a/tstrings.py b/tstrings.py deleted file mode 100644 index 2f0c347b..00000000 --- a/tstrings.py +++ /dev/null @@ -1,135 +0,0 @@ -from collections.abc import Iterable -from typing import Literal, Protocol, TypedDict, runtime_checkable - -from typing_extensions import Self - - -# from string.templatelib import Interpolation, Template -class PInterpolation(Protocol): - value: object - conversion: Literal["s", "r", "a"] | None - - -@runtime_checkable -class PTemplate(Protocol): - strings: list[str] - interpolations: list[PInterpolation] - - -class Param(TypedDict): - name: str - value: object - - -class ResolvedQuery(TypedDict): - sql: str - params: list[Param] - - -class PQueryContext(Protocol): - def current_params(self) -> list[Param]: - """Return the list of parameters currently in the context.""" - - def resolve_next( - self, - thing: object, - /, - ) -> str: - """Resolve the given thing as the next step in resolving a query.""" - - def add_param(self, value: object, /, name: str | None = None) -> Param: - """Create a new parameter with the given value and an optional name. Does NOT add it to the context.""" - - -class ResolveOneFunc(Protocol): - def __call__(self, thing: object, ctx: PQueryContext, /, **ignored_kwargs) -> str: ... 
- - -class QueryContext: - def __init__(self, params: Iterable[Param] | None = None, resolver: ResolveOneFunc | None = None): - self._params: list[Param] = list(params) if params is not None else [] - self._resolver: ResolveOneFunc = resolver if resolver is not None else resolve_one - - def current_params(self) -> list[Param]: - # Full deep copy so the user can't mess up the internal state - return [{"name": p["name"], "value": p["value"]} for p in self._params] - - def resolve_next(self, thing: object, /) -> str: - return self._resolver(thing, self) - - def add_param(self, value: object, /, name: str | None = None) -> Param: - param_name = ( - name - if name is not None - else generate_unique_param_name((p["name"] for p in self._params), template="param_{idx}") - ) - p = Param(name=param_name, value=value) - self._params.append(p) - return p - - -def resolve_query(thing: object) -> ResolvedQuery: - ctx = QueryContext() - sql = resolve_one(thing, ctx) - return {"sql": sql, "params": ctx.current_params()} - - -def resolve_one(thing: object, ctx: PQueryContext, /, **ignored_kwargs) -> str: - if isinstance(thing, SupportsDuckdbResolve): - return thing.__duckdb_resolve__(ctx, **ignored_kwargs) - if isinstance(thing, PTemplate): - return resolve_template(thing, ctx) - param = ctx.add_param(thing) - return f"${param['name']}" - - -def resolve_template(template: PTemplate, ctx: PQueryContext) -> str: - """Resolve a Template, recursively resolving any interpolations and applying any specified conversions.""" - sql_parts = [] - - for i, static_part in enumerate(template.strings): - sql_parts.append(static_part) - if i < len(template.interpolations): - interp = template.interpolations[i] - value = interp.value - - # Apply conversion if specified (!s, !r, !a) - if interp.conversion == "s": - value = str(value) - elif interp.conversion == "r": - value = repr(value) - elif interp.conversion == "a": - value = ascii(value) - - sql_parts.append(ctx.resolve_next(value)) - - return 
"".join(sql_parts) - - -def generate_unique_param_name(existing_names: Iterable[str], *, template: str | None = None) -> str: - """Generate a unique parameter name that does not conflict with existing_names. - - If template is provided, it should be a string with a single {idx} placeholder - that will be replaced with an integer index to generate candidate names. - If template is None, a default template of "param_{idx}" will be used. - """ - if template is None: - template = "param_{idx}" - existing_names = set(existing_names) - idx = 1 - while True: - candidate = template.format(idx=idx) - if candidate not in existing_names: - return candidate - idx += 1 - - -@runtime_checkable -class SupportsDuckdbResolve(Protocol): - def __duckdb_resolve__(self, ctx: PQueryContext, /, **future_kwargs) -> str: ... - - -class DuckdDbPyRelation: - def __duckdb_resolve__(self, ctx: PQueryContext, /, **future_kwargs) -> str: - # this would just return the existing SQL. - return "SELECT * FROM some_table" From 8af421d7afb37c20d385c3fe1efe59f904107a40 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 11:28:40 -0800 Subject: [PATCH 03/29] tweak API example more --- template.py | 264 ++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 184 insertions(+), 80 deletions(-) diff --git a/template.py b/template.py index 1ce9c0fc..5f86f4a0 100644 --- a/template.py +++ b/template.py @@ -1,10 +1,30 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Literal, Protocol, TypedDict, runtime_checkable +from typing import TYPE_CHECKING, Literal, NoReturn, Protocol, TypedDict, runtime_checkable if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Sequence + from typing_extensions import NotRequired, TypeIs + +__all__ = [ + "CompiledSql", + "IntoInterpolation", + "IntoParam", + "IntoTemplate", + "Param", + "SupportsDuckdbTemplate", + "compile", + "template", +] + + +class CompiledSql(TypedDict): + """Represents a compiled SQL statement, 
with the final SQL string and a list of Params to be passed to duckdb.""" + + sql: str + params: list[Param] + class Param(TypedDict): """Represents a parameter to be passed to duckdb, with a name and a value.""" @@ -13,39 +33,112 @@ class Param(TypedDict): name: str -class CompiledSql(TypedDict): - """Represents a compiled SQL statement, with the final SQL string and a list of Params to be passed to duckdb.""" +class IntoParam(TypedDict): + """A Param with a name that is None, which can be used as input to the template engine, which will assign it a name based on its position and optionally an expression.""" - sql: str - params: list[Param] + value: object + name: NotRequired[str | None] -@runtime_checkable -class SupportsDuckdbTemplate(Protocol): - def __duckdb_template__( - self, /, **future_kwargs - ) -> str | IntoInterpolation | Iterable[str | IntoInterpolation]: ... +def is_into_param(thing: object) -> TypeIs[IntoParam]: + try: + value = thing["value"] # ty:ignore[not-subscriptable] + except (TypeError, KeyError): + return False + try: + name = thing["name"] # ty:ignore[not-subscriptable] + except KeyError: + name = None + + return isinstance(value, object) and isinstance(name, (str, type(None))) -def resolve_to_template(thing: object, /, **ignored_kwargs) -> SqlTemplate: +@runtime_checkable +class SupportsDuckdbTemplate(Protocol): + """Something that can be converted into a SqlTemplate by implementing the __duckdb_template__ method.""" + + def __duckdb_template__(self, /, **future_kwargs) -> str | IntoParam | Iterable[str | IntoParam]: + """Convert self into a SqlTemplate, by returning either a string, an IntoParam, or an iterable of these. + + The future_kwargs are for future extensibility, in case duckdb wants + to pass additional information in the future. + To be future-proof, implementations should accept any additional kwargs, + and ignore them at this point. 
+ + Examples: + A simple implementation might just return a string, eg + ```python + class MyRelation: + def __duckdb_template__(self, **kwargs): + return "SELECT * FROM my_table" + ``` + + A more complex implementation might return an iterable of strings and IntoParams. + An IntoParam is a dict with a "value" key, and optionally a "name" key. + + For example: + ```python + class User: + def __init__(self, user_id: int): + self.user_id = user_id + + def __duckdb_template__(self, **kwargs): + return [ + "SELECT * FROM users WHERE id = ", + {"value": self.user_id, "name": "user_id"}, + ] + ``` + + This will resolve to the final SQL and params: + ```python + { + "sql": "SELECT * FROM users WHERE id = $p0_user_id", + "params": [{"name": "p0_user_id", "value": 123}], + } + ``` + """ + + +def param(value: object, name: str | None = None) -> IntoParam: + """Helper function to create an IntoParam with an optional name.""" + if name is not None: + assert_param_name_legal(name) + return IntoParam(value=value, name=name) + + +def template(thing: object, /, **ignored_kwargs) -> SqlTemplate: + """Convert something to a SqlTemplate. + + The rules are: + - If the thing has a __duckdb_template__ method, call it and convert the + resuling strings and IntoParams into a SqlTemplate. + - If the thing is a string, treat it as raw SQL and return a SqlTemplate with that string. + - If the thing is an IntoTemplate, resolve it into a SqlTemplate by recursively resolving + any inner IntoInterpolations and flattening any nested templates. + - If the thing is an IntoInterpolation, resolve it into a SqlTemplate by recursively resolving + its value, and if it has a conversion specified (!s, !r, !a), treat it as raw SQL. + - Otherwise, treat the thing as a param. 
+ """ if isinstance(thing, SupportsDuckdbTemplate): raw = thing.__duckdb_template__(**ignored_kwargs) - return SqlTemplate.from_part_or_parts(raw) + return SqlTemplate(raw) + if isinstance(thing, str): + return SqlTemplate(thing) if isinstance(thing, IntoTemplate): return resolve_into_template(thing) if isinstance(thing, IntoInterpolation): return resolve_interpolation(thing) - return SqlTemplate(OurInterpolation(thing)) + return SqlTemplate(param(value=thing)) def compile(thing: object) -> CompiledSql: - resolved_template = resolve_to_template(thing) + resolved_template = template(thing) return compile_sql_template(resolved_template) def resolve_into_template(template: IntoTemplate) -> SqlTemplate: """Resolve a Template, recursively resolving interpolations and flattening nested templates.""" - parts: list[str | IntoInterpolation] = [] + parts: list[str | IntoParam] = [] for part in template: if isinstance(part, str): parts.append(part) @@ -56,16 +149,30 @@ def resolve_into_template(template: IntoTemplate) -> SqlTemplate: def resolve_interpolation(interp: IntoInterpolation) -> SqlTemplate: + """Resolve something that can be converted into an Interpolation, recursively resolving any inner templates.""" + value = interp.value # if conversion specified (!s, !r, !a), treat as raw sql, eg - # t"SELECT {"mycol"!s} FROM foo" should be "SELECT mycol FROM foo + # name = "Alice" + # t"SELECT * FROM users where name = '{name!s}'" should be + # "SELECT * FROM users where name = 'Alice'", eg with no param, + # since the user is explicitly asking for the value to be directly interpolated into the SQL string, + # rather than passed as a param. 
+ # This is useful for cases where the value is not something that can be passed as a param, + # eg an identifier like a table or column name, or a SQL expression like "CURRENT_DATE", + # or if the user just wants to write raw SQL and doesn't care about safety + # Note that this is potentially unsafe if the value comes from an untrusted source, + # since it could lead to SQL injection vulnerabilities, so it should be used with caution. if interp.conversion == "s": - return SqlTemplate(str(interp.value)) + return SqlTemplate(str(value)) elif interp.conversion == "r": - return SqlTemplate(repr(interp.value)) + return SqlTemplate(repr(value)) elif interp.conversion == "a": - return SqlTemplate(ascii(interp.value)) + return SqlTemplate(ascii(value)) - templ = resolve_to_template(interp.value) + if isinstance(value, str): + # do NOT pass to template, since that would treat it as a raw SQL. + return SqlTemplate(param(value, name=interp.expression)) + templ = template(value) # If the resolved inner is just a single Interpolation, then just return # the original value so that we preserve the expression name. 
# For example, if we have @@ -78,8 +185,8 @@ def resolve_interpolation(interp: IntoInterpolation) -> SqlTemplate: # "SELECT * FROM people WHERE age = $age", with a param $age=37, # eg with a friendly param name, rather than # "SELECT * FROM people WHERE age = $p0", with a param $p0=37 - if len(templ.strings) == 2 and templ.strings[0] == "" and templ.strings[1] == "" and len(templ.interpolations) == 1: - return SqlTemplate(interp) + if len(templ.strings) == 2 and templ.strings[0] == "" and templ.strings[1] == "" and len(templ.params) == 1: + return SqlTemplate(param(value=value, name=interp.expression)) else: # We got something nested, eg # age = 37 @@ -90,69 +197,37 @@ def resolve_interpolation(interp: IntoInterpolation) -> SqlTemplate: return templ -class DuckdDbPyRelation: - def __duckdb_template__(self, /, **future_kwargs) -> str: - # this would just return the existing SQL. - return "SELECT * FROM some_table" - - -class OurInterpolation: - def __init__( - self, - value: object, - conversion: Literal["s", "r", "a"] | None = None, - expression: str | None = None, - ): - self.value = value - self.conversion = conversion - self.expression = expression +# class DuckdDbPyRelation: +# def __duckdb_template__(self, /, **future_kwargs) -> str: +# # this would just return the existing SQL. 
+# return "SELECT * FROM some_table" class SqlTemplate: - def __init__(self, *parts: str | IntoInterpolation): - strings, interpolations = [], [] - last_thing: Literal["string", "interpolation"] | None = None - for part in parts: - if isinstance(part, str): - if last_thing == "string": - # Merge adjacent string parts for efficiency, since the template engine allows that - strings[-1] += part - strings.append(part) - last_thing = "string" - else: - if last_thing is None or last_thing == "interpolation": - # this is the first part, - # or there were two adjacent interpolations, - # so we need an empty string spacer - strings.append("") - interpolations.append(OurInterpolation(part.value, part.conversion)) - last_thing = "interpolation" - if last_thing == "interpolation": - # If the last part was an interpolation, we need to end with an empty string - strings.append("") - assert len(strings) == len(interpolations) + 1 - self.strings = strings - self.interpolations = interpolations - - def __iter__(self): - for s, i in zip(self.strings, self.interpolations): + """Very similar to string.templatelib.Template, but instead of Interpolations, we use IntoParams.""" + + def __init__(self, thing: str | IntoParam | Iterable[str | IntoParam]) -> None: + parts = [thing] if isinstance(thing, str) or is_into_param(thing) else list(thing) + strings, params = parse_strings_and_params(parts) + for param in params: + if name := param.get("name"): + assert_param_name_legal(name) + self.strings = tuple(strings) + self.params = tuple(params) + + def __iter__(self) -> Iterator[str | IntoParam]: + for s, i in zip(self.strings, self.params, strict=False): yield s yield i yield self.strings[-1] - @classmethod - def from_part_or_parts( - cls, part_or_parts: str | IntoInterpolation | Iterable[str | IntoInterpolation] - ) -> SqlTemplate: - if isinstance(part_or_parts, (str, IntoInterpolation)): - return cls(part_or_parts) - else: - return cls(*part_or_parts) - - def __str__(self): + def 
__str__(self) -> NoReturn: msg = f"{self.__class__.__name__} cannot be directly converted to string. It needs to be processed by the SQL engine to produce the final SQL string." # noqa: E501 raise NotImplementedError(msg) + def compile(self) -> CompiledSql: + return compile_sql_template(self) + def compile_sql_template(template: SqlTemplate) -> CompiledSql: """Compile a resolved SqlTemplate into a final SQL string with named parameter placeholders, and a list of Params.""" @@ -163,11 +238,10 @@ def compile_sql_template(template: SqlTemplate) -> CompiledSql: sql_parts.append(part) else: param_name = f"p{len(params)}" - if part.expression is not None: - param_name += f"_{part.expression}" - assert_param_name_legal(param_name) + if passed_name := part.get("name"): + param_name += f"_{passed_name}" sql_parts.append(f"${param_name}") - params.append({"name": param_name, "value": part.value}) + params.append({"name": param_name, "value": part["value"]}) return { "sql": "".join(sql_parts), "params": params, @@ -186,10 +260,13 @@ class IntoInterpolation(Protocol): @runtime_checkable class IntoTemplate(Protocol): + """Something that can be converted into string.templatelib.Template.""" + strings: Sequence[str] interpolations: Sequence[IntoInterpolation] - def __iter__(self) -> Iterator[str | IntoInterpolation]: ... 
+ def __iter__(self) -> Iterator[str | IntoInterpolation]: + """Iterate over the strings and interpolations in order.""" def assert_param_name_legal(name: str) -> None: @@ -197,3 +274,30 @@ def assert_param_name_legal(name: str) -> None: # not implemented yet # Not exactly sure what part of the stack this should get called in, # or perhaps we shouldn't even check, just pass it to duckdb and let it error if it's illegal + + +def parse_strings_and_params( + parts: Iterable[str | IntoParam], +) -> tuple[tuple[str, ...], tuple[IntoParam, ...]]: + """Parse an iterable of strings and params into separate tuples of strings and params, merging adjacent strings and ensuring that the number of strings is one more than the number of params.""" + strings, params = [], [] + last_thing: Literal["string", "param"] | None = None + for part in parts: + if isinstance(part, str): + if last_thing == "string": + # Merge adjacent string parts for efficiency, since the template engine allows that + strings[-1] += part + strings.append(part) + last_thing = "string" + else: + if last_thing is None or last_thing == "param": + # this is the first part or there were two adjacent params, + # so we need an empty string spacer + strings.append("") + params.append(part) + last_thing = "param" + if last_thing == "param": + # If the last part was a param, we need to end with an empty string + strings.append("") + assert len(strings) == len(params) + 1 + return tuple(strings), tuple(params) From 12d1cdae17878e3ffd167d40aed12fe1dd1f572b Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 12:43:37 -0800 Subject: [PATCH 04/29] try separate steps of creation, resolving, and compilation --- template.py | 191 +++++++++++++++++++++++++++------------------------- 1 file changed, 101 insertions(+), 90 deletions(-) diff --git a/template.py b/template.py index 5f86f4a0..5097e1ef 100644 --- a/template.py +++ b/template.py @@ -1,16 +1,15 @@ from __future__ import annotations -from typing import 
TYPE_CHECKING, Literal, NoReturn, Protocol, TypedDict, runtime_checkable +from typing import TYPE_CHECKING, Literal, Protocol, TypedDict, TypeVar, runtime_checkable if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Sequence - from typing_extensions import NotRequired, TypeIs + from typing_extensions import TypeIs __all__ = [ "CompiledSql", "IntoInterpolation", - "IntoParam", "IntoTemplate", "Param", "SupportsDuckdbTemplate", @@ -33,32 +32,36 @@ class Param(TypedDict): name: str -class IntoParam(TypedDict): - """A Param with a name that is None, which can be used as input to the template engine, which will assign it a name based on its position and optionally an expression.""" +# class IntoParam(TypedDict): +# """A Param with a name that is None, which can be used as input to the template engine, which will assign it a name based on its position and optionally an expression.""" - value: object - name: NotRequired[str | None] +# value: object +# name: NotRequired[str | None] + + +# def is_into_param(thing: object) -> TypeIs[IntoParam]: +# try: +# value = thing["value"] # ty:ignore[not-subscriptable] +# except (TypeError, KeyError): +# return False +# try: +# name = thing["name"] # ty:ignore[not-subscriptable] +# except KeyError: +# name = None +# return isinstance(value, object) and isinstance(name, (str, type(None))) -def is_into_param(thing: object) -> TypeIs[IntoParam]: - try: - value = thing["value"] # ty:ignore[not-subscriptable] - except (TypeError, KeyError): - return False - try: - name = thing["name"] # ty:ignore[not-subscriptable] - except KeyError: - name = None - return isinstance(value, object) and isinstance(name, (str, type(None))) +def is_into_interpolation(thing: object) -> TypeIs[IntoInterpolation]: + return isinstance(thing, IntoInterpolation) @runtime_checkable class SupportsDuckdbTemplate(Protocol): - """Something that can be converted into a SqlTemplate by implementing the __duckdb_template__ method.""" + """Something that can be 
converted into a Template by implementing the __duckdb_template__ method.""" - def __duckdb_template__(self, /, **future_kwargs) -> str | IntoParam | Iterable[str | IntoParam]: - """Convert self into a SqlTemplate, by returning either a string, an IntoParam, or an iterable of these. + def __duckdb_template__(self, /, **future_kwargs) -> str | IntoInterpolation | Iterable[str | IntoInterpolation]: + """Convert self into an IntoTemplate, by returning either a string, an IntoInterpolation, or an iterable of these. The future_kwargs are for future extensibility, in case duckdb wants to pass additional information in the future. @@ -83,10 +86,8 @@ def __init__(self, user_id: int): self.user_id = user_id def __duckdb_template__(self, **kwargs): - return [ - "SELECT * FROM users WHERE id = ", - {"value": self.user_id, "name": "user_id"}, - ] + user_id = self.user_id + return t"SELECT * FROM users WHERE id = {user_id}" ``` This will resolve to the final SQL and params: @@ -99,15 +100,15 @@ def __duckdb_template__(self, **kwargs): """ -def param(value: object, name: str | None = None) -> IntoParam: +def param(value: object, name: str | None = None) -> ParamInterpolation: """Helper function to create an IntoParam with an optional name.""" if name is not None: assert_param_name_legal(name) - return IntoParam(value=value, name=name) + return ParamInterpolation(value=value, expression=name, conversion=None) -def template(thing: object, /, **ignored_kwargs) -> SqlTemplate: - """Convert something to a SqlTemplate. +def template(thing: object, /, **ignored_kwargs) -> OurTemplate: + """Convert something to a Template-ish. 
The rules are: - If the thing has a __duckdb_template__ method, call it and convert the @@ -121,35 +122,35 @@ def template(thing: object, /, **ignored_kwargs) -> SqlTemplate: """ if isinstance(thing, SupportsDuckdbTemplate): raw = thing.__duckdb_template__(**ignored_kwargs) - return SqlTemplate(raw) + parts = [raw] if isinstance(raw, str) or is_into_interpolation(raw) else list(raw) + return OurTemplate(*parts) if isinstance(thing, str): - return SqlTemplate(thing) + return OurTemplate(thing) if isinstance(thing, IntoTemplate): - return resolve_into_template(thing) + return OurTemplate(*thing) if isinstance(thing, IntoInterpolation): - return resolve_interpolation(thing) - return SqlTemplate(param(value=thing)) + return OurTemplate(thing) + return OurTemplate(param(value=thing)) def compile(thing: object) -> CompiledSql: - resolved_template = template(thing) - return compile_sql_template(resolved_template) + t = template(thing) + resolved = resolve(t) + return compile_parts(resolved) -def resolve_into_template(template: IntoTemplate) -> SqlTemplate: - """Resolve a Template, recursively resolving interpolations and flattening nested templates.""" - parts: list[str | IntoParam] = [] - for part in template: +def resolve(parts: Iterable[str | IntoInterpolation]) -> OurTemplate[str | IntoInterpolation]: + """Resolve an OurTemplate by recursively resolving any inner templates and interpolations.""" + resolved: list[str | IntoInterpolation] = [] + for part in parts: if isinstance(part, str): - parts.append(part) + resolved.append(part) else: - inner_parts = resolve_interpolation(part) - parts.extend(inner_parts) - return SqlTemplate(*parts) + resolved.extend(resolve_interpolation(part)) + return OurTemplate(*resolved) -def resolve_interpolation(interp: IntoInterpolation) -> SqlTemplate: - """Resolve something that can be converted into an Interpolation, recursively resolving any inner templates.""" +def resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | 
IntoInterpolation]: value = interp.value # if conversion specified (!s, !r, !a), treat as raw sql, eg # name = "Alice" @@ -163,16 +164,17 @@ def resolve_interpolation(interp: IntoInterpolation) -> SqlTemplate: # Note that this is potentially unsafe if the value comes from an untrusted source, # since it could lead to SQL injection vulnerabilities, so it should be used with caution. if interp.conversion == "s": - return SqlTemplate(str(value)) + return OurTemplate(str(value)) elif interp.conversion == "r": - return SqlTemplate(repr(value)) + return OurTemplate(repr(value)) elif interp.conversion == "a": - return SqlTemplate(ascii(value)) + return OurTemplate(ascii(value)) if isinstance(value, str): # do NOT pass to template, since that would treat it as a raw SQL. - return SqlTemplate(param(value, name=interp.expression)) + return OurTemplate(param(value, name=interp.expression)) templ = template(value) + # resolved = resolve(templ) # If the resolved inner is just a single Interpolation, then just return # the original value so that we preserve the expression name. 
# For example, if we have @@ -185,8 +187,8 @@ def resolve_interpolation(interp: IntoInterpolation) -> SqlTemplate: # "SELECT * FROM people WHERE age = $age", with a param $age=37, # eg with a friendly param name, rather than # "SELECT * FROM people WHERE age = $p0", with a param $p0=37 - if len(templ.strings) == 2 and templ.strings[0] == "" and templ.strings[1] == "" and len(templ.params) == 1: - return SqlTemplate(param(value=value, name=interp.expression)) + if len(templ.strings) == 2 and templ.strings[0] == "" and templ.strings[1] == "" and len(templ.interpolations) == 1: + return OurTemplate(param(value=value, name=interp.expression)) else: # We got something nested, eg # age = 37 @@ -197,58 +199,25 @@ def resolve_interpolation(interp: IntoInterpolation) -> SqlTemplate: return templ -# class DuckdDbPyRelation: -# def __duckdb_template__(self, /, **future_kwargs) -> str: -# # this would just return the existing SQL. -# return "SELECT * FROM some_table" - - -class SqlTemplate: - """Very similar to string.templatelib.Template, but instead of Interpolations, we use IntoParams.""" - - def __init__(self, thing: str | IntoParam | Iterable[str | IntoParam]) -> None: - parts = [thing] if isinstance(thing, str) or is_into_param(thing) else list(thing) - strings, params = parse_strings_and_params(parts) - for param in params: - if name := param.get("name"): - assert_param_name_legal(name) - self.strings = tuple(strings) - self.params = tuple(params) - - def __iter__(self) -> Iterator[str | IntoParam]: - for s, i in zip(self.strings, self.params, strict=False): - yield s - yield i - yield self.strings[-1] - - def __str__(self) -> NoReturn: - msg = f"{self.__class__.__name__} cannot be directly converted to string. It needs to be processed by the SQL engine to produce the final SQL string." 
# noqa: E501 - raise NotImplementedError(msg) - - def compile(self) -> CompiledSql: - return compile_sql_template(self) - - -def compile_sql_template(template: SqlTemplate) -> CompiledSql: +def compile_parts(parts: Iterable[str | IntoInterpolation], /) -> CompiledSql: """Compile a resolved SqlTemplate into a final SQL string with named parameter placeholders, and a list of Params.""" sql_parts: list[str] = [] params: list[Param] = [] - for part in template: + for part in parts: if isinstance(part, str): sql_parts.append(part) else: param_name = f"p{len(params)}" - if passed_name := part.get("name"): + if passed_name := part.expression: param_name += f"_{passed_name}" sql_parts.append(f"${param_name}") - params.append({"name": param_name, "value": part["value"]}) + params.append({"name": param_name, "value": part.value}) return { "sql": "".join(sql_parts), "params": params, } -# from string.templatelib import Interpolation, Template @runtime_checkable class IntoInterpolation(Protocol): """Something that can be converted into a string.templatelib.Interpolation.""" @@ -256,6 +225,7 @@ class IntoInterpolation(Protocol): value: object conversion: Literal["s", "r", "a"] | None expression: str | None + format_spec: str @runtime_checkable @@ -276,9 +246,50 @@ def assert_param_name_legal(name: str) -> None: # or perhaps we shouldn't even check, just pass it to duckdb and let it error if it's illegal +class ParamInterpolation: + """A simple implementation of IntoInterpolation, for testing purposes.""" + + def __init__( + self, value: object, conversion: Literal["s", "r", "a"] | None = None, expression: str | None = None + ) -> None: + self.value = value + self.conversion = conversion + self.expression = expression + self.format_spec = "" + + +class OurTemplate: + """A simple implementation of IntoTemplate, for testing purposes.""" + + def __init__( + self, + *parts: str | IntoInterpolation, + ) -> None: + self.strings, self.interpolations = parse_strings_and_params(parts) + 
+ def __iter__(self) -> Iterator[str | IntoInterpolation]: + """Iterate over the strings and interpolations in order.""" + for s, i in zip(self.strings, self.interpolations, strict=False): + yield s + yield i + yield self.strings[-1] + + def resolve(self) -> OurTemplate: + """Resolve any inner templates and interpolations, returning a new OurTemplate with only strings and ParamInterpolations.""" + return resolve(self) + + def compile(self) -> CompiledSql: + """Compile this template into a final SQL string with named parameter placeholders, and a list of Params.""" + resolved = self.resolve() + return compile_parts(resolved) + + +T = TypeVar("T") + + def parse_strings_and_params( - parts: Iterable[str | IntoParam], -) -> tuple[tuple[str, ...], tuple[IntoParam, ...]]: + parts: Iterable[str | T], +) -> tuple[tuple[str, ...], tuple[T, ...]]: """Parse an iterable of strings and params into separate tuples of strings and params, merging adjacent strings and ensuring that the number of strings is one more than the number of params.""" strings, params = [], [] last_thing: Literal["string", "param"] | None = None From 4677b11af9b5f8499cf0f6dd24b5f9627d7fcce9 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 13:06:28 -0800 Subject: [PATCH 05/29] fix parse_strings_and_params bug --- template.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/template.py b/template.py index 5097e1ef..2598298f 100644 --- a/template.py +++ b/template.py @@ -298,7 +298,8 @@ def parse_strings_and_params( if last_thing == "string": # Merge adjacent string parts for efficiency, since the template engine allows that strings[-1] += part - strings.append(part) + else: + strings.append(part) last_thing = "string" else: if last_thing is None or last_thing == "param": From 1925a9da02858aecf2258a796db657228c557de6 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 13:15:56 -0800 Subject: [PATCH 06/29] use dataclass, not dict --- template.py | 43 
++++++++++++++----------------------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/template.py b/template.py index 2598298f..304a190c 100644 --- a/template.py +++ b/template.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Literal, Protocol, TypedDict, TypeVar, runtime_checkable +import dataclasses +from typing import TYPE_CHECKING, Literal, Protocol, TypeVar, runtime_checkable if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Sequence @@ -18,40 +19,26 @@ ] -class CompiledSql(TypedDict): +@dataclasses.dataclass(frozen=True) +class CompiledSql: """Represents a compiled SQL statement, with the final SQL string and a list of Params to be passed to duckdb.""" sql: str - params: list[Param] + params: tuple[Param, ...] + # def __init__(self, sql: str, params: Iterable[Param]) -> None: + # self.sql = sql + # self.params = tuple(params) -class Param(TypedDict): + +@dataclasses.dataclass(frozen=True) +class Param: """Represents a parameter to be passed to duckdb, with a name and a value.""" value: object name: str -# class IntoParam(TypedDict): -# """A Param with a name that is None, which can be used as input to the template engine, which will assign it a name based on its position and optionally an expression.""" - -# value: object -# name: NotRequired[str | None] - - -# def is_into_param(thing: object) -> TypeIs[IntoParam]: -# try: -# value = thing["value"] # ty:ignore[not-subscriptable] -# except (TypeError, KeyError): -# return False -# try: -# name = thing["name"] # ty:ignore[not-subscriptable] -# except KeyError: -# name = None - -# return isinstance(value, object) and isinstance(name, (str, type(None))) - - def is_into_interpolation(thing: object) -> TypeIs[IntoInterpolation]: return isinstance(thing, IntoInterpolation) @@ -211,11 +198,9 @@ def compile_parts(parts: Iterable[str | IntoInterpolation], /) -> CompiledSql: if passed_name := part.expression: param_name += f"_{passed_name}" 
sql_parts.append(f"${param_name}") - params.append({"name": param_name, "value": part.value}) - return { - "sql": "".join(sql_parts), - "params": params, - } + # params.append({"name": param_name, "value": part.value}) + params.append(Param(name=param_name, value=part.value)) + return CompiledSql(sql="".join(sql_parts), params=tuple(params)) @runtime_checkable From ff542125e51a7040b6e902f573f8fb00a2c12c17 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 13:19:37 -0800 Subject: [PATCH 07/29] remove conversion from ParamInterpretation --- template.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/template.py b/template.py index 304a190c..bc514dfe 100644 --- a/template.py +++ b/template.py @@ -91,7 +91,7 @@ def param(value: object, name: str | None = None) -> ParamInterpolation: """Helper function to create an IntoParam with an optional name.""" if name is not None: assert_param_name_legal(name) - return ParamInterpolation(value=value, expression=name, conversion=None) + return ParamInterpolation(value=value, expression=name) def template(thing: object, /, **ignored_kwargs) -> OurTemplate: @@ -198,7 +198,6 @@ def compile_parts(parts: Iterable[str | IntoInterpolation], /) -> CompiledSql: if passed_name := part.expression: param_name += f"_{passed_name}" sql_parts.append(f"${param_name}") - # params.append({"name": param_name, "value": part.value}) params.append(Param(name=param_name, value=part.value)) return CompiledSql(sql="".join(sql_parts), params=tuple(params)) @@ -234,12 +233,10 @@ def assert_param_name_legal(name: str) -> None: class ParamInterpolation: """A simple implementation of IntoInterpolation, for testing purposes.""" - def __init__( - self, value: object, conversion: Literal["s", "r", "a"] | None = None, expression: str | None = None - ) -> None: + def __init__(self, value: object, expression: str | None = None) -> None: self.value = value - self.conversion = conversion self.expression = expression + 
self.conversion = None self.format_spec = "" From f57f4087b5a78e4622ca480874133c9c4ce82802 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 13:22:31 -0800 Subject: [PATCH 08/29] fixup --- template.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/template.py b/template.py index bc514dfe..e2580376 100644 --- a/template.py +++ b/template.py @@ -94,13 +94,13 @@ def param(value: object, name: str | None = None) -> ParamInterpolation: return ParamInterpolation(value=value, expression=name) -def template(thing: object, /, **ignored_kwargs) -> OurTemplate: +def template(thing: object, /, **ignored_kwargs) -> SqlTemplate: """Convert something to a Template-ish. The rules are: - If the thing has a __duckdb_template__ method, call it and convert the resuling strings and IntoParams into a SqlTemplate. - - If the thing is a string, treat it as raw SQL and return a SqlTemplate with that string. + - If the thing is a `str`, treat it as raw SQL and return a SqlTemplate with that string. - If the thing is an IntoTemplate, resolve it into a SqlTemplate by recursively resolving any inner IntoInterpolations and flattening any nested templates. 
- If the thing is an IntoInterpolation, resolve it into a SqlTemplate by recursively resolving @@ -110,23 +110,24 @@ def template(thing: object, /, **ignored_kwargs) -> OurTemplate: if isinstance(thing, SupportsDuckdbTemplate): raw = thing.__duckdb_template__(**ignored_kwargs) parts = [raw] if isinstance(raw, str) or is_into_interpolation(raw) else list(raw) - return OurTemplate(*parts) + return SqlTemplate(*parts) if isinstance(thing, str): - return OurTemplate(thing) + return SqlTemplate(thing) if isinstance(thing, IntoTemplate): - return OurTemplate(*thing) + return SqlTemplate(*thing) if isinstance(thing, IntoInterpolation): - return OurTemplate(thing) - return OurTemplate(param(value=thing)) + return SqlTemplate(thing) + return SqlTemplate(param(value=thing)) def compile(thing: object) -> CompiledSql: + """Compile a thing into a final SQL string with named parameter placeholders, and a list of Params.""" t = template(thing) resolved = resolve(t) return compile_parts(resolved) -def resolve(parts: Iterable[str | IntoInterpolation]) -> OurTemplate[str | IntoInterpolation]: +def resolve(parts: Iterable[str | IntoInterpolation]) -> SqlTemplate[str | IntoInterpolation]: """Resolve an OurTemplate by recursively resolving any inner templates and interpolations.""" resolved: list[str | IntoInterpolation] = [] for part in parts: @@ -134,7 +135,7 @@ def resolve(parts: Iterable[str | IntoInterpolation]) -> OurTemplate[str | IntoI resolved.append(part) else: resolved.extend(resolve_interpolation(part)) - return OurTemplate(*resolved) + return SqlTemplate(*resolved) def resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | IntoInterpolation]: @@ -151,15 +152,15 @@ def resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | IntoInter # Note that this is potentially unsafe if the value comes from an untrusted source, # since it could lead to SQL injection vulnerabilities, so it should be used with caution. 
if interp.conversion == "s": - return OurTemplate(str(value)) + return SqlTemplate(str(value)) elif interp.conversion == "r": - return OurTemplate(repr(value)) + return SqlTemplate(repr(value)) elif interp.conversion == "a": - return OurTemplate(ascii(value)) + return SqlTemplate(ascii(value)) if isinstance(value, str): # do NOT pass to template, since that would treat it as a raw SQL. - return OurTemplate(param(value, name=interp.expression)) + return SqlTemplate(param(value, name=interp.expression)) templ = template(value) # resolved = resolve(templ) # If the resolved inner is just a single Interpolation, then just return @@ -175,7 +176,7 @@ def resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | IntoInter # eg with a friendly param name, rather than # "SELECT * FROM people WHERE age = $p0", with a param $p0=37 if len(templ.strings) == 2 and templ.strings[0] == "" and templ.strings[1] == "" and len(templ.interpolations) == 1: - return OurTemplate(param(value=value, name=interp.expression)) + return SqlTemplate(param(value=value, name=interp.expression)) else: # We got something nested, eg # age = 37 @@ -240,7 +241,7 @@ def __init__(self, value: object, expression: str | None = None) -> None: self.format_spec = "" -class OurTemplate: +class SqlTemplate: """A simple implementation of IntoTemplate, for testing purposes.""" def __init__( @@ -256,7 +257,7 @@ def __iter__(self) -> Iterator[str | IntoInterpolation]: yield i yield self.strings[-1] - def resolve(self) -> OurTemplate: + def resolve(self) -> SqlTemplate: """Resolve any inner templates and interpolations, returning a new OurTemplate with only strings and ParamInterpolations.""" return resolve(self) From dddcc5541aa510ba502c370dde5327003e1e812c Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 13:45:57 -0800 Subject: [PATCH 09/29] adjust implementation --- template.py | 43 ++++++++++++++++++------------------------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git 
a/template.py b/template.py index e2580376..8f299b88 100644 --- a/template.py +++ b/template.py @@ -15,6 +15,7 @@ "Param", "SupportsDuckdbTemplate", "compile", + "resolve", "template", ] @@ -26,10 +27,6 @@ class CompiledSql: sql: str params: tuple[Param, ...] - # def __init__(self, sql: str, params: Iterable[Param]) -> None: - # self.sql = sql - # self.params = tuple(params) - @dataclasses.dataclass(frozen=True) class Param: @@ -94,21 +91,17 @@ def param(value: object, name: str | None = None) -> ParamInterpolation: return ParamInterpolation(value=value, expression=name) -def template(thing: object, /, **ignored_kwargs) -> SqlTemplate: - """Convert something to a Template-ish. +def template(thing: object, /) -> SqlTemplate: + """Convert something to a SqlTemplate. The rules are: - - If the thing has a __duckdb_template__ method, call it and convert the + - If the thing has a `.__duckdb_template__()` method, call it and convert the resuling strings and IntoParams into a SqlTemplate. - If the thing is a `str`, treat it as raw SQL and return a SqlTemplate with that string. - - If the thing is an IntoTemplate, resolve it into a SqlTemplate by recursively resolving - any inner IntoInterpolations and flattening any nested templates. - - If the thing is an IntoInterpolation, resolve it into a SqlTemplate by recursively resolving - its value, and if it has a conversion specified (!s, !r, !a), treat it as raw SQL. - Otherwise, treat the thing as a param. 
""" if isinstance(thing, SupportsDuckdbTemplate): - raw = thing.__duckdb_template__(**ignored_kwargs) + raw = thing.__duckdb_template__() parts = [raw] if isinstance(raw, str) or is_into_interpolation(raw) else list(raw) return SqlTemplate(*parts) if isinstance(thing, str): @@ -120,15 +113,8 @@ def template(thing: object, /, **ignored_kwargs) -> SqlTemplate: return SqlTemplate(param(value=thing)) -def compile(thing: object) -> CompiledSql: - """Compile a thing into a final SQL string with named parameter placeholders, and a list of Params.""" - t = template(thing) - resolved = resolve(t) - return compile_parts(resolved) - - -def resolve(parts: Iterable[str | IntoInterpolation]) -> SqlTemplate[str | IntoInterpolation]: - """Resolve an OurTemplate by recursively resolving any inner templates and interpolations.""" +def resolve(parts: Iterable[str | IntoInterpolation]) -> SqlTemplate: + """Resolve a stream of strings and Interpolations, recursively resolving inner interpolations.""" resolved: list[str | IntoInterpolation] = [] for part in parts: if isinstance(part, str): @@ -138,6 +124,13 @@ def resolve(parts: Iterable[str | IntoInterpolation]) -> SqlTemplate[str | IntoI return SqlTemplate(*resolved) +def compile(thing: object) -> CompiledSql: + """Compile a thing into a final SQL string with named parameter placeholders, and a list of Params.""" + t = template(thing) + resolved = resolve(t) + return compile_parts(resolved) + + def resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | IntoInterpolation]: value = interp.value # if conversion specified (!s, !r, !a), treat as raw sql, eg @@ -188,7 +181,7 @@ def resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | IntoInter def compile_parts(parts: Iterable[str | IntoInterpolation], /) -> CompiledSql: - """Compile a resolved SqlTemplate into a final SQL string with named parameter placeholders, and a list of Params.""" + """Compile parts into a final SQL string with named parameter placeholders, 
and a list of Params.""" sql_parts: list[str] = [] params: list[Param] = [] for part in parts: @@ -242,13 +235,13 @@ def __init__(self, value: object, expression: str | None = None) -> None: class SqlTemplate: - """A simple implementation of IntoTemplate, for testing purposes.""" + """A sequence of strings and Interpolations.""" def __init__( self, *parts: str | IntoInterpolation, ) -> None: - self.strings, self.interpolations = parse_strings_and_params(parts) + self.strings, self.interpolations = parse_parts(parts) def __iter__(self) -> Iterator[str | IntoInterpolation]: """Iterate over the strings and interpolations in order.""" @@ -270,7 +263,7 @@ def compile(self) -> CompiledSql: T = TypeVar("T") -def parse_strings_and_params( +def parse_parts( parts: Iterable[str | T], ) -> tuple[tuple[str, ...], tuple[T, ...]]: """Parse an iterable of strings and params into separate tuples of strings and params, merging adjacent strings and ensuring that the number of strings is one more than the number of params.""" From ed1ba0928d07b586bfaee1e2dff0b7a6c96d5f33 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 15:25:44 -0800 Subject: [PATCH 10/29] ovehaul api more --- template.py | 336 ++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 234 insertions(+), 102 deletions(-) diff --git a/template.py b/template.py index 8f299b88..ece5c57a 100644 --- a/template.py +++ b/template.py @@ -1,7 +1,8 @@ from __future__ import annotations import dataclasses -from typing import TYPE_CHECKING, Literal, Protocol, TypeVar, runtime_checkable +from collections import Counter +from typing import TYPE_CHECKING, Literal, NoReturn, Protocol, TypeVar, runtime_checkable if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Sequence @@ -20,20 +21,12 @@ ] -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, slots=True) class CompiledSql: """Represents a compiled SQL statement, with the final SQL string and a list of Params to be passed 
to duckdb.""" sql: str - params: tuple[Param, ...] - - -@dataclasses.dataclass(frozen=True) -class Param: - """Represents a parameter to be passed to duckdb, with a name and a value.""" - - value: object - name: str + params: dict[str, object] def is_into_interpolation(thing: object) -> TypeIs[IntoInterpolation]: @@ -54,44 +47,57 @@ def __duckdb_template__(self, /, **future_kwargs) -> str | IntoInterpolation | I Examples: A simple implementation might just return a string, eg - ```python - class MyRelation: - def __duckdb_template__(self, **kwargs): - return "SELECT * FROM my_table" - ``` - A more complex implementation might return an iterable of strings and IntoParams. - An IntoParam is a dict with a "value" key, and optionally a "name" key. + >>> class MyRelation: + ... def __duckdb_template__(self, **kwargs): + ... return "SELECT * FROM my_table" + >>> t = template(MyRelation()) + >>> t.compile() + CompiledSql(sql='SELECT * FROM my_table', params={}) + + A more complex implementation might return an iterable of strings and Params. + A Param is a dict with a "value" key, and optionally a "name" key. For example: - ```python - class User: - def __init__(self, user_id: int): - self.user_id = user_id - - def __duckdb_template__(self, **kwargs): - user_id = self.user_id - return t"SELECT * FROM users WHERE id = {user_id}" - ``` - - This will resolve to the final SQL and params: - ```python - { - "sql": "SELECT * FROM users WHERE id = $p0_user_id", - "params": [{"name": "p0_user_id", "value": 123}], - } - ``` + >>> class Record: + ... def __init__(self, table: str, id: int): + ... self.table = table + ... self.id = id + ... + ... def __duckdb_template__(self, **kwargs): + ... id = self.id + ... 
return t"SELECT * FROM {self.table!s} WHERE id = {id}" + >>> t = template(Record("users", 123)) + >>> t.compile() + CompiledSql(sql='SELECT * FROM users WHERE id = $p0_id', params={'p0_id': 123}) """ -def param(value: object, name: str | None = None) -> ParamInterpolation: - """Helper function to create an IntoParam with an optional name.""" - if name is not None: - assert_param_name_legal(name) - return ParamInterpolation(value=value, expression=name) +@dataclasses.dataclass(frozen=True, slots=True) +class Param: + """Represents a parameter to be passed to duckdb, with an optional name and an optional exact flag.""" + + value: object + name: str | None = None + exact: bool = False + + def __post_init__(self) -> None: + if self.exact: + if self.name is None: + msg = "Param with exact=True must have a name." + raise ValueError(msg) + else: + assert_param_name_legal(self.name) + + +def param(value: object, name: str | None = None, *, exact: bool = False) -> Param: + """Helper function to create an Param with an optional name and an optional exact flag.""" + return Param(value=value, name=name, exact=exact) -def template(thing: object, /) -> SqlTemplate: +def template( + thing: str | Param | IntoInterpolation | SupportsDuckdbTemplate | Iterable[str | IntoInterpolation | Param], / +) -> SqlTemplate: """Convert something to a SqlTemplate. The rules are: @@ -99,29 +105,98 @@ def template(thing: object, /) -> SqlTemplate: resuling strings and IntoParams into a SqlTemplate. - If the thing is a `str`, treat it as raw SQL and return a SqlTemplate with that string. - Otherwise, treat the thing as a param. 
- """ + + Examples: + A very simple example is just passing a string, which will be treated as raw SQL: + + >>> t = template("SELECT * FROM users WHERE id = 123") + >>> repr(t) + SqlTemplate('SELECT * FROM users WHERE id = 123') + >>> t.compile() + CompiledSql(sql='SELECT * FROM users WHERE id = 123', params={}) + + In python 3.14+, [tstrings](https://docs.python.org/3/library/string.templatelib.html) + are very useful here. + Any interpolation inside a tstring will be richly interpreted, + either treated as a param, or expanded as a subquery: + + >>> user_id = 123 + >>> t = template(t"SELECT * FROM users WHERE id = {user_id}") + >>> repr(t) + SqlTemplate('SELECT * FROM users WHERE id = ', Param(value=123, name='user_id'), '') + >>> t.compile() + CompiledSql(sql='SELECT * FROM users WHERE id = $p0_user_id', params={'p0_user_id': 123}) + + This is very friendly with chaining relations: + >>> all_people = duckdb.sql("SELECT * FROM people") + >>> age = 18 + >>> adults = template(t"SELECT * FROM ({all_people}) WHERE age >= {age}") + >>> names = template(t"SELECT name FROM ({adults})") + >>> names.compile() + CompiledSql(sql='SELECT name FROM (SELECT * FROM (SELECT * FROM people) WHERE age >= $p0_age)', params={'p0_age': 18}) + + We also support iterables of strings and Interpolations/Params/etc, + which will be joined together into a single template. + This is very useful for versions of python before 3.14 that don't have tstrings, + since it allows you to build up a template from smaller pieces: + + >>> t = template(["SELECT * FROM (", all_people, ") WHERE age >= ", age]) + >>> t.compile() + CompiledSql(sql='SELECT * FROM (SELECT * FROM people) WHERE age >= $p0_age', params={'p0_age': 18}) + + You can define evaluation logic for your custom types by defining a `.__duckdb_template__()` method. + If this method is defined, the result of that call will be used to create the template. + + >>> class Record: + ... def __init__(self, table_name: str, id: int): + ... 
self.table_name = table_name + ... self.id = id + ... + ... def __duckdb_template__(self, **kwargs): + ... id = self.id + ... # note the use of !s to indicate that the table name should be treated as raw SQL + ... return t"SELECT * FROM {self.table_name!s} WHERE id = {id}" + >>> t = template(Record("users", 123)) + >>> t.compile() + CompiledSql(sql='SELECT * FROM users WHERE id = $p0_id', params={'p0_id': 123}) + """ # noqa: E501 if isinstance(thing, SupportsDuckdbTemplate): raw = thing.__duckdb_template__() - parts = [raw] if isinstance(raw, str) or is_into_interpolation(raw) else list(raw) - return SqlTemplate(*parts) + return template(raw) if isinstance(thing, str): return SqlTemplate(thing) - if isinstance(thing, IntoTemplate): - return SqlTemplate(*thing) + if isinstance(thing, Param): + return SqlTemplate(ParamInterpolation(thing)) if isinstance(thing, IntoInterpolation): return SqlTemplate(thing) - return SqlTemplate(param(value=thing)) + if is_iterable(thing): + return SqlTemplate(*thing) + return SqlTemplate(ParamInterpolation(param(value=thing))) + + +def is_iterable(thing: object) -> TypeIs[Iterable]: + return isinstance(thing, Iterable) and not isinstance(thing, (str, bytes)) -def resolve(parts: Iterable[str | IntoInterpolation]) -> SqlTemplate: +class ParamInterpolation: + """A simple wrapper that implements the IntoInterpolation protocol for a given IntoParam.""" + + def __init__(self, param: Param): + self.value = param + self.expression = param.name + self.conversion = None + self.format_spec = "" + + +def resolve(parts: Iterable[str | IntoInterpolation]) -> ResolvedSqlTemplate: """Resolve a stream of strings and Interpolations, recursively resolving inner interpolations.""" - resolved: list[str | IntoInterpolation] = [] + resolved: list[str | Param] = [] for part in parts: if isinstance(part, str): resolved.append(part) else: resolved.extend(resolve_interpolation(part)) - return SqlTemplate(*resolved) + return ResolvedSqlTemplate(resolved) def 
compile(thing: object) -> CompiledSql: @@ -131,8 +206,12 @@ def compile(thing: object) -> CompiledSql: return compile_parts(resolved) -def resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | IntoInterpolation]: +def resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | Param]: value = interp.value + if isinstance(value, Param): + # If it's already a Param, we can skip the template resolution and just return it as a param. + return (value,) + # if conversion specified (!s, !r, !a), treat as raw sql, eg # name = "Alice" # t"SELECT * FROM users where name = '{name!s}'" should be @@ -144,16 +223,18 @@ def resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | IntoInter # or if the user just wants to write raw SQL and doesn't care about safety # Note that this is potentially unsafe if the value comes from an untrusted source, # since it could lead to SQL injection vulnerabilities, so it should be used with caution. + formatted = format(value, interp.format_spec) + if interp.conversion == "s": - return SqlTemplate(str(value)) + return (str(formatted),) elif interp.conversion == "r": - return SqlTemplate(repr(value)) + return (repr(formatted),) elif interp.conversion == "a": - return SqlTemplate(ascii(value)) + return (ascii(formatted),) if isinstance(value, str): # do NOT pass to template, since that would treat it as a raw SQL. 
- return SqlTemplate(param(value, name=interp.expression)) + return (param(value, name=interp.expression),) templ = template(value) # resolved = resolve(templ) # If the resolved inner is just a single Interpolation, then just return @@ -169,7 +250,7 @@ def resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | IntoInter # eg with a friendly param name, rather than # "SELECT * FROM people WHERE age = $p0", with a param $p0=37 if len(templ.strings) == 2 and templ.strings[0] == "" and templ.strings[1] == "" and len(templ.interpolations) == 1: - return SqlTemplate(param(value=value, name=interp.expression)) + return (param(value, name=interp.expression),) else: # We got something nested, eg # age = 37 @@ -177,23 +258,7 @@ def resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | IntoInter # names = t"SELECT name FROM ({foo})" # names should resolve to: # "SELECT name FROM (SELECT * FROM people WHERE age = $age)", with a param $age=37 - return templ - - -def compile_parts(parts: Iterable[str | IntoInterpolation], /) -> CompiledSql: - """Compile parts into a final SQL string with named parameter placeholders, and a list of Params.""" - sql_parts: list[str] = [] - params: list[Param] = [] - for part in parts: - if isinstance(part, str): - sql_parts.append(part) - else: - param_name = f"p{len(params)}" - if passed_name := part.expression: - param_name += f"_{passed_name}" - sql_parts.append(f"${param_name}") - params.append(Param(name=param_name, value=part.value)) - return CompiledSql(sql="".join(sql_parts), params=tuple(params)) + return resolve(templ) @runtime_checkable @@ -224,24 +289,24 @@ def assert_param_name_legal(name: str) -> None: # or perhaps we shouldn't even check, just pass it to duckdb and let it error if it's illegal -class ParamInterpolation: - """A simple implementation of IntoInterpolation, for testing purposes.""" - - def __init__(self, value: object, expression: str | None = None) -> None: - self.value = value - 
self.expression = expression - self.conversion = None - self.format_spec = "" - - class SqlTemplate: """A sequence of strings and Interpolations.""" def __init__( self, - *parts: str | IntoInterpolation, + *parts: str | IntoInterpolation | Param, ) -> None: - self.strings, self.interpolations = parse_parts(parts) + self.strings, interps_and_params = parse_parts(parts) + interps = [] + for part in interps_and_params: + if isinstance(part, IntoInterpolation): + interps.append(part) + elif isinstance(part, Param): + # it's a Param, so we need to wrap it in a ParamInterpolation + interps.append(ParamInterpolation(part)) + else: + raise TypeError(f"Unexpected part type: {type(part)}. Expected str, IntoInterpolation, or Param.") + self.interpolations = interps def __iter__(self) -> Iterator[str | IntoInterpolation]: """Iterate over the strings and interpolations in order.""" @@ -250,8 +315,8 @@ def __iter__(self) -> Iterator[str | IntoInterpolation]: yield i yield self.strings[-1] - def resolve(self) -> SqlTemplate: - """Resolve any inner templates and interpolations, returning a new OurTemplate with only strings and ParamInterpolations.""" + def resolve(self) -> ResolvedSqlTemplate: + """Recursively resolve Interpolations into Params, returning a ResolvedSqlTemplate.""" return resolve(self) def compile(self) -> CompiledSql: @@ -259,33 +324,100 @@ def compile(self) -> CompiledSql: resolved = self.resolve() return compile_parts(resolved) + def __str__(self) -> NoReturn: + msg = "SqlTemplate cannot be directly converted to a string, since it may contain unresolved interpolations. Please call .resolve() or .compile() first." 
# noqa: E501 + raise NotImplementedError(msg) + + def __repr__(self) -> str: + part_strings = [repr(part) for part in self] + return f"SqlTemplate({', '.join(part_strings)})" + + +class ResolvedSqlTemplate: + """A SqlTemplate that has been resolved to only strings and Params.""" + + def __init__(self, parts: Iterable[str | Param]) -> None: + self.parts = tuple(parts) + + def compile(self) -> CompiledSql: + """Compile this template into a final SQL string with named parameter placeholders, and a list of Params.""" + return compile_parts(self.parts) + + def __str__(self) -> NoReturn: + msg = "ResolvedSqlTemplate cannot be directly converted to a string, since it may contain unresolved interpolations. Please call .compile() first." # noqa: E501 + raise NotImplementedError(msg) + + def __repr__(self) -> str: + part_strings = [] + for part in self.parts: + if isinstance(part, str): + part_strings.append(repr(part)) + else: + part_strings.append(f"{{{part.name}={part.value}}}") + return f"ResolvedSqlTemplate({', '.join(part_strings)})" + + def __iter__(self) -> Iterator[str | Param]: + yield from self.parts + T = TypeVar("T") -def parse_parts( - parts: Iterable[str | T], -) -> tuple[tuple[str, ...], tuple[T, ...]]: - """Parse an iterable of strings and params into separate tuples of strings and params, merging adjacent strings and ensuring that the number of strings is one more than the number of params.""" - strings, params = [], [] - last_thing: Literal["string", "param"] | None = None +def parse_parts(parts: Iterable[str | T]) -> tuple[tuple[str, ...], tuple[T, ...]]: + """Parse an iterable of strings and others into separate tuples of strings and others. + + This merges adjacent strings and ensuring that the number of strings is one more than the number of others. 
+ """ + strings, others = [], [] + last_thing: Literal["string", "other"] | None = None for part in parts: if isinstance(part, str): if last_thing == "string": - # Merge adjacent string parts for efficiency, since the template engine allows that + # Merge adjacent string parts strings[-1] += part else: strings.append(part) last_thing = "string" else: - if last_thing is None or last_thing == "param": - # this is the first part or there were two adjacent params, + if last_thing is None or last_thing == "other": + # this is the first part or there were two adjacent others, # so we need an empty string spacer strings.append("") - params.append(part) - last_thing = "param" - if last_thing == "param": - # If the last part was a param, we need to end with an empty string + others.append(part) + last_thing = "other" + if last_thing == "other": + # If the last part was an other, we need to end with an empty string strings.append("") - assert len(strings) == len(params) + 1 - return tuple(strings), tuple(params) + assert len(strings) == len(others) + 1 + return tuple(strings), tuple(others) + + +def compile_parts(parts: Iterable[str | Param], /) -> CompiledSql: + """Compile parts into a final SQL string with named parameter placeholders, and a list of Params.""" + sql_parts: list[str] = [] + params_items = [] + + def next_name(suffix: str | None = None) -> str: + base = f"p{len(params_items)}" + if suffix is not None: + return f"{base}_{suffix}" + else: + return base + + for part in parts: + if isinstance(part, str): + sql_parts.append(part) + else: + if passed_name := part.name: + param_name = passed_name if part.exact else next_name(passed_name) + else: + param_name = next_name() + assert_param_name_legal(param_name) + sql_parts.append(f"${param_name}") + params_items.append((param_name, part.value)) + param_name_counts = Counter(name for name, _ in params_items) + dupes = [name for name, count in param_name_counts.items() if count > 1] + if dupes: + msg = f"Duplicate 
parameter names found: {dupes}. Please ensure all parameter names are unique." + raise ValueError(msg) + return CompiledSql(sql="".join(sql_parts), params=dict(params_items)) From 971d28de98d23b9d88bb03a6b621391842545c33 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 15:33:55 -0800 Subject: [PATCH 11/29] fixup --- template.py | 37 +++++++++++++------------------------ 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/template.py b/template.py index ece5c57a..2a720245 100644 --- a/template.py +++ b/template.py @@ -2,10 +2,11 @@ import dataclasses from collections import Counter +from collections.abc import Iterable from typing import TYPE_CHECKING, Literal, NoReturn, Protocol, TypeVar, runtime_checkable if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Sequence + from collections.abc import Iterator, Sequence from typing_extensions import TypeIs @@ -15,8 +16,7 @@ "IntoTemplate", "Param", "SupportsDuckdbTemplate", - "compile", - "resolve", + "param", "template", ] @@ -29,10 +29,6 @@ class CompiledSql: params: dict[str, object] -def is_into_interpolation(thing: object) -> TypeIs[IntoInterpolation]: - return isinstance(thing, IntoInterpolation) - - @runtime_checkable class SupportsDuckdbTemplate(Protocol): """Something that can be converted into a Template by implementing the __duckdb_template__ method.""" @@ -169,12 +165,12 @@ def template( return SqlTemplate(ParamInterpolation(thing)) if isinstance(thing, IntoInterpolation): return SqlTemplate(thing) - if is_iterable(thing): + if _is_iterable(thing): return SqlTemplate(*thing) return SqlTemplate(ParamInterpolation(param(value=thing))) -def is_iterable(thing: object) -> TypeIs[Iterable]: +def _is_iterable(thing: object) -> TypeIs[Iterable]: return isinstance(thing, Iterable) and not isinstance(thing, (str, bytes)) @@ -188,25 +184,18 @@ def __init__(self, param: Param): self.format_spec = "" -def resolve(parts: Iterable[str | IntoInterpolation]) -> ResolvedSqlTemplate: 
+def _resolve(parts: Iterable[str | IntoInterpolation]) -> ResolvedSqlTemplate: """Resolve a stream of strings and Interpolations, recursively resolving inner interpolations.""" resolved: list[str | Param] = [] for part in parts: if isinstance(part, str): resolved.append(part) else: - resolved.extend(resolve_interpolation(part)) + resolved.extend(_resolve_interpolation(part)) return ResolvedSqlTemplate(resolved) -def compile(thing: object) -> CompiledSql: - """Compile a thing into a final SQL string with named parameter placeholders, and a list of Params.""" - t = template(thing) - resolved = resolve(t) - return compile_parts(resolved) - - -def resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | Param]: +def _resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | Param]: value = interp.value if isinstance(value, Param): # If it's already a Param, we can skip the template resolution and just return it as a param. @@ -235,8 +224,7 @@ def resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | Param]: if isinstance(value, str): # do NOT pass to template, since that would treat it as a raw SQL. return (param(value, name=interp.expression),) - templ = template(value) - # resolved = resolve(templ) + templ = template(value) # ty:ignore[invalid-argument-type] # If the resolved inner is just a single Interpolation, then just return # the original value so that we preserve the expression name. # For example, if we have @@ -258,7 +246,7 @@ def resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | Param]: # names = t"SELECT name FROM ({foo})" # names should resolve to: # "SELECT name FROM (SELECT * FROM people WHERE age = $age)", with a param $age=37 - return resolve(templ) + return _resolve(templ) @runtime_checkable @@ -305,7 +293,8 @@ def __init__( # it's a Param, so we need to wrap it in a ParamInterpolation interps.append(ParamInterpolation(part)) else: - raise TypeError(f"Unexpected part type: {type(part)}. 
Expected str, IntoInterpolation, or Param.") + msg = f"Unexpected part type: {type(part)}. Expected str, IntoInterpolation, or Param." + raise TypeError(msg) self.interpolations = interps def __iter__(self) -> Iterator[str | IntoInterpolation]: @@ -317,7 +306,7 @@ def __iter__(self) -> Iterator[str | IntoInterpolation]: def resolve(self) -> ResolvedSqlTemplate: """Recursively resolve Interpolations into Params, returning a ResolvedSqlTemplate.""" - return resolve(self) + return _resolve(self) def compile(self) -> CompiledSql: """Compile this template into a final SQL string with named parameter placeholders, and a list of Params.""" From 69e509b6e0ca691ad21f27de9c7ff9d2230b6b13 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 15:50:20 -0800 Subject: [PATCH 12/29] add comment explaining param incrementing, remove IntoTemplate --- template.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/template.py b/template.py index 2a720245..3b12d42f 100644 --- a/template.py +++ b/template.py @@ -6,14 +6,13 @@ from typing import TYPE_CHECKING, Literal, NoReturn, Protocol, TypeVar, runtime_checkable if TYPE_CHECKING: - from collections.abc import Iterator, Sequence + from collections.abc import Iterator from typing_extensions import TypeIs __all__ = [ "CompiledSql", "IntoInterpolation", - "IntoTemplate", "Param", "SupportsDuckdbTemplate", "param", @@ -259,17 +258,6 @@ class IntoInterpolation(Protocol): format_spec: str -@runtime_checkable -class IntoTemplate(Protocol): - """Something that can be converted into string.templatelib.Template.""" - - strings: Sequence[str] - interpolations: Sequence[IntoInterpolation] - - def __iter__(self) -> Iterator[str | IntoInterpolation]: - """Iterate over the strings and interpolations in order.""" - - def assert_param_name_legal(name: str) -> None: """Eg `$param_1` is legal, but `$1param`, `$param-1`, `$param 1`, and `$p ; DROP TABLE users` are not.""" # not implemented yet @@ -387,6 +375,7 
@@ def compile_parts(parts: Iterable[str | Param], /) -> CompiledSql: params_items = [] def next_name(suffix: str | None = None) -> str: + # still count exact params in the count, so we get p0, my_param, p2, p3, my_param_2, p5, etc base = f"p{len(params_items)}" if suffix is not None: return f"{base}_{suffix}" From 0c0fd4ebe905a93c27d24e744f707609b8673bf9 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 16:57:45 -0800 Subject: [PATCH 13/29] adjust implementation and tests --- template.py | 153 ++-- tests/fast/test_template.py | 999 ++++++++++++++++++++++++ tests/fast/test_template_tstrings.py314 | 178 +++++ 3 files changed, 1249 insertions(+), 81 deletions(-) create mode 100644 tests/fast/test_template.py create mode 100644 tests/fast/test_template_tstrings.py314 diff --git a/template.py b/template.py index 3b12d42f..09b27c7f 100644 --- a/template.py +++ b/template.py @@ -1,3 +1,5 @@ +"""Template system for duckdb SQL statements, based on Python's string.templatelib.""" + from __future__ import annotations import dataclasses @@ -8,8 +10,6 @@ if TYPE_CHECKING: from collections.abc import Iterator - from typing_extensions import TypeIs - __all__ = [ "CompiledSql", "IntoInterpolation", @@ -32,40 +32,17 @@ class CompiledSql: class SupportsDuckdbTemplate(Protocol): """Something that can be converted into a Template by implementing the __duckdb_template__ method.""" - def __duckdb_template__(self, /, **future_kwargs) -> str | IntoInterpolation | Iterable[str | IntoInterpolation]: - """Convert self into an IntoTemplate, by returning either a string, an IntoInterpolation, or an iterable of these. - - The future_kwargs are for future extensibility, in case duckdb wants - to pass additional information in the future. - To be future-proof, implementations should accept any additional kwargs, - and ignore them at this point. - - Examples: - A simple implementation might just return a string, eg - - >>> class MyRelation: - ... 
def __duckdb_template__(self, **kwargs): - ... return "SELECT * FROM my_table" - >>> t = template(MyRelation()) - >>> t.compile() - CompiledSql(sql='SELECT * FROM my_table', params={}) - - A more complex implementation might return an iterable of strings and Params. - A Param is a dict with a "value" key, and optionally a "name" key. - - For example: - >>> class Record: - ... def __init__(self, table: str, id: int): - ... self.table = table - ... self.id = id - ... - ... def __duckdb_template__(self, **kwargs): - ... id = self.id - ... return t"SELECT * FROM {self.table!s} WHERE id = {id}" - >>> t = template(Record("users", 123)) - >>> t.compile() - CompiledSql(sql='SELECT * FROM users WHERE id = $p0_id', params={'p0_id': 123}) - """ + def __duckdb_template__( + self, /, **future_kwargs + ) -> ( + str + | IntoInterpolation + | Param + | SupportsDuckdbTemplate + | object + | Iterable[str | IntoInterpolation | Param | SupportsDuckdbTemplate | object] + ): + """Convert self into something that template() understands.""" @dataclasses.dataclass(frozen=True, slots=True) @@ -77,6 +54,7 @@ class Param: exact: bool = False def __post_init__(self) -> None: + """Ensure passed args were valid.""" if self.exact: if self.name is None: msg = "Param with exact=True must have a name." @@ -90,16 +68,17 @@ def param(value: object, name: str | None = None, *, exact: bool = False) -> Par return Param(value=value, name=name, exact=exact) -def template( - thing: str | Param | IntoInterpolation | SupportsDuckdbTemplate | Iterable[str | IntoInterpolation | Param], / -) -> SqlTemplate: - """Convert something to a SqlTemplate. +def template(*part: str | IntoInterpolation | Param | SupportsDuckdbTemplate | object) -> SqlTemplate: + """Convert a sequence of things into a SqlTemplate. - The rules are: - - If the thing has a `.__duckdb_template__()` method, call it and convert the - resuling strings and IntoParams into a SqlTemplate. 
+ We go through the parts and convert it into a sequence of str and Interpolations, + which we then hand off to SqlTemplate. + - If the thing has a `.__duckdb_template__()` method, call it, + and call template() recursively on the result. - If the thing is a `str`, treat it as raw SQL and return a SqlTemplate with that string. - - Otherwise, treat the thing as a param. + - If it's an Interpolation, leave it as is, treating it as an interpolation. + - It it's a Param, wrap it in an Interpolation. + - Otherwise, treat the thing as a param, and then wrap the Param in an Interpolation. Examples: A very simple example is just passing a string, which will be treated as raw SQL: @@ -155,33 +134,47 @@ def template( >>> t.compile() CompiledSql(sql='SELECT * FROM users WHERE id = $p0_id', params={'p0_id': 123}) """ # noqa: E501 - if isinstance(thing, SupportsDuckdbTemplate): - raw = thing.__duckdb_template__() - return template(raw) - if isinstance(thing, str): - return SqlTemplate(thing) - if isinstance(thing, Param): - return SqlTemplate(ParamInterpolation(thing)) - if isinstance(thing, IntoInterpolation): - return SqlTemplate(thing) - if _is_iterable(thing): - return SqlTemplate(*thing) - return SqlTemplate(ParamInterpolation(param(value=thing))) - - -def _is_iterable(thing: object) -> TypeIs[Iterable]: - return isinstance(thing, Iterable) and not isinstance(thing, (str, bytes)) + expanded = _expand_part(part) + return SqlTemplate(*expanded) + + +def _expand_part(part: object) -> Iterable[str | IntoInterpolation]: + if isinstance(part, SupportsDuckdbTemplate): + raw = part.__duckdb_template__() + if isinstance(raw, str): # noqa: SIM114 + yield raw + elif isinstance(raw, IntoInterpolation): + yield raw + elif isinstance(raw, Param): + yield ParamInterpolation(raw) + elif isinstance(raw, Iterable): + yield from _expand_part(raw) + else: + p = param(value=raw) + yield ParamInterpolation(p) + elif isinstance(part, str): # noqa: SIM114 + yield part + elif isinstance(part, 
IntoInterpolation): + yield part + elif isinstance(part, Param): + yield ParamInterpolation(part) + else: + p = param(value=part) + yield ParamInterpolation(p) class ParamInterpolation: """A simple wrapper that implements the IntoInterpolation protocol for a given IntoParam.""" - def __init__(self, param: Param): + def __init__(self, param: Param): # noqa: ANN204 self.value = param self.expression = param.name self.conversion = None self.format_spec = "" + def __repr__(self) -> str: + return repr(self.value) + def _resolve(parts: Iterable[str | IntoInterpolation]) -> ResolvedSqlTemplate: """Resolve a stream of strings and Interpolations, recursively resolving inner interpolations.""" @@ -211,19 +204,25 @@ def _resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | Param]: # or if the user just wants to write raw SQL and doesn't care about safety # Note that this is potentially unsafe if the value comes from an untrusted source, # since it could lead to SQL injection vulnerabilities, so it should be used with caution. - formatted = format(value, interp.format_spec) - + # + # Follow Python's f-string semantics: apply conversion first, then format_spec. + # e.g. {x!r:.10} means format(repr(x), ".10") if interp.conversion == "s": - return (str(formatted),) + converted = str(value) elif interp.conversion == "r": - return (repr(formatted),) + converted = repr(value) elif interp.conversion == "a": - return (ascii(formatted),) + converted = ascii(value) + else: + converted = None + + if converted is not None: + return (format(converted, interp.format_spec),) if isinstance(value, str): # do NOT pass to template, since that would treat it as a raw SQL. return (param(value, name=interp.expression),) - templ = template(value) # ty:ignore[invalid-argument-type] + templ = template(value) # If the resolved inner is just a single Interpolation, then just return # the original value so that we preserve the expression name. 
# For example, if we have @@ -270,20 +269,9 @@ class SqlTemplate: def __init__( self, - *parts: str | IntoInterpolation | Param, + *parts: str | IntoInterpolation, ) -> None: - self.strings, interps_and_params = parse_parts(parts) - interps = [] - for part in interps_and_params: - if isinstance(part, IntoInterpolation): - interps.append(part) - elif isinstance(part, Param): - # it's a Param, so we need to wrap it in a ParamInterpolation - interps.append(ParamInterpolation(part)) - else: - msg = f"Unexpected part type: {type(part)}. Expected str, IntoInterpolation, or Param." - raise TypeError(msg) - self.interpolations = interps + self.strings, self.interpolations = parse_parts(parts) def __iter__(self) -> Iterator[str | IntoInterpolation]: """Iterate over the strings and interpolations in order.""" @@ -362,7 +350,10 @@ def parse_parts(parts: Iterable[str | T]) -> tuple[tuple[str, ...], tuple[T, ... strings.append("") others.append(part) last_thing = "other" - if last_thing == "other": + if last_thing is None: + # Empty input — return a single empty string to maintain the invariant + strings.append("") + elif last_thing == "other": # If the last part was an other, we need to end with an empty string strings.append("") assert len(strings) == len(others) + 1 diff --git a/tests/fast/test_template.py b/tests/fast/test_template.py new file mode 100644 index 00000000..94ebd80c --- /dev/null +++ b/tests/fast/test_template.py @@ -0,0 +1,999 @@ +"""Exhaustive tests for template.py — the SQL template / t-string system.""" + +from __future__ import annotations + +import pytest + +from template import ( + CompiledSql, + IntoInterpolation, + Param, + ParamInterpolation, + ResolvedSqlTemplate, + SqlTemplate, + SupportsDuckdbTemplate, + compile_parts, + param, + parse_parts, + template, +) + +# ── helpers ─────────────────────────────────────────────────────────────────── + + +class FakeInterpolation: + """Minimal object satisfying IntoInterpolation protocol.""" + + def 
__init__(self, value, expression=None, conversion=None, format_spec="") -> None: + self.value = value + self.expression = expression + self.conversion = conversion + self.format_spec = format_spec + + +class SimpleRelation: + """A minimal SupportsDuckdbTemplate implementation returning a string.""" + + def __init__(self, sql: str) -> None: + self._sql = sql + + def __duckdb_template__(self, **kwargs) -> str: + return self._sql + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Param +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestParam: + def test_basic_creation(self): + p = Param(value=42) + assert p.value == 42 + assert p.name is None + assert p.exact is False + + def test_named_param(self): + p = Param(value="hello", name="greeting") + assert p.name == "greeting" + assert p.exact is False + + def test_exact_param_requires_name(self): + with pytest.raises(ValueError, match="exact=True must have a name"): + Param(value=1, exact=True) + + def test_exact_param_with_name(self): + p = Param(value=1, name="x", exact=True) + assert p.name == "x" + assert p.exact is True + + def test_frozen(self): + p = Param(value=1) + with pytest.raises(AttributeError): + p.value = 2 # ty:ignore[invalid-assignment] + + def test_param_helper_function(self): + p = param(42, "answer", exact=True) + assert isinstance(p, Param) + assert p.value == 42 + assert p.name == "answer" + assert p.exact is True + + def test_param_repr(self): + p = Param(value=42, name="x") + r = repr(p) + assert "42" in r + assert "x" in r + + def test_param_equality(self): + """Frozen dataclasses support equality by default.""" + assert Param(value=1, name="x") == Param(value=1, name="x") + assert Param(value=1) != Param(value=2) + + def test_param_various_value_types(self): + """Params should accept any Python object as a value.""" + for val in [None, 3.14, True, [1, 2], {"a": 1}, b"bytes", object()]: + p = Param(value=val) + 
assert p.value is val + + +# ═══════════════════════════════════════════════════════════════════════════════ +# ParamInterpolation +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestParamInterpolation: + def test_wraps_param(self): + p = Param(value=42, name="x") + pi = ParamInterpolation(p) + assert pi.value is p + assert pi.expression == "x" + assert pi.conversion is None + assert pi.format_spec == "" + + def test_unnamed_param_expression_is_none(self): + p = Param(value=42) + pi = ParamInterpolation(p) + assert pi.expression is None + + def test_satisfies_into_interpolation_protocol(self): + pi = ParamInterpolation(Param(value=1, name="x")) + assert isinstance(pi, IntoInterpolation) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# parse_parts +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestParseParts: + def test_all_strings(self): + strings, others = parse_parts(["hello", " ", "world"]) + assert strings == ("hello world",) + assert others == () + + def test_all_others(self): + a, b = object(), object() + strings, others = parse_parts([a, b]) + # two others → need three string spacers: "", "", "" + assert strings == ("", "", "") + assert others == (a, b) + + def test_alternating(self): + a = object() + strings, others = parse_parts(["before", a, "after"]) + assert strings == ("before", "after") + assert others == (a,) + + def test_string_then_other(self): + a = object() + strings, others = parse_parts(["sql", a]) + assert strings == ("sql", "") + assert others == (a,) + + def test_other_then_string(self): + a = object() + strings, others = parse_parts([a, "sql"]) + assert strings == ("", "sql") + assert others == (a,) + + def test_adjacent_strings_merged(self): + strings, others = parse_parts(["a", "b", "c"]) + assert strings == ("abc",) + assert others == () + + def test_adjacent_others_get_empty_string_spacers(self): + a, b, 
c = object(), object(), object() + strings, others = parse_parts([a, b, c]) + assert strings == ("", "", "", "") + assert others == (a, b, c) + + def test_invariant_strings_one_more_than_others(self): + """The fundamental invariant: len(strings) == len(others) + 1.""" + cases = [ + ["a"], + ["a", object()], + [object(), "a"], + [object(), object()], + ["a", object(), "b", object(), "c"], + ] + for parts in cases: + strings, others = parse_parts(parts) + assert len(strings) == len(others) + 1, f"Failed for parts={parts}" + + def test_empty_input(self): + """Empty input should return a single empty string and no others.""" + strings, others = parse_parts([]) + assert strings == ("",) + assert others == () + + def test_single_string(self): + strings, others = parse_parts(["SELECT 1"]) + assert strings == ("SELECT 1",) + assert others == () + + def test_single_other(self): + a = object() + strings, others = parse_parts([a]) + assert strings == ("", "") + assert others == (a,) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# SqlTemplate construction +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestSqlTemplateConstruction: + def test_plain_string(self): + t = SqlTemplate("SELECT 1") + assert t.strings == ("SELECT 1",) + assert t.interpolations == [] + + def test_multiple_strings_merged(self): + t = SqlTemplate("SELECT ", "1") + assert t.strings == ("SELECT 1",) + assert t.interpolations == [] + + def test_with_param(self): + p = Param(value=42, name="x") + t = SqlTemplate("SELECT ", p, " FROM t") + assert t.strings == ("SELECT ", " FROM t") + assert len(t.interpolations) == 1 + assert t.interpolations[0].value is p + + def test_with_interpolation(self): + interp = FakeInterpolation(value=42, expression="x") + t = SqlTemplate("a ", interp, " b") + assert len(t.interpolations) == 1 + assert t.interpolations[0] is interp + + def test_bare_param_errors(self): + p = Param(value=42) + 
with pytest.raises(TypeError, match="Unexpected part type"): + SqlTemplate("SELECT ", p) # ty:ignore[invalid-argument-type] + + def test_rejects_invalid_types(self): + """Items that are not str, IntoInterpolation, or Param should raise TypeError.""" + with pytest.raises(TypeError, match="Unexpected part type"): + SqlTemplate(42) # ty:ignore[invalid-argument-type] + + def test_no_args(self): + """Empty SqlTemplate should produce a single empty string.""" + t = SqlTemplate() + assert t.strings == ("",) + assert t.interpolations == [] + + +# ═══════════════════════════════════════════════════════════════════════════════ +# SqlTemplate iteration and repr +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestSqlTemplateIterRepr: + def test_iter_plain_string(self): + t = SqlTemplate("hello") + parts = list(t) + assert parts == ["hello"] + + def test_iter_with_interpolations(self): + p = Param(value=1, name="x") + t = SqlTemplate("a ", p, " b") + parts = list(t) + assert len(parts) == 3 + assert parts[0] == "a " + assert isinstance(parts[1], IntoInterpolation) + assert parts[2] == " b" + + def test_str_raises(self): + t = SqlTemplate("hello") + with pytest.raises(NotImplementedError): + str(t) + + def test_repr_plain_string(self): + t = SqlTemplate("SELECT 1") + r = repr(t) + assert "SqlTemplate" in r + assert "'SELECT 1'" in r + + +# ═══════════════════════════════════════════════════════════════════════════════ +# template() factory — basic cases +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestTemplateFactory: + def test_plain_string(self): + t = template("SELECT 1") + compiled = t.compile() + assert compiled.sql == "SELECT 1" + assert compiled.params == {} + + def test_param(self): + p = Param(value=42, name="answer") + t = template(p) + compiled = t.compile() + assert "$" in compiled.sql + assert 42 in compiled.params.values() + + def test_into_interpolation(self): + interp = 
FakeInterpolation(value=42, expression="x") + t = template(interp) + compiled = t.compile() + assert 42 in compiled.params.values() + + def test_supports_duckdb_template_string(self): + rel = SimpleRelation("SELECT 1") + t = template(rel) + compiled = t.compile() + assert compiled.sql == "SELECT 1" + assert compiled.params == {} + + def test_supports_duckdb_template_returns_iterable(self): + class MultiPart: + def __duckdb_template__(self, **kwargs) -> list[str | Param]: + return ["SELECT * FROM ", param(42, "x")] + + t = template(MultiPart()) + compiled = t.compile() + assert "$" in compiled.sql + assert 42 in compiled.params.values() + + def test_supports_duckdb_template_returns_interpolation(self): + class InterpReturner: + def __duckdb_template__(self, **kwargs) -> FakeInterpolation: + return FakeInterpolation(value="hello", expression="val") + + t = template(InterpReturner()) + compiled = t.compile() + assert "hello" in compiled.params.values() + + def test_iterable_of_strings(self): + t = template("SELECT ", "1") + compiled = t.compile() + assert compiled.sql == "SELECT 1" + assert compiled.params == {} + + def test_iterable_with_params(self): + t = template("SELECT * FROM t WHERE id = ", Param(value=5, name="id")) + compiled = t.compile() + assert compiled.params + assert 5 in compiled.params.values() + + def test_iterable_with_bare_values(self): + """Bare values in an iterable should be treated as params.""" + t = template("SELECT * FROM t WHERE id = ", 42) + compiled = t.compile() + assert 42 in compiled.params.values() + assert "42" not in compiled.sql + + def test_bare_value_at_top_level_becomes_param(self): + """A bare value passed directly to template() becomes a param.""" + t = template(42) # type: ignore[arg-type] + compiled = t.compile() + assert compiled.params + assert 42 in compiled.params.values() + + def test_bytes_not_treated_as_iterable(self): + """Bytes should not be iterated — _is_iterable excludes bytes.""" + t = template(b"hello") # type: 
ignore[arg-type] + compiled = t.compile() + expected = CompiledSql(sql="$p0", params={"p0": b"hello"}) + assert compiled == expected + + +# ═══════════════════════════════════════════════════════════════════════════════ +# template() factory — t-string integration (Python 3.14+) +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestTemplateWithInterpolations: + """Tests for template() with interpolations, using FakeInterpolation to simulate t-string behavior.""" + + def test_simple_param(self): + interp = FakeInterpolation(value=123, expression="user_id") + t = template("SELECT * FROM users WHERE id = ", interp, "") + compiled = t.compile() + assert "user_id" in compiled.sql + assert "$" in compiled.sql + assert 123 in compiled.params.values() + + def test_multiple_params(self): + name_interp = FakeInterpolation(value="Alice", expression="name") + age_interp = FakeInterpolation(value=30, expression="age") + t = template("SELECT * FROM users WHERE name = ", name_interp, " AND age = ", age_interp, "") + compiled = t.compile() + assert len(compiled.params) == 2 + assert "Alice" in compiled.params.values() + assert 30 in compiled.params.values() + + def test_no_params(self): + t = template("SELECT 1") + compiled = t.compile() + assert compiled.sql == "SELECT 1" + assert compiled.params == {} + + def test_string_conversion_s(self): + """!s should inline the value as raw SQL (no param).""" + interp = FakeInterpolation(value="users", expression="table", conversion="s") + t = template("SELECT * FROM ", interp, "") + compiled = t.compile() + assert compiled.sql == "SELECT * FROM users" + assert compiled.params == {} + + def test_repr_conversion_r(self): + """!r should inline repr(value) as raw SQL.""" + interp = FakeInterpolation(value="hello", expression="val", conversion="r") + t = template("SELECT ", interp, "") + compiled = t.compile() + # repr of "hello" is "'hello'" + assert "'hello'" in compiled.sql + assert compiled.params 
== {} + + def test_ascii_conversion_a(self): + """!a should inline ascii(value) as raw SQL.""" + interp = FakeInterpolation(value="café", expression="val", conversion="a") + t = template("SELECT ", interp, "") + compiled = t.compile() + assert compiled.params == {} + # ascii() should escape non-ASCII + assert "\\xe9" in compiled.sql or "caf" in compiled.sql + + def test_string_value_becomes_param_not_raw_sql(self): + """A string interpolation WITHOUT conversion should be a param, not raw SQL.""" + interp = FakeInterpolation(value="Alice", expression="name") + t = template("SELECT * FROM users WHERE name = ", interp, "") + compiled = t.compile() + assert "Alice" not in compiled.sql + assert "Alice" in compiled.params.values() + + def test_nested_template_via_supports_duckdb_template(self): + """An interpolated SupportsDuckdbTemplate should be expanded inline.""" + inner = SimpleRelation("SELECT * FROM people") + interp = FakeInterpolation(value=inner, expression="inner") + t = template("SELECT name FROM (", interp, ")") + compiled = t.compile() + assert "SELECT * FROM people" in compiled.sql + assert compiled.params == {} + + def test_nested_chaining(self): + """Chaining templates: inner template params should propagate.""" + inner = template("SELECT * FROM people WHERE age >= ", 18) + interp = FakeInterpolation(value=inner, expression="inner") + t = template("SELECT name FROM (", interp, ")") + compiled = t.compile() + assert "SELECT * FROM people WHERE age >= $" in compiled.sql + assert 18 in compiled.params.values() + + def test_param_name_derived_from_expression(self): + """Param names should be based on the expression in the interpolation.""" + interp = FakeInterpolation(value=99, expression="my_value") + t = template("SELECT ", interp, "") + compiled = t.compile() + # The name should contain "my_value" (with a prefix) + assert any("my_value" in k for k in compiled.params) + + def test_explicit_param_in_interpolation(self): + """An explicit Param() used in an 
interpolation should be treated as a param.""" + p = param(42, "answer") + interp = FakeInterpolation(value=p, expression="p") + t = template("SELECT ", interp, "") + compiled = t.compile() + assert 42 in compiled.params.values() + + def test_format_spec_on_conversion(self): + """Format spec combined with conversion: conversion first, then format.""" + interp = FakeInterpolation(value=3.14159, expression="val", conversion="s", format_spec=".5") + t = template("SELECT ", interp, "") + compiled = t.compile() + # Python semantics: str(3.14159) = "3.14159", then format("3.14159", ".5") = "3.141" + assert compiled.params == {} + assert "3.141" in compiled.sql + + def test_format_spec_without_conversion_is_ignored(self): + """Format spec without conversion should ideally apply, but currently it's silently dropped.""" + interp = FakeInterpolation(value=3.14159, expression="val", format_spec=".2f") + t = template("SELECT ", interp, "") + compiled = t.compile() + # The format spec is currently ignored — val becomes a param with its original value + assert 3.14159 in compiled.params.values() + # The formatted "3.14" is NOT in the SQL — it's parameterized + assert "3.14" not in compiled.sql + + +# ═══════════════════════════════════════════════════════════════════════════════ +# resolve / _resolve_interpolation +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestResolve: + def test_plain_string_resolves_to_itself(self): + t = SqlTemplate("SELECT 1") + resolved = t.resolve() + assert list(resolved) == ["SELECT 1"] + + def test_conversion_s_resolves_to_string(self): + interp = FakeInterpolation(value=42, expression="x", conversion="s") + t = SqlTemplate("SELECT ", interp) + resolved = t.compile() + expected = CompiledSql(sql="SELECT 42", params={}) + assert resolved == expected + + def test_conversion_r_resolves_to_repr(self): + interp = FakeInterpolation(value="hello", expression="x", conversion="r") + t = SqlTemplate("SELECT ", 
interp) + resolved = t.resolve() + parts = list(resolved) + joined = "".join(p if isinstance(p, str) else "" for p in parts) + assert "'hello'" in joined + + def test_conversion_a_resolves_to_ascii(self): + interp = FakeInterpolation(value="café", expression="x", conversion="a") + t = SqlTemplate("SELECT ", interp) + resolved = t.resolve() + parts = list(resolved) + joined = "".join(p if isinstance(p, str) else "" for p in parts) + # ascii("café") → "'caf\\xe9'" + assert "caf" in joined + + def test_string_value_without_conversion_becomes_param(self): + """A string value in an interpolation (no conversion) must become a param, not raw SQL.""" + interp = FakeInterpolation(value="DROP TABLE users", expression="val") + t = SqlTemplate("SELECT ", interp) + resolved = t.resolve() + parts = list(resolved) + # Must NOT inline "DROP TABLE users" as raw SQL + param_parts = [p for p in parts if isinstance(p, Param)] + assert len(param_parts) == 1 + assert param_parts[0].value == "DROP TABLE users" + + def test_nested_supports_duckdb_template(self): + rel = SimpleRelation("SELECT 1") + interp = FakeInterpolation(value=rel, expression="rel") + t = SqlTemplate("SELECT * FROM (", interp, ")") + resolved = t.resolve() + parts = list(resolved) + sql = "".join(p if isinstance(p, str) else f"${p.name}" for p in parts) + assert "SELECT 1" in sql + + def test_expression_name_preserved_for_simple_param(self): + """When interpolation resolves to a single param, the expression name should be kept.""" + interp = FakeInterpolation(value=42, expression="my_age") + t = SqlTemplate("age = ", interp) + resolved = t.resolve() + parts = list(resolved) + param_parts = [p for p in parts if isinstance(p, Param)] + assert len(param_parts) == 1 + assert param_parts[0].name == "my_age" + + +# ═══════════════════════════════════════════════════════════════════════════════ +# ResolvedSqlTemplate +# ═══════════════════════════════════════════════════════════════════════════════ + + +class 
TestResolvedSqlTemplate: + def test_basic(self): + r = ResolvedSqlTemplate(["SELECT ", Param(value=42, name="x")]) + parts = list(r) + assert len(parts) == 2 + + def test_compile(self): + r = ResolvedSqlTemplate(["SELECT ", Param(value=42, name="x")]) + compiled = r.compile() + assert isinstance(compiled, CompiledSql) + assert 42 in compiled.params.values() + + def test_str_raises(self): + r = ResolvedSqlTemplate(["SELECT 1"]) + with pytest.raises(NotImplementedError): + str(r) + + def test_repr(self): + r = ResolvedSqlTemplate(["SELECT ", Param(value=42, name="x")]) + rep = repr(r) + assert "ResolvedSqlTemplate" in rep + assert "x=42" in rep + + def test_iter(self): + parts_in = ["a", Param(value=1, name="x"), "b"] + r = ResolvedSqlTemplate(parts_in) + assert list(r) == parts_in + + +# ═══════════════════════════════════════════════════════════════════════════════ +# compile_parts +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestCompileParts: + def test_all_strings(self): + result = compile_parts(["SELECT 1"]) + assert result == CompiledSql(sql="SELECT 1", params={}) + + def test_single_unnamed_param(self): + result = compile_parts(["SELECT ", Param(value=42)]) + assert result.sql == "SELECT $p0" + assert result.params == {"p0": 42} + + def test_single_named_param(self): + result = compile_parts(["SELECT ", Param(value=42, name="x")]) + assert result.sql == "SELECT $p0_x" + assert result.params == {"p0_x": 42} + + def test_exact_param_uses_literal_name(self): + result = compile_parts(["SELECT ", Param(value=42, name="my_param", exact=True)]) + assert result.sql == "SELECT $my_param" + assert result.params == {"my_param": 42} + + def test_multiple_params_numbered_sequentially(self): + result = compile_parts( + [ + "SELECT * WHERE a = ", + Param(value=1, name="a"), + " AND b = ", + Param(value=2, name="b"), + ] + ) + assert result.sql == "SELECT * WHERE a = $p0_a AND b = $p1_b" + assert result.params == {"p0_a": 1, 
"p1_b": 2} + + def test_duplicate_param_names_raises(self): + with pytest.raises(ValueError, match="Duplicate parameter names"): + compile_parts( + [ + Param(value=1, name="x", exact=True), + Param(value=2, name="x", exact=True), + ] + ) + + def test_unnamed_params_get_sequential_names(self): + result = compile_parts( + [ + "a = ", + Param(value=1), + " AND b = ", + Param(value=2), + ] + ) + assert result.sql == "a = $p0 AND b = $p1" + assert result.params == {"p0": 1, "p1": 2} + + def test_exact_param_causes_counter_gap(self): + """Known issue: exact params still increment the counter, causing non-sequential auto names.""" + result = compile_parts( + [ + "a = ", + Param(value=1, name="x"), # → p0_x + " AND b = ", + Param(value=2, name="b", exact=True), # → b (exact), but counter increments + " AND c = ", + Param(value=3, name="y"), # → p2_y (not p1_y!) + ] + ) + assert result.params == {"p0_x": 1, "b": 2, "p2_y": 3} + + def test_empty_parts(self): + result = compile_parts([]) + assert result == CompiledSql(sql="", params={}) + + def test_adjacent_strings(self): + result = compile_parts(["SELECT ", "1"]) + assert result.sql == "SELECT 1" + + def test_param_with_none_value(self): + result = compile_parts(["SELECT ", Param(value=None, name="x")]) + assert result.params == {"p0_x": None} + + +# ═══════════════════════════════════════════════════════════════════════════════ +# SupportsDuckdbTemplate protocol +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestSupportsDuckdbTemplate: + def test_protocol_check(self): + rel = SimpleRelation("SELECT 1") + assert isinstance(rel, SupportsDuckdbTemplate) + + def test_non_implementor(self): + assert not isinstance("hello", SupportsDuckdbTemplate) + assert not isinstance(42, SupportsDuckdbTemplate) + + def test_template_calls_dunder(self): + class Tracking: + def __init__(self) -> None: + self.called = False + + def __duckdb_template__(self, **kwargs) -> str: + self.called = True + 
return "SELECT 1" + + obj = Tracking() + template(obj) + assert obj.called + + def test_future_kwargs_accepted(self): + """Implementations should accept **kwargs for future extensibility.""" + + class Strict: + def __duckdb_template__(self, **kwargs) -> str: + assert isinstance(kwargs, dict) + return "SELECT 1" + + template(Strict()) + + def test_returns_interpolation(self): + """__duckdb_template__ can return an interpolation.""" + + class InterpolationReturner: + def __duckdb_template__(self, **kwargs) -> FakeInterpolation: + return FakeInterpolation(value=42, expression="val") + + t = template(InterpolationReturner()) + compiled = t.compile() + assert 42 in compiled.params.values() + + +# ═══════════════════════════════════════════════════════════════════════════════ +# IntoInterpolation protocol +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestIntoInterpolation: + def test_protocol_check_positive(self): + interp = FakeInterpolation(value=1, expression="x") + assert isinstance(interp, IntoInterpolation) + + def test_protocol_check_negative(self): + assert not isinstance("hello", IntoInterpolation) + assert not isinstance(42, IntoInterpolation) + + def test_param_interpolation_satisfies(self): + pi = ParamInterpolation(Param(value=1)) + assert isinstance(pi, IntoInterpolation) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# CompiledSql +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestCompiledSql: + def test_basic(self): + c = CompiledSql(sql="SELECT $p0", params={"p0": 42}) + assert c.sql == "SELECT $p0" + assert c.params == {"p0": 42} + + def test_frozen(self): + c = CompiledSql(sql="SELECT 1", params={}) + with pytest.raises(AttributeError): + c.sql = "SELECT 2" # ty:ignore[invalid-assignment] + + def test_equality(self): + a = CompiledSql(sql="SELECT 1", params={}) + b = CompiledSql(sql="SELECT 1", params={}) + assert a == 
b + + def test_repr(self): + c = CompiledSql(sql="SELECT 1", params={}) + assert "SELECT 1" in repr(c) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# End-to-end: compile +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestEndToEndCompile: + def test_plain_sql(self): + result = template("SELECT * FROM users").compile() + assert result == CompiledSql(sql="SELECT * FROM users", params={}) + + def test_param_in_list(self): + result = template(["SELECT * FROM users WHERE id = ", Param(value=5, name="id")]).compile() + assert result.sql == "SELECT * FROM users WHERE id = $p0_id" + assert result.params == {"p0_id": 5} + + def test_multiple_params_in_list(self): + result = template( + [ + "SELECT * FROM users WHERE name = ", + Param(value="Alice", name="name"), + " AND age > ", + Param(value=18, name="age"), + ] + ).compile() + assert result.sql == "SELECT * FROM users WHERE name = $p0_name AND age > $p1_age" + assert result.params == {"p0_name": "Alice", "p1_age": 18} + + def test_supports_duckdb_template_end_to_end(self): + class MyTable: + def __duckdb_template__(self, **kwargs) -> str: + return "SELECT * FROM my_table" + + result = template(MyTable()).compile() + assert result.sql == "SELECT * FROM my_table" + assert result.params == {} + + def test_interpolations_end_to_end(self): + result = template("SELECT * FROM users WHERE id = ", 42, " AND name = ", "Alice").compile() + assert 42 in result.params.values() + assert "Alice" in result.params.values() + # The raw values should NOT be in the SQL + assert "42" not in result.sql + assert "Alice" not in result.sql + + def test_nested_template_relations(self): + """The key use case: chaining template queries.""" + inner = template("SELECT * FROM people WHERE age >= ", 18) + outer = template("SELECT name FROM (", inner, ")") + result = outer.compile() + assert "SELECT * FROM people WHERE age >=" in result.sql + assert "SELECT name FROM 
(" in result.sql + assert 18 in result.params.values() + # Only one param + assert len(result.params) == 1 + + def test_deeply_nested_tstrings(self): + """Three levels of nesting.""" + val = 100 + level1 = template("SELECT * FROM t WHERE x = ", val) + level2 = template("SELECT * FROM (", level1, ") WHERE y = 1") + level3 = template("SELECT count(*) FROM (", level2, ")") + result = level3.compile() + assert "SELECT * FROM t WHERE x = $" in result.sql + assert 100 in result.params.values() + assert len(result.params) == 1 + + def test_multiple_nested_with_separate_params(self): + """Two nested templates each with their own params.""" + a = 1 + b = 2 + t1 = template("SELECT * FROM t1 WHERE a = ", a) + t2 = template("SELECT * FROM t2 WHERE b = ", b) + outer = template("SELECT * FROM (", t1, ") JOIN (", t2, ")") + result = outer.compile() + assert 1 in result.params.values() + assert 2 in result.params.values() + assert len(result.params) == 2 + + def test_sql_injection_prevented_by_default(self): + """Without !s, even SQL-like strings become params, not raw SQL.""" + evil = "1; DROP TABLE users; --" + interp = FakeInterpolation(value=evil, expression="evil") + result = template("SELECT * FROM users WHERE id = ", interp, "").compile() + # The evil string should be a param, not inlined + assert "DROP TABLE" not in result.sql + assert evil in result.params.values() + + def test_sql_injection_possible_with_s_conversion(self): + """With !s, values ARE inlined — this is intentional but dangerous.""" + evil = "1; DROP TABLE users; --" + interp = FakeInterpolation(value=evil, expression="evil", conversion="s") + result = template("SELECT * FROM users WHERE id = ", interp, "").compile() + assert "DROP TABLE" in result.sql + + def test_exact_param_name(self): + result = template( + [ + "SELECT * FROM t WHERE id = ", + Param(value=42, name="my_id", exact=True), + ] + ).compile() + assert result.sql == "SELECT * FROM t WHERE id = $my_id" + assert result.params == {"my_id": 42} + + 
+# ═══════════════════════════════════════════════════════════════════════════════ +# Edge cases and known issues +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestEdgeCases: + def test_empty_string_template(self): + result = template("").compile() + assert result == CompiledSql(sql="", params={}) + + def test_param_with_none_value(self): + result = template([Param(value=None, name="x")]).compile() + assert result.params["p0_x"] is None + + def test_param_with_list_value(self): + result = template([Param(value=[1, 2, 3], name="ids")]).compile() + assert result.params["p0_ids"] == [1, 2, 3] + + def test_param_with_dict_value(self): + d = {"key": "value"} + result = template([Param(value=d, name="data")]).compile() + assert result.params["p0_data"] == d + + def test_bool_param(self): + result = template("SELECT * FROM t WHERE active = ", True).compile() + assert result == CompiledSql(sql="SELECT * FROM t WHERE active = $p0", params={"p0": True}) + + def test_float_param(self): + interp = FakeInterpolation(value=3.14, expression="threshold") + result = template("SELECT * FROM t WHERE score > ", interp, "").compile() + assert 3.14 in result.params.values() + + def test_none_param(self): + interp = FakeInterpolation(value=None, expression="val") + result = template("SELECT * FROM t WHERE x IS ", interp, "").compile() + assert None in result.params.values() + + def test_param_object_in_interpolation_preserves_name(self): + """An explicit Param used in an interpolation should keep its name.""" + p = Param(value=42, name="custom_name") + interp = FakeInterpolation(value=p, expression="p") + result = template("SELECT ", interp, "").compile() + assert 42 in result.params.values() + + def test_supports_duckdb_template_priority_over_iterable(self): + class IterableWithTemplate: + def __duckdb_template__(self, **kwargs) -> str: + return "SELECT 1" + + def __iter__(self) -> list[str]: + return ["SELECT 2"] + + result = 
template(IterableWithTemplate()).compile() + assert result.sql == "SELECT 1" + + def test_same_expression_used_twice(self): + """Using the same expression twice should create two separate params.""" + interp1 = FakeInterpolation(value=42, expression="x") + interp2 = FakeInterpolation(value=42, expression="x") + result = template("SELECT * FROM t WHERE a = ", interp1, " AND b = ", interp2, "").compile() + assert len(result.params) == 2 + assert all(v == 42 for v in result.params.values()) + + def test_mixed_conversion_and_param(self): + """Mix of !s (inlined) and regular (parameterized) in same template.""" + table_interp = FakeInterpolation(value="users", expression="table", conversion="s") + id_interp = FakeInterpolation(value=5, expression="user_id") + result = template("SELECT * FROM ", table_interp, " WHERE id = ", id_interp, "").compile() + assert "users" in result.sql # inlined + assert "5" not in result.sql # parameterized + assert 5 in result.params.values() + + def test_resolved_sql_template_compile_matches_sql_template_compile(self): + """resolve().compile() should match compile() directly.""" + parts = ["SELECT * FROM t WHERE id = ", Param(value=42, name="x")] + t = template(*parts) + assert t.compile() == t.resolve().compile() + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Conversion semantics — documenting current (potentially incorrect) behavior +# ═══════════════════════════════════════════════════════════════════════════════ + + +class TestConversionSemantics: + """Verify conversion + format_spec follows Python f-string semantics.""" + + def test_s_conversion_on_int(self): + interp = FakeInterpolation(value=42, expression="x", conversion="s") + t = SqlTemplate(interp) + resolved = list(t.resolve()) + assert len(resolved) == 1 + assert resolved[0] == "42" + + def test_r_conversion_on_string(self): + """repr('hello') = "'hello'".""" + interp = FakeInterpolation(value="hello", expression="x", conversion="r") + 
t = SqlTemplate(interp) + resolved = list(t.resolve()) + assert len(resolved) == 1 + assert resolved[0] == "'hello'" + + def test_r_conversion_on_int(self): + """repr(42) = '42', no quotes.""" + interp = FakeInterpolation(value=42, expression="x", conversion="r") + t = SqlTemplate(interp) + resolved = list(t.resolve()) + assert resolved[0] == "42" + + def test_s_conversion_with_format_spec(self): + """Conversion first, then format_spec: str(3.14159) then format with '.5' truncates.""" + interp = FakeInterpolation(value=3.14159, expression="x", conversion="s", format_spec=".5") + t = SqlTemplate(interp) + resolved = list(t.resolve()) + # str(3.14159) = "3.14159", then format("3.14159", ".5") = "3.141" (truncates string to 5 chars) + assert resolved[0] == "3.141" + + def test_r_conversion_with_format_spec(self): + """Python semantics: repr first, then format.""" + interp = FakeInterpolation(value="hi", expression="x", conversion="r", format_spec=".4") + t = SqlTemplate(interp) + resolved = list(t.resolve()) + # repr("hi") = "'hi'", then format("'hi'", ".4") = "'hi'" (already 4 chars) + assert resolved[0] == "'hi'" + + def test_format_spec_ignored_for_parameterized_values(self): + """When no conversion is specified, format_spec is silently ignored — the value is parameterized as-is.""" + interp = FakeInterpolation(value=3.14159, expression="x", format_spec=".2f") + t = SqlTemplate(interp) + compiled = t.compile() + # The value is parameterized with its original value, format_spec is dropped + assert 3.14159 in compiled.params.values() diff --git a/tests/fast/test_template_tstrings.py314 b/tests/fast/test_template_tstrings.py314 new file mode 100644 index 00000000..2cae9e06 --- /dev/null +++ b/tests/fast/test_template_tstrings.py314 @@ -0,0 +1,178 @@ +"""Tests for template.py that use t-string literals (Python 3.14+ only).""" + +from __future__ import annotations + +import sys + +import pytest + +from template import ( + Param, + param, + template, +) + +pytestmark = 
pytest.mark.skipif( + sys.version_info < (3, 14), + reason="t-strings require Python 3.14+", +) + + +class SimpleRelation: + """A minimal SupportsDuckdbTemplate implementation returning a string.""" + + def __init__(self, sql: str): + self._sql = sql + + def __duckdb_template__(self, **kwargs): + return self._sql + + +class TestTStringBasics: + def test_simple_param(self): + user_id = 123 + t = template(t"SELECT * FROM users WHERE id = {user_id}") + compiled = t.compile() + assert "user_id" in compiled.sql + assert "$" in compiled.sql + assert 123 in compiled.params.values() + + def test_multiple_params(self): + name = "Alice" + age = 30 + t = template(t"SELECT * FROM users WHERE name = {name} AND age = {age}") + compiled = t.compile() + assert len(compiled.params) == 2 + assert "Alice" in compiled.params.values() + assert 30 in compiled.params.values() + + def test_no_params(self): + t = template(t"SELECT 1") + compiled = t.compile() + assert compiled.sql == "SELECT 1" + assert compiled.params == {} + + def test_param_name_derived_from_expression(self): + my_value = 99 + t = template(t"SELECT {my_value}") + compiled = t.compile() + assert any("my_value" in k for k in compiled.params) + + def test_explicit_param_in_tstring(self): + p = param(42, "answer") + t = template(t"SELECT {p}") + compiled = t.compile() + assert 42 in compiled.params.values() + + +class TestTStringConversions: + def test_string_conversion_s(self): + table = "users" + t = template(t"SELECT * FROM {table!s}") + compiled = t.compile() + assert compiled.sql == "SELECT * FROM users" + assert compiled.params == {} + + def test_repr_conversion_r(self): + val = "hello" + t = template(t"SELECT {val!r}") + compiled = t.compile() + assert "'hello'" in compiled.sql + assert compiled.params == {} + + def test_ascii_conversion_a(self): + val = "café" + t = template(t"SELECT {val!a}") + compiled = t.compile() + assert compiled.params == {} + assert "\\xe9" in compiled.sql or "caf" in compiled.sql + + def 
test_string_value_becomes_param_not_raw_sql(self): + name = "Alice" + t = template(t"SELECT * FROM users WHERE name = {name}") + compiled = t.compile() + assert "Alice" not in compiled.sql + assert "Alice" in compiled.params.values() + + def test_format_spec_on_conversion(self): + val = 3.14159 + t = template(t"SELECT {val!s:.5}") + compiled = t.compile() + assert compiled.params == {} + assert "3.141" in compiled.sql + + def test_format_spec_without_conversion_is_ignored(self): + val = 3.14159 + t = template(t"SELECT {val:.2f}") + compiled = t.compile() + assert 3.14159 in compiled.params.values() + assert "3.14" not in compiled.sql + + +class TestTStringNesting: + def test_nested_template_via_supports_duckdb_template(self): + inner = SimpleRelation("SELECT * FROM people") + t = template(t"SELECT name FROM ({inner})") + compiled = t.compile() + assert "SELECT * FROM people" in compiled.sql + assert compiled.params == {} + + def test_nested_tstring_chaining(self): + age = 18 + inner = template(t"SELECT * FROM people WHERE age >= {age}") + t = template(t"SELECT name FROM ({inner})") + compiled = t.compile() + assert "SELECT * FROM people WHERE age >= $" in compiled.sql + assert 18 in compiled.params.values() + + def test_duckdb_template_returns_tstring(self): + class TStringReturner: + def __duckdb_template__(self, **kwargs): + val = 42 + return t"SELECT {val}" + + t = template(TStringReturner()) + compiled = t.compile() + assert 42 in compiled.params.values() + + +class TestTStringEdgeCases: + def test_float_param(self): + threshold = 3.14 + result = template(t"SELECT * FROM t WHERE score > {threshold}").compile() + assert 3.14 in result.params.values() + + def test_none_param(self): + val = None + result = template(t"SELECT * FROM t WHERE x IS {val}").compile() + assert None in result.params.values() + + def test_param_object_preserves_name(self): + p = Param(value=42, name="custom_name") + result = template(t"SELECT {p}").compile() + assert 42 in 
result.params.values() + + def test_same_variable_used_twice(self): + x = 42 + result = template(t"SELECT * FROM t WHERE a = {x} AND b = {x}").compile() + assert len(result.params) == 2 + assert all(v == 42 for v in result.params.values()) + + def test_mixed_conversion_and_param(self): + table = "users" + user_id = 5 + result = template(t"SELECT * FROM {table!s} WHERE id = {user_id}").compile() + assert "users" in result.sql + assert "5" not in result.sql + assert 5 in result.params.values() + + def test_sql_injection_prevented_by_default(self): + evil = "1; DROP TABLE users; --" + result = template(t"SELECT * FROM users WHERE id = {evil}").compile() + assert "DROP TABLE" not in result.sql + assert evil in result.params.values() + + def test_sql_injection_possible_with_s_conversion(self): + evil = "1; DROP TABLE users; --" + result = template(t"SELECT * FROM users WHERE id = {evil!s}").compile() + assert "DROP TABLE" in result.sql From 373f0aa0211dabba778e63f3d83349eb587ecd42 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 16:59:07 -0800 Subject: [PATCH 14/29] move template.py into actual package --- template.py => duckdb/template.py | 0 tests/fast/test_template.py | 2 +- tests/fast/test_template_tstrings.py314 | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename template.py => duckdb/template.py (100%) diff --git a/template.py b/duckdb/template.py similarity index 100% rename from template.py rename to duckdb/template.py diff --git a/tests/fast/test_template.py b/tests/fast/test_template.py index 94ebd80c..54d3ea4b 100644 --- a/tests/fast/test_template.py +++ b/tests/fast/test_template.py @@ -4,7 +4,7 @@ import pytest -from template import ( +from duckdb.template import ( CompiledSql, IntoInterpolation, Param, diff --git a/tests/fast/test_template_tstrings.py314 b/tests/fast/test_template_tstrings.py314 index 2cae9e06..27d9c0d9 100644 --- a/tests/fast/test_template_tstrings.py314 +++ b/tests/fast/test_template_tstrings.py314 @@ -6,7 +6,7 
@@ import sys import pytest -from template import ( +from duckdb.template import ( Param, param, template, From 843987e685d0ca59853f6c11342111dd0025c618 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 17:08:17 -0800 Subject: [PATCH 15/29] Improve CompiledSql ergonomics and docstrings --- duckdb/template.py | 31 +++++++++++++++++++++++++++---- tests/fast/test_template.py | 13 ++++++++++++- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/duckdb/template.py b/duckdb/template.py index 09b27c7f..9b302991 100644 --- a/duckdb/template.py +++ b/duckdb/template.py @@ -15,6 +15,7 @@ "IntoInterpolation", "Param", "SupportsDuckdbTemplate", + "compile", "param", "template", ] @@ -22,10 +23,21 @@ @dataclasses.dataclass(frozen=True, slots=True) class CompiledSql: - """Represents a compiled SQL statement, with the final SQL string and a list of Params to be passed to duckdb.""" + """Represents a compiled SQL statement, with the final SQL string and a dict of params to be passed to duckdb. + + You will typically not create this directly, but will get it as the result + of calling .compile() on a SqlTemplate or ResolvedSqlTemplate. 
+ + Example: + >>> age = 37 + >>> c = compile(t"SELECT * FROM users WHERE age >= {age}") + >>> c + CompiledSql(sql='SELECT * FROM users WHERE age >= $p0_age', params={'p0_age': 37}) + >>> duckdb.query(c.sql, c.params) + """ sql: str - params: dict[str, object] + params: dict[str, object] = dataclasses.field(default_factory=dict) @runtime_checkable @@ -114,7 +126,7 @@ def template(*part: str | IntoInterpolation | Param | SupportsDuckdbTemplate | o This is very useful for versions of python before 3.14 that don't have tstrings, since it allows you to build up a template from smaller pieces: - >>> t = template(["SELECT * FROM (", all_people, ") WHERE age >= ", age]) + >>> t = template("SELECT * FROM (", all_people, ") WHERE age >= ", age) >>> t.compile() CompiledSql(sql='SELECT * FROM (SELECT * FROM people) WHERE age >= $p0_age', params={'p0_age': 18}) @@ -138,6 +150,17 @@ def template(*part: str | IntoInterpolation | Param | SupportsDuckdbTemplate | o return SqlTemplate(*expanded) +def compile(*part: str | IntoInterpolation | Param | SupportsDuckdbTemplate | object) -> CompiledSql: + """Compile a sequence of things into a final SQL string with named parameter placeholders, and a dict of params. + + This is a convenience function that combines template() and .compile() into one step. + + For more details and examples, see template(). 
+ """ + t = template(*part) + return t.compile() + + def _expand_part(part: object) -> Iterable[str | IntoInterpolation]: if isinstance(part, SupportsDuckdbTemplate): raw = part.__duckdb_template__() @@ -164,7 +187,7 @@ def _expand_part(part: object) -> Iterable[str | IntoInterpolation]: class ParamInterpolation: - """A simple wrapper that implements the IntoInterpolation protocol for a given IntoParam.""" + """A simple wrapper that implements the IntoInterpolation protocol for a given Param.""" def __init__(self, param: Param): # noqa: ANN204 self.value = param diff --git a/tests/fast/test_template.py b/tests/fast/test_template.py index 54d3ea4b..9c9a9aaa 100644 --- a/tests/fast/test_template.py +++ b/tests/fast/test_template.py @@ -745,6 +745,16 @@ def test_basic(self): assert c.sql == "SELECT $p0" assert c.params == {"p0": 42} + def test_empty_params(self): + c = CompiledSql(sql="SELECT 1", params={}) + assert c.sql == "SELECT 1" + assert c.params == {} + + def test_optional_params(self): + c = CompiledSql(sql="SELECT 1") + assert c.sql == "SELECT 1" + assert c.params == {} + def test_frozen(self): c = CompiledSql(sql="SELECT 1", params={}) with pytest.raises(AttributeError): @@ -757,7 +767,8 @@ def test_equality(self): def test_repr(self): c = CompiledSql(sql="SELECT 1", params={}) - assert "SELECT 1" in repr(c) + expected = "CompiledSql(sql='SELECT 1', params={})" + assert repr(c) == expected # ═══════════════════════════════════════════════════════════════════════════════ From ce621a6fe061fd633c65bad5306184da8775eedb Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 17:33:01 -0800 Subject: [PATCH 16/29] fixup how parts are expanded --- duckdb/template.py | 11 +++++------ tests/fast/test_template.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/duckdb/template.py b/duckdb/template.py index 9b302991..916e7c79 100644 --- a/duckdb/template.py +++ b/duckdb/template.py @@ -80,7 +80,7 @@ def param(value: object, name: str | None 
= None, *, exact: bool = False) -> Par return Param(value=value, name=name, exact=exact) -def template(*part: str | IntoInterpolation | Param | SupportsDuckdbTemplate | object) -> SqlTemplate: +def template(*parts: str | IntoInterpolation | Param | SupportsDuckdbTemplate | object) -> SqlTemplate: """Convert a sequence of things into a SqlTemplate. We go through the parts and convert it into a sequence of str and Interpolations, @@ -146,7 +146,9 @@ def template(*part: str | IntoInterpolation | Param | SupportsDuckdbTemplate | o >>> t.compile() CompiledSql(sql='SELECT * FROM users WHERE id = $p0_id', params={'p0_id': 123}) """ # noqa: E501 - expanded = _expand_part(part) + expanded = [] + for part in parts: + expanded.extend(_expand_part(part)) return SqlTemplate(*expanded) @@ -290,10 +292,7 @@ def assert_param_name_legal(name: str) -> None: class SqlTemplate: """A sequence of strings and Interpolations.""" - def __init__( - self, - *parts: str | IntoInterpolation, - ) -> None: + def __init__(self, *parts: str | IntoInterpolation) -> None: self.strings, self.interpolations = parse_parts(parts) def __iter__(self) -> Iterator[str | IntoInterpolation]: diff --git a/tests/fast/test_template.py b/tests/fast/test_template.py index 9c9a9aaa..941d3389 100644 --- a/tests/fast/test_template.py +++ b/tests/fast/test_template.py @@ -779,7 +779,7 @@ def test_repr(self): class TestEndToEndCompile: def test_plain_sql(self): result = template("SELECT * FROM users").compile() - assert result == CompiledSql(sql="SELECT * FROM users", params={}) + assert result == CompiledSql("SELECT * FROM users", {}) def test_param_in_list(self): result = template(["SELECT * FROM users WHERE id = ", Param(value=5, name="id")]).compile() From 9269020a5fbdd26763957fa896f8323a9e36cd07 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 17:38:33 -0800 Subject: [PATCH 17/29] fixup: template(*mylist), not template(mylist) --- tests/fast/test_template.py | 22 +++++++++------------- 1 file 
changed, 9 insertions(+), 13 deletions(-) diff --git a/tests/fast/test_template.py b/tests/fast/test_template.py index 941d3389..b7c6095c 100644 --- a/tests/fast/test_template.py +++ b/tests/fast/test_template.py @@ -788,12 +788,10 @@ def test_param_in_list(self): def test_multiple_params_in_list(self): result = template( - [ - "SELECT * FROM users WHERE name = ", - Param(value="Alice", name="name"), - " AND age > ", - Param(value=18, name="age"), - ] + "SELECT * FROM users WHERE name = ", + Param(value="Alice", name="name"), + " AND age > ", + Param(value=18, name="age"), ).compile() assert result.sql == "SELECT * FROM users WHERE name = $p0_name AND age > $p1_age" assert result.params == {"p0_name": "Alice", "p1_age": 18} @@ -867,10 +865,8 @@ def test_sql_injection_possible_with_s_conversion(self): def test_exact_param_name(self): result = template( - [ - "SELECT * FROM t WHERE id = ", - Param(value=42, name="my_id", exact=True), - ] + "SELECT * FROM t WHERE id = ", + Param(value=42, name="my_id", exact=True), ).compile() assert result.sql == "SELECT * FROM t WHERE id = $my_id" assert result.params == {"my_id": 42} @@ -887,16 +883,16 @@ def test_empty_string_template(self): assert result == CompiledSql(sql="", params={}) def test_param_with_none_value(self): - result = template([Param(value=None, name="x")]).compile() + result = template(Param(value=None, name="x")).compile() assert result.params["p0_x"] is None def test_param_with_list_value(self): - result = template([Param(value=[1, 2, 3], name="ids")]).compile() + result = template(Param(value=[1, 2, 3], name="ids")).compile() assert result.params["p0_ids"] == [1, 2, 3] def test_param_with_dict_value(self): d = {"key": "value"} - result = template([Param(value=d, name="data")]).compile() + result = template(Param(value=d, name="data")).compile() assert result.params["p0_data"] == d def test_bool_param(self): From d9a253740bf46bb639ef54d77a77f145700b7da0 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 
Mar 2026 17:40:54 -0800 Subject: [PATCH 18/29] fix a few tests --- tests/fast/test_template.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/fast/test_template.py b/tests/fast/test_template.py index b7c6095c..2b7dae84 100644 --- a/tests/fast/test_template.py +++ b/tests/fast/test_template.py @@ -208,19 +208,20 @@ class TestSqlTemplateConstruction: def test_plain_string(self): t = SqlTemplate("SELECT 1") assert t.strings == ("SELECT 1",) - assert t.interpolations == [] + assert t.interpolations == () def test_multiple_strings_merged(self): t = SqlTemplate("SELECT ", "1") assert t.strings == ("SELECT 1",) - assert t.interpolations == [] + assert t.interpolations == () - def test_with_param(self): + def test_bare_param_raises(self): p = Param(value=42, name="x") - t = SqlTemplate("SELECT ", p, " FROM t") - assert t.strings == ("SELECT ", " FROM t") + with pytest.raises(TypeError, match="Unexpected part type"): + SqlTemplate("SELECT ", p, " FROM t") # ty:ignore[invalid-argument-type] + wrapped = ParamInterpolation(p) + t = SqlTemplate("SELECT ", wrapped, " FROM t") assert len(t.interpolations) == 1 - assert t.interpolations[0].value is p def test_with_interpolation(self): interp = FakeInterpolation(value=42, expression="x") @@ -242,7 +243,7 @@ def test_no_args(self): """Empty SqlTemplate should produce a single empty string.""" t = SqlTemplate() assert t.strings == ("",) - assert t.interpolations == [] + assert t.interpolations == () # ═══════════════════════════════════════════════════════════════════════════════ From 4a10dfa73ebdc99552edd678c899d284e33486d6 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 17:42:50 -0800 Subject: [PATCH 19/29] fixup --- tests/fast/test_template.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/fast/test_template.py b/tests/fast/test_template.py index 2b7dae84..19c951d7 100644 --- a/tests/fast/test_template.py +++ 
b/tests/fast/test_template.py @@ -215,14 +215,6 @@ def test_multiple_strings_merged(self): assert t.strings == ("SELECT 1",) assert t.interpolations == () - def test_bare_param_raises(self): - p = Param(value=42, name="x") - with pytest.raises(TypeError, match="Unexpected part type"): - SqlTemplate("SELECT ", p, " FROM t") # ty:ignore[invalid-argument-type] - wrapped = ParamInterpolation(p) - t = SqlTemplate("SELECT ", wrapped, " FROM t") - assert len(t.interpolations) == 1 - def test_with_interpolation(self): interp = FakeInterpolation(value=42, expression="x") t = SqlTemplate("a ", interp, " b") @@ -234,6 +226,11 @@ def test_bare_param_errors(self): with pytest.raises(TypeError, match="Unexpected part type"): SqlTemplate("SELECT ", p) # ty:ignore[invalid-argument-type] + def test_wrapped_param(self): + wrapped = ParamInterpolation(Param(value=42, name="x")) + t = SqlTemplate("SELECT ", wrapped, " FROM t") + assert len(t.interpolations) == 1 + def test_rejects_invalid_types(self): """Items that are not str, IntoInterpolation, or Param should raise TypeError.""" with pytest.raises(TypeError, match="Unexpected part type"): From 6b62b938839edd4f77dac111fd858a7197e4b08d Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 17:43:44 -0800 Subject: [PATCH 20/29] fixup typecheck error --- tests/fast/test_template.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/fast/test_template.py b/tests/fast/test_template.py index 19c951d7..e2a753ab 100644 --- a/tests/fast/test_template.py +++ b/tests/fast/test_template.py @@ -255,8 +255,7 @@ def test_iter_plain_string(self): assert parts == ["hello"] def test_iter_with_interpolations(self): - p = Param(value=1, name="x") - t = SqlTemplate("a ", p, " b") + t = template("a ", param(1, "x"), " b") parts = list(t) assert len(parts) == 3 assert parts[0] == "a " From 11e5e3d34aa4ccc60fe02818ee16ee7d5630706c Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 18:09:03 -0800 Subject: 
[PATCH 21/29] improve test style --- tests/fast/test_template.py | 329 ++++++++++++++++++------------------ 1 file changed, 166 insertions(+), 163 deletions(-) diff --git a/tests/fast/test_template.py b/tests/fast/test_template.py index e2a753ab..b5943e2c 100644 --- a/tests/fast/test_template.py +++ b/tests/fast/test_template.py @@ -3,7 +3,9 @@ from __future__ import annotations import pytest +from sympy import comp +from duckdb.experimental.spark.sql.functions import exp from duckdb.template import ( CompiledSql, IntoInterpolation, @@ -12,6 +14,7 @@ ResolvedSqlTemplate, SqlTemplate, SupportsDuckdbTemplate, + compile, compile_parts, param, parse_parts, @@ -41,6 +44,11 @@ def __duckdb_template__(self, **kwargs) -> str: return self._sql +class Cafe: + def __ascii__(self) -> str: + return "cafe" + + # ═══════════════════════════════════════════════════════════════════════════════ # Param # ═══════════════════════════════════════════════════════════════════════════════ @@ -283,28 +291,29 @@ class TestTemplateFactory: def test_plain_string(self): t = template("SELECT 1") compiled = t.compile() - assert compiled.sql == "SELECT 1" - assert compiled.params == {} + expected = CompiledSql("SELECT 1", {}) + assert expected == compiled def test_param(self): p = Param(value=42, name="answer") t = template(p) compiled = t.compile() - assert "$" in compiled.sql - assert 42 in compiled.params.values() + expected = CompiledSql("$p0_answer", {"p0_answer": 42}) + assert expected == compiled def test_into_interpolation(self): interp = FakeInterpolation(value=42, expression="x") t = template(interp) compiled = t.compile() - assert 42 in compiled.params.values() + expected = CompiledSql("$p0_x", {"p0_x": 42}) + assert expected == compiled def test_supports_duckdb_template_string(self): rel = SimpleRelation("SELECT 1") t = template(rel) compiled = t.compile() - assert compiled.sql == "SELECT 1" - assert compiled.params == {} + expected = CompiledSql("SELECT 1", {}) + assert expected == 
compiled def test_supports_duckdb_template_returns_iterable(self): class MultiPart: @@ -313,8 +322,8 @@ def __duckdb_template__(self, **kwargs) -> list[str | Param]: t = template(MultiPart()) compiled = t.compile() - assert "$" in compiled.sql - assert 42 in compiled.params.values() + expected = CompiledSql("SELECT * FROM $p0_x", {"p0_x": 42}) + assert expected == compiled def test_supports_duckdb_template_returns_interpolation(self): class InterpReturner: @@ -323,33 +332,34 @@ def __duckdb_template__(self, **kwargs) -> FakeInterpolation: t = template(InterpReturner()) compiled = t.compile() - assert "hello" in compiled.params.values() + expected = CompiledSql("$p0_val", {"p0_val": "hello"}) + assert expected == compiled def test_iterable_of_strings(self): t = template("SELECT ", "1") compiled = t.compile() - assert compiled.sql == "SELECT 1" - assert compiled.params == {} + expected = CompiledSql("SELECT 1", {}) + assert expected == compiled def test_iterable_with_params(self): t = template("SELECT * FROM t WHERE id = ", Param(value=5, name="id")) compiled = t.compile() - assert compiled.params - assert 5 in compiled.params.values() + expected = CompiledSql("SELECT * FROM t WHERE id = $p0_id", {"p0_id": 5}) + assert expected == compiled def test_iterable_with_bare_values(self): """Bare values in an iterable should be treated as params.""" t = template("SELECT * FROM t WHERE id = ", 42) compiled = t.compile() - assert 42 in compiled.params.values() - assert "42" not in compiled.sql + expected = CompiledSql("SELECT * FROM t WHERE id = $p0", {"p0": 42}) + assert expected == compiled def test_bare_value_at_top_level_becomes_param(self): """A bare value passed directly to template() becomes a param.""" t = template(42) # type: ignore[arg-type] compiled = t.compile() - assert compiled.params - assert 42 in compiled.params.values() + expected = CompiledSql("$p0", {"p0": 42}) + assert expected == compiled def test_bytes_not_treated_as_iterable(self): """Bytes should not 
be iterated — _is_iterable excludes bytes.""" @@ -369,69 +379,61 @@ class TestTemplateWithInterpolations: def test_simple_param(self): interp = FakeInterpolation(value=123, expression="user_id") - t = template("SELECT * FROM users WHERE id = ", interp, "") - compiled = t.compile() - assert "user_id" in compiled.sql - assert "$" in compiled.sql - assert 123 in compiled.params.values() + compiled = compile("SELECT * FROM users WHERE id = ", interp, "") + expected = CompiledSql("SELECT * FROM users WHERE id = $p0_user_id", {"p0_user_id": 123}) + assert expected == compiled def test_multiple_params(self): name_interp = FakeInterpolation(value="Alice", expression="name") age_interp = FakeInterpolation(value=30, expression="age") t = template("SELECT * FROM users WHERE name = ", name_interp, " AND age = ", age_interp, "") compiled = t.compile() - assert len(compiled.params) == 2 - assert "Alice" in compiled.params.values() - assert 30 in compiled.params.values() + expected = CompiledSql( + "SELECT * FROM users WHERE name = $p0_name AND age = $p1_age", + {"p0_name": "Alice", "p1_age": 30}, + ) + assert expected == compiled def test_no_params(self): - t = template("SELECT 1") - compiled = t.compile() - assert compiled.sql == "SELECT 1" - assert compiled.params == {} + compiled = compile("SELECT 1") + expected = CompiledSql("SELECT 1", {}) + assert expected == compiled def test_string_conversion_s(self): """!s should inline the value as raw SQL (no param).""" interp = FakeInterpolation(value="users", expression="table", conversion="s") - t = template("SELECT * FROM ", interp, "") - compiled = t.compile() - assert compiled.sql == "SELECT * FROM users" - assert compiled.params == {} + compiled = compile("SELECT * FROM ", interp, "") + expected = CompiledSql("SELECT * FROM users", {}) + assert expected == compiled def test_repr_conversion_r(self): """!r should inline repr(value) as raw SQL.""" interp = FakeInterpolation(value="hello", expression="val", conversion="r") - t = 
template("SELECT ", interp, "") - compiled = t.compile() + compiled = compile("SELECT ", interp, "") # repr of "hello" is "'hello'" - assert "'hello'" in compiled.sql - assert compiled.params == {} + expected = CompiledSql("SELECT 'hello'", {}) + assert expected == compiled def test_ascii_conversion_a(self): - """!a should inline ascii(value) as raw SQL.""" - interp = FakeInterpolation(value="café", expression="val", conversion="a") - t = template("SELECT ", interp, "") - compiled = t.compile() - assert compiled.params == {} - # ascii() should escape non-ASCII - assert "\\xe9" in compiled.sql or "caf" in compiled.sql + interp = FakeInterpolation(value=Cafe(), expression="val", conversion="a") + compiled = compile("SELECT ", interp) + expected = CompiledSql("SELECT cafe", {}) + assert expected == compiled def test_string_value_becomes_param_not_raw_sql(self): """A string interpolation WITHOUT conversion should be a param, not raw SQL.""" interp = FakeInterpolation(value="Alice", expression="name") - t = template("SELECT * FROM users WHERE name = ", interp, "") - compiled = t.compile() - assert "Alice" not in compiled.sql - assert "Alice" in compiled.params.values() + compiled = compile("SELECT * FROM users WHERE name = ", interp, "") + expected = CompiledSql("SELECT * FROM users WHERE name = $p0_name", {"p0_name": "Alice"}) + assert expected == compiled def test_nested_template_via_supports_duckdb_template(self): """An interpolated SupportsDuckdbTemplate should be expanded inline.""" inner = SimpleRelation("SELECT * FROM people") interp = FakeInterpolation(value=inner, expression="inner") - t = template("SELECT name FROM (", interp, ")") - compiled = t.compile() - assert "SELECT * FROM people" in compiled.sql - assert compiled.params == {} + compiled = compile("SELECT name FROM (", interp, ")") + expected = CompiledSql("SELECT name FROM (SELECT * FROM people)", {}) + assert expected == compiled def test_nested_chaining(self): """Chaining templates: inner template 
params should propagate.""" @@ -439,6 +441,7 @@ def test_nested_chaining(self): interp = FakeInterpolation(value=inner, expression="inner") t = template("SELECT name FROM (", interp, ")") compiled = t.compile() + # Skipping CompiledSql equality: exact param name depends on counter offsets across nesting levels assert "SELECT * FROM people WHERE age >= $" in compiled.sql assert 18 in compiled.params.values() @@ -447,8 +450,8 @@ def test_param_name_derived_from_expression(self): interp = FakeInterpolation(value=99, expression="my_value") t = template("SELECT ", interp, "") compiled = t.compile() - # The name should contain "my_value" (with a prefix) - assert any("my_value" in k for k in compiled.params) + expected = CompiledSql("SELECT $p0_my_value", {"p0_my_value": 99}) + assert expected == compiled def test_explicit_param_in_interpolation(self): """An explicit Param() used in an interpolation should be treated as a param.""" @@ -456,6 +459,7 @@ def test_explicit_param_in_interpolation(self): interp = FakeInterpolation(value=p, expression="p") t = template("SELECT ", interp, "") compiled = t.compile() + # Skipping exact equality: param name may come from Param.name or interpolation expression assert 42 in compiled.params.values() def test_format_spec_on_conversion(self): @@ -464,18 +468,16 @@ def test_format_spec_on_conversion(self): t = template("SELECT ", interp, "") compiled = t.compile() # Python semantics: str(3.14159) = "3.14159", then format("3.14159", ".5") = "3.141" - assert compiled.params == {} - assert "3.141" in compiled.sql + expected = CompiledSql("SELECT 3.141", {}) + assert expected == compiled def test_format_spec_without_conversion_is_ignored(self): """Format spec without conversion should ideally apply, but currently it's silently dropped.""" interp = FakeInterpolation(value=3.14159, expression="val", format_spec=".2f") t = template("SELECT ", interp, "") compiled = t.compile() - # The format spec is currently ignored — val becomes a param with 
its original value - assert 3.14159 in compiled.params.values() - # The formatted "3.14" is NOT in the SQL — it's parameterized - assert "3.14" not in compiled.sql + expected = CompiledSql("SELECT $p0_val", {"p0_val": 3.14159}) + assert expected == compiled # ═══════════════════════════════════════════════════════════════════════════════ @@ -500,48 +502,37 @@ def test_conversion_r_resolves_to_repr(self): interp = FakeInterpolation(value="hello", expression="x", conversion="r") t = SqlTemplate("SELECT ", interp) resolved = t.resolve() - parts = list(resolved) - joined = "".join(p if isinstance(p, str) else "" for p in parts) - assert "'hello'" in joined + expected = CompiledSql(sql="SELECT 'hello'", params={}) def test_conversion_a_resolves_to_ascii(self): - interp = FakeInterpolation(value="café", expression="x", conversion="a") - t = SqlTemplate("SELECT ", interp) - resolved = t.resolve() - parts = list(resolved) - joined = "".join(p if isinstance(p, str) else "" for p in parts) - # ascii("café") → "'caf\\xe9'" - assert "caf" in joined + interp = FakeInterpolation(value=Cafe(), expression="x", conversion="a") + actual = compile("SELECT ", interp) + expected = CompiledSql(sql="SELECT cafe") + assert actual == expected def test_string_value_without_conversion_becomes_param(self): """A string value in an interpolation (no conversion) must become a param, not raw SQL.""" interp = FakeInterpolation(value="DROP TABLE users", expression="val") t = SqlTemplate("SELECT ", interp) resolved = t.resolve() - parts = list(resolved) - # Must NOT inline "DROP TABLE users" as raw SQL - param_parts = [p for p in parts if isinstance(p, Param)] - assert len(param_parts) == 1 - assert param_parts[0].value == "DROP TABLE users" + expected = CompiledSql(sql="SELECT $p0_val", params={"p0_val": "DROP TABLE users"}) + assert resolved == expected def test_nested_supports_duckdb_template(self): rel = SimpleRelation("SELECT 1") interp = FakeInterpolation(value=rel, expression="rel") t = 
SqlTemplate("SELECT * FROM (", interp, ")") resolved = t.resolve() - parts = list(resolved) - sql = "".join(p if isinstance(p, str) else f"${p.name}" for p in parts) - assert "SELECT 1" in sql + expected = CompiledSql(sql="SELECT * FROM (SELECT 1)", params={}) + assert resolved == expected def test_expression_name_preserved_for_simple_param(self): """When interpolation resolves to a single param, the expression name should be kept.""" interp = FakeInterpolation(value=42, expression="my_age") t = SqlTemplate("age = ", interp) resolved = t.resolve() - parts = list(resolved) - param_parts = [p for p in parts if isinstance(p, Param)] - assert len(param_parts) == 1 - assert param_parts[0].name == "my_age" + expected = CompiledSql(sql="age = $p0_my_age", params={"p0_my_age": 42}) + assert resolved == expected # ═══════════════════════════════════════════════════════════════════════════════ @@ -590,30 +581,23 @@ def test_all_strings(self): def test_single_unnamed_param(self): result = compile_parts(["SELECT ", Param(value=42)]) - assert result.sql == "SELECT $p0" - assert result.params == {"p0": 42} + expected = CompiledSql(sql="SELECT $p0", params={"p0": 42}) + assert result == expected def test_single_named_param(self): result = compile_parts(["SELECT ", Param(value=42, name="x")]) - assert result.sql == "SELECT $p0_x" - assert result.params == {"p0_x": 42} + expected = CompiledSql(sql="SELECT $p0_x", params={"p0_x": 42}) + assert result == expected def test_exact_param_uses_literal_name(self): result = compile_parts(["SELECT ", Param(value=42, name="my_param", exact=True)]) - assert result.sql == "SELECT $my_param" - assert result.params == {"my_param": 42} + expected = CompiledSql(sql="SELECT $my_param", params={"my_param": 42}) + assert result == expected def test_multiple_params_numbered_sequentially(self): - result = compile_parts( - [ - "SELECT * WHERE a = ", - Param(value=1, name="a"), - " AND b = ", - Param(value=2, name="b"), - ] - ) - assert result.sql == 
"SELECT * WHERE a = $p0_a AND b = $p1_b" - assert result.params == {"p0_a": 1, "p1_b": 2} + result = compile_parts(["SELECT * WHERE a = ", Param(value=1, name="a"), " AND b = ", Param(value=2, name="b")]) + expected = CompiledSql(sql="SELECT * WHERE a = $p0_a AND b = $p1_b", params={"p0_a": 1, "p1_b": 2}) + assert result == expected def test_duplicate_param_names_raises(self): with pytest.raises(ValueError, match="Duplicate parameter names"): @@ -633,8 +617,8 @@ def test_unnamed_params_get_sequential_names(self): Param(value=2), ] ) - assert result.sql == "a = $p0 AND b = $p1" - assert result.params == {"p0": 1, "p1": 2} + expected = CompiledSql(sql="a = $p0 AND b = $p1", params={"p0": 1, "p1": 2}) + assert result == expected def test_exact_param_causes_counter_gap(self): """Known issue: exact params still increment the counter, causing non-sequential auto names.""" @@ -648,19 +632,23 @@ def test_exact_param_causes_counter_gap(self): Param(value=3, name="y"), # → p2_y (not p1_y!) ] ) - assert result.params == {"p0_x": 1, "b": 2, "p2_y": 3} + expected = CompiledSql(sql="a = $p0_x AND b = $b AND c = $p2_y", params={"p0_x": 1, "b": 2, "p2_y": 3}) + assert result == expected def test_empty_parts(self): result = compile_parts([]) - assert result == CompiledSql(sql="", params={}) + expected = CompiledSql(sql="", params={}) + assert result == expected def test_adjacent_strings(self): result = compile_parts(["SELECT ", "1"]) - assert result.sql == "SELECT 1" + expected = CompiledSql(sql="SELECT 1", params={}) + assert result == expected def test_param_with_none_value(self): result = compile_parts(["SELECT ", Param(value=None, name="x")]) - assert result.params == {"p0_x": None} + expected = CompiledSql(sql="SELECT $p0_x", params={"p0_x": None}) + assert result == expected # ═══════════════════════════════════════════════════════════════════════════════ @@ -742,6 +730,11 @@ def test_basic(self): assert c.sql == "SELECT $p0" assert c.params == {"p0": 42} + def 
test_positional_params(self): + c = CompiledSql("SELECT $p0", {"p0": 42}) + assert c.sql == "SELECT $p0" + assert c.params == {"p0": 42} + def test_empty_params(self): c = CompiledSql(sql="SELECT 1", params={}) assert c.sql == "SELECT 1" @@ -775,50 +768,57 @@ def test_repr(self): class TestEndToEndCompile: def test_plain_sql(self): - result = template("SELECT * FROM users").compile() + result = compile("SELECT * FROM users") + assert result == CompiledSql("SELECT * FROM users", {}) + + def test_strings_joined(self): + result = compile("SELECT * FROM ", "users") assert result == CompiledSql("SELECT * FROM users", {}) def test_param_in_list(self): - result = template(["SELECT * FROM users WHERE id = ", Param(value=5, name="id")]).compile() - assert result.sql == "SELECT * FROM users WHERE id = $p0_id" - assert result.params == {"p0_id": 5} + result = compile("SELECT * FROM users WHERE id = ", Param(value=5, name="id")) + expected = CompiledSql("SELECT * FROM users WHERE id = $p0_id", {"p0_id": 5}) + assert expected == result def test_multiple_params_in_list(self): - result = template( + result = compile( "SELECT * FROM users WHERE name = ", Param(value="Alice", name="name"), " AND age > ", Param(value=18, name="age"), - ).compile() - assert result.sql == "SELECT * FROM users WHERE name = $p0_name AND age > $p1_age" - assert result.params == {"p0_name": "Alice", "p1_age": 18} + ) + expected = CompiledSql( + "SELECT * FROM users WHERE name = $p0_name AND age > $p1_age", + {"p0_name": "Alice", "p1_age": 18}, + ) + assert expected == result def test_supports_duckdb_template_end_to_end(self): class MyTable: def __duckdb_template__(self, **kwargs) -> str: return "SELECT * FROM my_table" - result = template(MyTable()).compile() - assert result.sql == "SELECT * FROM my_table" - assert result.params == {} + result = compile(MyTable()) + expected = CompiledSql("SELECT * FROM my_table", {}) + assert expected == result def test_interpolations_end_to_end(self): - result = 
template("SELECT * FROM users WHERE id = ", 42, " AND name = ", "Alice").compile() - assert 42 in result.params.values() - assert "Alice" in result.params.values() - # The raw values should NOT be in the SQL - assert "42" not in result.sql - assert "Alice" not in result.sql + result = compile("SELECT * FROM users WHERE id = ", 42, " AND name = ", "Alice") + expected = CompiledSql( + "SELECT * FROM users WHERE id = $p0 AND name = $p1", + {"p0": 42, "p1": "Alice"}, + ) + assert expected == result def test_nested_template_relations(self): """The key use case: chaining template queries.""" inner = template("SELECT * FROM people WHERE age >= ", 18) outer = template("SELECT name FROM (", inner, ")") result = outer.compile() + # Skipping exact equality: param name depends on counter offsets across nesting levels assert "SELECT * FROM people WHERE age >=" in result.sql assert "SELECT name FROM (" in result.sql assert 18 in result.params.values() - # Only one param assert len(result.params) == 1 def test_deeply_nested_tstrings(self): @@ -828,6 +828,7 @@ def test_deeply_nested_tstrings(self): level2 = template("SELECT * FROM (", level1, ") WHERE y = 1") level3 = template("SELECT count(*) FROM (", level2, ")") result = level3.compile() + # Skipping exact equality: param name depends on counter offsets across nesting levels assert "SELECT * FROM t WHERE x = $" in result.sql assert 100 in result.params.values() assert len(result.params) == 1 @@ -840,6 +841,7 @@ def test_multiple_nested_with_separate_params(self): t2 = template("SELECT * FROM t2 WHERE b = ", b) outer = template("SELECT * FROM (", t1, ") JOIN (", t2, ")") result = outer.compile() + # Skipping exact equality: param names depend on counter offsets across nesting levels assert 1 in result.params.values() assert 2 in result.params.values() assert len(result.params) == 2 @@ -848,25 +850,22 @@ def test_sql_injection_prevented_by_default(self): """Without !s, even SQL-like strings become params, not raw SQL.""" evil = 
"1; DROP TABLE users; --" interp = FakeInterpolation(value=evil, expression="evil") - result = template("SELECT * FROM users WHERE id = ", interp, "").compile() - # The evil string should be a param, not inlined - assert "DROP TABLE" not in result.sql - assert evil in result.params.values() + result = compile("SELECT * FROM users WHERE id = ", interp, "") + expected = CompiledSql("SELECT * FROM users WHERE id = $p0_evil", {"p0_evil": evil}) + assert expected == result def test_sql_injection_possible_with_s_conversion(self): """With !s, values ARE inlined — this is intentional but dangerous.""" evil = "1; DROP TABLE users; --" interp = FakeInterpolation(value=evil, expression="evil", conversion="s") - result = template("SELECT * FROM users WHERE id = ", interp, "").compile() - assert "DROP TABLE" in result.sql + result = compile("SELECT * FROM users WHERE id = ", interp, "") + expected = CompiledSql("SELECT * FROM users WHERE id = 1; DROP TABLE users; --", {}) + assert expected == result def test_exact_param_name(self): - result = template( - "SELECT * FROM t WHERE id = ", - Param(value=42, name="my_id", exact=True), - ).compile() - assert result.sql == "SELECT * FROM t WHERE id = $my_id" - assert result.params == {"my_id": 42} + result = compile("SELECT * FROM t WHERE id = ", param(value=42, name="my_id", exact=True)) + expected = CompiledSql("SELECT * FROM t WHERE id = $my_id", {"my_id": 42}) + assert expected == result # ═══════════════════════════════════════════════════════════════════════════════ @@ -876,42 +875,50 @@ def test_exact_param_name(self): class TestEdgeCases: def test_empty_string_template(self): - result = template("").compile() - assert result == CompiledSql(sql="", params={}) + result = compile("") + expected = CompiledSql(sql="", params={}) + assert result == expected def test_param_with_none_value(self): - result = template(Param(value=None, name="x")).compile() - assert result.params["p0_x"] is None + result = compile(param(value=None, 
name="x")) + expected = CompiledSql(sql="$p0_x", params={"p0_x": None}) + assert result == expected def test_param_with_list_value(self): - result = template(Param(value=[1, 2, 3], name="ids")).compile() - assert result.params["p0_ids"] == [1, 2, 3] + result = compile(param(value=[1, 2, 3], name="ids")) + expected = CompiledSql(sql="$p0_ids", params={"p0_ids": [1, 2, 3]}) + assert result == expected def test_param_with_dict_value(self): d = {"key": "value"} - result = template(Param(value=d, name="data")).compile() - assert result.params["p0_data"] == d + result = compile(param(value=d, name="data")) + expected = CompiledSql(sql="$p0_data", params={"p0_data": d}) + assert result == expected def test_bool_param(self): - result = template("SELECT * FROM t WHERE active = ", True).compile() - assert result == CompiledSql(sql="SELECT * FROM t WHERE active = $p0", params={"p0": True}) + result = compile("SELECT * FROM t WHERE active = ", True) + expected = CompiledSql(sql="SELECT * FROM t WHERE active = $p0", params={"p0": True}) + assert result == expected def test_float_param(self): interp = FakeInterpolation(value=3.14, expression="threshold") result = template("SELECT * FROM t WHERE score > ", interp, "").compile() - assert 3.14 in result.params.values() + expected = CompiledSql(sql="SELECT * FROM t WHERE score > $p0_threshold", params={"p0_threshold": 3.14}) + assert result == expected def test_none_param(self): interp = FakeInterpolation(value=None, expression="val") result = template("SELECT * FROM t WHERE x IS ", interp, "").compile() - assert None in result.params.values() + expected = CompiledSql(sql="SELECT * FROM t WHERE x IS $p0_val", params={"p0_val": None}) + assert result == expected def test_param_object_in_interpolation_preserves_name(self): """An explicit Param used in an interpolation should keep its name.""" p = Param(value=42, name="custom_name") interp = FakeInterpolation(value=p, expression="p") result = template("SELECT ", interp, "").compile() - 
assert 42 in result.params.values() + expected = CompiledSql(sql="SELECT $custom_name", params={"custom_name": 42}) + assert result == expected def test_supports_duckdb_template_priority_over_iterable(self): class IterableWithTemplate: @@ -922,30 +929,24 @@ def __iter__(self) -> list[str]: return ["SELECT 2"] result = template(IterableWithTemplate()).compile() - assert result.sql == "SELECT 1" + expected = CompiledSql(sql="SELECT 1") + assert result == expected def test_same_expression_used_twice(self): """Using the same expression twice should create two separate params.""" interp1 = FakeInterpolation(value=42, expression="x") interp2 = FakeInterpolation(value=42, expression="x") result = template("SELECT * FROM t WHERE a = ", interp1, " AND b = ", interp2, "").compile() - assert len(result.params) == 2 - assert all(v == 42 for v in result.params.values()) + expected = CompiledSql(sql="SELECT * FROM t WHERE a = $p0_x AND b = $p1_x", params={"p0_x": 42, "p1_x": 42}) + assert result == expected def test_mixed_conversion_and_param(self): """Mix of !s (inlined) and regular (parameterized) in same template.""" table_interp = FakeInterpolation(value="users", expression="table", conversion="s") id_interp = FakeInterpolation(value=5, expression="user_id") result = template("SELECT * FROM ", table_interp, " WHERE id = ", id_interp, "").compile() - assert "users" in result.sql # inlined - assert "5" not in result.sql # parameterized - assert 5 in result.params.values() - - def test_resolved_sql_template_compile_matches_sql_template_compile(self): - """resolve().compile() should match compile() directly.""" - parts = ["SELECT * FROM t WHERE id = ", Param(value=42, name="x")] - t = template(*parts) - assert t.compile() == t.resolve().compile() + expected = CompiledSql(sql="SELECT * FROM users WHERE id = $p0_user_id", params={"p0_user_id": 5}) + assert result == expected # ═══════════════════════════════════════════════════════════════════════════════ @@ -990,9 +991,10 @@ def 
test_r_conversion_with_format_spec(self): """Python semantics: repr first, then format.""" interp = FakeInterpolation(value="hi", expression="x", conversion="r", format_spec=".4") t = SqlTemplate(interp) - resolved = list(t.resolve()) + resolved = t.resolve() # repr("hi") = "'hi'", then format("'hi'", ".4") = "'hi'" (already 4 chars) - assert resolved[0] == "'hi'" + expected = CompiledSql(sql="'hi'", params={}) + assert resolved == expected def test_format_spec_ignored_for_parameterized_values(self): """When no conversion is specified, format_spec is silently ignored — the value is parameterized as-is.""" @@ -1000,4 +1002,5 @@ def test_format_spec_ignored_for_parameterized_values(self): t = SqlTemplate(interp) compiled = t.compile() # The value is parameterized with its original value, format_spec is dropped - assert 3.14159 in compiled.params.values() + expected = CompiledSql(sql="SELECT $p0_x", params={"p0_x": 3.14159}) + assert compiled == expected From 758768f2f1a5a7ba320693e2cbf5fd2dfb157fcc Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 18:10:20 -0800 Subject: [PATCH 22/29] move protocol tests to be together --- tests/fast/test_template.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/fast/test_template.py b/tests/fast/test_template.py index b5943e2c..173c4d39 100644 --- a/tests/fast/test_template.py +++ b/tests/fast/test_template.py @@ -699,6 +699,18 @@ def __duckdb_template__(self, **kwargs) -> FakeInterpolation: compiled = t.compile() assert 42 in compiled.params.values() + def test_supports_duckdb_template_priority_over_iterable(self): + class IterableWithTemplate: + def __duckdb_template__(self, **kwargs) -> str: + return "SELECT 1" + + def __iter__(self) -> list[str]: + return ["SELECT 2"] + + result = template(IterableWithTemplate()).compile() + expected = CompiledSql(sql="SELECT 1") + assert result == expected + # 
═══════════════════════════════════════════════════════════════════════════════ # IntoInterpolation protocol @@ -920,18 +932,6 @@ def test_param_object_in_interpolation_preserves_name(self): expected = CompiledSql(sql="SELECT $custom_name", params={"custom_name": 42}) assert result == expected - def test_supports_duckdb_template_priority_over_iterable(self): - class IterableWithTemplate: - def __duckdb_template__(self, **kwargs) -> str: - return "SELECT 1" - - def __iter__(self) -> list[str]: - return ["SELECT 2"] - - result = template(IterableWithTemplate()).compile() - expected = CompiledSql(sql="SELECT 1") - assert result == expected - def test_same_expression_used_twice(self): """Using the same expression twice should create two separate params.""" interp1 = FakeInterpolation(value=42, expression="x") From 3164e66fef16fb4f7208b9dc35de2b06a3537099 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 18:11:44 -0800 Subject: [PATCH 23/29] improve test style --- tests/fast/test_template.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/fast/test_template.py b/tests/fast/test_template.py index 173c4d39..a1f001fa 100644 --- a/tests/fast/test_template.py +++ b/tests/fast/test_template.py @@ -697,7 +697,8 @@ def __duckdb_template__(self, **kwargs) -> FakeInterpolation: t = template(InterpolationReturner()) compiled = t.compile() - assert 42 in compiled.params.values() + expected = CompiledSql(sql="$p0_val", params={"p0_val": 42}) + assert compiled == expected def test_supports_duckdb_template_priority_over_iterable(self): class IterableWithTemplate: From c731ea797639963a517819befd97465ac1dfa4a3 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 18:13:50 -0800 Subject: [PATCH 24/29] fixup test --- tests/fast/test_template.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/fast/test_template.py b/tests/fast/test_template.py index a1f001fa..10ce34a9 100644 --- a/tests/fast/test_template.py +++ 
b/tests/fast/test_template.py @@ -3,9 +3,7 @@ from __future__ import annotations import pytest -from sympy import comp -from duckdb.experimental.spark.sql.functions import exp from duckdb.template import ( CompiledSql, IntoInterpolation, @@ -503,6 +501,7 @@ def test_conversion_r_resolves_to_repr(self): t = SqlTemplate("SELECT ", interp) resolved = t.resolve() expected = CompiledSql(sql="SELECT 'hello'", params={}) + assert resolved == expected def test_conversion_a_resolves_to_ascii(self): interp = FakeInterpolation(value=Cafe(), expression="x", conversion="a") From 637ab4b67c6fa9f7b4212f1fdf2cf2f5f66457c6 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 18:37:30 -0800 Subject: [PATCH 25/29] improve style of tests and get passing --- duckdb/template.py | 46 +++++++++-- tests/fast/test_template.py | 154 +++++++++++++++++++----------------- 2 files changed, 121 insertions(+), 79 deletions(-) diff --git a/duckdb/template.py b/duckdb/template.py index 916e7c79..e386cecd 100644 --- a/duckdb/template.py +++ b/duckdb/template.py @@ -10,6 +10,8 @@ if TYPE_CHECKING: from collections.abc import Iterator + from typing_extensions import TypeIs + __all__ = [ "CompiledSql", "IntoInterpolation", @@ -39,6 +41,11 @@ class CompiledSql: sql: str params: dict[str, object] = dataclasses.field(default_factory=dict) + def __str__(self) -> NoReturn: + """Disallow accidentally converting to a string, since that would lose the parameters.""" + msg = "CompiledSql cannot be directly converted to a string, since it may contain parameters. Please use the .sql attribute for the SQL string, and the .params attribute for the parameters." 
# noqa: E501 + raise NotImplementedError(msg) + @runtime_checkable class SupportsDuckdbTemplate(Protocol): @@ -163,6 +170,10 @@ def compile(*part: str | IntoInterpolation | Param | SupportsDuckdbTemplate | ob return t.compile() +def _is_iterable_nonstring(value: object) -> TypeIs[Iterable]: + return isinstance(value, Iterable) and not isinstance(value, str | bytes) + + def _expand_part(part: object) -> Iterable[str | IntoInterpolation]: if isinstance(part, SupportsDuckdbTemplate): raw = part.__duckdb_template__() @@ -172,8 +183,9 @@ def _expand_part(part: object) -> Iterable[str | IntoInterpolation]: yield raw elif isinstance(raw, Param): yield ParamInterpolation(raw) - elif isinstance(raw, Iterable): - yield from _expand_part(raw) + elif _is_iterable_nonstring(raw): + for item in raw: + yield from _expand_part(item) else: p = param(value=raw) yield ParamInterpolation(p) @@ -183,6 +195,9 @@ def _expand_part(part: object) -> Iterable[str | IntoInterpolation]: yield part elif isinstance(part, Param): yield ParamInterpolation(part) + elif _is_iterable_nonstring(part): + for item in part: + yield from _expand_part(item) else: p = param(value=part) yield ParamInterpolation(p) @@ -215,8 +230,12 @@ def _resolve(parts: Iterable[str | IntoInterpolation]) -> ResolvedSqlTemplate: def _resolve_interpolation(interp: IntoInterpolation) -> Iterable[str | Param]: value = interp.value if isinstance(value, Param): - # If it's already a Param, we can skip the template resolution and just return it as a param. - return (value,) + # Preserve direct ParamInterpolation behavior while allowing nested named Params to keep exact names. 
+ if isinstance(interp, ParamInterpolation): + return (value,) + if value.name is None: + return (param(value.value),) + return (param(value.value, name=value.name, exact=True),) # if conversion specified (!s, !r, !a), treat as raw sql, eg # name = "Alice" @@ -293,6 +312,10 @@ class SqlTemplate: """A sequence of strings and Interpolations.""" def __init__(self, *parts: str | IntoInterpolation) -> None: + for part in parts: + if not isinstance(part, str | IntoInterpolation): + msg = f"Unexpected part type in SqlTemplate: {type(part).__name__}. Expected str or IntoInterpolation." + raise TypeError(msg) self.strings, self.interpolations = parse_parts(parts) def __iter__(self) -> Iterator[str | IntoInterpolation]: @@ -344,7 +367,20 @@ def __repr__(self) -> str: return f"ResolvedSqlTemplate({', '.join(part_strings)})" def __iter__(self) -> Iterator[str | Param]: - yield from self.parts + start = 0 + end = len(self.parts) + while start < end and self.parts[start] == "": + start += 1 + while end > start and self.parts[end - 1] == "": + end -= 1 + yield from self.parts[start:end] + + def __eq__(self, other: object) -> bool: + if isinstance(other, CompiledSql): + return self.compile() == other + if isinstance(other, ResolvedSqlTemplate): + return self.parts == other.parts + return False T = TypeVar("T") diff --git a/tests/fast/test_template.py b/tests/fast/test_template.py index 10ce34a9..cb992f5a 100644 --- a/tests/fast/test_template.py +++ b/tests/fast/test_template.py @@ -25,7 +25,7 @@ class FakeInterpolation: """Minimal object satisfying IntoInterpolation protocol.""" - def __init__(self, value, expression=None, conversion=None, format_spec="") -> None: + def __init__(self, value, *, expression=None, conversion=None, format_spec="") -> None: self.value = value self.expression = expression self.conversion = conversion @@ -43,8 +43,14 @@ def __duckdb_template__(self, **kwargs) -> str: class Cafe: - def __ascii__(self) -> str: - return "cafe" + """Test for ascii(obj) 
conversion.""" + + def __repr__(self) -> str: + return "Café" + + @classmethod + def ascii(cls) -> str: + return r"Caf\xe9" # ═══════════════════════════════════════════════════════════════════════════════ @@ -363,7 +369,7 @@ def test_bytes_not_treated_as_iterable(self): """Bytes should not be iterated — _is_iterable excludes bytes.""" t = template(b"hello") # type: ignore[arg-type] compiled = t.compile() - expected = CompiledSql(sql="$p0", params={"p0": b"hello"}) + expected = CompiledSql("$p0", params={"p0": b"hello"}) assert compiled == expected @@ -394,14 +400,14 @@ def test_multiple_params(self): def test_no_params(self): compiled = compile("SELECT 1") - expected = CompiledSql("SELECT 1", {}) + expected = CompiledSql("SELECT 1") assert expected == compiled def test_string_conversion_s(self): """!s should inline the value as raw SQL (no param).""" interp = FakeInterpolation(value="users", expression="table", conversion="s") compiled = compile("SELECT * FROM ", interp, "") - expected = CompiledSql("SELECT * FROM users", {}) + expected = CompiledSql("SELECT * FROM users") assert expected == compiled def test_repr_conversion_r(self): @@ -409,13 +415,13 @@ def test_repr_conversion_r(self): interp = FakeInterpolation(value="hello", expression="val", conversion="r") compiled = compile("SELECT ", interp, "") # repr of "hello" is "'hello'" - expected = CompiledSql("SELECT 'hello'", {}) + expected = CompiledSql("SELECT 'hello'") assert expected == compiled def test_ascii_conversion_a(self): interp = FakeInterpolation(value=Cafe(), expression="val", conversion="a") compiled = compile("SELECT ", interp) - expected = CompiledSql("SELECT cafe", {}) + expected = CompiledSql("SELECT " + Cafe.ascii()) assert expected == compiled def test_string_value_becomes_param_not_raw_sql(self): @@ -466,7 +472,7 @@ def test_format_spec_on_conversion(self): t = template("SELECT ", interp, "") compiled = t.compile() # Python semantics: str(3.14159) = "3.14159", then format("3.14159", ".5") = 
"3.141" - expected = CompiledSql("SELECT 3.141", {}) + expected = CompiledSql("SELECT 3.141") assert expected == compiled def test_format_spec_without_conversion_is_ignored(self): @@ -493,20 +499,20 @@ def test_conversion_s_resolves_to_string(self): interp = FakeInterpolation(value=42, expression="x", conversion="s") t = SqlTemplate("SELECT ", interp) resolved = t.compile() - expected = CompiledSql(sql="SELECT 42", params={}) + expected = CompiledSql("SELECT 42") assert resolved == expected def test_conversion_r_resolves_to_repr(self): interp = FakeInterpolation(value="hello", expression="x", conversion="r") t = SqlTemplate("SELECT ", interp) resolved = t.resolve() - expected = CompiledSql(sql="SELECT 'hello'", params={}) + expected = CompiledSql("SELECT 'hello'") assert resolved == expected def test_conversion_a_resolves_to_ascii(self): interp = FakeInterpolation(value=Cafe(), expression="x", conversion="a") actual = compile("SELECT ", interp) - expected = CompiledSql(sql="SELECT cafe") + expected = CompiledSql("SELECT " + Cafe.ascii()) assert actual == expected def test_string_value_without_conversion_becomes_param(self): @@ -514,7 +520,7 @@ def test_string_value_without_conversion_becomes_param(self): interp = FakeInterpolation(value="DROP TABLE users", expression="val") t = SqlTemplate("SELECT ", interp) resolved = t.resolve() - expected = CompiledSql(sql="SELECT $p0_val", params={"p0_val": "DROP TABLE users"}) + expected = CompiledSql("SELECT $p0_val", params={"p0_val": "DROP TABLE users"}) assert resolved == expected def test_nested_supports_duckdb_template(self): @@ -522,7 +528,7 @@ def test_nested_supports_duckdb_template(self): interp = FakeInterpolation(value=rel, expression="rel") t = SqlTemplate("SELECT * FROM (", interp, ")") resolved = t.resolve() - expected = CompiledSql(sql="SELECT * FROM (SELECT 1)", params={}) + expected = CompiledSql("SELECT * FROM (SELECT 1)") assert resolved == expected def test_expression_name_preserved_for_simple_param(self): 
@@ -530,7 +536,7 @@ def test_expression_name_preserved_for_simple_param(self): interp = FakeInterpolation(value=42, expression="my_age") t = SqlTemplate("age = ", interp) resolved = t.resolve() - expected = CompiledSql(sql="age = $p0_my_age", params={"p0_my_age": 42}) + expected = CompiledSql("age = $p0_my_age", params={"p0_my_age": 42}) assert resolved == expected @@ -576,26 +582,26 @@ def test_iter(self): class TestCompileParts: def test_all_strings(self): result = compile_parts(["SELECT 1"]) - assert result == CompiledSql(sql="SELECT 1", params={}) + assert result == CompiledSql("SELECT 1") def test_single_unnamed_param(self): result = compile_parts(["SELECT ", Param(value=42)]) - expected = CompiledSql(sql="SELECT $p0", params={"p0": 42}) + expected = CompiledSql("SELECT $p0", params={"p0": 42}) assert result == expected def test_single_named_param(self): result = compile_parts(["SELECT ", Param(value=42, name="x")]) - expected = CompiledSql(sql="SELECT $p0_x", params={"p0_x": 42}) + expected = CompiledSql("SELECT $p0_x", params={"p0_x": 42}) assert result == expected def test_exact_param_uses_literal_name(self): result = compile_parts(["SELECT ", Param(value=42, name="my_param", exact=True)]) - expected = CompiledSql(sql="SELECT $my_param", params={"my_param": 42}) + expected = CompiledSql("SELECT $my_param", params={"my_param": 42}) assert result == expected def test_multiple_params_numbered_sequentially(self): result = compile_parts(["SELECT * WHERE a = ", Param(value=1, name="a"), " AND b = ", Param(value=2, name="b")]) - expected = CompiledSql(sql="SELECT * WHERE a = $p0_a AND b = $p1_b", params={"p0_a": 1, "p1_b": 2}) + expected = CompiledSql("SELECT * WHERE a = $p0_a AND b = $p1_b", params={"p0_a": 1, "p1_b": 2}) assert result == expected def test_duplicate_param_names_raises(self): @@ -616,7 +622,7 @@ def test_unnamed_params_get_sequential_names(self): Param(value=2), ] ) - expected = CompiledSql(sql="a = $p0 AND b = $p1", params={"p0": 1, "p1": 2}) + 
expected = CompiledSql("a = $p0 AND b = $p1", params={"p0": 1, "p1": 2}) assert result == expected def test_exact_param_causes_counter_gap(self): @@ -631,22 +637,22 @@ def test_exact_param_causes_counter_gap(self): Param(value=3, name="y"), # → p2_y (not p1_y!) ] ) - expected = CompiledSql(sql="a = $p0_x AND b = $b AND c = $p2_y", params={"p0_x": 1, "b": 2, "p2_y": 3}) + expected = CompiledSql("a = $p0_x AND b = $b AND c = $p2_y", params={"p0_x": 1, "b": 2, "p2_y": 3}) assert result == expected def test_empty_parts(self): result = compile_parts([]) - expected = CompiledSql(sql="", params={}) + expected = CompiledSql("") assert result == expected def test_adjacent_strings(self): result = compile_parts(["SELECT ", "1"]) - expected = CompiledSql(sql="SELECT 1", params={}) + expected = CompiledSql("SELECT 1") assert result == expected def test_param_with_none_value(self): result = compile_parts(["SELECT ", Param(value=None, name="x")]) - expected = CompiledSql(sql="SELECT $p0_x", params={"p0_x": None}) + expected = CompiledSql("SELECT $p0_x", params={"p0_x": None}) assert result == expected @@ -696,7 +702,7 @@ def __duckdb_template__(self, **kwargs) -> FakeInterpolation: t = template(InterpolationReturner()) compiled = t.compile() - expected = CompiledSql(sql="$p0_val", params={"p0_val": 42}) + expected = CompiledSql("$p0_val", params={"p0_val": 42}) assert compiled == expected def test_supports_duckdb_template_priority_over_iterable(self): @@ -708,7 +714,7 @@ def __iter__(self) -> list[str]: return ["SELECT 2"] result = template(IterableWithTemplate()).compile() - expected = CompiledSql(sql="SELECT 1") + expected = CompiledSql("SELECT 1") assert result == expected @@ -738,7 +744,7 @@ def test_param_interpolation_satisfies(self): class TestCompiledSql: def test_basic(self): - c = CompiledSql(sql="SELECT $p0", params={"p0": 42}) + c = CompiledSql("SELECT $p0", params={"p0": 42}) assert c.sql == "SELECT $p0" assert c.params == {"p0": 42} @@ -748,30 +754,35 @@ def 
test_positional_params(self): assert c.params == {"p0": 42} def test_empty_params(self): - c = CompiledSql(sql="SELECT 1", params={}) + c = CompiledSql("SELECT 1") assert c.sql == "SELECT 1" assert c.params == {} def test_optional_params(self): - c = CompiledSql(sql="SELECT 1") + c = CompiledSql("SELECT 1") assert c.sql == "SELECT 1" assert c.params == {} def test_frozen(self): - c = CompiledSql(sql="SELECT 1", params={}) + c = CompiledSql("SELECT 1") with pytest.raises(AttributeError): c.sql = "SELECT 2" # ty:ignore[invalid-assignment] def test_equality(self): - a = CompiledSql(sql="SELECT 1", params={}) - b = CompiledSql(sql="SELECT 1", params={}) + a = CompiledSql("SELECT 1") + b = CompiledSql("SELECT 1") assert a == b def test_repr(self): - c = CompiledSql(sql="SELECT 1", params={}) + c = CompiledSql("SELECT 1") expected = "CompiledSql(sql='SELECT 1', params={})" assert repr(c) == expected + def test_str(self): + c = CompiledSql("SELECT 1") + with pytest.raises(NotImplementedError): + str(c) + # ═══════════════════════════════════════════════════════════════════════════════ # End-to-end: compile @@ -814,12 +825,14 @@ def __duckdb_template__(self, **kwargs) -> str: expected = CompiledSql("SELECT * FROM my_table", {}) assert expected == result - def test_interpolations_end_to_end(self): + def test_unparameterized_strings(self): result = compile("SELECT * FROM users WHERE id = ", 42, " AND name = ", "Alice") - expected = CompiledSql( - "SELECT * FROM users WHERE id = $p0 AND name = $p1", - {"p0": 42, "p1": "Alice"}, - ) + expected = CompiledSql("SELECT * FROM users WHERE id = $p0 AND name = Alice", {"p0": 42}) + assert expected == result + + def test_parameterized_strings(self): + result = compile("SELECT * FROM users WHERE id = ", 42, " AND name = ", param("Alice")) + expected = CompiledSql("SELECT * FROM users WHERE id = $p0 AND name = $p1", {"p0": 42, "p1": "Alice"}) assert expected == result def test_nested_template_relations(self): @@ -888,40 +901,40 @@ def 
test_exact_param_name(self): class TestEdgeCases: def test_empty_string_template(self): result = compile("") - expected = CompiledSql(sql="", params={}) + expected = CompiledSql("") assert result == expected def test_param_with_none_value(self): result = compile(param(value=None, name="x")) - expected = CompiledSql(sql="$p0_x", params={"p0_x": None}) + expected = CompiledSql("$p0_x", params={"p0_x": None}) assert result == expected def test_param_with_list_value(self): result = compile(param(value=[1, 2, 3], name="ids")) - expected = CompiledSql(sql="$p0_ids", params={"p0_ids": [1, 2, 3]}) + expected = CompiledSql("$p0_ids", params={"p0_ids": [1, 2, 3]}) assert result == expected def test_param_with_dict_value(self): d = {"key": "value"} result = compile(param(value=d, name="data")) - expected = CompiledSql(sql="$p0_data", params={"p0_data": d}) + expected = CompiledSql("$p0_data", params={"p0_data": d}) assert result == expected def test_bool_param(self): result = compile("SELECT * FROM t WHERE active = ", True) - expected = CompiledSql(sql="SELECT * FROM t WHERE active = $p0", params={"p0": True}) + expected = CompiledSql("SELECT * FROM t WHERE active = $p0", params={"p0": True}) assert result == expected def test_float_param(self): interp = FakeInterpolation(value=3.14, expression="threshold") result = template("SELECT * FROM t WHERE score > ", interp, "").compile() - expected = CompiledSql(sql="SELECT * FROM t WHERE score > $p0_threshold", params={"p0_threshold": 3.14}) + expected = CompiledSql("SELECT * FROM t WHERE score > $p0_threshold", params={"p0_threshold": 3.14}) assert result == expected def test_none_param(self): interp = FakeInterpolation(value=None, expression="val") result = template("SELECT * FROM t WHERE x IS ", interp, "").compile() - expected = CompiledSql(sql="SELECT * FROM t WHERE x IS $p0_val", params={"p0_val": None}) + expected = CompiledSql("SELECT * FROM t WHERE x IS $p0_val", params={"p0_val": None}) assert result == expected def 
test_param_object_in_interpolation_preserves_name(self): @@ -929,7 +942,7 @@ def test_param_object_in_interpolation_preserves_name(self): p = Param(value=42, name="custom_name") interp = FakeInterpolation(value=p, expression="p") result = template("SELECT ", interp, "").compile() - expected = CompiledSql(sql="SELECT $custom_name", params={"custom_name": 42}) + expected = CompiledSql("SELECT $custom_name", params={"custom_name": 42}) assert result == expected def test_same_expression_used_twice(self): @@ -937,7 +950,7 @@ def test_same_expression_used_twice(self): interp1 = FakeInterpolation(value=42, expression="x") interp2 = FakeInterpolation(value=42, expression="x") result = template("SELECT * FROM t WHERE a = ", interp1, " AND b = ", interp2, "").compile() - expected = CompiledSql(sql="SELECT * FROM t WHERE a = $p0_x AND b = $p1_x", params={"p0_x": 42, "p1_x": 42}) + expected = CompiledSql("SELECT * FROM t WHERE a = $p0_x AND b = $p1_x", params={"p0_x": 42, "p1_x": 42}) assert result == expected def test_mixed_conversion_and_param(self): @@ -945,7 +958,7 @@ def test_mixed_conversion_and_param(self): table_interp = FakeInterpolation(value="users", expression="table", conversion="s") id_interp = FakeInterpolation(value=5, expression="user_id") result = template("SELECT * FROM ", table_interp, " WHERE id = ", id_interp, "").compile() - expected = CompiledSql(sql="SELECT * FROM users WHERE id = $p0_user_id", params={"p0_user_id": 5}) + expected = CompiledSql("SELECT * FROM users WHERE id = $p0_user_id", params={"p0_user_id": 5}) assert result == expected @@ -959,48 +972,41 @@ class TestConversionSemantics: def test_s_conversion_on_int(self): interp = FakeInterpolation(value=42, expression="x", conversion="s") - t = SqlTemplate(interp) - resolved = list(t.resolve()) - assert len(resolved) == 1 - assert resolved[0] == "42" + actual = compile(interp) + expected = CompiledSql("42") + assert actual == expected def test_r_conversion_on_string(self): """repr('hello') = 
"'hello'".""" interp = FakeInterpolation(value="hello", expression="x", conversion="r") - t = SqlTemplate(interp) - resolved = list(t.resolve()) - assert len(resolved) == 1 - assert resolved[0] == "'hello'" + actual = compile(interp) + expected = CompiledSql("'hello'") + assert actual == expected def test_r_conversion_on_int(self): """repr(42) = '42', no quotes.""" interp = FakeInterpolation(value=42, expression="x", conversion="r") - t = SqlTemplate(interp) - resolved = list(t.resolve()) - assert resolved[0] == "42" + actual = compile(interp) + expected = CompiledSql("42") + assert actual == expected def test_s_conversion_with_format_spec(self): """Conversion first, then format_spec: str(3.14159) then format with '.5' truncates.""" interp = FakeInterpolation(value=3.14159, expression="x", conversion="s", format_spec=".5") - t = SqlTemplate(interp) - resolved = list(t.resolve()) - # str(3.14159) = "3.14159", then format("3.14159", ".5") = "3.141" (truncates string to 5 chars) - assert resolved[0] == "3.141" + actual = compile(interp) + expected = CompiledSql("3.141") + assert actual == expected def test_r_conversion_with_format_spec(self): """Python semantics: repr first, then format.""" interp = FakeInterpolation(value="hi", expression="x", conversion="r", format_spec=".4") - t = SqlTemplate(interp) - resolved = t.resolve() - # repr("hi") = "'hi'", then format("'hi'", ".4") = "'hi'" (already 4 chars) - expected = CompiledSql(sql="'hi'", params={}) - assert resolved == expected + actual = compile(interp) + expected = CompiledSql("'hi'") + assert actual == expected def test_format_spec_ignored_for_parameterized_values(self): """When no conversion is specified, format_spec is silently ignored — the value is parameterized as-is.""" interp = FakeInterpolation(value=3.14159, expression="x", format_spec=".2f") - t = SqlTemplate(interp) - compiled = t.compile() - # The value is parameterized with its original value, format_spec is dropped - expected = 
CompiledSql(sql="SELECT $p0_x", params={"p0_x": 3.14159}) - assert compiled == expected + actual = compile(interp) + expected = CompiledSql("$p0_x", params={"p0_x": 3.14159}) + assert actual == expected From 8531fc49cb77b2109c4fc24575da8e52ff256c11 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 18:49:43 -0800 Subject: [PATCH 26/29] reduce verbosity of param construction in tests --- tests/fast/test_template.py | 106 ++++++++++++++++++++---------------- 1 file changed, 59 insertions(+), 47 deletions(-) diff --git a/tests/fast/test_template.py b/tests/fast/test_template.py index cb992f5a..1d645dfc 100644 --- a/tests/fast/test_template.py +++ b/tests/fast/test_template.py @@ -70,6 +70,12 @@ def test_named_param(self): assert p.name == "greeting" assert p.exact is False + def test_positional_creation(self): + p = Param(42, "greeting", True) + assert p.value == 42 + assert p.name == "greeting" + assert p.exact is True + def test_exact_param_requires_name(self): with pytest.raises(ValueError, match="exact=True must have a name"): Param(value=1, exact=True) @@ -94,8 +100,8 @@ def test_param_helper_function(self): def test_param_repr(self): p = Param(value=42, name="x") r = repr(p) - assert "42" in r - assert "x" in r + expected = "Param(value=42, name='x', exact=False)" + assert r == expected def test_param_equality(self): """Frozen dataclasses support equality by default.""" @@ -108,6 +114,32 @@ def test_param_various_value_types(self): p = Param(value=val) assert p.value is val + def test_param_factory_function_defaults(self): + """param() should allow defaults for name and exact.""" + expected = Param(value=42, name=None, exact=False) + assert param(42) == expected + assert param(value=42) == expected + + def test_param_factory_function_named(self): + p = param(42, name="x") + assert p.value == 42 + assert p.name == "x" + assert p.exact is False + + def test_param_factory_function_exact(self): + p = param(42, name="x", exact=True) + assert p.value == 
42 + assert p.name == "x" + assert p.exact is True + + def test_param_factory_function_exact_requires_name(self): + with pytest.raises(ValueError, match="exact=True must have a name"): + param(42, exact=True) + + def test_param_factory_function_exact_with_name(self): + with pytest.raises(TypeError): + param(42, "x", True) # ty:ignore[too-many-positional-arguments] + # ═══════════════════════════════════════════════════════════════════════════════ # ParamInterpolation @@ -116,7 +148,7 @@ def test_param_various_value_types(self): class TestParamInterpolation: def test_wraps_param(self): - p = Param(value=42, name="x") + p = param(42, "x") pi = ParamInterpolation(p) assert pi.value is p assert pi.expression == "x" @@ -124,12 +156,12 @@ def test_wraps_param(self): assert pi.format_spec == "" def test_unnamed_param_expression_is_none(self): - p = Param(value=42) + p = param(42) pi = ParamInterpolation(p) assert pi.expression is None def test_satisfies_into_interpolation_protocol(self): - pi = ParamInterpolation(Param(value=1, name="x")) + pi = ParamInterpolation(param(1, "x")) assert isinstance(pi, IntoInterpolation) @@ -234,12 +266,11 @@ def test_with_interpolation(self): assert t.interpolations[0] is interp def test_bare_param_errors(self): - p = Param(value=42) with pytest.raises(TypeError, match="Unexpected part type"): - SqlTemplate("SELECT ", p) # ty:ignore[invalid-argument-type] + SqlTemplate("SELECT ", param(42)) # ty:ignore[invalid-argument-type] def test_wrapped_param(self): - wrapped = ParamInterpolation(Param(value=42, name="x")) + wrapped = ParamInterpolation(param(42, "x")) t = SqlTemplate("SELECT ", wrapped, " FROM t") assert len(t.interpolations) == 1 @@ -299,8 +330,7 @@ def test_plain_string(self): assert expected == compiled def test_param(self): - p = Param(value=42, name="answer") - t = template(p) + t = template(param(42, "answer")) compiled = t.compile() expected = CompiledSql("$p0_answer", {"p0_answer": 42}) assert expected == compiled @@ -346,7 
+376,7 @@ def test_iterable_of_strings(self): assert expected == compiled def test_iterable_with_params(self): - t = template("SELECT * FROM t WHERE id = ", Param(value=5, name="id")) + t = template("SELECT * FROM t WHERE id = ", param(5, "id")) compiled = t.compile() expected = CompiledSql("SELECT * FROM t WHERE id = $p0_id", {"p0_id": 5}) assert expected == compiled @@ -547,12 +577,12 @@ def test_expression_name_preserved_for_simple_param(self): class TestResolvedSqlTemplate: def test_basic(self): - r = ResolvedSqlTemplate(["SELECT ", Param(value=42, name="x")]) + r = ResolvedSqlTemplate(["SELECT ", param(42, "x")]) parts = list(r) assert len(parts) == 2 def test_compile(self): - r = ResolvedSqlTemplate(["SELECT ", Param(value=42, name="x")]) + r = ResolvedSqlTemplate(["SELECT ", param(42, "x")]) compiled = r.compile() assert isinstance(compiled, CompiledSql) assert 42 in compiled.params.values() @@ -563,13 +593,13 @@ def test_str_raises(self): str(r) def test_repr(self): - r = ResolvedSqlTemplate(["SELECT ", Param(value=42, name="x")]) + r = ResolvedSqlTemplate(["SELECT ", param(42, "x")]) rep = repr(r) assert "ResolvedSqlTemplate" in rep assert "x=42" in rep def test_iter(self): - parts_in = ["a", Param(value=1, name="x"), "b"] + parts_in = ["a", param(1, "x"), "b"] r = ResolvedSqlTemplate(parts_in) assert list(r) == parts_in @@ -585,43 +615,31 @@ def test_all_strings(self): assert result == CompiledSql("SELECT 1") def test_single_unnamed_param(self): - result = compile_parts(["SELECT ", Param(value=42)]) + result = compile_parts(["SELECT ", param(42)]) expected = CompiledSql("SELECT $p0", params={"p0": 42}) assert result == expected def test_single_named_param(self): - result = compile_parts(["SELECT ", Param(value=42, name="x")]) + result = compile_parts(["SELECT ", param(42, "x")]) expected = CompiledSql("SELECT $p0_x", params={"p0_x": 42}) assert result == expected def test_exact_param_uses_literal_name(self): - result = compile_parts(["SELECT ", 
Param(value=42, name="my_param", exact=True)]) + result = compile_parts(["SELECT ", param(42, "my_param", exact=True)]) expected = CompiledSql("SELECT $my_param", params={"my_param": 42}) assert result == expected def test_multiple_params_numbered_sequentially(self): - result = compile_parts(["SELECT * WHERE a = ", Param(value=1, name="a"), " AND b = ", Param(value=2, name="b")]) + result = compile_parts(["SELECT * WHERE a = ", param(1, "a"), " AND b = ", param(2, "b")]) expected = CompiledSql("SELECT * WHERE a = $p0_a AND b = $p1_b", params={"p0_a": 1, "p1_b": 2}) assert result == expected def test_duplicate_param_names_raises(self): with pytest.raises(ValueError, match="Duplicate parameter names"): - compile_parts( - [ - Param(value=1, name="x", exact=True), - Param(value=2, name="x", exact=True), - ] - ) + compile_parts([param(1, "x", exact=True), param(2, "x", exact=True)]) def test_unnamed_params_get_sequential_names(self): - result = compile_parts( - [ - "a = ", - Param(value=1), - " AND b = ", - Param(value=2), - ] - ) + result = compile_parts(["a = ", param(1), " AND b = ", param(2)]) expected = CompiledSql("a = $p0 AND b = $p1", params={"p0": 1, "p1": 2}) assert result == expected @@ -630,11 +648,11 @@ def test_exact_param_causes_counter_gap(self): result = compile_parts( [ "a = ", - Param(value=1, name="x"), # → p0_x + param(1, "x"), # → p0_x " AND b = ", - Param(value=2, name="b", exact=True), # → b (exact), but counter increments + param(2, "b", exact=True), # → b (exact), but counter increments " AND c = ", - Param(value=3, name="y"), # → p2_y (not p1_y!) + param(3, "y"), # → p2_y (not p1_y!) 
] ) expected = CompiledSql("a = $p0_x AND b = $b AND c = $p2_y", params={"p0_x": 1, "b": 2, "p2_y": 3}) @@ -651,7 +669,7 @@ def test_adjacent_strings(self): assert result == expected def test_param_with_none_value(self): - result = compile_parts(["SELECT ", Param(value=None, name="x")]) + result = compile_parts(["SELECT ", param(None, "x")]) expected = CompiledSql("SELECT $p0_x", params={"p0_x": None}) assert result == expected @@ -733,7 +751,7 @@ def test_protocol_check_negative(self): assert not isinstance(42, IntoInterpolation) def test_param_interpolation_satisfies(self): - pi = ParamInterpolation(Param(value=1)) + pi = ParamInterpolation(param(1)) assert isinstance(pi, IntoInterpolation) @@ -799,17 +817,12 @@ def test_strings_joined(self): assert result == CompiledSql("SELECT * FROM users", {}) def test_param_in_list(self): - result = compile("SELECT * FROM users WHERE id = ", Param(value=5, name="id")) + result = compile("SELECT * FROM users WHERE id = ", param(5, "id")) expected = CompiledSql("SELECT * FROM users WHERE id = $p0_id", {"p0_id": 5}) assert expected == result def test_multiple_params_in_list(self): - result = compile( - "SELECT * FROM users WHERE name = ", - Param(value="Alice", name="name"), - " AND age > ", - Param(value=18, name="age"), - ) + result = compile("SELECT * FROM users WHERE name = ", param("Alice", "name"), " AND age > ", param(18, "age")) expected = CompiledSql( "SELECT * FROM users WHERE name = $p0_name AND age > $p1_age", {"p0_name": "Alice", "p1_age": 18}, @@ -939,8 +952,7 @@ def test_none_param(self): def test_param_object_in_interpolation_preserves_name(self): """An explicit Param used in an interpolation should keep its name.""" - p = Param(value=42, name="custom_name") - interp = FakeInterpolation(value=p, expression="p") + interp = FakeInterpolation(value=param(42, "custom_name"), expression="p") result = template("SELECT ", interp, "").compile() expected = CompiledSql("SELECT $custom_name", params={"custom_name": 42}) 
assert result == expected From 41c639fd741a5cb1d003b14472516bd2d97082b1 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 18:52:14 -0800 Subject: [PATCH 27/29] rename test to make room for test file that does actually execute code --- tests/fast/{test_template.py => test_template_python.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename tests/fast/{test_template.py => test_template_python.py} (99%) diff --git a/tests/fast/test_template.py b/tests/fast/test_template_python.py similarity index 99% rename from tests/fast/test_template.py rename to tests/fast/test_template_python.py index 1d645dfc..2896455b 100644 --- a/tests/fast/test_template.py +++ b/tests/fast/test_template_python.py @@ -1,4 +1,4 @@ -"""Exhaustive tests for template.py — the SQL template / t-string system.""" +"""Pure-python tests that don't require a compiled extension module.""" from __future__ import annotations From 0ff65159111b0b19f799e6b98014fb81defd687f Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 19:16:10 -0800 Subject: [PATCH 28/29] add e2e tests --- tests/fast/test_template_e2e.py | 114 ++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 tests/fast/test_template_e2e.py diff --git a/tests/fast/test_template_e2e.py b/tests/fast/test_template_e2e.py new file mode 100644 index 00000000..a294c7e1 --- /dev/null +++ b/tests/fast/test_template_e2e.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import pytest + +import duckdb +from duckdb.template import param, template + + +def test_connection_sql_accepts_sql_template() -> None: + conn = duckdb.connect() + query = template("SELECT i FROM range(5) t(i) WHERE i >= ", 2, " ORDER BY i") + assert conn.sql(query).fetchall() == [(2,), (3,), (4,)] + + +def test_connection_query_accepts_sql_template() -> None: + conn = duckdb.connect() + query = template("SELECT i FROM range(3) t(i) WHERE i < ", 2, " ORDER BY i") + assert conn.query(query).fetchall() == [(0,), 
(1,)] + + +def test_connection_from_query_accepts_sql_template() -> None: + conn = duckdb.connect() + query = template("SELECT i FROM range(4) t(i) WHERE i % ", 2, " = 0 ORDER BY i") + assert conn.from_query(query).fetchall() == [(0,), (2,)] + + +def test_connection_execute_accepts_sql_template() -> None: + conn = duckdb.connect() + query = template("SELECT ", 42) + assert conn.execute(query).fetchone() == (42,) + + +def test_module_level_sql_apis_accept_sql_template() -> None: + conn = duckdb.connect() + query = template("SELECT i FROM range(5) t(i) WHERE i BETWEEN ", 1, " AND ", 3, " ORDER BY i") + + assert duckdb.sql(query, connection=conn).fetchall() == [(1,), (2,), (3,)] + assert duckdb.query(query, connection=conn).fetchall() == [(1,), (2,), (3,)] + assert duckdb.from_query(query, connection=conn).fetchall() == [(1,), (2,), (3,)] + + +def test_module_level_execute_accepts_sql_template() -> None: + conn = duckdb.connect() + query = template("SELECT ", "hello") + assert duckdb.execute(query, connection=conn).fetchone() == ("hello",) + + +def test_connection_sql_accepts_alias_kwarg_with_template() -> None: + conn = duckdb.connect() + inner = conn.sql(template("SELECT 42 AS x"), alias="my_alias") + assert inner.alias == "my_alias" + outer = conn.sql(template("SELECT x FROM (", inner, ")")) + assert outer.fetchall() == [(42,)] + + +def test_connection_sql_template_can_merge_additional_params() -> None: + conn = duckdb.connect() + query = template("SELECT ", 10, " + $another") + assert conn.sql(query, params={"another": 5}).fetchall() == [(15,)] + + +def test_connection_sql_template_param_name_conflict_with_additional_params_raises() -> None: + conn = duckdb.connect() + query = template("SELECT ", param(10, "num"), " + $num") + with pytest.raises((duckdb.InvalidInputException, ValueError)): + conn.sql(query, params={"num": 5}).fetchall() + + +def test_sql_apis_accept_compiled_sql() -> None: + conn = duckdb.connect() + compiled = template("SELECT i FROM range(5) 
t(i) WHERE i >= ", 3, " ORDER BY i").compile() + + assert conn.sql(compiled).fetchall() == [(3,), (4,)] + assert conn.query(compiled).fetchall() == [(3,), (4,)] + assert conn.from_query(compiled).fetchall() == [(3,), (4,)] + assert conn.execute(compiled).fetchall() == [(3,), (4,)] + + +def test_relation_interpolation_works_end_to_end() -> None: + conn = duckdb.connect() + rel = conn.sql("SELECT i FROM range(6) t(i)") + query = template("SELECT i FROM (", rel, ") WHERE i % ", 2, " = 0 ORDER BY i") + assert conn.sql(query).fetchall() == [(0,), (2,), (4,)] + + +def test_interpolated_strings_are_parameterized_by_default() -> None: + conn = duckdb.connect() + conn.execute("CREATE TABLE names(name VARCHAR)") + conn.execute("INSERT INTO names VALUES ('alice'), ('bob')") + + untrusted = "alice' OR 1=1 --" + query = template("SELECT count(*) FROM names WHERE name = ", untrusted) + assert conn.sql(query).fetchone() == (0,) + + +def test_builtin_duckdbpytype_object_interpolates_in_template() -> None: + conn = duckdb.connect() + integer_type = duckdb.sqltype("INTEGER") + query = template("SELECT 42::", integer_type) + assert conn.sql(query).fetchall() == [(42,)] + + +def test_builtin_expression_object_interpolates_in_template() -> None: + conn = duckdb.connect() + expr = duckdb.ColumnExpression("i") + query = template("SELECT ", expr, " FROM range(3) t(i) ORDER BY i") + assert conn.sql(query).fetchall() == [(0,), (1,), (2,)] + + +def test_builtin_sqlexpression_object_interpolates_in_template() -> None: + conn = duckdb.connect() + expr = duckdb.SQLExpression("i + 1") + query = template("SELECT ", expr, " FROM range(3) t(i) ORDER BY i") + assert conn.sql(query).fetchall() == [(1,), (2,), (3,)] From 3b8746f2bd6c3edafedb808571aa58f35e0128aa Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sat, 28 Mar 2026 19:52:51 -0800 Subject: [PATCH 29/29] implement at the cpp level --- _duckdb-stubs/__init__.pyi | 34 +++++--- .../pyconnection/pyconnection.hpp | 1 + 
src/duckdb_py/pyconnection.cpp | 79 ++++++++++++++++++- src/duckdb_py/pyexpression/initialize.cpp | 1 + src/duckdb_py/pyrelation/initialize.cpp | 1 + src/duckdb_py/typing/pytype.cpp | 1 + tests/fast/test_template_e2e.py | 24 +++--- 7 files changed, 115 insertions(+), 26 deletions(-) diff --git a/_duckdb-stubs/__init__.pyi b/_duckdb-stubs/__init__.pyi index 81d69be7..ea7d3c88 100644 --- a/_duckdb-stubs/__init__.pyi +++ b/_duckdb-stubs/__init__.pyi @@ -13,7 +13,7 @@ if typing.TYPE_CHECKING: import pandas import pyarrow.lib from collections.abc import Callable, Iterable, Sequence, Mapping - from duckdb import sqltypes, func + from duckdb import sqltypes, func, template from builtins import list as lst # needed to avoid mypy error on DuckDBPyRelation.list method shadowing # the field_ids argument to to_parquet and write_parquet has a recursive structure @@ -241,8 +241,12 @@ class DuckDBPyConnection: def dtype(self, type_str: str) -> sqltypes.DuckDBPyType: ... def duplicate(self) -> DuckDBPyConnection: ... def enum_type(self, name: str, type: sqltypes.DuckDBPyType, values: lst[typing.Any]) -> sqltypes.DuckDBPyType: ... - def execute(self, query: Statement | str, parameters: object = None) -> DuckDBPyConnection: ... - def executemany(self, query: Statement | str, parameters: object = None) -> DuckDBPyConnection: ... + def execute( + self, query: Statement | str | template.SqlTemplate | template.CompiledSql, parameters: object = None + ) -> DuckDBPyConnection: ... + def executemany( + self, query: Statement | str | template.SqlTemplate | template.CompiledSql, parameters: object = None + ) -> DuckDBPyConnection: ... def extract_statements(self, query: str) -> lst[Statement]: ... def fetch_arrow_table(self, rows_per_batch: typing.SupportsInt = 1000000) -> pyarrow.lib.Table: """Deprecated: use to_arrow_table() instead.""" @@ -331,7 +335,9 @@ class DuckDBPyConnection: union_by_name: bool = False, compression: str | None = None, ) -> DuckDBPyRelation: ... 
- def from_query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... + def from_query( + self, query: str | template.SqlTemplate | template.CompiledSql, *, alias: str = "", params: object = None + ) -> DuckDBPyRelation: ... def get_table_names(self, query: str, *, qualified: bool = False) -> set[str]: ... def install_extension( self, @@ -360,7 +366,9 @@ class DuckDBPyConnection: def pl( self, rows_per_batch: typing.SupportsInt = 1000000, *, lazy: bool = False ) -> polars.DataFrame | polars.LazyFrame: ... - def query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... + def query( + self, query: str | template.SqlTemplate | template.CompiledSql, *, alias: str = "", params: object = None + ) -> DuckDBPyRelation: ... def query_progress(self) -> float: ... def read_csv( self, @@ -462,7 +470,13 @@ class DuckDBPyConnection: def row_type( self, fields: dict[str, sqltypes.DuckDBPyType] | lst[sqltypes.DuckDBPyType] ) -> sqltypes.DuckDBPyType: ... - def sql(self, query: Statement | str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... + def sql( + self, + query: Statement | str | template.SqlTemplate | template.CompiledSql, + *, + alias: str = "", + params: object = None, + ) -> DuckDBPyRelation: ... def sqltype(self, type_str: str) -> sqltypes.DuckDBPyType: ... def string_type(self, collation: str = "") -> sqltypes.DuckDBPyType: ... def struct_type( @@ -1160,7 +1174,7 @@ def enum_type( connection: DuckDBPyConnection | None = None, ) -> sqltypes.DuckDBPyType: ... def execute( - query: Statement | str, + query: Statement | str | template.SqlTemplate | template.CompiledSql, parameters: object = None, *, connection: DuckDBPyConnection | None = None, @@ -1282,7 +1296,7 @@ def from_parquet( connection: DuckDBPyConnection | None = None, ) -> DuckDBPyRelation: ... 
def from_query( - query: Statement | str, + query: Statement | str | template.SqlTemplate | template.CompiledSql, *, alias: str = "", params: object = None, @@ -1350,7 +1364,7 @@ def project( df: pandas.DataFrame, *args: _ExpressionLike, groups: str = "", connection: DuckDBPyConnection | None = None ) -> DuckDBPyRelation: ... def query( - query: Statement | str, + query: Statement | str | template.SqlTemplate | template.CompiledSql, *, alias: str = "", params: object = None, @@ -1474,7 +1488,7 @@ def row_type( def rowcount(*, connection: DuckDBPyConnection | None = None) -> int: ... def set_default_connection(connection: DuckDBPyConnection) -> None: ... def sql( - query: Statement | str, + query: Statement | str | template.SqlTemplate | template.CompiledSql, *, alias: str = "", params: object = None, diff --git a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp index dd7c9d2e..3d0e0509 100644 --- a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp +++ b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp @@ -363,6 +363,7 @@ struct DuckDBPyConnection : public enable_shared_from_this { FunctionNullHandling null_handling, PythonExceptionHandling exception_handling, bool side_effects); void RegisterArrowObject(const py::object &arrow_object, const string &name); + pair ExtractCompiledSqlAndParams(const py::object &query, py::object params); vector> GetStatements(const py::object &query); static PythonEnvironmentType environment; diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index 6883ba45..400dcd09 100644 --- a/src/duckdb_py/pyconnection.cpp +++ b/src/duckdb_py/pyconnection.cpp @@ -708,11 +708,74 @@ shared_ptr DuckDBPyConnection::ExecuteFromString(const strin return Execute(py::str(query)); } +pair DuckDBPyConnection::ExtractCompiledSqlAndParams(const py::object &query, py::object params) { + py::object compiled = query; + if 
(!py::hasattr(compiled, "sql") || !py::hasattr(compiled, "params")) { + if (!py::hasattr(query, "compile")) { + return {string(), params}; + } + compiled = query.attr("compile")(); + } + + if (!py::hasattr(compiled, "sql") || !py::hasattr(compiled, "params")) { + return {string(), params}; + } + + auto compiled_sql = py::cast(compiled.attr("sql")); + auto compiled_params_obj = compiled.attr("params"); + if (!py::is_dict_like(compiled_params_obj)) { + throw InvalidInputException("Compiled SQL parameters must be a dictionary"); + } + + auto compiled_params = py::cast(compiled_params_obj); + if (compiled_params.empty()) { + return {compiled_sql, params}; + } + + if (params.is_none()) { + return {compiled_sql, compiled_params}; + } + + if (py::is_dict_like(params)) { + auto merged_params = py::dict(); + for (auto &item : compiled_params) { + merged_params[item.first] = item.second; + } + auto provided_params = py::cast(params); + for (auto &item : provided_params) { + if (merged_params.contains(item.first)) { + throw py::value_error("Cannot merge compiled SQL parameters with duplicate parameter names"); + } + merged_params[item.first] = item.second; + } + return {compiled_sql, merged_params}; + } + + if (py::is_list_like(params)) { + if (py::len(params) == 0) { + return {compiled_sql, compiled_params}; + } + throw py::value_error("Cannot merge compiled SQL named parameters with positional parameters"); + } + + throw InvalidInputException("Prepared parameters can only be passed as a list or a dictionary"); +} + shared_ptr DuckDBPyConnection::Execute(const py::object &query, py::object params) { py::gil_scoped_acquire gil; con.SetResult(nullptr); - auto statements = GetStatements(query); + auto normalized_query = ExtractCompiledSqlAndParams(query, params); + auto &compiled_sql = normalized_query.first; + auto &merged_params = normalized_query.second; + vector> statements; + if (!compiled_sql.empty()) { + statements = GetStatements(py::str(compiled_sql)); + params = 
merged_params; + } else { + statements = GetStatements(query); + } + if (statements.empty()) { // TODO: should we throw? return nullptr; @@ -1603,7 +1666,17 @@ unique_ptr DuckDBPyConnection::RunQuery(const py::object &quer alias = "unnamed_relation_" + StringUtil::GenerateRandomName(16); } - auto statements = GetStatements(query); + auto normalized_query = ExtractCompiledSqlAndParams(query, params); + auto &compiled_sql = normalized_query.first; + auto &merged_params = normalized_query.second; + vector> statements; + if (!compiled_sql.empty()) { + statements = GetStatements(py::str(compiled_sql)); + params = merged_params; + } else { + statements = GetStatements(query); + } + if (statements.empty()) { // TODO: should we throw? return nullptr; @@ -1616,7 +1689,7 @@ unique_ptr DuckDBPyConnection::RunQuery(const py::object &quer // Attempt to create a Relation for lazy execution if possible shared_ptr relation; - bool has_params = !py::none().is(params) && py::len(params) > 0; + bool has_params = !params.is_none() && py::len(params) > 0; if (!has_params) { // No params (or empty params) — use lazy QueryRelation path { diff --git a/src/duckdb_py/pyexpression/initialize.cpp b/src/duckdb_py/pyexpression/initialize.cpp index 11cf5dc3..d709c582 100644 --- a/src/duckdb_py/pyexpression/initialize.cpp +++ b/src/duckdb_py/pyexpression/initialize.cpp @@ -314,6 +314,7 @@ void DuckDBPyExpression::Initialize(py::module_ &m) { Print the stringified version of the expression. )"; expression.def("show", &DuckDBPyExpression::Print, docs); + expression.def("__duckdb_template__", &DuckDBPyExpression::ToString); docs = R"( Set the order by modifier to ASCENDING. 
diff --git a/src/duckdb_py/pyrelation/initialize.cpp b/src/duckdb_py/pyrelation/initialize.cpp index 4393889a..cd9b1e67 100644 --- a/src/duckdb_py/pyrelation/initialize.cpp +++ b/src/duckdb_py/pyrelation/initialize.cpp @@ -342,6 +342,7 @@ void DuckDBPyRelation::Initialize(py::handle &m) { .def("show", &DuckDBPyRelation::Print, "Display a summary of the data", py::kw_only(), py::arg("max_width") = py::none(), py::arg("max_rows") = py::none(), py::arg("max_col_width") = py::none(), py::arg("null_value") = py::none(), py::arg("render_mode") = py::none()) + .def("__duckdb_template__", &DuckDBPyRelation::ToSQL) .def("__str__", &DuckDBPyRelation::ToString) .def("__repr__", &DuckDBPyRelation::ToString); diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index 5087de50..03e87873 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -329,6 +329,7 @@ void DuckDBPyType::Initialize(py::handle &m) { auto type_module = py::class_>(m, "DuckDBPyType", py::module_local()); type_module.def("__repr__", &DuckDBPyType::ToString, "Stringified representation of the type object"); + type_module.def("__duckdb_template__", &DuckDBPyType::ToString); type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other"), py::is_operator()); type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), diff --git a/tests/fast/test_template_e2e.py b/tests/fast/test_template_e2e.py index a294c7e1..96f14314 100644 --- a/tests/fast/test_template_e2e.py +++ b/tests/fast/test_template_e2e.py @@ -41,8 +41,8 @@ def test_module_level_sql_apis_accept_sql_template() -> None: def test_module_level_execute_accepts_sql_template() -> None: conn = duckdb.connect() - query = template("SELECT ", "hello") - assert duckdb.execute(query, connection=conn).fetchone() == ("hello",) + query = template("SELECT ", 5) + assert duckdb.execute(query, connection=conn).fetchone() == (5,) def 
test_connection_sql_accepts_alias_kwarg_with_template() -> None: @@ -61,11 +61,19 @@ def test_connection_sql_template_can_merge_additional_params() -> None: def test_connection_sql_template_param_name_conflict_with_additional_params_raises() -> None: conn = duckdb.connect() - query = template("SELECT ", param(10, "num"), " + $num") + query = template("SELECT ", param(10, "num", exact=True), " + $num") with pytest.raises((duckdb.InvalidInputException, ValueError)): conn.sql(query, params={"num": 5}).fetchall() +def test_cant_merge_with_positional_params() -> None: + conn = duckdb.connect() + # It doesn't even have a name, but still should error + query = template("SELECT ", 10, " + ?") + with pytest.raises(ValueError, match="Cannot merge compiled SQL named parameters with positional parameters"): + conn.sql(query, params=[5]).fetchall() + + def test_sql_apis_accept_compiled_sql() -> None: conn = duckdb.connect() compiled = template("SELECT i FROM range(5) t(i) WHERE i >= ", 3, " ORDER BY i").compile() @@ -83,16 +91,6 @@ def test_relation_interpolation_works_end_to_end() -> None: assert conn.sql(query).fetchall() == [(0,), (2,), (4,)] -def test_interpolated_strings_are_parameterized_by_default() -> None: - conn = duckdb.connect() - conn.execute("CREATE TABLE names(name VARCHAR)") - conn.execute("INSERT INTO names VALUES ('alice'), ('bob')") - - untrusted = "alice' OR 1=1 --" - query = template("SELECT count(*) FROM names WHERE name = ", untrusted) - assert conn.sql(query).fetchone() == (0,) - - def test_builtin_duckdbpytype_object_interpolates_in_template() -> None: conn = duckdb.connect() integer_type = duckdb.sqltype("INTEGER")