-
Notifications
You must be signed in to change notification settings - Fork 2
feat(Init): Initial upload of array + interoperability package #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1a482dc
49226bc
ce74739
7ff0352
6c617e7
4633489
19b5060
8852fdb
f12d604
609b939
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| * @Simpag @nicola-bastianello |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| **/__pycache__ | ||
| **/build | ||
| *.egg-info | ||
| *.so | ||
| .DS_Store | ||
| .mypy_cache | ||
| .tox | ||
| .vscode | ||
| dist | ||
| pyrightconfig.json | ||
| .claude | ||
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| """ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would be nice to add instructions on how to run the benchmarks and interpret the results; maybe we add this when preparing the docs. |
||
| Microbenchmark: ``decent_array.Array`` operator overhead vs native frameworks. | ||
|
|
||
| Measures the wrapper cost added by routing operators through ``Array.__add__``, | ||
| ``Array.__neg__`` etc. against calling the framework's native operators | ||
| directly. Iterates over every framework whose package is importable; missing | ||
| optional dependencies are skipped silently. | ||
|
|
||
| The overhead column is ``wrapped / native`` runtime — values close to 1.0x mean | ||
| the wrapper is essentially free. Large values at small sizes are expected | ||
| (operator dispatch dominates) and should converge toward 1.0x as elementwise | ||
| work grows. | ||
|
|
||
| Run with:: | ||
|
|
||
| python benchmarks/bench_array.py | ||
| """ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it would be nice to have benchmarks that directly compare compiled and uncompiled versions; or at least have an option at the top to say which version should be used
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In any case all the benchmarks look good to me. I ran them and the results of compilation are fantastic! |
||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from bench_common import ( | ||
| SIZES, | ||
| BackendCase, | ||
| activate_backend, | ||
| discover_backends, | ||
| fmt_row, | ||
| parse_backends_arg, | ||
| print_preamble, | ||
| print_size_header, | ||
| time_us_safe, | ||
| ) | ||
|
|
||
| from decent_array import Array | ||
|
|
||
|
|
||
def _bench_case(case: BackendCase) -> None:
    """Benchmark one backend: native operators vs the ``Array`` wrapper, per size."""
    activate_backend(case.name)
    print(f"## {case.name}\n")
    for size in SIZES:
        native_a = case.make(size)
        native_b = case.make(size)
        wrapped_a, wrapped_b = Array(native_a), Array(native_b)

        print_size_header(size)
        # Default-argument binding pins the current arrays into each lambda.
        ops = [
            ("add", lambda x=native_a, y=native_b: x + y, lambda x=wrapped_a, y=wrapped_b: x + y),
            ("sub", lambda x=native_a, y=native_b: x - y, lambda x=wrapped_a, y=wrapped_b: x - y),
            ("mul", lambda x=native_a, y=native_b: x * y, lambda x=wrapped_a, y=wrapped_b: x * y),
            ("div", lambda x=native_a, y=native_b: x / y, lambda x=wrapped_a, y=wrapped_b: x / y),
            ("neg", lambda x=native_a: -x, lambda x=wrapped_a: -x),
            ("abs", lambda x=native_a: abs(x), lambda x=wrapped_a: abs(x)),
            ("pow", lambda x=native_a: x ** 2.0, lambda x=wrapped_a: x ** 2.0),
        ]
        for op_name, native_fn, wrapped_fn in ops:
            # Native is timed first, matching the column order of the output.
            native_us = time_us_safe(case, native_fn)
            wrapped_us = time_us_safe(case, wrapped_fn)
            print(fmt_row(op_name, native_us, wrapped_us))
        print()
    print()
|
|
||
|
|
||
def main() -> None:
    """Entry point: discover requested backends and benchmark each in turn."""
    print_preamble("Array operator overhead vs native frameworks")
    cases = discover_backends(only=parse_backends_arg())
    names = ", ".join(c.name for c in cases)
    print(f"available backends: {names}\n")
    for backend_case in cases:
        _bench_case(backend_case)


if __name__ == "__main__":
    main()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,179 @@ | ||
