Skip to content

Commit 3f66fa3

Browse files
committed
Speed up cached factory resolution
1 parent 563ffac commit 3f66fa3

2 files changed

Lines changed: 153 additions & 34 deletions

File tree

src/miniject/_container.py

Lines changed: 96 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -303,33 +303,38 @@ def _invoke_factory(
303303
**overrides: Any,
304304
) -> Any:
305305
"""Call a factory, resolving its parameters from the container."""
306-
sig_and_hints = introspect_factory(factory, resolution_error=ResolutionError)
307-
if sig_and_hints is None:
308-
if inspect.iscoroutinefunction(factory):
309-
_raise_async_resolution_error(factory, stack=_stack)
310-
instance = factory()
311-
if inspect.isawaitable(instance):
312-
if inspect.iscoroutine(instance):
313-
instance.close()
314-
_raise_async_resolution_error(factory, stack=_stack, returned_awaitable=True)
315-
return instance
316-
sig, hints = sig_and_hints
306+
plan = introspect_factory(factory, resolution_error=ResolutionError)
307+
if plan is None:
308+
return _call_sync_factory_checked(
309+
factory,
310+
stack=_stack,
311+
is_async=inspect.iscoroutinefunction(factory),
312+
)
313+
fast_args = self._build_fast_positional_args(
314+
plan.fast_positional_deps,
315+
_stack=_stack,
316+
overrides=overrides,
317+
)
318+
if fast_args is not None:
319+
return _call_sync_factory_checked(
320+
factory,
321+
args=fast_args,
322+
stack=_stack,
323+
is_async=plan.is_async,
324+
)
317325
kwargs = self._build_factory_kwargs(
318326
factory,
319-
sig,
320-
hints,
327+
plan.signature,
328+
plan.hints,
321329
_stack=_stack,
322330
overrides=overrides,
323331
)
324-
if inspect.iscoroutinefunction(factory):
325-
_raise_async_resolution_error(factory, stack=_stack)
326-
327-
instance = factory(**kwargs)
328-
if inspect.isawaitable(instance):
329-
if inspect.iscoroutine(instance):
330-
instance.close()
331-
_raise_async_resolution_error(factory, stack=_stack, returned_awaitable=True)
332-
return instance
332+
return _call_sync_factory_checked(
333+
factory,
334+
kwargs=kwargs,
335+
stack=_stack,
336+
is_async=plan.is_async,
337+
)
333338