| """ | ||
| Shared helpers for ``bench_array.py`` and ``bench_iop.py``. | ||
|
|
||
| Three concerns live here so the benchmarks stay focused on the comparison logic: | ||
|
|
||
| * :func:`discover_backends` returns the subset of frameworks whose package is | ||
| importable; backends with a missing optional dependency are skipped silently. | ||
| * :func:`is_compiled` / :func:`print_preamble` report whether the user is | ||
| running against a mypyc-compiled build of ``decent_array`` or the pure-Python | ||
| source — this materially changes overhead numbers, so the result is printed | ||
| at the top of every run. | ||
| * :func:`time_us` / :func:`time_us_safe` wrap :mod:`timeit` to take the | ||
| ``min`` of several auto-ranged repeats. ``min`` is the canonical choice: it | ||
| reports the lower bound of the machine's per-call cost and is the metric | ||
| least sensitive to background activity. A warmup call precedes timing so | ||
| JIT-style backends (JAX) don't skew the first iteration. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import importlib | ||
| import timeit | ||
| from collections.abc import Callable | ||
| from dataclasses import dataclass | ||
| from typing import Any | ||
|
|
||
# Array lengths exercised by every benchmark table (small → dispatch-dominated,
# large → elementwise-work-dominated).
SIZES: tuple[int, ...] = (10, 100, 1_000, 10_000)
# Number of timeit repeats; the minimum across repeats is what gets reported.
REPEATS: int = 7
|
|
||
|
|
||
| def _no_sync(_value: Any) -> None: # noqa: ANN401 | ||
| """No-op sync used for synchronous backends (numpy, torch CPU, tf eager CPU).""" | ||
|
|
||
|
|
||
def _sync_jax(value: Any) -> None:  # noqa: ANN401
    """Block until a JAX DeviceArray is materialized, unwrapping ``Array`` if needed."""
    # Lazy import keeps this module loadable before decent_array is importable.
    from decent_array import Array  # noqa: PLC0415

    if isinstance(value, Array):
        value = value.value
    value.block_until_ready()
|
|
||
|
|
||
@dataclass(slots=True)
class BackendCase:
    """A discovered backend plus the helpers needed to drive it in a benchmark."""

    # Backend identifier as accepted by activate_backend (e.g. "numpy", "jax").
    name: str
    # Builds a fresh array of the given length in the backend's native type.
    make: Callable[[int], Any]
    # Blocks until the computation behind a value completes (no-op for sync backends).
    sync: Callable[[Any], None]
|
|
||
|
|
||
def activate_backend(name: str) -> None:
    """Activate ``name`` as the live backend, resetting any previously active one.

    ``decent_array`` enforces a single-active-backend invariant per execution context;
    swapping between frameworks within one process requires resetting first.
    """
    # Lazy import of a private manager API — benchmarks take this shortcut so the
    # module can be imported even before decent_array is installed/built.
    from decent_array.interoperability._backend_manager import reset_backends, set_backend  # noqa: PLC0415

    reset_backends()
    set_backend(name)
|
|
||
|
|
||
def discover_backends(only: list[str] | None = None) -> list[BackendCase]:
    """Return one :class:`BackendCase` per importable framework, in a stable order.

    Args:
        only: Optional allowlist of backend names. When given, frameworks outside
            the list are skipped entirely (their packages aren't even imported),
            and any requested name that isn't a known backend raises
            :class:`ValueError`.

    """
    import numpy as np  # always available — hard dependency  # noqa: PLC0415

    known = {"numpy", "pytorch", "jax", "tensorflow"}
    if only is None:
        wanted = known
    else:
        unknown = set(only) - known
        if unknown:
            raise ValueError(f"unknown backend(s): {sorted(unknown)}; known: {sorted(known)}")
        wanted = set(only)

    discovered: list[BackendCase] = []

    if "numpy" in wanted:
        discovered.append(BackendCase("numpy", lambda n: np.random.rand(n), _no_sync))

    if "pytorch" in wanted:
        try:
            torch = importlib.import_module("torch")
        except ImportError:
            pass  # optional dependency not installed — skip silently
        else:
            discovered.append(BackendCase("pytorch", lambda n: torch.from_numpy(np.random.rand(n)), _no_sync))

    if "jax" in wanted:
        try:
            jnp = importlib.import_module("jax.numpy")
        except ImportError:
            pass  # optional dependency not installed — skip silently
        else:
            discovered.append(BackendCase("jax", lambda n: jnp.asarray(np.random.rand(n)), _sync_jax))

    if "tensorflow" in wanted:
        try:
            tf = importlib.import_module("tensorflow")
        except ImportError:
            pass  # optional dependency not installed — skip silently
        else:
            discovered.append(BackendCase("tensorflow", lambda n: tf.constant(np.random.rand(n)), _no_sync))

    return discovered
|
|
||
|
|
||
def parse_backends_arg() -> list[str] | None:
    """Parse the shared ``--backends`` CLI flag; returns ``None`` if not given."""
    import argparse  # noqa: PLC0415

    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument(
        "--backends",
        type=str,
        default=None,
        help="comma-separated allowlist of backends (numpy,pytorch,jax,tensorflow); default = all available",
    )
    raw = parser.parse_args().backends
    if raw is None:
        return None
    # Tolerate spaces and empty segments ("a, b,," -> ["a", "b"]).
    return [name.strip() for name in raw.split(",") if name.strip()]
|
|
||
|
|
||
def is_compiled() -> tuple[bool, str]:
    """Report whether ``decent_array._array`` loaded from a compiled extension.

    Returns ``(True, path)`` for a ``.so``/``.pyd`` module, ``(False, path)``
    for the pure-Python source; the path is ``"<unknown>"`` if unavailable.
    """
    array_module = importlib.import_module("decent_array._array")
    location = array_module.__file__ or "<unknown>"
    return location.endswith((".so", ".pyd")), location
|
|
||
|
|
||
def print_preamble(title: str) -> None:
    """Print the run header: title, compiled status, and timing methodology."""
    compiled, location = is_compiled()
    header = [
        f"# {title}\n",
        f"decent_array compiled: {'yes' if compiled else 'no'}",
        f"  Array loaded from: {location}",
        f"  timing: min over {REPEATS} repeats, iterations per repeat auto-tuned to ~0.2s\n",
    ]
    # One write; the joined newlines reproduce the original multi-print output.
    print("\n".join(header))
|
|
||
|
|
||
def time_us(case: BackendCase, fn: Callable[[], Any]) -> float:
    """Per-call runtime in µs; min over :data:`REPEATS` measurements with autoranged N."""

    def measured() -> None:
        # Sync after every call so async backends (JAX) are actually timed.
        case.sync(fn())

    measured()  # warmup — absorbs JAX-style first-call compilation
    timer = timeit.Timer(measured)
    calls_per_repeat, _ = timer.autorange()
    best = min(timer.repeat(repeat=REPEATS, number=calls_per_repeat))
    return best / calls_per_repeat * 1e6
|
|
||
|
|
||
def time_us_safe(case: BackendCase, fn: Callable[[], Any]) -> float | None:
    """Like :func:`time_us` but returns ``None`` if ``fn`` raises (e.g. TF 1D matmul)."""
    try:
        return time_us(case, fn)
    except Exception:  # noqa: BLE001
        # Deliberately broad: a backend rejecting an op must not abort the whole
        # run; fmt_row renders the missing measurement as "n/a".
        return None
|
|
||
|
|
||
def fmt_row(op: str, native_us: float | None, wrapped_us: float | None) -> str:
    """Format one benchmark table row; a ``None`` timing renders the row as ``n/a``.

    Column widths mirror :func:`print_size_header`: 13 for each timing column
    (10-digit float + " µs") and 8 for the overhead ratio (7 digits + "x").
    """
    if native_us is None or wrapped_us is None:
        return f"  {op:<8} {'n/a':>13} {'n/a':>13} {'n/a':>8}"
    # Guard against a zero denominator from a pathologically fast native call.
    ratio = wrapped_us / native_us if native_us > 0 else float("inf")
    # Fix: ratio was {ratio:>6.2f}x (7 chars) — one short of the 8-wide
    # "overhead" header and the n/a fallback, misaligning every value row.
    return f"  {op:<8} {native_us:>10.3f} µs {wrapped_us:>10.3f} µs {ratio:>7.2f}x"
|
|
||
|
|
||
def print_size_header(n: int) -> None:
    """Print the per-size banner and the column headers for the rows that follow."""
    column_titles = f"{'op':<8} {'native':>13} {'wrapped':>13} {'overhead':>8}"
    # Single write with an embedded newline — identical output to two prints.
    print(f"size = {n:_}\n  {column_titles}")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we should add
.pyd?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
on my system at least they are not ignored by vscode source control