334339
async def _invoke_factory_async(
335340
self,
@@ -339,17 +344,26 @@ async def _invoke_factory_async(
339344
**overrides: Any,
340345
) -> Any:
341346
"""Call a factory, awaiting it if needed and resolving deps async."""
342-
sig_and_hints = introspect_factory(factory, resolution_error=ResolutionError)
343-
if sig_and_hints is None:
347+
plan = introspect_factory(factory, resolution_error=ResolutionError)
348+
if plan is None:
344349
instance = factory()
345350
if inspect.isawaitable(instance):
346351
return await instance
347352
return instance
348-
sig, hints = sig_and_hints
353+
fast_args = await self._build_fast_positional_args_async(
354+
plan.fast_positional_deps,
355+
_stack=_stack,
356+
overrides=overrides,
357+
)
358+
if fast_args is not None:
359+
instance = factory(*fast_args)
360+
if inspect.isawaitable(instance):
361+
return await instance
362+
return instance
349363
kwargs = await self._build_factory_kwargs_async(
350364
factory,
351-
sig,
352-
hints,
365+
plan.signature,
366+
plan.hints,
353367
_stack=_stack,
354368
overrides=overrides,
355369
)
@@ -397,6 +411,24 @@ def _build_factory_kwargs(
397411
)
398412
return kwargs
399413

414+
def _build_fast_positional_args(
415+
self,
416+
fast_positional_deps: tuple[type, ...] | None,
417+
*,
418+
_stack: tuple[type, ...],
419+
overrides: dict[str, Any],
420+
) -> tuple[object, ...] | None:
421+
if fast_positional_deps is None or overrides:
422+
return None
423+
424+
resolved: list[object] = []
425+
for dep in fast_positional_deps:
426+
if self._find_binding(dep) is None:
427+
return None
428+
resolved_instance = cast("object", self._resolve(dep, _stack=_stack))
429+
resolved.append(resolved_instance)
430+
return tuple(resolved)
431+
400432
async def _build_factory_kwargs_async(
401433
self,
402434
factory: Callable[..., Any],
@@ -438,6 +470,24 @@ async def _build_factory_kwargs_async(
438470
)
439471
return kwargs
440472

473+
async def _build_fast_positional_args_async(
474+
self,
475+
fast_positional_deps: tuple[type, ...] | None,
476+
*,
477+
_stack: tuple[type, ...],
478+
overrides: dict[str, Any],
479+
) -> tuple[object, ...] | None:
480+
if fast_positional_deps is None or overrides:
481+
return None
482+
483+
resolved: list[object] = []
484+
for dep in fast_positional_deps:
485+
if self._find_binding(dep) is None:
486+
return None
487+
resolved_instance = cast("object", await self._resolve_async(dep, _stack=_stack))
488+
resolved.append(resolved_instance)
489+
return tuple(resolved)
490+
441491

442492
def _require_factory(binding: _Binding) -> Callable[..., Any]:
443493
"""Return a binding factory or raise if the binding is malformed."""
@@ -463,6 +513,25 @@ def _raise_async_resolution_error(
463513
)
464514

465515

516+
def _call_sync_factory_checked(
517+
factory: Callable[..., Any],
518+
*,
519+
stack: tuple[type, ...],
520+
is_async: bool,
521+
args: tuple[object, ...] = (),
522+
kwargs: dict[str, Any] | None = None,
523+
) -> Any:
524+
if is_async:
525+
_raise_async_resolution_error(factory, stack=stack)
526+
527+
instance = factory(*args, **({} if kwargs is None else kwargs))
528+
if inspect.isawaitable(instance):
529+
if inspect.iscoroutine(instance):
530+
instance.close()
531+
_raise_async_resolution_error(factory, stack=stack, returned_awaitable=True)
532+
return instance
533+
534+
466535
def _validate_or_skip_missing_param(
467536
factory: Callable[..., Any],
468537
param_name: str,

src/miniject/_introspection.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
_DISALLOWED_AUTO_INJECT_TYPES: frozenset[type] = frozenset({bool, bytes, float, int, str})
1414
_NONE_TYPE: type[None] = type(None)
1515
_UNION_TYPES: tuple[object, ...] = (typing.Union, types.UnionType)
16-
_INTROSPECTION_CACHE: dict[object, tuple[inspect.Signature, dict[str, Any]] | None] = {}
16+
_INTROSPECTION_CACHE: dict[object, FactoryPlan | None] = {}
1717

1818

1919
@dataclass(frozen=True, slots=True)
@@ -24,18 +24,28 @@ class ResolvedParamType:
2424
display_name: str
2525

2626

27+
@dataclass(frozen=True, slots=True)
28+
class FactoryPlan:
29+
"""Cached callable metadata used during dependency resolution."""
30+
31+
signature: inspect.Signature
32+
hints: dict[str, Any]
33+
is_async: bool
34+
fast_positional_deps: tuple[type, ...] | None
35+
36+
2737
def introspect_factory(
2838
factory: Callable[..., Any],
2939
*,
3040
resolution_error: type[Exception],
31-
) -> tuple[inspect.Signature, dict[str, Any]] | None:
41+
) -> FactoryPlan | None:
3242
"""Extract signature and type hints for a factory, or None if not introspectable."""
3343
try:
3444
cached = _INTROSPECTION_CACHE.get(factory, _EMPTY)
3545
except TypeError:
3646
return _compute_factory_introspection(factory, resolution_error=resolution_error)
3747
if cached is not _EMPTY:
38-
return cast("tuple[inspect.Signature, dict[str, Any]] | None", cached)
48+
return cast("FactoryPlan | None", cached)
3949

4050
result = _compute_factory_introspection(factory, resolution_error=resolution_error)
4151
_INTROSPECTION_CACHE[factory] = result
@@ -46,26 +56,66 @@ def _compute_factory_introspection(
4656
factory: Callable[..., Any],
4757
*,
4858
resolution_error: type[Exception],
49-
) -> tuple[inspect.Signature, dict[str, Any]] | None:
59+
) -> FactoryPlan | None:
5060
"""Compute signature and type hints for a factory without caching."""
5161
hint_target = factory.__init__ if isinstance(factory, type) else factory
62+
factory_name = callable_name(factory)
5263
try:
5364
sig = inspect.signature(factory)
5465
except (ValueError, TypeError):
5566
return None
5667
except NameError as exc:
5768
_raise_type_hint_resolution_error(
5869
hint_target,
59-
factory_name=callable_name(factory),
70+
factory_name=factory_name,
6071
resolution_error=resolution_error,
6172
exc=exc,
6273
)
6374
hints = _get_type_hints_or_raise(
6475
hint_target,
65-
factory_name=callable_name(factory),
76+
factory_name=factory_name,
6677
resolution_error=resolution_error,
6778
)
68-
return sig, hints
79+
return FactoryPlan(
80+
signature=sig,
81+
hints=hints,
82+
is_async=inspect.iscoroutinefunction(factory),
83+
fast_positional_deps=_build_fast_positional_deps(
84+
sig,
85+
hints,
86+
factory_name=factory_name,
87+
resolution_error=resolution_error,
88+
),
89+
)
90+
91+
92+
def _build_fast_positional_deps(
93+
sig: inspect.Signature,
94+
hints: dict[str, Any],
95+
*,
96+
factory_name: str,
97+
resolution_error: type[Exception],
98+
) -> tuple[type, ...] | None:
99+
deps: list[type] = []
100+
for param_name, param in sig.parameters.items():
101+
if param_name == "self":
102+
continue
103+
if param.kind not in (
104+
inspect.Parameter.POSITIONAL_ONLY,
105+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
106+
):
107+
return None
108+
109+
resolved_type = resolve_param_type(
110+
hints.get(param_name),
111+
factory_name=factory_name,
112+
param_name=param_name,
113+
resolution_error=resolution_error,
114+
)
115+
if resolved_type.binding_key is None or param.default is not inspect.Parameter.empty:
116+
return None
117+
deps.append(resolved_type.binding_key)
118+
return tuple(deps)
69119

70120

71121
def _get_type_hints_or_raise(

0 commit comments

Comments
 (0)