From 3b025bd02f53ffb1e8d7d8d31ca5f70981683b85 Mon Sep 17 00:00:00 2001 From: neurolabusc Date: Sun, 1 Mar 2026 09:45:33 -0500 Subject: [PATCH 1/2] Metal support --- .gitignore | 9 + README.md | 8 +- cuslines/__init__.py | 56 +- cuslines/boot_utils.py | 72 + cuslines/cuda_python/cu_direction_getters.py | 85 +- cuslines/metal/README.md | 127 ++ cuslines/metal/__init__.py | 13 + cuslines/metal/mt_direction_getters.py | 463 ++++ cuslines/metal/mt_propagate_seeds.py | 204 ++ cuslines/metal/mt_tractography.py | 246 +++ cuslines/metal/mutils.py | 142 ++ cuslines/metal_shaders/boot.metal | 869 ++++++++ cuslines/metal_shaders/disc.h | 1890 +++++++++++++++++ .../generate_streamlines_metal.metal | 400 ++++ cuslines/metal_shaders/globals.h | 61 + cuslines/metal_shaders/philox_rng.h | 152 ++ cuslines/metal_shaders/ptt.metal | 1061 +++++++++ cuslines/metal_shaders/tracking_helpers.metal | 221 ++ cuslines/metal_shaders/types.h | 50 + cuslines/metal_shaders/utils.metal | 107 + cuslines/metal_shaders/warp_sort.metal | 109 + pyproject.toml | 4 + run_gpu_streamlines.py | 14 +- setup.py | 2 +- 24 files changed, 6278 insertions(+), 87 deletions(-) create mode 100644 cuslines/boot_utils.py create mode 100644 cuslines/metal/README.md create mode 100644 cuslines/metal/__init__.py create mode 100644 cuslines/metal/mt_direction_getters.py create mode 100644 cuslines/metal/mt_propagate_seeds.py create mode 100644 cuslines/metal/mt_tractography.py create mode 100644 cuslines/metal/mutils.py create mode 100644 cuslines/metal_shaders/boot.metal create mode 100644 cuslines/metal_shaders/disc.h create mode 100644 cuslines/metal_shaders/generate_streamlines_metal.metal create mode 100644 cuslines/metal_shaders/globals.h create mode 100644 cuslines/metal_shaders/philox_rng.h create mode 100644 cuslines/metal_shaders/ptt.metal create mode 100644 cuslines/metal_shaders/tracking_helpers.metal create mode 100644 cuslines/metal_shaders/types.h create mode 100644 cuslines/metal_shaders/utils.metal create mode 100644 cuslines/metal_shaders/warp_sort.metal diff --git a/.gitignore b/.gitignore index 78bb5e2..4718dd1 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,12 @@ *.pyo *.pyd +# Build artifacts +*.egg-info/ +dist/ +build/ + +# Test outputs +*.trk +*.trx +*.nii.gz diff --git a/README.md b/README.md index 9ae3163..0d2ae45 100644 --- a/README.md +++ b/README.md @@ -3,12 +3,16 @@ ## Installation To install from pypi, simply run `pip install "cuslines[cu13]"` or `pip install "cuslines[cu12]"` depending on your CUDA version. -To install from dev, simply run `pip install ".[cu13]"` or `pip install ".[cu12]"` in the top-level repository directory. +For Apple Silicon (M1/M2/M3/M4), install the Metal backend: `pip install "cuslines[metal]"` + +To install from dev, simply run `pip install ".[cu13]"` or `pip install ".[cu12]"` (or `pip install ".[metal]"` on macOS) in the top-level repository directory. + +The GPU backend is auto-detected at import time. On macOS with Apple Silicon, Metal is used; on Linux/Windows with an NVIDIA GPU, CUDA is used. ## Running the examples This repository contains several example usage scripts. -The script `run_gpu_streamlines.py` demonstrates how to run any diffusion MRI dataset on the GPU. It can also run on the CPU for reference, if the argument `--device=cpu` is used. If not data is passed, it will donaload and use the HARDI dataset. +The script `run_gpu_streamlines.py` demonstrates how to run any diffusion MRI dataset on the GPU. It can also run on the CPU for reference, if the argument `--device=cpu` is used. Use `--device=metal` to explicitly select the Metal backend on macOS. If no data is passed, it will download and use the HARDI dataset. To run the baseline CPU example on a random set of 1000 seeds, this is the command and example output: ``` diff --git a/cuslines/__init__.py b/cuslines/__init__.py index b96cca1..07b44ea 100644 --- a/cuslines/__init__.py +++ b/cuslines/__init__.py @@ -1,13 +1,55 @@ -from .cuda_python import ( - GPUTracker, - ProbDirectionGetter, - PttDirectionGetter, - BootDirectionGetter -) +import platform as _platform + + +def _detect_backend(): + """Auto-detect the best available GPU backend.""" + system = _platform.system() + if system == "Darwin": + try: + import Metal + + if Metal.MTLCreateSystemDefaultDevice() is not None: + return "metal" + except ImportError: + pass + try: + from cuda.bindings import runtime + + count = runtime.cudaGetDeviceCount() + if count[1] > 0: + return "cuda" + except (ImportError, Exception): + pass + return None + + +BACKEND = _detect_backend() + +if BACKEND == "metal": + from cuslines.metal import ( + MetalGPUTracker as GPUTracker, + MetalProbDirectionGetter as ProbDirectionGetter, + MetalPttDirectionGetter as PttDirectionGetter, + MetalBootDirectionGetter as BootDirectionGetter, + ) +elif BACKEND == "cuda": + from cuslines.cuda_python import ( + GPUTracker, + ProbDirectionGetter, + PttDirectionGetter, + BootDirectionGetter, + ) +else: + raise ImportError( + "No GPU backend available. Install either:\n" + " - CUDA: pip install 'cuslines[cu13]' (NVIDIA GPU)\n" + " - Metal: pip install 'cuslines[metal]' (Apple Silicon)" + ) __all__ = [ "GPUTracker", "ProbDirectionGetter", "PttDirectionGetter", - "BootDirectionGetter" + "BootDirectionGetter", + "BACKEND", ] diff --git a/cuslines/boot_utils.py b/cuslines/boot_utils.py new file mode 100644 index 0000000..50abd7b --- /dev/null +++ b/cuslines/boot_utils.py @@ -0,0 +1,72 @@ +"""Shared utilities for bootstrap direction getters (CUDA and Metal). + +Extracts DIPY model matrices (H, R, delta_b, delta_q, sampling_matrix) +for OPDT and CSA models. Both backends need the same matrices — only +the GPU dispatch differs. +""" + +from dipy.reconst import shm + + +def prepare_opdt(gtab, sphere, sh_order_max=6, full_basis=False, + sh_lambda=0.006, min_signal=1): + """Build bootstrap matrices for the OPDT model. + + Returns dict with keys: model_type, min_signal, H, R, delta_b, + delta_q, sampling_matrix, b0s_mask. + """ + sampling_matrix, _, _ = shm.real_sh_descoteaux( + sh_order_max, sphere.theta, sphere.phi, + full_basis=full_basis, legacy=True, + ) + model = shm.OpdtModel( + gtab, sh_order_max=sh_order_max, smooth=sh_lambda, + min_signal=min_signal, + ) + delta_b, delta_q = model._fit_matrix + + H, R = _hat_and_lcr(gtab, model, sh_order_max) + + return dict( + model_type="OPDT", min_signal=min_signal, + H=H, R=R, delta_b=delta_b, delta_q=delta_q, + sampling_matrix=sampling_matrix, b0s_mask=gtab.b0s_mask, + ) + + +def prepare_csa(gtab, sphere, sh_order_max=6, full_basis=False, + sh_lambda=0.006, min_signal=1): + """Build bootstrap matrices for the CSA model. + + Returns dict with keys: model_type, min_signal, H, R, delta_b, + delta_q, sampling_matrix, b0s_mask. + """ + sampling_matrix, _, _ = shm.real_sh_descoteaux( + sh_order_max, sphere.theta, sphere.phi, + full_basis=full_basis, legacy=True, + ) + model = shm.CsaOdfModel( + gtab, sh_order_max=sh_order_max, smooth=sh_lambda, + min_signal=min_signal, + ) + delta_b = model._fit_matrix + delta_q = model._fit_matrix + + H, R = _hat_and_lcr(gtab, model, sh_order_max) + + return dict( + model_type="CSA", min_signal=min_signal, + H=H, R=R, delta_b=delta_b, delta_q=delta_q, + sampling_matrix=sampling_matrix, b0s_mask=gtab.b0s_mask, + ) + + +def _hat_and_lcr(gtab, model, sh_order_max): + """Compute hat matrix H and leveraged centered residuals matrix R.""" + dwi_mask = ~gtab.b0s_mask + x, y, z = model.gtab.gradients[dwi_mask].T + _, theta, phi = shm.cart2sphere(x, y, z) + B, _, _ = shm.real_sh_descoteaux(sh_order_max, theta, phi, legacy=True) + H = shm.hat(B) + R = shm.lcr_matrix(H) + return H, R diff --git a/cuslines/cuda_python/cu_direction_getters.py b/cuslines/cuda_python/cu_direction_getters.py index 617f893..36d2c66 100644 --- a/cuslines/cuda_python/cu_direction_getters.py +++ b/cuslines/cuda_python/cu_direction_getters.py @@ -4,7 +4,7 @@ from importlib.resources import files from time import time -from dipy.reconst import shm +from cuslines.boot_utils import prepare_opdt, prepare_csa from cuda.core import Device, LaunchConfig, Program, launch, ProgramOptions from cuda.pathfinder import find_nvidia_header_directory @@ -135,83 +135,16 @@ def __init__( self.compile_program() @classmethod - def from_dipy_opdt( - cls, - gtab, - sphere, - sh_order_max=6, - full_basis=False, - sh_lambda=0.006, - min_signal=1, - ): - sampling_matrix, _, _ = shm.real_sh_descoteaux( - sh_order_max, sphere.theta, sphere.phi, full_basis=full_basis, legacy=False - ) - - model = shm.OpdtModel( - gtab, sh_order_max=sh_order_max, smooth=sh_lambda, min_signal=min_signal - ) - fit_matrix = model._fit_matrix - delta_b, delta_q = fit_matrix - - b0s_mask = gtab.b0s_mask - dwi_mask = ~b0s_mask - x, y, z = model.gtab.gradients[dwi_mask].T - _, theta, phi = shm.cart2sphere(x, y, z) - B, _, _ = shm.real_sym_sh_basis(sh_order_max, theta, phi) - H = shm.hat(B) - R = shm.lcr_matrix(H) - - return cls( - model_type="OPDT", - min_signal=min_signal, - H=H, - R=R, - delta_b=delta_b, - delta_q=delta_q, - sampling_matrix=sampling_matrix, - b0s_mask=gtab.b0s_mask, - ) + def from_dipy_opdt(cls, gtab, sphere, sh_order_max=6, full_basis=False, + sh_lambda=0.006, min_signal=1): + return cls(**prepare_opdt(gtab, sphere, sh_order_max, full_basis, + sh_lambda, min_signal)) @classmethod - def from_dipy_csa( - cls, - gtab, - sphere, - sh_order_max=6, - full_basis=False, - sh_lambda=0.006, - min_signal=1, - ): - sampling_matrix, _, _ = shm.real_sh_descoteaux( - sh_order_max, sphere.theta, sphere.phi, full_basis=full_basis, legacy=False - ) - - model = shm.CsaOdfModel( - gtab, sh_order_max=sh_order_max, smooth=sh_lambda, min_signal=min_signal - ) - fit_matrix = model._fit_matrix - delta_b = fit_matrix - delta_q = fit_matrix - - b0s_mask = gtab.b0s_mask - dwi_mask = ~b0s_mask - x, y, z = model.gtab.gradients[dwi_mask].T - _, theta, phi = shm.cart2sphere(x, y, z) - B, _, _ = shm.real_sym_sh_basis(sh_order_max, theta, phi) - H = shm.hat(B) - R = shm.lcr_matrix(H) - - return cls( - model_type="CSA", - min_signal=min_signal, - H=H, - R=R, - delta_b=delta_b, - delta_q=delta_q, - sampling_matrix=sampling_matrix, - b0s_mask=gtab.b0s_mask, - ) + def from_dipy_csa(cls, gtab, sphere, sh_order_max=6, full_basis=False, + sh_lambda=0.006, min_signal=1): + return cls(**prepare_csa(gtab, sphere, sh_order_max, full_basis, + sh_lambda, min_signal)) def allocate_on_gpu(self, n): self.H_d.append(checkCudaErrors(runtime.cudaMalloc(REAL_SIZE * self.H.size))) diff --git a/cuslines/metal/README.md b/cuslines/metal/README.md new file mode 100644 index 0000000..966704b --- /dev/null +++ b/cuslines/metal/README.md @@ -0,0 +1,127 @@ +# Metal Backend for GPUStreamlines + +The Metal backend runs GPU-accelerated tractography on Apple Silicon (M1/M2/M3/M4) using Apple's Metal Shading Language. It mirrors the CUDA backend's functionality with the same API surface, and is auto-detected at import time on macOS. + +## Installation + +```bash +pip install "cuslines[metal]" # from PyPI +pip install ".[metal]" # from source +``` + +Requires macOS 13+ and Apple Silicon. Dependencies: `pyobjc-framework-Metal`, `pyobjc-framework-MetalPerformanceShaders`. + +## Usage + +```bash +# GPU (auto-detects Metal on macOS) +python run_gpu_streamlines.py --output-prefix out --nseeds 10000 --ngpus 1 + +# Explicit Metal device +python run_gpu_streamlines.py --device metal --output-prefix out --nseeds 10000 + +# CPU reference (DIPY) +python run_gpu_streamlines.py --device cpu --output-prefix out_cpu --nseeds 10000 +``` + +All CLI arguments (`--max-angle`, `--step-size`, `--fa-threshold`, `--model`, `--dg`, etc.) work identically to the CUDA backend. + +## Benchmarks + +Measured on Apple M4 Pro (20-core GPU), Stanford HARDI dataset (81x106x76, 160 directions), OPDT model with bootstrap direction getter, 10,000 seeds: + +| | Metal GPU | CPU (DIPY) | +|---|---|---| +| **Streamline generation time** | 0.89 s | 91.6 s | +| **Speedup** | **~100x** | 1x | +| **Streamlines generated** | 13,205 | 13,647 | +| **Mean fiber length** | 53.8 pts | 45.4 pts | +| **Median fiber length** | 42.0 pts | 33.0 pts | +| **Commissural fibers** | 1,656 | 1,522 | + +The GPU produces comparable streamline counts and commissural fiber density. Mean fiber length is ~18% longer on the GPU due to float32 vs float64 precision differences in ODF peak selection at fiber crossings. + +## Architecture + +### Unified memory advantage + +Apple Silicon shares CPU and GPU memory. Metal buffers use `storageModeShared`, so numpy arrays backing `MTLBuffer` objects are directly GPU-accessible. The CUDA backend requires ~6 `cudaMemcpy` calls per seed batch to transfer data between host and device; **the Metal backend requires zero**. For workloads with large read-only input data (the 4D ODF array is often hundreds of MB), this eliminates a significant source of latency. + +### Kernel compilation + +MSL source files in `cuslines/metal_shaders/` are concatenated and compiled at runtime via `MTLDevice.newLibraryWithSource`. This mirrors the CUDA path (NVRTC), with compile-time constants passed as preprocessor defines. + +### File layout + +``` +cuslines/metal/ + mt_tractography.py MetalGPUTracker context manager + mt_propagate_seeds.py Chunked seed processing (no memcpy) + mt_direction_getters.py Boot/Prob/PTT direction getters + mutils.py Types, aligned allocation, error checking + +cuslines/metal_shaders/ + globals.h Shared constants (float32 only) + types.h packed_float3 <-> float3 helpers + philox_rng.h Philox4x32-10 RNG (replaces curand) + boot.metal Bootstrap direction getter kernel + ptt.metal PTT direction getter kernel + generate_streamlines_metal.metal Main streamline generation kernel + tracking_helpers.metal Trilinear interpolation, peak finding + utils.metal SIMD reductions, prefix sum + warp_sort.metal Bitonic sort + disc.h Lookup tables for PTT +``` + +### Key implementation details + +- **float3 alignment**: CUDA `float3` is 12 bytes in arrays; Metal `float3` is 16 bytes. All device buffers use `packed_float3` (12 bytes) with `load_f3()`/`store_f3()` helpers for register conversion. +- **Page alignment**: Metal shared buffers require 16KB-aligned memory. `aligned_array()` in `mutils.py` handles this. +- **RNG**: Philox4x32-10 counter-based RNG in MSL, matching curand's algorithm for reproducible streams. +- **SIMD mapping**: CUDA warp primitives map directly to Metal SIMD group operations (`__shfl_sync` -> `simd_shuffle`, `__ballot_sync` -> `simd_ballot`). Apple GPU SIMD width is 32, matching CUDA's warp size. +- **No double precision**: Metal GPUs do not support float64. Only the float32 path is ported. +- **SH basis convention**: The sampling matrix, H/R matrices, and OPDT/CSA model matrices must all use the same spherical harmonics basis (`real_sh_descoteaux` with `legacy=True`). A basis mismatch causes sign flips in odd-m SH columns that corrupt ODF reconstruction. + +## Optional: Soft Angular Weighting + +The bootstrap direction getter in `boot.metal` includes an optional soft angular weighting feature that is **disabled by default** and compiled out at the preprocessor level (zero runtime cost when disabled). + +### Motivation + +At fiber crossings (e.g., the corona radiata, where commissural and projection fibers intersect), the ODF typically shows multiple peaks. The standard algorithm selects the peak closest to the current trajectory direction. However, when two peaks have similar magnitudes, float32 precision noise can cause the wrong peak to be selected, sending the fiber on an incorrect trajectory. + +In biological white matter, a fiber that has been traveling in a consistent direction is more likely to continue in that direction than to make a sharp turn. This prior is not captured by the standard closest-peak algorithm, which treats all peaks above threshold equally during the peak-finding step. + +### Implementation + +When enabled, the weighting multiplies each ODF sample by an angular similarity factor before the PMF threshold is applied: + +``` +PMF[j] *= (1 - w) + w * |cos(angle between current direction and sphere vertex j)| +``` + +This has two effects: +1. Peaks aligned with the current trajectory retain full weight +2. Perpendicular peaks are suppressed by a factor of `(1 - w)` + +Because the weighting is applied before the 5% absolute threshold and 25% relative peak threshold, it can prevent aligned peaks from being incorrectly zeroed out when a strong perpendicular peak dominates. + +### Configuration + +Set the `angular_weight` attribute on the direction getter before tracking: + +```python +from cuslines import BootDirectionGetter +dg = BootDirectionGetter.from_dipy_opdt(gtab, sphere) +dg.angular_weight = 0.5 # 0.0 = disabled (default), 0.5 = moderate +``` + +### Effect on tracking (10,000 seeds, HARDI dataset) + +| | weight = 0.0 (default) | weight = 0.5 | CPU (DIPY) | +|---|---|---|---| +| **Streamlines** | 13,205 | 13,307 | 13,647 | +| **Mean fiber length** | 53.8 pts | 64.8 pts | 45.4 pts | +| **Commissural fibers** | 1,656 | 1,915 | 1,522 | + +With the corrected SH basis, the default (no weighting) already produces good parity with CPU. The weighting increases mean fiber length and commissural fiber count beyond what the CPU produces. Whether this deviation is desirable depends on the application: for strict CPU/GPU reproducibility, leave it disabled; for applications where longer fibers through crossing regions are preferred, a value of 0.3-0.5 may be appropriate. diff --git a/cuslines/metal/__init__.py b/cuslines/metal/__init__.py new file mode 100644 index 0000000..00a75ed --- /dev/null +++ b/cuslines/metal/__init__.py @@ -0,0 +1,13 @@ +from cuslines.metal.mt_tractography import MetalGPUTracker +from cuslines.metal.mt_direction_getters import ( + MetalBootDirectionGetter, + MetalProbDirectionGetter, + MetalPttDirectionGetter, +) + +__all__ = [ + "MetalGPUTracker", + "MetalBootDirectionGetter", + "MetalProbDirectionGetter", + "MetalPttDirectionGetter", +] diff --git a/cuslines/metal/mt_direction_getters.py b/cuslines/metal/mt_direction_getters.py new file mode 100644 index 0000000..d6ed0ff --- /dev/null +++ b/cuslines/metal/mt_direction_getters.py @@ -0,0 +1,463 @@ +"""Metal direction getters — mirrors cuslines/cuda_python/cu_direction_getters.py. + +Compiles MSL shaders at runtime and dispatches kernel launches via +MTLComputeCommandEncoder. +""" + +import numpy as np +import struct +from abc import ABC, abstractmethod +import logging +from importlib.resources import files +from time import time + +from cuslines.boot_utils import prepare_opdt, prepare_csa + +from cuslines.metal.mutils import ( + REAL_SIZE, + REAL_DTYPE, + REAL3_SIZE, + BLOCK_Y, + THR_X_SL, + div_up, + checkMetalError, +) + +logger = logging.getLogger("GPUStreamlines") + + +class MetalGPUDirectionGetter(ABC): + """Abstract base for Metal direction getters.""" + + # Soft angular weighting factor for bootstrap direction getters. + # 0.0 = disabled (match CPU behavior), 0.5 = moderate bias toward + # current trajectory at fiber crossings. + angular_weight = 0.0 + + @abstractmethod + def getNumStreamlines(self, nseeds_gpu, block, grid, sp): + pass + + @abstractmethod + def generateStreamlines(self, nseeds_gpu, block, grid, sp): + pass + + def setup_device(self, device): + """Called once when GPUTracker allocates resources.""" + pass + + def compile_program(self, device): + import Metal + import re + + start_time = time() + logger.info("Compiling Metal shaders...") + + shader_dir = files("cuslines").joinpath("metal_shaders") + + # Read header files in dependency order and inline them. + # Metal's runtime compiler doesn't support include search paths, + # so we prepend all headers and strip #include "..." directives. + header_files = [ + "globals.h", + "types.h", + "philox_rng.h", + ] + # Add disc.h if boot.metal or ptt.metal is in the shader set + if "boot.metal" in self._shader_files() or "ptt.metal" in self._shader_files(): + header_files.append("disc.h") + + source_parts = [] + for fname in header_files: + path = shader_dir.joinpath(fname) + with open(path, "r") as f: + source_parts.append(f"// ── {fname} ──\n") + source_parts.append(f.read()) + + # Metal source files + metal_files = [ + "utils.metal", + "warp_sort.metal", + "tracking_helpers.metal", + ] + metal_files += self._shader_files() + metal_files.append("generate_streamlines_metal.metal") + + for fname in metal_files: + path = shader_dir.joinpath(fname) + with open(path, "r") as f: + src = f.read() + # Strip local #include directives (headers already inlined above) + src = re.sub(r'#include\s+"[^"]*"', '', src) + source_parts.append(f"// ── {fname} ──\n") + source_parts.append(src) + + full_source = "\n".join(source_parts) + + # Prepend compile-time constants + enable = 1 if self.angular_weight > 0 else 0 + defines = ( + f"#define ENABLE_ANGULAR_WEIGHT {enable}\n" + f"#define ANGULAR_WEIGHT {self.angular_weight:.2f}f\n" + ) + full_source = defines + full_source + + options = Metal.MTLCompileOptions.new() + options.setFastMathEnabled_(True) + + library, error = device.newLibraryWithSource_options_error_( + full_source, options, None + ) + if error is not None: + raise RuntimeError(f"Metal shader compilation failed: {error}") + + self.library = library + logger.info("Metal shaders compiled in %.2f seconds", time() - start_time) + + def _shader_files(self): + """Return list of additional .metal files needed by this direction getter.""" + return [] + + def _make_pipeline(self, device, kernel_name): + import Metal + + fn = self.library.newFunctionWithName_(kernel_name) + if fn is None: + raise RuntimeError(f"Metal kernel '{kernel_name}' not found in library") + pipeline, error = device.newComputePipelineStateWithFunction_error_(fn, None) + if error is not None: + raise RuntimeError(f"Failed to create pipeline for '{kernel_name}': {error}") + return pipeline + + @staticmethod + def _check_cmd_buf(cmd_buf, kernel_name=""): + """Check command buffer status after waitUntilCompleted.""" + import Metal + + status = cmd_buf.status() + if status == Metal.MTLCommandBufferStatusError: + error = cmd_buf.error() + raise RuntimeError( + f"Metal command buffer error in {kernel_name}: {error}" + ) + + +class MetalProbDirectionGetter(MetalGPUDirectionGetter): + """Probabilistic direction getter for Metal.""" + + def __init__(self): + self.library = None + self.getnum_pipeline = None + self.gen_pipeline = None + + def _shader_files(self): + return [] + + def setup_device(self, device): + self.compile_program(device) + self.getnum_pipeline = self._make_pipeline(device, "getNumStreamlinesProb_k") + self.gen_pipeline = self._make_pipeline(device, "genStreamlinesMergeProb_k") + + def _make_params_bytes(self, sp, nseeds_gpu, for_gen=False): + gt = sp.gpu_tracker + rng_seed = gt.rng_seed + rng_seed_lo = rng_seed & 0xFFFFFFFF + rng_seed_hi = (rng_seed >> 32) & 0xFFFFFFFF + + # ProbTrackingParams struct layout (must match Metal struct) + # float max_angle, tc_threshold, step_size, relative_peak_thresh, min_separation_angle + # int rng_seed_lo, rng_seed_hi, rng_offset, nseed + # int dimx, dimy, dimz, dimt, samplm_nr, num_edges, model_type + values = [ + gt.max_angle, + gt.tc_threshold if for_gen else 0.0, + gt.step_size if for_gen else 0.0, + gt.relative_peak_thresh, + gt.min_separation_angle, + rng_seed_lo, + rng_seed_hi, + gt.rng_offset if for_gen else 0, + nseeds_gpu, + gt.dimx, gt.dimy, gt.dimz, gt.dimt, + gt.samplm_nr, gt.nedges, 2, # model_type = PROB + ] + # 5 floats + 11 ints + return struct.pack("5f11i", *values) + + def getNumStreamlines(self, nseeds_gpu, block, grid, sp): + import Metal + + gt = sp.gpu_tracker + params_bytes = self._make_params_bytes(sp, nseeds_gpu, for_gen=False) + + cmd_buf = gt.command_queue.commandBuffer() + encoder = cmd_buf.computeCommandEncoder() + encoder.setComputePipelineState_(self.getnum_pipeline) + + encoder.setBytes_length_atIndex_(params_bytes, len(params_bytes), 0) + encoder.setBuffer_offset_atIndex_(sp.seeds_buf, 0, 1) + encoder.setBuffer_offset_atIndex_(gt.dataf_buf, 0, 2) + encoder.setBuffer_offset_atIndex_(gt.sphere_vertices_buf, 0, 3) + encoder.setBuffer_offset_atIndex_(gt.sphere_edges_buf, 0, 4) + encoder.setBuffer_offset_atIndex_(sp.shDirTemp0_buf, 0, 5) + encoder.setBuffer_offset_atIndex_(sp.slinesOffs_buf, 0, 6) + + threads_per_group = Metal.MTLSize(block[0], block[1], block[2]) + groups = Metal.MTLSize(grid[0], grid[1], grid[2]) + encoder.dispatchThreadgroups_threadsPerThreadgroup_(groups, threads_per_group) + + encoder.endEncoding() + cmd_buf.commit() + cmd_buf.waitUntilCompleted() + self._check_cmd_buf(cmd_buf, "getNumStreamlinesProb_k") + + def generateStreamlines(self, nseeds_gpu, block, grid, sp): + import Metal + + gt = sp.gpu_tracker + params_bytes = self._make_params_bytes(sp, nseeds_gpu, for_gen=True) + + cmd_buf = gt.command_queue.commandBuffer() + encoder = cmd_buf.computeCommandEncoder() + encoder.setComputePipelineState_(self.gen_pipeline) + + encoder.setBytes_length_atIndex_(params_bytes, len(params_bytes), 0) + encoder.setBuffer_offset_atIndex_(sp.seeds_buf, 0, 1) + encoder.setBuffer_offset_atIndex_(gt.dataf_buf, 0, 2) + encoder.setBuffer_offset_atIndex_(gt.metric_map_buf, 0, 3) + encoder.setBuffer_offset_atIndex_(gt.sphere_vertices_buf, 0, 4) + encoder.setBuffer_offset_atIndex_(gt.sphere_edges_buf, 0, 5) + encoder.setBuffer_offset_atIndex_(sp.slinesOffs_buf, 0, 6) + encoder.setBuffer_offset_atIndex_(sp.shDirTemp0_buf, 0, 7) + encoder.setBuffer_offset_atIndex_(sp.slineSeed_buf, 0, 8) + encoder.setBuffer_offset_atIndex_(sp.slineLen_buf, 0, 9) + encoder.setBuffer_offset_atIndex_(sp.sline_buf, 0, 10) + + threads_per_group = Metal.MTLSize(block[0], block[1], block[2]) + groups = Metal.MTLSize(grid[0], grid[1], grid[2]) + encoder.dispatchThreadgroups_threadsPerThreadgroup_(groups, threads_per_group) + + encoder.endEncoding() + cmd_buf.commit() + cmd_buf.waitUntilCompleted() + self._check_cmd_buf(cmd_buf, "genStreamlinesMergeProb_k") + + +class MetalPttDirectionGetter(MetalProbDirectionGetter): + """PTT direction getter for Metal.""" + + def _shader_files(self): + return ["ptt.metal"] + + def setup_device(self, device): + self.compile_program(device) + # PTT reuses Prob's getNum kernel for initial direction finding + self.getnum_pipeline = self._make_pipeline(device, "getNumStreamlinesProb_k") + # PTT has its own gen kernel with parallel transport frame tracking + self.gen_pipeline = self._make_pipeline(device, "genStreamlinesMergePtt_k") + + def _make_params_bytes(self, sp, nseeds_gpu, for_gen=False): + gt = sp.gpu_tracker + rng_seed = gt.rng_seed + rng_seed_lo = rng_seed & 0xFFFFFFFF + rng_seed_hi = (rng_seed >> 32) & 0xFFFFFFFF + values = [ + gt.max_angle, + gt.tc_threshold if for_gen else 0.0, + gt.step_size if for_gen else 0.0, + gt.relative_peak_thresh, + gt.min_separation_angle, + rng_seed_lo, + rng_seed_hi, + gt.rng_offset if for_gen else 0, + nseeds_gpu, + gt.dimx, gt.dimy, gt.dimz, gt.dimt, + gt.samplm_nr, gt.nedges, 3, # model_type = PTT + ] + return struct.pack("5f11i", *values) + + +class MetalBootDirectionGetter(MetalGPUDirectionGetter): + """Bootstrap direction getter for Metal.""" + + def __init__( + self, + model_type: str, + min_signal: float, + H: np.ndarray, + R: np.ndarray, + delta_b: np.ndarray, + delta_q: np.ndarray, + sampling_matrix: np.ndarray, + b0s_mask: np.ndarray, + ): + self.model_type_str = model_type.upper() + if self.model_type_str == "OPDT": + self.model_type = 0 + elif self.model_type_str == "CSA": + self.model_type = 1 + else: + raise ValueError(f"Invalid model_type {model_type}, must be 'OPDT' or 'CSA'") + + self.H = np.ascontiguousarray(H, dtype=REAL_DTYPE) + self.R = np.ascontiguousarray(R, dtype=REAL_DTYPE) + self.delta_b = np.ascontiguousarray(delta_b, dtype=REAL_DTYPE) + self.delta_q = np.ascontiguousarray(delta_q, dtype=REAL_DTYPE) + self.delta_nr = int(delta_b.shape[0]) + self.min_signal = np.float32(min_signal) + self.sampling_matrix = np.ascontiguousarray(sampling_matrix, dtype=REAL_DTYPE) + self.b0s_mask = np.ascontiguousarray(b0s_mask, dtype=np.int32) + + self.library = None + self.getnum_pipeline = None + self.gen_pipeline = None + + # Buffers created on setup_device + self.H_buf = None + self.R_buf = None + self.delta_b_buf = None + self.delta_q_buf = None + self.b0s_mask_buf = None + self.sampling_matrix_buf = None + + @classmethod + def from_dipy_opdt(cls, gtab, sphere, sh_order_max=6, full_basis=False, + sh_lambda=0.006, min_signal=1): + return cls(**prepare_opdt(gtab, sphere, sh_order_max, full_basis, + sh_lambda, min_signal)) + + @classmethod + def from_dipy_csa(cls, gtab, sphere, sh_order_max=6, full_basis=False, + sh_lambda=0.006, min_signal=1): + return cls(**prepare_csa(gtab, sphere, sh_order_max, full_basis, + sh_lambda, min_signal)) + + def _shader_files(self): + return ["boot.metal"] + + def setup_device(self, device): + from cuslines.metal.mt_tractography import _make_shared_buffer + + self.compile_program(device) + self.getnum_pipeline = self._make_pipeline(device, "getNumStreamlinesBoot_k") + self.gen_pipeline = self._make_pipeline(device, "genStreamlinesMergeBoot_k") + + # Create shared buffers for boot-specific data + self.H_buf = _make_shared_buffer(device, self.H) + self.R_buf = _make_shared_buffer(device, self.R) + self.delta_b_buf = _make_shared_buffer(device, self.delta_b) + self.delta_q_buf = _make_shared_buffer(device, self.delta_q) + self.b0s_mask_buf = _make_shared_buffer(device, self.b0s_mask) + self.sampling_matrix_buf = _make_shared_buffer(device, self.sampling_matrix) + + def _make_params_bytes(self, sp, nseeds_gpu, for_gen=False): + gt = sp.gpu_tracker + rng_seed = gt.rng_seed + rng_seed_lo = rng_seed & 0xFFFFFFFF + rng_seed_hi = (rng_seed >> 32) & 0xFFFFFFFF + + # BootTrackingParams struct layout (must match Metal struct in boot.metal) + # float max_angle, tc_threshold, step_size, relative_peak_thresh, + # min_separation_angle, min_signal + # int rng_seed_lo, rng_seed_hi, rng_offset, nseed + # int dimx, dimy, dimz, dimt, samplm_nr, num_edges, delta_nr, model_type + values = [ + gt.max_angle, + gt.tc_threshold if for_gen else 0.0, + gt.step_size if for_gen else 0.0, + gt.relative_peak_thresh, + gt.min_separation_angle, + float(self.min_signal), + rng_seed_lo, + rng_seed_hi, + gt.rng_offset if for_gen else 0, + nseeds_gpu, + gt.dimx, gt.dimy, gt.dimz, gt.dimt, + gt.samplm_nr, gt.nedges, self.delta_nr, self.model_type, + ] + # 6 floats + 12 ints + return struct.pack("6f12i", *values) + + def _boot_sh_pool_bytes(self, gt): + """Compute dynamic threadgroup memory size for boot kernels.""" + n32dimt = ((gt.dimt + 31) // 32) * 32 + sh_per_row = 2 * n32dimt + 2 * max(n32dimt, gt.samplm_nr) + return BLOCK_Y * sh_per_row * REAL_SIZE # bytes + + def getNumStreamlines(self, nseeds_gpu, block, grid, sp): + import Metal + + gt = sp.gpu_tracker + params_bytes = self._make_params_bytes(sp, nseeds_gpu, for_gen=False) + + cmd_buf = gt.command_queue.commandBuffer() + encoder = cmd_buf.computeCommandEncoder() + encoder.setComputePipelineState_(self.getnum_pipeline) + + # Buffer bindings match getNumStreamlinesBoot_k signature in boot.metal + encoder.setBytes_length_atIndex_(params_bytes, len(params_bytes), 0) + encoder.setBuffer_offset_atIndex_(sp.seeds_buf, 0, 1) + encoder.setBuffer_offset_atIndex_(gt.dataf_buf, 0, 2) + encoder.setBuffer_offset_atIndex_(self.H_buf, 0, 3) + encoder.setBuffer_offset_atIndex_(self.R_buf, 0, 4) + encoder.setBuffer_offset_atIndex_(self.delta_b_buf, 0, 5) + encoder.setBuffer_offset_atIndex_(self.delta_q_buf, 0, 6) + encoder.setBuffer_offset_atIndex_(self.b0s_mask_buf, 0, 7) + encoder.setBuffer_offset_atIndex_(self.sampling_matrix_buf, 0, 8) + encoder.setBuffer_offset_atIndex_(gt.sphere_vertices_buf, 0, 9) + encoder.setBuffer_offset_atIndex_(gt.sphere_edges_buf, 0, 10) + encoder.setBuffer_offset_atIndex_(sp.shDirTemp0_buf, 0, 11) + encoder.setBuffer_offset_atIndex_(sp.slinesOffs_buf, 0, 12) + + # Dynamic threadgroup memory (replaces CUDA extern __shared__) + encoder.setThreadgroupMemoryLength_atIndex_(self._boot_sh_pool_bytes(gt), 0) + + threads_per_group = Metal.MTLSize(block[0], block[1], block[2]) + groups = Metal.MTLSize(grid[0], grid[1], grid[2]) + encoder.dispatchThreadgroups_threadsPerThreadgroup_(groups, threads_per_group) + + encoder.endEncoding() + cmd_buf.commit() + cmd_buf.waitUntilCompleted() + self._check_cmd_buf(cmd_buf, "getNumStreamlinesBoot_k") + + def generateStreamlines(self, nseeds_gpu, block, grid, sp): + import Metal + + gt = sp.gpu_tracker + params_bytes = self._make_params_bytes(sp, nseeds_gpu, for_gen=True) + + cmd_buf = gt.command_queue.commandBuffer() + encoder = cmd_buf.computeCommandEncoder() + encoder.setComputePipelineState_(self.gen_pipeline) + + # Buffer bindings match genStreamlinesMergeBoot_k signature in boot.metal + encoder.setBytes_length_atIndex_(params_bytes, len(params_bytes), 0) + encoder.setBuffer_offset_atIndex_(sp.seeds_buf, 0, 1) + encoder.setBuffer_offset_atIndex_(gt.dataf_buf, 0, 2) + encoder.setBuffer_offset_atIndex_(gt.metric_map_buf, 0, 3) + encoder.setBuffer_offset_atIndex_(gt.sphere_vertices_buf, 0, 4) + encoder.setBuffer_offset_atIndex_(gt.sphere_edges_buf, 0, 5) + encoder.setBuffer_offset_atIndex_(self.H_buf, 0, 6) + encoder.setBuffer_offset_atIndex_(self.R_buf, 0, 7) + encoder.setBuffer_offset_atIndex_(self.delta_b_buf, 0, 8) + encoder.setBuffer_offset_atIndex_(self.delta_q_buf, 0, 9) + encoder.setBuffer_offset_atIndex_(self.sampling_matrix_buf, 0, 10) + encoder.setBuffer_offset_atIndex_(self.b0s_mask_buf, 0, 11) + encoder.setBuffer_offset_atIndex_(sp.slinesOffs_buf, 0, 12) + encoder.setBuffer_offset_atIndex_(sp.shDirTemp0_buf, 0, 13) + encoder.setBuffer_offset_atIndex_(sp.slineSeed_buf, 0, 14) + encoder.setBuffer_offset_atIndex_(sp.slineLen_buf, 0, 15) + encoder.setBuffer_offset_atIndex_(sp.sline_buf, 0, 16) + + # Dynamic threadgroup memory (replaces CUDA extern __shared__) + encoder.setThreadgroupMemoryLength_atIndex_(self._boot_sh_pool_bytes(gt), 0) + + threads_per_group = Metal.MTLSize(block[0], block[1], block[2]) + groups = Metal.MTLSize(grid[0], grid[1], grid[2]) + encoder.dispatchThreadgroups_threadsPerThreadgroup_(groups, threads_per_group) + + encoder.endEncoding() + cmd_buf.commit() + cmd_buf.waitUntilCompleted() + self._check_cmd_buf(cmd_buf, "genStreamlinesMergeBoot_k") diff --git a/cuslines/metal/mt_propagate_seeds.py b/cuslines/metal/mt_propagate_seeds.py new file mode 100644 index 0000000..48eaf30 --- /dev/null +++ b/cuslines/metal/mt_propagate_seeds.py @@ -0,0 +1,204 @@ +"""Metal seed batch propagator — mirrors cuslines/cuda_python/cu_propagate_seeds.py. + +Unified memory advantage: no cudaMemcpy needed. Seeds and results live in +shared CPU/GPU buffers. +""" + +import numpy as np +import math +import gc +import logging + +from nibabel.streamlines.array_sequence import ArraySequence, MEGABYTE + +from cuslines.metal.mutils import ( + REAL_SIZE, + REAL_DTYPE, + REAL3_SIZE, + MAX_SLINE_LEN, + EXCESS_ALLOC_FACT, + THR_X_SL, + THR_X_BL, + BLOCK_Y, + div_up, +) + +logger = logging.getLogger("GPUStreamlines") + + +class MetalSeedBatchPropagator: + def __init__(self, gpu_tracker, minlen=0, maxlen=np.inf): + self.gpu_tracker = gpu_tracker + self.minlen = minlen + self.maxlen = maxlen + + self.nSlines = 0 + self.nSlines_old = 0 + self.slines = None + self.sline_lens = None + + # Metal buffers + self.seeds_buf = None + self.slinesOffs_buf = None + self.shDirTemp0_buf = None + self.slineSeed_buf = None + self.slineLen_buf = None + self.sline_buf = None + + # Backing numpy arrays (unified memory — these ARE the GPU data) + self._seeds_arr = None + self._slinesOffs_arr = None + self._shDirTemp0_arr = None + self._slineSeed_arr = None + self._slineLen_arr = None + self._sline_arr = None + + def _get_sl_buffer_size(self): + return REAL_SIZE * 2 * 3 * MAX_SLINE_LEN * int(self.nSlines) + + def _allocate_seed_memory(self, seeds): + from cuslines.metal.mt_tractography import ( + _make_shared_buffer, _make_dynamic_buffer, _buffer_as_array, + ) + + nseeds = len(seeds) + device = self.gpu_tracker.device + block = (THR_X_SL, BLOCK_Y, 1) + grid = (div_up(nseeds, BLOCK_Y), 1, 1) + + # Seeds — copy into Metal shared buffer + seeds_arr = np.ascontiguousarray(seeds, dtype=REAL_DTYPE) + self.seeds_buf = _make_shared_buffer(device, seeds_arr) + + # Streamline offsets — dynamic buffer (GPU writes, CPU reads for prefix sum) + offs_nbytes = (nseeds + 1) * np.dtype(np.int32).itemsize + self.slinesOffs_buf = _make_dynamic_buffer(device, offs_nbytes) + self._slinesOffs_arr = _buffer_as_array( + self.slinesOffs_buf, np.int32, (nseeds + 1,) + ) + self._slinesOffs_arr[:] = 0 + + # Initial directions from each seed + shdir_size = self.gpu_tracker.samplm_nr * grid[0] * block[1] + shdir_nbytes = shdir_size * 3 * REAL_SIZE + self.shDirTemp0_buf = _make_dynamic_buffer(device, shdir_nbytes) + + return nseeds, block, grid + + def _cumsum_offsets(self, nseeds): + """CPU-side prefix sum on offsets — no memcpy needed with unified memory.""" + offs = self._slinesOffs_arr + + # Exclusive prefix sum: shift cumsum right, insert 0 at start + counts = offs[:nseeds].copy() + np.cumsum(counts, out=offs[1:nseeds + 1]) + offs[0] = 0 + self.nSlines = int(offs[nseeds]) + + def _allocate_tracking_memory(self): + from cuslines.metal.mt_tractography import ( + _make_dynamic_buffer, _buffer_as_array, + ) + + device = self.gpu_tracker.device + + if self.nSlines > EXCESS_ALLOC_FACT * self.nSlines_old: + self.slines = None + self.sline_lens = None + gc.collect() + + if self.slines is None: + self.slines = np.empty( + (EXCESS_ALLOC_FACT * self.nSlines, MAX_SLINE_LEN * 2, 3), + dtype=REAL_DTYPE, + ) + if self.sline_lens is None: + self.sline_lens = np.empty( + EXCESS_ALLOC_FACT * self.nSlines, dtype=np.int32 + ) + + # Seed-to-streamline mapping — dynamic buffer (GPU writes seed indices) + seed_nbytes = self.nSlines * np.dtype(np.int32).itemsize + self.slineSeed_buf = _make_dynamic_buffer(device, seed_nbytes) + self._slineSeed_arr = _buffer_as_array( + self.slineSeed_buf, np.int32, (self.nSlines,) + ) + self._slineSeed_arr[:] = -1 + + # Streamline lengths — dynamic buffer (GPU writes lengths) + len_nbytes = self.nSlines * np.dtype(np.int32).itemsize + self.slineLen_buf = _make_dynamic_buffer(device, len_nbytes) + self._slineLen_arr = _buffer_as_array( + self.slineLen_buf, np.int32, (self.nSlines,) + ) + self._slineLen_arr[:] = 0 + + # Streamline output buffer — dynamic buffer (GPU writes streamline points) + buffer_count = 2 * 3 * MAX_SLINE_LEN * self.nSlines + sline_nbytes = buffer_count * REAL_SIZE + self.sline_buf = _make_dynamic_buffer(device, sline_nbytes) + self._sline_arr = _buffer_as_array( + self.sline_buf, REAL_DTYPE, (buffer_count,) + ) + + def _copy_results(self): + """With unified memory, results are already in CPU-accessible memory. + Just reshape/copy into the output arrays.""" + if self.nSlines == 0: + return + + # Reshape the flat sline buffer into (nSlines, MAX_SLINE_LEN*2, 3) + sline_view = self._sline_arr.reshape(self.nSlines, MAX_SLINE_LEN * 2, 3) + self.slines[:self.nSlines] = sline_view + self.sline_lens[:self.nSlines] = self._slineLen_arr + + def propagate(self, seeds): + self.nseeds = len(seeds) + + nseeds, block, grid = self._allocate_seed_memory(seeds) + + # Pass 1: count streamlines per seed + self.gpu_tracker.dg.getNumStreamlines(nseeds, block, grid, self) + + # Prefix sum offsets (no memcpy — unified memory) + self._cumsum_offsets(nseeds) + + if self.nSlines == 0: + self.nSlines_old = self.nSlines + self.gpu_tracker.rng_offset += self.nseeds + return + + self._allocate_tracking_memory() + + # Pass 2: generate streamlines + self.gpu_tracker.dg.generateStreamlines(nseeds, block, grid, self) + + # Copy results (trivial with unified memory) + self._copy_results() + + self.nSlines_old = self.nSlines + self.gpu_tracker.rng_offset += self.nseeds + + def get_buffer_size(self): + buffer_size = 0 + lens = self.sline_lens + for jj in range(self.nSlines): + if lens[jj] < self.minlen or lens[jj] > self.maxlen: + continue + buffer_size += lens[jj] * 3 * REAL_SIZE + return math.ceil(buffer_size / MEGABYTE) + + def as_generator(self): + def _yield_slines(): + sls = self.slines + lens = self.sline_lens + for jj in range(self.nSlines): + npts = lens[jj] + if npts < self.minlen or npts > self.maxlen: + continue + yield np.asarray(sls[jj], dtype=REAL_DTYPE)[:npts] + + return _yield_slines() + + def as_array_sequence(self): + return ArraySequence(self.as_generator(), self.get_buffer_size()) diff --git a/cuslines/metal/mt_tractography.py b/cuslines/metal/mt_tractography.py new file mode 100644 index 0000000..337d6f5 --- /dev/null +++ b/cuslines/metal/mt_tractography.py @@ -0,0 +1,246 @@ +"""Metal GPU tracker — mirrors cuslines/cuda_python/cu_tractography.py. + +Key difference from the CUDA backend: Apple Silicon unified memory means +we wrap numpy arrays as Metal shared buffers with zero copies. +""" + +import numpy as np +from tqdm import tqdm +import logging +from math import radians + +from cuslines.metal.mutils import ( + REAL_SIZE, + REAL_DTYPE, + aligned_array, + PAGE_SIZE, + checkMetalError, +) + +from cuslines.metal.mt_direction_getters import MetalGPUDirectionGetter, MetalBootDirectionGetter +from cuslines.metal.mt_propagate_seeds import MetalSeedBatchPropagator + +from trx.trx_file_memmap import TrxFile +from nibabel.streamlines.tractogram import Tractogram +from nibabel.streamlines.array_sequence import ArraySequence, MEGABYTE +from dipy.io.stateful_tractogram import Space, StatefulTractogram + +logger = logging.getLogger("GPUStreamlines") + + +def _make_shared_buffer(device, arr): + """Copy a numpy array into a Metal shared buffer. + + Uses newBufferWithBytes (one copy at setup time). The buffer lives in + unified memory and is GPU-accessible without further copies. + """ + import Metal + + buf = device.newBufferWithBytes_length_options_( + arr.tobytes(), arr.nbytes, Metal.MTLResourceStorageModeShared + ) + return buf + + +def _make_dynamic_buffer(device, nbytes): + """Create an empty Metal shared buffer and return (buf, numpy_view). + + The numpy array is a writable view of the Metal buffer's contents, + giving true zero-copy CPU/GPU sharing for dynamic per-batch data. + """ + import Metal + + buf = device.newBufferWithLength_options_( + nbytes, Metal.MTLResourceStorageModeShared + ) + return buf + + +def _buffer_as_array(buf, dtype, shape): + """Create a numpy array view of a Metal buffer's contents (zero-copy).""" + nbytes = buf.length() + memview = buf.contents().as_buffer(nbytes) + count = int(np.prod(shape)) + return np.frombuffer(memview, dtype=dtype, count=count).reshape(shape) + + +class MetalGPUTracker: + def __init__( + self, + dg: MetalGPUDirectionGetter, + dataf: np.ndarray, + stop_map: np.ndarray, + stop_theshold: float, + sphere_vertices: np.ndarray, + sphere_edges: np.ndarray, + max_angle: float = radians(60), + step_size: float = 0.5, + min_pts=0, + max_pts=np.inf, + relative_peak_thresh: float = 0.25, + min_separation_angle: float = radians(45), + ngpus: int = 1, + rng_seed: int = 0, + rng_offset: int = 0, + chunk_size: int = 25000, + ): + import Metal + + self.device = Metal.MTLCreateSystemDefaultDevice() + if self.device is None: + raise RuntimeError("No Metal GPU device found") + self.command_queue = self.device.newCommandQueue() + + # Ensure contiguous float32 arrays + self.dataf = np.ascontiguousarray(dataf, dtype=REAL_DTYPE) + self.metric_map = np.ascontiguousarray(stop_map, dtype=REAL_DTYPE) + self.sphere_vertices = np.ascontiguousarray(sphere_vertices, dtype=REAL_DTYPE) + self.sphere_edges = np.ascontiguousarray(sphere_edges, dtype=np.int32) + + self.dimx, self.dimy, self.dimz, self.dimt = dataf.shape + self.nedges = int(sphere_edges.shape[0]) + if isinstance(dg, MetalBootDirectionGetter): + self.samplm_nr = int(dg.sampling_matrix.shape[0]) + else: + self.samplm_nr = self.dimt + self.n32dimt = ((self.dimt + 31) // 32) * 32 + + self.dg = dg + self.max_angle = np.float32(max_angle) + self.tc_threshold = np.float32(stop_theshold) + self.step_size = np.float32(step_size) + self.relative_peak_thresh = np.float32(relative_peak_thresh) + self.min_separation_angle = np.float32(min_separation_angle) + + # Metal: single GPU (ngpus ignored, always 1) + self.ngpus = 1 + self.rng_seed = int(rng_seed) + self.rng_offset = int(rng_offset) + self.chunk_size = int(chunk_size) + + logger.info("Creating MetalGPUTracker on %s", self.device.name()) + + # Shared buffers — created lazily in __enter__ + self.dataf_buf = None + self.metric_map_buf = None + self.sphere_vertices_buf = None + self.sphere_edges_buf = None + + self.seed_propagator = MetalSeedBatchPropagator( + gpu_tracker=self, minlen=min_pts, maxlen=max_pts + ) + self._allocated = False + + def __enter__(self): + self._allocate() + return self + + def _allocate(self): + if self._allocated: + return + + # Unified memory: wrap numpy arrays as shared Metal buffers + self.dataf_buf = _make_shared_buffer(self.device, self.dataf) + self.metric_map_buf = _make_shared_buffer(self.device, self.metric_map) + self.sphere_vertices_buf = _make_shared_buffer(self.device, self.sphere_vertices) + self.sphere_edges_buf = _make_shared_buffer(self.device, self.sphere_edges) + + self.dg.setup_device(self.device) + self._allocated = True + + def __exit__(self, exc_type, exc, tb): + logger.info("Destroying MetalGPUTracker...") + # Metal buffers are reference-counted; dropping refs is sufficient. + self.dataf_buf = None + self.metric_map_buf = None + self.sphere_vertices_buf = None + self.sphere_edges_buf = None + # Clean up direction getter buffers + if hasattr(self.dg, 'H_buf'): + for attr in ('H_buf', 'R_buf', 'delta_b_buf', 'delta_q_buf', + 'b0s_mask_buf', 'sampling_matrix_buf'): + setattr(self.dg, attr, None) + self.dg.library = None + self.dg.getnum_pipeline = None + self.dg.gen_pipeline = None + self._allocated = False + return False + + def _divide_chunks(self, seeds): + global_chunk_sz = self.chunk_size # single GPU + nchunks = (seeds.shape[0] + global_chunk_sz - 1) // global_chunk_sz + return global_chunk_sz, nchunks + + def generate_sft(self, seeds, ref_img): + global_chunk_sz, nchunks = self._divide_chunks(seeds) + buffer_size = 0 + generators = [] + + with tqdm(total=seeds.shape[0]) as pbar: + for idx in range(nchunks): + chunk = seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz] + self.seed_propagator.propagate(chunk) + buffer_size += self.seed_propagator.get_buffer_size() + generators.append(self.seed_propagator.as_generator()) + pbar.update(chunk.shape[0]) + + array_sequence = ArraySequence( + (item for gen in generators for item in gen), buffer_size + ) + return StatefulTractogram(array_sequence, ref_img, Space.VOX) + + def generate_trx(self, seeds, ref_img): + global_chunk_sz, nchunks = self._divide_chunks(seeds) + + sl_len_guess = 100 + sl_per_seed_guess = 2 + n_sls_guess = sl_per_seed_guess * seeds.shape[0] + + trx_reference = TrxFile(reference=ref_img) + trx_reference.streamlines._data = trx_reference.streamlines._data.astype(np.float32) + trx_reference.streamlines._offsets = trx_reference.streamlines._offsets.astype(np.uint64) + + trx_file = TrxFile( + nb_streamlines=n_sls_guess, + nb_vertices=n_sls_guess * sl_len_guess, + init_as=trx_reference, + ) + offsets_idx = 0 + sls_data_idx = 0 + + with tqdm(total=seeds.shape[0]) as pbar: + for idx in range(int(nchunks)): + chunk = seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz] + self.seed_propagator.propagate(chunk) + tractogram = Tractogram( + self.seed_propagator.as_array_sequence(), + affine_to_rasmm=ref_img.affine, + ) + tractogram.to_world() + sls = tractogram.streamlines + + new_offsets_idx = offsets_idx + len(sls._offsets) + new_sls_data_idx = sls_data_idx + len(sls._data) + + if ( + new_offsets_idx > trx_file.header["NB_STREAMLINES"] + or new_sls_data_idx > trx_file.header["NB_VERTICES"] + ): + logger.info("TRX resizing...") + trx_file.resize( + nb_streamlines=new_offsets_idx * 2, + nb_vertices=new_sls_data_idx * 2, + ) + + trx_file.streamlines._data[sls_data_idx:new_sls_data_idx] = sls._data + trx_file.streamlines._offsets[offsets_idx:new_offsets_idx] = ( + sls_data_idx + sls._offsets + ) + trx_file.streamlines._lengths[offsets_idx:new_offsets_idx] = sls._lengths + + offsets_idx = new_offsets_idx + sls_data_idx = new_sls_data_idx + pbar.update(chunk.shape[0]) + + trx_file.resize() + return trx_file diff --git a/cuslines/metal/mutils.py b/cuslines/metal/mutils.py new file mode 100644 index 0000000..d190e59 --- /dev/null +++ b/cuslines/metal/mutils.py @@ -0,0 +1,142 @@ +"""Metal backend utilities — type definitions, error checking, aligned allocation. + +Mirrors cuslines/cuda_python/cutils.py for the Metal backend. +Metal only supports float32, so no REAL_SIZE branching is needed. +""" + +import numpy as np +import ctypes +import ctypes.util +import importlib.util +from enum import IntEnum +from pathlib import Path + +# Import _globals.py directly (bypasses cuslines.cuda_python.__init__ +# which would trigger CUDA imports). +_globals_path = Path(__file__).resolve().parent.parent / "cuda_python" / "_globals.py" +_spec = importlib.util.spec_from_file_location("_globals", str(_globals_path)) +_globals_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_globals_mod) + +MAX_SLINE_LEN = _globals_mod.MAX_SLINE_LEN +EXCESS_ALLOC_FACT = _globals_mod.EXCESS_ALLOC_FACT +MAX_SLINES_PER_SEED = _globals_mod.MAX_SLINES_PER_SEED +THR_X_BL = _globals_mod.THR_X_BL +THR_X_SL = _globals_mod.THR_X_SL +PMF_THRESHOLD_P = _globals_mod.PMF_THRESHOLD_P +NORM_EPS = _globals_mod.NORM_EPS + +# Re-export globals +__all__ = [ + "ModelType", + "REAL_SIZE", + "REAL_DTYPE", + "REAL3_SIZE", + "REAL3_DTYPE", + "BLOCK_Y", + "MAX_SLINE_LEN", + "EXCESS_ALLOC_FACT", + "MAX_SLINES_PER_SEED", + "THR_X_BL", + "THR_X_SL", + "PMF_THRESHOLD_P", + "NORM_EPS", + "div_up", + "checkMetalError", + "aligned_array", + "PAGE_SIZE", +] + + +class ModelType(IntEnum): + OPDT = 0 + CSA = 1 + PROB = 2 + PTT = 3 + + +# Metal only supports float32 +REAL_SIZE = 4 +REAL_DTYPE = np.float32 + +# packed_float3 in Metal is 12 bytes — same layout as CUDA float3 in arrays. +# align=False ensures numpy uses 12-byte stride, not 16. +REAL3_SIZE = 3 * REAL_SIZE +REAL3_DTYPE = np.dtype( + [("x", np.float32), ("y", np.float32), ("z", np.float32)], align=False +) + +BLOCK_Y = THR_X_BL // THR_X_SL + +# Apple Silicon page size (16 KB). Buffers passed to +# newBufferWithBytesNoCopy must be page-aligned. +PAGE_SIZE = 16384 + + +def div_up(a, b): + return (a + b - 1) // b + + +def checkMetalError(error): + """Raise if an NSError was returned from a Metal API call.""" + if error is not None: + desc = error.localizedDescription() + raise RuntimeError(f"Metal error: {desc}") + + +# ── page-aligned allocation ─────────────────────────────────────────── + +_libc_name = ctypes.util.find_library("c") +_libc = ctypes.CDLL(_libc_name, use_errno=True) +_libc.free.argtypes = [ctypes.c_void_p] +_libc.free.restype = None + + +def _posix_memalign(size, alignment=PAGE_SIZE): + """Allocate *size* bytes aligned to *alignment* using posix_memalign.""" + ptr = ctypes.c_void_p() + ret = _libc.posix_memalign(ctypes.byref(ptr), alignment, size) + if ret != 0: + raise MemoryError( + f"posix_memalign failed (ret={ret}) for size={size}, align={alignment}" + ) + return ptr + + +def aligned_array(shape, dtype=np.float32, alignment=PAGE_SIZE): + """Return a C-contiguous numpy array whose underlying memory is page-aligned. + + Suitable for wrapping with Metal's ``newBufferWithBytesNoCopy``. + The returned array owns a prevent-GC reference to the raw buffer. + """ + dtype = np.dtype(dtype) + count = int(np.prod(shape)) + nbytes = count * dtype.itemsize + # Round up to page boundary so the buffer length is also page-aligned, + # which Metal requires for newBufferWithBytesNoCopy. + nbytes_aligned = ((nbytes + alignment - 1) // alignment) * alignment + + raw_ptr = _posix_memalign(nbytes_aligned, alignment) + + # Create a numpy array that shares the allocated memory. + # We use ctypes to expose the raw pointer to numpy. + ctypes_array = (ctypes.c_byte * nbytes_aligned).from_address(raw_ptr.value) + arr = np.frombuffer(ctypes_array, dtype=dtype, count=count).reshape(shape) + + # Prevent the raw allocation from being freed while the array lives. + # When the ref is dropped numpy will drop ctypes_array which does NOT + # free the underlying posix_memalign memory (ctypes doesn't own it). + # We attach a Release helper via the buffer owner chain instead. + arr._aligned_raw_ptr = raw_ptr # prevent GC + arr._aligned_ctypes_buf = ctypes_array # prevent GC + + # Register a weakref-free destructor using a ref-cycle-safe closure. + import weakref + + def _free_cb(ptr_val=raw_ptr.value): + _libc.free(ptr_val) + + # Invoke _free_cb when arr gets collected. + weakref.ref(ctypes_array, lambda _: _free_cb()) + + return arr diff --git a/cuslines/metal_shaders/boot.metal b/cuslines/metal_shaders/boot.metal new file mode 100644 index 0000000..1c5fd6f --- /dev/null +++ b/cuslines/metal_shaders/boot.metal @@ -0,0 +1,869 @@ +/* Metal port of cuslines/cuda_c/boot.cu — bootstrap streamline generation. + * + * Translation notes: + * - CUDA __device__ functions → plain inline functions + * - CUDA __global__ kernels → kernel functions + * - CUDA templates removed; concrete float types used throughout + * - __shared__ → threadgroup + * - Warp intrinsics → SIMD group intrinsics (Apple GPU SIMD width == 32) + * - curandStatePhilox4_32_10_t → PhiloxState (from philox_rng.h) + * - REAL_T → float, REAL3_T → float3 (packed_float3 for device buffers) + * - All #ifdef DEBUG / #if 0 blocks removed + * - USE_FIXED_PERMUTATION block removed + */ + +#include "globals.h" +#include "types.h" +#include "philox_rng.h" + +// ── params struct for kernel arguments ────────────────────────────── + +struct BootTrackingParams { + float max_angle; + float tc_threshold; + float step_size; + float relative_peak_thresh; + float min_separation_angle; + float min_signal; + int rng_seed_lo; + int rng_seed_hi; + int rng_offset; + int nseed; + int dimx, dimy, dimz, dimt; + int samplm_nr; + int num_edges; + int delta_nr; + int model_type; +}; + +// ── raw uint from Philox (equivalent to CUDA curand(&st)) ────────── + +inline uint philox_uint(thread PhiloxState& s) { + if (s.idx >= 4) { + philox_next(s); + } + uint bits; + switch (s.idx) { + case 0: bits = s.output.x; break; + case 1: bits = s.output.y; break; + case 2: bits = s.output.z; break; + default: bits = s.output.w; break; + } + s.idx++; + return bits; +} + +// ── avgMask — SIMD-parallel masked average ────────────────────────── + +inline float avgMask(const int mskLen, + const device int* mask, + const threadgroup float* data, + uint tidx) { + + int myCnt = 0; + float mySum = 0.0f; + + for (int i = int(tidx); i < mskLen; i += THR_X_SL) { + if (mask[i]) { + myCnt++; + mySum += data[i]; + } + } + + for (int i = THR_X_SL / 2; i > 0; i /= 2) { + mySum += simd_shuffle_xor(mySum, ushort(i)); + myCnt += simd_shuffle_xor(myCnt, ushort(i)); + } + + return mySum / float(myCnt); +} + +// ── maskGet — compact non-masked entries ──────────────────────────── + +inline int maskGet(const int n, + const device int* mask, + const threadgroup float* plain, + threadgroup float* masked, + uint tidx) { + + const uint laneMask = (1u << tidx) - 1u; + + int woff = 0; + for (int j = 0; j < n; j += THR_X_SL) { + + const int act = (j + int(tidx) < n) ? (!mask[j + int(tidx)]) : 0; + const uint msk = SIMD_BALLOT_MASK(bool(act)); + + const int toff = popcount(msk & laneMask); + if (act) { + masked[woff + toff] = plain[j + int(tidx)]; + } + woff += popcount(msk); + } + return woff; +} + +// ── maskPut — scatter masked entries back ─────────────────────────── + +inline void maskPut(const int n, + const device int* mask, + const threadgroup float* masked, + threadgroup float* plain, + uint tidx) { + + const uint laneMask = (1u << tidx) - 1u; + + int woff = 0; + for (int j = 0; j < n; j += THR_X_SL) { + + const int act = (j + int(tidx) < n) ? (!mask[j + int(tidx)]) : 0; + const uint msk = SIMD_BALLOT_MASK(bool(act)); + + const int toff = popcount(msk & laneMask); + if (act) { + plain[j + int(tidx)] = masked[woff + toff]; + } + woff += popcount(msk); + } +} + +// ── closest_peak_d — find closest peak to current direction ───────── + +inline int closest_peak_d(const float max_angle, + const float3 direction, + const int npeaks, + const threadgroup float3* peaks, + threadgroup float3* peak, + uint tidx) { + + const float cos_similarity = COS(max_angle); + + float cpeak_dot = 0.0f; + int cpeak_idx = -1; + for (int j = 0; j < npeaks; j += THR_X_SL) { + if (j + int(tidx) < npeaks) { + const float dot = direction.x * peaks[j + int(tidx)].x + + direction.y * peaks[j + int(tidx)].y + + direction.z * peaks[j + int(tidx)].z; + + if (FABS(dot) > FABS(cpeak_dot)) { + cpeak_dot = dot; + cpeak_idx = j + int(tidx); + } + } + } + + for (int j = THR_X_SL / 2; j > 0; j /= 2) { + const float dot = simd_shuffle_xor(cpeak_dot, ushort(j)); + const int idx = simd_shuffle_xor(cpeak_idx, ushort(j)); + if (FABS(dot) > FABS(cpeak_dot)) { + cpeak_dot = dot; + cpeak_idx = idx; + } + } + + if (cpeak_idx >= 0) { + if (cpeak_dot >= cos_similarity) { + peak[0] = peaks[cpeak_idx]; + return 1; + } + if (cpeak_dot <= -cos_similarity) { + peak[0] = float3(-peaks[cpeak_idx].x, + -peaks[cpeak_idx].y, + -peaks[cpeak_idx].z); + return 1; + } + } + return 0; +} + +// ── ndotp_d — matrix-vector dot product ───────────────────────────── + +inline void ndotp_d(const int N, + const int M, + const threadgroup float* srcV, + const device float* srcM, + threadgroup float* dstV, + uint tidx) { + + for (int i = 0; i < N; i++) { + + float tmp = 0.0f; + + for (int j = 0; j < M; j += THR_X_SL) { + if (j + int(tidx) < M) { + tmp += srcV[j + int(tidx)] * srcM[i * M + j + int(tidx)]; + } + } + for (int j = THR_X_SL / 2; j > 0; j /= 2) { + tmp += simd_shuffle_down(tmp, ushort(j)); + } + + if (tidx == 0) { + dstV[i] = tmp; + } + } +} + +// ── ndotp_log_opdt_d — OPDT log-weighted dot product ──────────────── + +inline void ndotp_log_opdt_d(const int N, + const int M, + const threadgroup float* srcV, + const device float* srcM, + threadgroup float* dstV, + uint tidx) { + + const float ONEP5 = 1.5f; + + for (int i = 0; i < N; i++) { + + float tmp = 0.0f; + + for (int j = 0; j < M; j += THR_X_SL) { + if (j + int(tidx) < M) { + const float v = srcV[j + int(tidx)]; + tmp += -LOG(v) * (ONEP5 + LOG(v)) * v * srcM[i * M + j + int(tidx)]; + } + } + for (int j = THR_X_SL / 2; j > 0; j /= 2) { + tmp += simd_shuffle_down(tmp, ushort(j)); + } + + if (tidx == 0) { + dstV[i] = tmp; + } + } +} + +// ── ndotp_log_csa_d — CSA log-log-weighted dot product ────────────── + +inline void ndotp_log_csa_d(const int N, + const int M, + const threadgroup float* srcV, + const device float* srcM, + threadgroup float* dstV, + uint tidx) { + + const float csa_min = 0.001f; + const float csa_max = 0.999f; + + for (int i = 0; i < N; i++) { + + float tmp = 0.0f; + + for (int j = 0; j < M; j += THR_X_SL) { + if (j + int(tidx) < M) { + const float v = MIN(MAX(srcV[j + int(tidx)], csa_min), csa_max); + tmp += LOG(-LOG(v)) * srcM[i * M + j + int(tidx)]; + } + } + for (int j = THR_X_SL / 2; j > 0; j /= 2) { + tmp += simd_shuffle_down(tmp, ushort(j)); + } + + if (tidx == 0) { + dstV[i] = tmp; + } + } +} + +// ── fit_opdt — OPDT model fitting ─────────────────────────────────── + +inline void fit_opdt(const int delta_nr, + const int hr_side, + const device float* delta_q, + const device float* delta_b, + const threadgroup float* msk_data_sh, + threadgroup float* h_sh, + threadgroup float* r_sh, + uint tidx) { + + ndotp_log_opdt_d(delta_nr, hr_side, msk_data_sh, delta_q, r_sh, tidx); + ndotp_d(delta_nr, hr_side, msk_data_sh, delta_b, h_sh, tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + for (int j = int(tidx); j < delta_nr; j += THR_X_SL) { + r_sh[j] -= h_sh[j]; + } + simdgroup_barrier(mem_flags::mem_threadgroup); +} + +// ── fit_csa — CSA model fitting ───────────────────────────────────── + +inline void fit_csa(const int delta_nr, + const int hr_side, + const device float* fit_matrix, + const threadgroup float* msk_data_sh, + threadgroup float* r_sh, + uint tidx) { + + const float n0_const = 0.28209479177387814f; + ndotp_log_csa_d(delta_nr, hr_side, msk_data_sh, fit_matrix, r_sh, tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + if (tidx == 0) { + r_sh[0] = n0_const; + } + simdgroup_barrier(mem_flags::mem_threadgroup); +} + +// ── fit_model_coef — dispatch to OPDT or CSA ──────────────────────── + +inline void fit_model_coef(const int model_type, + const int delta_nr, + const int hr_side, + const device float* delta_q, + const device float* delta_b, + const threadgroup float* msk_data_sh, + threadgroup float* h_sh, + threadgroup float* r_sh, + uint tidx) { + switch (model_type) { + case OPDT: + fit_opdt(delta_nr, hr_side, delta_q, delta_b, msk_data_sh, h_sh, r_sh, tidx); + break; + case CSA: + fit_csa(delta_nr, hr_side, delta_q, msk_data_sh, r_sh, tidx); + break; + default: + break; + } +} + +// ── get_direction_boot_d — bootstrap direction getter ─────────────── + +inline int get_direction_boot_d( + thread PhiloxState& st, + const int nattempts, + const int model_type, + const float max_angle, + const float min_signal, + const float relative_peak_thres, + const float min_separation_angle, + float3 dir, + const int dimx, + const int dimy, + const int dimz, + const int dimt, + const device float* dataf, + const device int* b0s_mask, + const float3 point, + const device float* H, + const device float* R, + const int delta_nr, + const device float* delta_b, + const device float* delta_q, + const int samplm_nr, + const device float* sampling_matrix, + const device packed_float3* sphere_vertices, + const device int2* sphere_edges, + const int num_edges, + threadgroup float3* dirs, + threadgroup float* sh_mem, + threadgroup float3* scratch_f3, + uint tidx, + uint tidy) { + + const int n32dimt = ((dimt + 31) / 32) * 32; + + // Partition shared memory — mirrors the CUDA layout + threadgroup float* vox_data_sh = sh_mem; + threadgroup float* msk_data_sh = vox_data_sh + n32dimt; + + threadgroup float* r_sh = msk_data_sh + n32dimt; + threadgroup float* h_sh = r_sh + MAX(n32dimt, samplm_nr); + + // Compute hr_side (number of non-b0 volumes) + int hr_side = 0; + for (int j = int(tidx); j < dimt; j += THR_X_SL) { + hr_side += (!b0s_mask[j]) ? 1 : 0; + } + for (int i = THR_X_SL / 2; i > 0; i /= 2) { + hr_side += simd_shuffle_xor(hr_side, ushort(i)); + } + + for (int attempt = 0; attempt < nattempts; attempt++) { + + const int rv = trilinear_interp(dimx, dimy, dimz, dimt, -1, + dataf, point, vox_data_sh, tidx); + + maskGet(dimt, b0s_mask, vox_data_sh, msk_data_sh, tidx); + + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (rv == 0) { + + // Multiply masked data by R and H matrices + ndotp_d(hr_side, hr_side, msk_data_sh, R, r_sh, tidx); + ndotp_d(hr_side, hr_side, msk_data_sh, H, h_sh, tidx); + + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Bootstrap: add permuted residuals + for (int j = 0; j < hr_side; j += THR_X_SL) { + if (j + int(tidx) < hr_side) { + const int srcPermInd = int(philox_uint(st) % uint(hr_side)); + h_sh[j + int(tidx)] += r_sh[srcPermInd]; + } + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // vox_data[dwi_mask] = masked_data + maskPut(dimt, b0s_mask, h_sh, vox_data_sh, tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (int j = int(tidx); j < dimt; j += THR_X_SL) { + vox_data_sh[j] = MAX(min_signal, vox_data_sh[j]); + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + const float denom = avgMask(dimt, b0s_mask, vox_data_sh, tidx); + + for (int j = int(tidx); j < dimt; j += THR_X_SL) { + vox_data_sh[j] /= denom; + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + maskGet(dimt, b0s_mask, vox_data_sh, msk_data_sh, tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + + fit_model_coef(model_type, delta_nr, hr_side, + delta_q, delta_b, msk_data_sh, h_sh, r_sh, tidx); + + // r_sh <- coef; compute pmf = sampling_matrix * coef + ndotp_d(samplm_nr, delta_nr, r_sh, sampling_matrix, h_sh, tidx); + + // h_sh <- pmf + } else { + for (int j = int(tidx); j < samplm_nr; j += THR_X_SL) { + h_sh[j] = 0.0f; + } + // h_sh <- pmf (all zeros) + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Optional soft angular weighting: boost PMF values near the + // current trajectory direction BEFORE thresholding. At fiber + // crossings (e.g. corona radiata), the commissural peak may be + // weaker than the dominant projection peak. Without weighting, + // the aligned peak can fall below the 5% absolute or 25% + // relative threshold and be zeroed out. By weighting first, + // the aligned peak is preserved and the perpendicular peak is + // suppressed. + // Controlled by ANGULAR_WEIGHT (0.0 = disabled, default). + // Typical value: 0.5 → weight = 0.5 + 0.5*|cos(angle)|. +#if ENABLE_ANGULAR_WEIGHT + if (nattempts > 1) { + for (int j = int(tidx); j < samplm_nr; j += THR_X_SL) { + const float3 sv = load_f3(sphere_vertices, uint(j)); + const float cos_sim = FABS(dir.x * sv.x + + dir.y * sv.y + + dir.z * sv.z); + h_sh[j] *= ((1.0f - ANGULAR_WEIGHT) + ANGULAR_WEIGHT * cos_sim); + } + simdgroup_barrier(mem_flags::mem_threadgroup); + } +#endif + + const float abs_pmf_thr = PMF_THRESHOLD_P * + simd_max_reduce(samplm_nr, h_sh, REAL_MIN, tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (int j = int(tidx); j < samplm_nr; j += THR_X_SL) { + const float v = h_sh[j]; + if (v < abs_pmf_thr) { + h_sh[j] = 0.0f; + } + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + const int ndir = peak_directions(h_sh, dirs, + sphere_vertices, + sphere_edges, + num_edges, + samplm_nr, + reinterpret_cast(r_sh), + relative_peak_thres, + min_separation_angle, + tidx); + if (nattempts == 1) { // init=True + return ndir; + } else { // init=False + if (ndir > 0) { + const int foundPeak = closest_peak_d(max_angle, dir, ndir, dirs, scratch_f3, tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + if (foundPeak) { + if (tidx == 0) { + dirs[0] = *scratch_f3; + } + return 1; + } + } + } + } + return 0; +} + +// ── tracker_boot_d — single-direction streamline tracker ──────────── + +inline int tracker_boot_d( + thread PhiloxState& st, + const int model_type, + const float max_angle, + const float tc_threshold, + const float step_size, + const float relative_peak_thres, + const float min_separation_angle, + float3 seed, + float3 first_step, + float3 voxel_size, + const int dimx, + const int dimy, + const int dimz, + const int dimt, + const device float* dataf, + const device float* metric_map, + const int samplm_nr, + const device packed_float3* sphere_vertices, + const device int2* sphere_edges, + const int num_edges, + const float min_signal, + const int delta_nr, + const device float* H, + const device float* R, + const device float* delta_b, + const device float* delta_q, + const device float* sampling_matrix, + const device int* b0s_mask, + threadgroup int* nsteps, + device packed_float3* streamline, + threadgroup float* sh_mem, + threadgroup float3* sh_dirs, + threadgroup float* sh_interp, + threadgroup float3* scratch_f3, + uint tidx, + uint tidy) { + + int tissue_class = TRACKPOINT; + + float3 point = seed; + float3 direction = first_step; + + if (tidx == 0) { + store_f3(streamline, 0, point); + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + const int step_frac = 1; + + int i; + for (i = 1; i < MAX_SLINE_LEN * step_frac; i++) { + int ndir = get_direction_boot_d( + st, + 5, // NATTEMPTS + model_type, + max_angle, + min_signal, + relative_peak_thres, + min_separation_angle, + direction, + dimx, dimy, dimz, dimt, dataf, + b0s_mask, + point, + H, R, + delta_nr, + delta_b, delta_q, + samplm_nr, + sampling_matrix, + sphere_vertices, + sphere_edges, + num_edges, + sh_dirs, + sh_mem, + scratch_f3, + tidx, tidy); + simdgroup_barrier(mem_flags::mem_threadgroup); + direction = *scratch_f3; + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (ndir == 0) { + break; + } + + point.x += (direction.x / voxel_size.x) * (step_size / float(step_frac)); + point.y += (direction.y / voxel_size.y) * (step_size / float(step_frac)); + point.z += (direction.z / voxel_size.z) * (step_size / float(step_frac)); + + if ((tidx == 0) && ((i % step_frac) == 0)) { + store_f3(streamline, uint(i / step_frac), point); + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + tissue_class = check_point(tc_threshold, point, + dimx, dimy, dimz, + metric_map, + sh_interp, + tidx, tidy); + + if (tissue_class == ENDPOINT || + tissue_class == INVALIDPOINT || + tissue_class == OUTSIDEIMAGE) { + break; + } + } + nsteps[0] = i / step_frac; + if (((i % step_frac) != 0) && i < step_frac * (MAX_SLINE_LEN - 1)) { + nsteps[0]++; + if (tidx == 0) { + store_f3(streamline, uint(nsteps[0]), point); + } + } + + return tissue_class; +} + +// ── getNumStreamlinesBoot_k — count streamlines per seed (kernel) ─── + +kernel void getNumStreamlinesBoot_k( + constant BootTrackingParams& params [[buffer(0)]], + const device packed_float3* seeds [[buffer(1)]], + const device float* dataf [[buffer(2)]], + const device float* H [[buffer(3)]], + const device float* R [[buffer(4)]], + const device float* delta_b [[buffer(5)]], + const device float* delta_q [[buffer(6)]], + const device int* b0s_mask [[buffer(7)]], + const device float* sampling_matrix [[buffer(8)]], + const device packed_float3* sphere_vertices [[buffer(9)]], + const device int2* sphere_edges [[buffer(10)]], + device packed_float3* shDir0 [[buffer(11)]], + device int* slineOutOff [[buffer(12)]], + threadgroup float* sh_pool [[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + uint3 tptg [[threads_per_threadgroup]], + uint3 tid_in_tg [[thread_position_in_threadgroup]], + uint simd_lane [[thread_index_in_simdgroup]]) { + + const uint tidx = tid_in_tg.x; + const uint tidy = tid_in_tg.y; + const uint BDIM_Y = tptg.y; + + const int slid = int(tgpig.x) * int(BDIM_Y) + int(tidy); + const uint gid = tgpig.x * tptg.y * tptg.x + tptg.x * tidy + tidx; + + if (slid >= params.nseed) { + return; + } + + float3 seed = load_f3(seeds, uint(slid)); + + PhiloxState st = philox_init(uint(params.rng_seed_lo), uint(params.rng_seed_hi), gid, 0); + + // Shared memory layout: + // Per-thread-row shared memory for get_direction_boot_d + const int n32dimt = ((params.dimt + 31) / 32) * 32; + const int sh_per_row = 2 * n32dimt + 2 * MAX(n32dimt, params.samplm_nr); + + // sh_pool is dynamically sized via setThreadgroupMemoryLength (CUDA extern __shared__ equivalent) + threadgroup float3 sh_dirs[BLOCK_Y * MAX_SLINES_PER_SEED]; // per-tidy dirs + threadgroup float3 scratch_f3[BLOCK_Y]; // per-tidy scratch for closest_peak_d + threadgroup float* sh_mem = sh_pool + tidy * sh_per_row; + + int ndir; + switch (params.model_type) { + case OPDT: + case CSA: + ndir = get_direction_boot_d( + st, + 1, // NATTEMPTS=1 (init=True) + params.model_type, + params.max_angle, + params.min_signal, + params.relative_peak_thresh, + params.min_separation_angle, + float3(0.0f, 0.0f, 0.0f), + params.dimx, params.dimy, params.dimz, params.dimt, + dataf, b0s_mask, + seed, + H, R, + params.delta_nr, + delta_b, delta_q, + params.samplm_nr, + sampling_matrix, + sphere_vertices, + sphere_edges, + params.num_edges, + sh_dirs + tidy * MAX_SLINES_PER_SEED, + sh_mem, + scratch_f3 + tidy, + tidx, tidy); + break; + default: + ndir = 0; + break; + } + + // Copy directions to output buffer + device packed_float3* dirOut = shDir0 + slid * params.samplm_nr; + for (int j = int(tidx); j < ndir; j += THR_X_SL) { + store_f3(dirOut, uint(j), sh_dirs[tidy * MAX_SLINES_PER_SEED + j]); + } + + if (tidx == 0) { + slineOutOff[slid] = ndir; + } +} + +// ── genStreamlinesMergeBoot_k — main bootstrap streamline kernel ──── + +kernel void genStreamlinesMergeBoot_k( + constant BootTrackingParams& params [[buffer(0)]], + const device packed_float3* seeds [[buffer(1)]], + const device float* dataf [[buffer(2)]], + const device float* metric_map [[buffer(3)]], + const device packed_float3* sphere_vertices [[buffer(4)]], + const device int2* sphere_edges [[buffer(5)]], + const device float* H [[buffer(6)]], + const device float* R [[buffer(7)]], + const device float* delta_b [[buffer(8)]], + const device float* delta_q [[buffer(9)]], + const device float* sampling_matrix [[buffer(10)]], + const device int* b0s_mask [[buffer(11)]], + const device int* slineOutOff [[buffer(12)]], + device packed_float3* shDir0 [[buffer(13)]], + device int* slineSeed [[buffer(14)]], + device int* slineLen [[buffer(15)]], + device packed_float3* sline [[buffer(16)]], + threadgroup float* sh_pool [[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + uint3 tptg [[threads_per_threadgroup]], + uint3 tid_in_tg [[thread_position_in_threadgroup]], + uint simd_lane [[thread_index_in_simdgroup]]) { + + const uint tidx = tid_in_tg.x; + const uint tidy = tid_in_tg.y; + const uint BDIM_Y = tptg.y; + + const int slid = int(tgpig.x) * int(BDIM_Y) + int(tidy); + + const uint gid = tgpig.x * tptg.y * tptg.x + tptg.x * tidy + tidx; + PhiloxState st = philox_init(uint(params.rng_seed_lo), uint(params.rng_seed_hi), gid + 1, 0); + + if (slid >= params.nseed) { + return; + } + + float3 seed = load_f3(seeds, uint(slid)); + + int ndir = slineOutOff[slid + 1] - slineOutOff[slid]; + + simdgroup_barrier(mem_flags::mem_threadgroup); + + int slineOff = slineOutOff[slid]; + + // Shared memory layout for this thread row + const int n32dimt = ((params.dimt + 31) / 32) * 32; + const int sh_per_row = 2 * n32dimt + 2 * MAX(n32dimt, params.samplm_nr); + + // sh_pool is dynamically sized via setThreadgroupMemoryLength (CUDA extern __shared__ equivalent) + threadgroup float3 sh_dirs[BLOCK_Y * MAX_SLINES_PER_SEED]; // per-tidy dirs + threadgroup float sh_interp[BLOCK_Y]; // for check_point (indexed by tidy) + threadgroup int sh_nsteps[BLOCK_Y]; // per-tidy step counts + threadgroup float3 scratch_f3[BLOCK_Y]; // per-tidy scratch for closest_peak_d + threadgroup float* sh_mem = sh_pool + tidy * sh_per_row; + + for (int i = 0; i < ndir; i++) { + float3 first_step = load_f3(shDir0, uint(slid * params.samplm_nr + i)); + + device packed_float3* currSline = sline + slineOff * MAX_SLINE_LEN * 2; + + if (tidx == 0) { + slineSeed[slineOff] = slid; + } + + // Track backward + int stepsB; + tracker_boot_d( + st, + params.model_type, + params.max_angle, + params.tc_threshold, + params.step_size, + params.relative_peak_thresh, + params.min_separation_angle, + seed, + float3(-first_step.x, -first_step.y, -first_step.z), + float3(1.0f, 1.0f, 1.0f), + params.dimx, params.dimy, params.dimz, params.dimt, + dataf, + metric_map, + params.samplm_nr, + sphere_vertices, + sphere_edges, + params.num_edges, + params.min_signal, + params.delta_nr, + H, R, + delta_b, delta_q, + sampling_matrix, + b0s_mask, + sh_nsteps + tidy, + currSline, + sh_mem, + sh_dirs + tidy * MAX_SLINES_PER_SEED, + sh_interp, + scratch_f3 + tidy, + tidx, tidy); + stepsB = sh_nsteps[tidy]; + + // Reverse backward streamline + for (int j = 0; j < stepsB / 2; j += THR_X_SL) { + if (j + int(tidx) < stepsB / 2) { + const float3 p = load_f3(currSline, uint(j + int(tidx))); + const float3 q = load_f3(currSline, uint(stepsB - 1 - (j + int(tidx)))); + store_f3(currSline, uint(j + int(tidx)), q); + store_f3(currSline, uint(stepsB - 1 - (j + int(tidx))), p); + } + } + + // Track forward + int stepsF; + tracker_boot_d( + st, + params.model_type, + params.max_angle, + params.tc_threshold, + params.step_size, + params.relative_peak_thresh, + params.min_separation_angle, + seed, + first_step, + float3(1.0f, 1.0f, 1.0f), + params.dimx, params.dimy, params.dimz, params.dimt, + dataf, + metric_map, + params.samplm_nr, + sphere_vertices, + sphere_edges, + params.num_edges, + params.min_signal, + params.delta_nr, + H, R, + delta_b, delta_q, + sampling_matrix, + b0s_mask, + sh_nsteps + tidy, + currSline + stepsB - 1, + sh_mem, + sh_dirs + tidy * MAX_SLINES_PER_SEED, + sh_interp, + scratch_f3 + tidy, + tidx, tidy); + stepsF = sh_nsteps[tidy]; + + if (tidx == 0) { + slineLen[slineOff] = stepsB - 1 + stepsF; + } + + slineOff += 1; + } +} diff --git a/cuslines/metal_shaders/disc.h b/cuslines/metal_shaders/disc.h new file mode 100644 index 0000000..dbeddda --- /dev/null +++ b/cuslines/metal_shaders/disc.h @@ -0,0 +1,1890 @@ + +/* +This code from: https://github.com/nibrary/nibrary/blob/main/src/math/disc.h + +BSD 3-Clause License + +Copyright (c) 2024, Dogu Baran Aydogan All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#ifndef __DISC_H__ +#define __DISC_H__ + +#define DISC_2_VERT_CNT 24 +#define DISC_2_FACE_CNT 31 + +#define DISC_2_VERT {\ + -0.99680788,-0.07983759,\ + -0.94276539,0.33345677,\ + -0.87928469,-0.47629658,\ + -0.72856617,0.68497542,\ + -0.60006556,-0.79995082,\ + -0.54129995,-0.02761342,\ + -0.39271207,0.37117272,\ + -0.39217391,0.91989110,\ + -0.36362884,-0.40757367,\ + -0.22391316,-0.97460910,\ + -0.00130022,0.53966106,\ + 0.00000000,0.00000000,\ + 0.00973999,0.99995257,\ + 0.01606516,-0.54289908,\ + 0.21342395,-0.97695968,\ + 0.38192071,-0.38666136,\ + 0.38897094,0.37442837,\ + 0.40696681,0.91344295,\ + 0.54387161,-0.01477123,\ + 0.59119367,-0.80652963,\ + 0.73955688,0.67309406,\ + 0.87601150,-0.48229022,\ + 0.94617928,0.32364298,\ + 0.99585368,-0.09096944} + + + +#define DISC_2_FACE {\ + 9,8,4,\ + 11,16,10,\ + 5,8,11,\ + 5,1,0,\ + 18,16,11,\ + 11,15,18,\ + 13,8,9,\ + 11,8,13,\ + 13,15,11,\ + 22,18,23,\ + 22,20,16,\ + 16,18,22,\ + 16,20,17,\ + 12,10,17,\ + 17,10,16,\ + 15,19,21,\ + 23,18,21,\ + 21,18,15,\ + 2,4,8,\ + 2,5,0,\ + 8,5,2,\ + 7,10,12,\ + 6,7,3,\ + 10,7,6,\ + 3,1,6,\ + 1,5,6,\ + 11,10,6,\ + 6,5,11,\ + 14,19,15,\ + 15,13,14,\ + 14,13,9} + + + +#define DISC_3_VERT_CNT 36 +#define DISC_3_FACE_CNT 52 + +#define DISC_3_VERT {\ + -0.98798409,-0.15455565,\ + -0.98026530,0.19768646,\ + -0.87061458,-0.49196570,\ + -0.85315536,0.52165691,\ + -0.67948751,0.12870519,\ + -0.65830249,-0.75275350,\ + -0.60977645,0.79257345,\ + -0.60599745,-0.25746218,\ + -0.49175185,0.45085081,\ + -0.39584766,-0.56449807,\ + -0.37151031,0.02599657,\ + -0.34538749,-0.93846017,\ + -0.31409968,0.94939001,\ + -0.19774331,-0.30335822,\ + -0.18708240,0.31479263,\ + -0.14013436,0.65112487,\ + -0.02230445,-0.65649640,\ + -0.01247874,-0.99992214,\ + 0,0,\ + 0.03699045,0.99931562,\ + 0.15587647,0.33306130,\ + 0.17739302,-0.31129535,\ + 0.23950456,0.64808985,\ + 0.31593561,-0.94878063,\ + 0.34839477,-0.59393230,\ + 0.35674244,0.02011329,\ + 0.38082583,0.92464679,\ + 0.52353496,0.39304489,\ + 0.57766607,-0.30046041,\ + 0.63711661,-0.77076743,\ + 0.66791137,0.74424082,\ + 0.68671421,0.06646131,\ + 0.85301727,-0.52188269,\ + 0.88617706,0.46334676,\ + 0.97866100,-0.20548151,\ + 0.98951863,0.14440528} + + + +#define DISC_3_FACE {\ + 27,30,22,\ + 9,2,5,\ + 28,32,34,\ + 22,30,26,\ + 26,19,22,\ + 30,27,33,\ + 31,34,35,\ + 28,34,31,\ + 35,33,31,\ + 31,33,27,\ + 25,31,27,\ + 28,31,25,\ + 10,14,8,\ + 10,13,18,\ + 18,14,10,\ + 15,19,12,\ + 22,19,15,\ + 8,14,15,\ + 11,9,5,\ + 13,9,16,\ + 16,11,17,\ + 9,11,16,\ + 17,23,16,\ + 23,24,16,\ + 29,24,23,\ + 29,32,28,\ + 28,24,29,\ + 6,3,8,\ + 6,15,12,\ + 8,15,6,\ + 20,27,22,\ + 20,25,27,\ + 18,25,20,\ + 20,14,18,\ + 22,15,20,\ + 20,15,14,\ + 21,24,28,\ + 28,25,21,\ + 21,25,18,\ + 18,13,21,\ + 13,16,21,\ + 21,16,24,\ + 4,10,8,\ + 8,3,4,\ + 4,3,1,\ + 4,1,0,\ + 7,9,13,\ + 13,10,7,\ + 2,9,7,\ + 10,4,7,\ + 0,2,7,\ + 7,4,0} + + + +#define DISC_4_VERT_CNT 62 +#define DISC_4_FACE_CNT 97 + +#define DISC_4_VERT {\ + -0.99632399,0.08566510,\ + -0.98618071,-0.16567317,\ + -0.94150749,0.33699206,\ + -0.91375624,-0.40626289,\ + -0.82498245,0.56515834,\ + -0.78016046,-0.62557946,\ + -0.76768368,-0.12856657,\ + -0.73146437,0.13012819,\ + -0.65993870,0.75131945,\ + -0.64923474,-0.36077026,\ + -0.64404827,0.37054571,\ + -0.59590501,-0.80305493,\ + -0.52723292,-0.08735736,\ + -0.50689203,0.59055453,\ + -0.48847116,-0.55956115,\ + -0.45296756,0.16857820,\ + -0.44983881,0.89310976,\ + -0.37980292,-0.92506742,\ + -0.37431635,-0.30508770,\ + -0.34347782,0.41024244,\ + -0.28695359,-0.72298195,\ + -0.26607611,-0.04337891,\ + -0.26300954,0.69460622,\ + -0.20716495,0.97830603,\ + -0.19234589,-0.49854986,\ + -0.17059736,0.21089158,\ + -0.13571649,-0.99074771,\ + -0.09401357,-0.25347966,\ + -0.08264934,0.47807857,\ + -0.02277615,-0.74173498,\ + -0.00823427,0.74390934,\ + 0,0,\ + 0.04408226,0.99902790,\ + 0.07601333,-0.47847223,\ + 0.09881278,0.25627739,\ + 0.12027544,-0.99274056,\ + 0.17542943,-0.20988186,\ + 0.18324588,0.50353410,\ + 0.23456374,-0.70587816,\ + 0.25179308,0.73431645,\ + 0.27247058,0.04550104,\ + 0.28231740,0.95932106,\ + 0.33582739,-0.41640371,\ + 0.36311579,-0.93174402,\ + 0.37710908,0.31092054,\ + 0.45674883,-0.17177025,\ + 0.47051745,0.57347807,\ + 0.47748339,-0.61101450,\ + 0.51751229,0.85567577,\ + 0.53030761,0.08858984,\ + 0.57509454,-0.81808696,\ + 0.63141296,-0.38867064,\ + 0.64811006,0.37237321,\ + 0.71524914,0.69886956,\ + 0.73323663,-0.14081300,\ + 0.76190940,0.12763299,\ + 0.76512049,-0.64388713,\ + 0.85941151,0.51128451,\ + 0.90165136,-0.43246368,\ + 0.95650965,0.29170069,\ + 0.97794249,-0.20887431,\ + 0.99951822,0.03103749} + + + +#define DISC_4_FACE {\ + 39,32,30,\ + 30,22,28,\ + 28,22,19,\ + 52,59,57,\ + 14,11,20,\ + 47,56,51,\ + 50,56,47,\ + 47,43,50,\ + 41,39,48,\ + 32,39,41,\ + 10,4,2,\ + 37,39,30,\ + 30,28,37,\ + 25,28,19,\ + 25,21,31,\ + 53,52,57,\ + 44,52,46,\ + 46,37,44,\ + 39,37,46,\ + 52,53,46,\ + 48,39,46,\ + 46,53,48,\ + 61,59,55,\ + 59,52,55,\ + 42,47,51,\ + 35,29,26,\ + 26,29,20,\ + 12,21,15,\ + 19,10,15,\ + 15,25,19,\ + 21,25,15,\ + 18,21,12,\ + 18,9,14,\ + 12,9,18,\ + 6,9,12,\ + 0,1,6,\ + 5,11,14,\ + 14,9,5,\ + 51,56,58,\ + 16,22,23,\ + 30,32,23,\ + 23,22,30,\ + 34,40,44,\ + 44,37,34,\ + 31,40,34,\ + 34,25,31,\ + 34,37,28,\ + 28,25,34,\ + 51,58,54,\ + 54,58,60,\ + 54,60,61,\ + 61,55,54,\ + 49,52,44,\ + 49,55,52,\ + 44,40,49,\ + 49,54,55,\ + 36,40,31,\ + 33,42,36,\ + 43,47,38,\ + 35,43,38,\ + 38,29,35,\ + 33,29,38,\ + 38,42,33,\ + 47,42,38,\ + 17,20,11,\ + 17,26,20,\ + 21,18,27,\ + 33,36,27,\ + 31,21,27,\ + 27,36,31,\ + 14,20,24,\ + 24,18,14,\ + 24,27,18,\ + 33,27,24,\ + 24,29,33,\ + 20,29,24,\ + 7,6,12,\ + 12,15,7,\ + 7,15,10,\ + 0,6,7,\ + 7,2,0,\ + 7,10,2,\ + 3,5,9,\ + 3,6,1,\ + 9,6,3,\ + 4,10,13,\ + 13,8,4,\ + 13,10,19,\ + 19,22,13,\ + 13,22,16,\ + 16,8,13,\ + 45,49,40,\ + 40,36,45,\ + 45,36,42,\ + 45,42,51,\ + 51,54,45,\ + 54,49,45} + + + + +#define DISC_5_VERT_CNT 88 +#define DISC_5_FACE_CNT 143 + +#define DISC_5_VERT {\ + -0.99971936,0.02368974,\ + -0.98497387,-0.17270345,\ + -0.97603282,0.21762338,\ + -0.92922869,-0.36950514,\ + -0.90708773,0.42094161,\ + -0.83415725,-0.55152668,\ + -0.79951365,0.60064792,\ + -0.79931959,-0.15614114,\ + -0.78301036,0.22417418,\ + -0.73599072,0.03246227,\ + -0.70693097,-0.70728255,\ + -0.70138379,-0.35670071,\ + -0.67496938,0.41799949,\ + -0.65872201,0.75238641,\ + -0.58146330,-0.53578945,\ + -0.58133729,-0.15872598,\ + -0.56876870,0.21601921,\ + -0.55866507,-0.82939335,\ + -0.55072833,0.60394640,\ + -0.49682087,0.86785311,\ + -0.48121950,0.02702752,\ + -0.46260160,-0.34991875,\ + -0.44852900,0.41519979,\ + -0.43296059,-0.69074037,\ + -0.37854350,-0.92558350,\ + -0.35947581,0.64734787,\ + -0.34709225,-0.16018397,\ + -0.34353985,0.21908385,\ + -0.32747735,-0.52839636,\ + -0.31642072,0.94861896,\ + -0.23858273,0.43081147,\ + -0.23727114,0.03012835,\ + -0.22120862,-0.34920779,\ + -0.22093401,0.78635645,\ + -0.21959193,-0.75538583,\ + -0.18652372,-0.98245046,\ + -0.11983054,0.22928120,\ + -0.11608911,0.60000621,\ + -0.11215562,-0.15690164,\ + -0.10768432,-0.56628902,\ + -0.10628279,0.99433594,\ + -0.00000000,-1.00000000,\ + -0.00000000,-0.78706952,\ + 0.00000000,0.41097842,\ + 0.00000000,-0.36210641,\ + 0.00000000,0.04030552,\ + 0.00000000,0.79503467,\ + 0.10628279,0.99433594,\ + 0.10768432,-0.56628902,\ + 0.11215563,-0.15690164,\ + 0.11608911,0.60000621,\ + 0.11983054,0.22928120,\ + 0.18652372,-0.98245046,\ + 0.21959193,-0.75538583,\ + 0.22093402,0.78635645,\ + 0.22120862,-0.34920779,\ + 0.23727114,0.03012835,\ + 0.23858273,0.43081146,\ + 0.31642072,0.94861896,\ + 0.32747735,-0.52839636,\ + 0.34353985,0.21908385,\ + 0.34709225,-0.16018397,\ + 0.35947581,0.64734786,\ + 0.37854350,-0.92558350,\ + 0.43296060,-0.69074037,\ + 0.44852900,0.41519979,\ + 0.46260160,-0.34991875,\ + 0.48121950,0.02702752,\ + 0.49682086,0.86785311,\ + 0.55072833,0.60394640,\ + 0.55866507,-0.82939336,\ + 0.56876870,0.21601921,\ + 0.58133729,-0.15872598,\ + 0.58146330,-0.53578945,\ + 0.65872201,0.75238641,\ + 0.67496938,0.41799949,\ + 0.70138379,-0.35670071,\ + 0.70693097,-0.70728255,\ + 0.73599072,0.03246227,\ + 0.78301036,0.22417418,\ + 0.79931959,-0.15614114,\ + 0.79951365,0.60064792,\ + 0.83415725,-0.55152668,\ + 0.90708773,0.42094162,\ + 0.92922869,-0.36950514,\ + 0.97603282,0.21762338,\ + 0.98497387,-0.17270345,\ + 0.99971936,0.02368974} + + + +#define DISC_5_FACE {\ + 12,6,4,\ + 69,74,68,\ + 1,7,0,\ + 81,74,69,\ + 75,81,69,\ + 83,81,75,\ + 41,52,42,\ + 42,35,41,\ + 42,52,53,\ + 53,48,42,\ + 73,64,77,\ + 77,64,70,\ + 45,31,38,\ + 56,61,67,\ + 42,48,39,\ + 59,73,66,\ + 64,73,59,\ + 59,53,64,\ + 48,53,59,\ + 40,46,47,\ + 36,31,45,\ + 30,36,43,\ + 54,47,46,\ + 18,6,12,\ + 9,0,7,\ + 27,36,30,\ + 31,36,27,\ + 72,61,66,\ + 67,61,72,\ + 5,11,3,\ + 3,7,1,\ + 3,11,7,\ + 17,23,10,\ + 24,23,17,\ + 10,23,14,\ + 14,5,10,\ + 14,11,5,\ + 24,35,34,\ + 34,23,24,\ + 34,35,42,\ + 42,39,34,\ + 83,75,79,\ + 79,85,83,\ + 75,71,79,\ + 80,72,76,\ + 66,73,76,\ + 76,72,66,\ + 63,53,52,\ + 64,53,63,\ + 63,70,64,\ + 44,38,32,\ + 32,39,44,\ + 44,39,48,\ + 66,61,55,\ + 55,59,66,\ + 48,59,55,\ + 55,44,48,\ + 30,43,37,\ + 51,36,45,\ + 51,43,36,\ + 45,56,51,\ + 57,43,51,\ + 58,54,68,\ + 47,54,58,\ + 69,68,62,\ + 68,54,62,\ + 50,43,57,\ + 57,62,50,\ + 50,62,54,\ + 50,54,46,\ + 46,37,50,\ + 50,37,43,\ + 6,18,13,\ + 13,18,19,\ + 12,4,8,\ + 4,2,8,\ + 8,2,0,\ + 0,9,8,\ + 7,11,15,\ + 15,9,7,\ + 12,8,16,\ + 16,8,9,\ + 28,14,23,\ + 28,39,32,\ + 23,34,28,\ + 28,34,39,\ + 85,79,87,\ + 87,80,86,\ + 86,80,84,\ + 80,76,84,\ + 38,44,49,\ + 44,55,49,\ + 45,38,49,\ + 49,56,45,\ + 61,56,49,\ + 49,55,61,\ + 19,18,25,\ + 30,37,25,\ + 60,51,56,\ + 60,56,67,\ + 67,71,60,\ + 57,51,60,\ + 32,38,26,\ + 26,38,31,\ + 22,27,30,\ + 22,16,27,\ + 30,25,22,\ + 22,25,18,\ + 22,18,12,\ + 12,16,22,\ + 9,15,20,\ + 20,16,9,\ + 15,26,20,\ + 20,26,31,\ + 31,27,20,\ + 27,16,20,\ + 14,28,21,\ + 11,14,21,\ + 21,15,11,\ + 21,26,15,\ + 21,28,32,\ + 32,26,21,\ + 80,87,78,\ + 78,87,79,\ + 67,72,78,\ + 78,72,80,\ + 78,71,67,\ + 78,79,71,\ + 82,84,76,\ + 82,73,77,\ + 82,76,73,\ + 33,29,19,\ + 19,25,33,\ + 33,25,37,\ + 33,37,46,\ + 33,46,40,\ + 40,29,33,\ + 57,60,65,\ + 65,60,71,\ + 69,62,65,\ + 65,62,57,\ + 65,75,69,\ + 65,71,75} + + + + +#define DISC_6_VERT_CNT 93 +#define DISC_6_FACE_CNT 152 + +#define DISC_6_VERT {\ + -0.99999594,0.00284872,\ + -0.98015885,-0.19821361,\ + -0.97910452,0.20335765,\ + -0.91824742,-0.39600716,\ + -0.91642084,0.40021599,\ + -0.82654691,0.00183534,\ + -0.81562261,-0.57858427,\ + -0.81341702,0.58168096,\ + -0.77345426,-0.20215086,\ + -0.77255314,0.20518626,\ + -0.68941469,-0.39022484,\ + -0.68791137,0.39257991,\ + -0.68013707,-0.73308497,\ + -0.67797282,0.73508697,\ + -0.63913549,0.00098160,\ + -0.57136133,-0.55867323,\ + -0.56955891,0.56029565,\ + -0.55841446,-0.19717430,\ + -0.55776663,0.19858763,\ + -0.52205785,-0.85291008,\ + -0.52029081,0.85398915,\ + -0.45456737,-0.37917968,\ + -0.45349521,0.38007165,\ + -0.43551427,0.00036391,\ + -0.42468657,-0.70786237,\ + -0.42282245,0.70855298,\ + -0.34062665,-0.94019864,\ + -0.33950518,0.94060419,\ + -0.33587158,-0.19180791,\ + -0.33542840,0.19211031,\ + -0.32531753,-0.54899060,\ + -0.32390201,0.54924250,\ + -0.22054068,-0.00005872,\ + -0.22037192,-0.37480908,\ + -0.21959501,-0.76658281,\ + -0.21957535,0.37471145,\ + -0.21830443,0.76666288,\ + -0.14314293,0.98970203,\ + -0.14308842,-0.98970991,\ + -0.11182136,-0.58019243,\ + -0.11144867,-0.19011261,\ + -0.11115009,0.18972064,\ + -0.11083876,0.58001267,\ + -0.01202610,-0.79058356,\ + -0.01146133,0.79044198,\ + -0.00296704,-0.38655576,\ + -0.00251985,0.38610767,\ + 0,0,\ + 0.05602383,0.99842943,\ + 0.05753595,-0.99834343,\ + 0.09929965,-0.58445955,\ + 0.09978787,0.58401640,\ + 0.10720128,-0.19034642,\ + 0.10732262,0.18978004,\ + 0.19741560,0.77640808,\ + 0.19782431,-0.77666373,\ + 0.21311365,-0.37705130,\ + 0.21324783,0.37657351,\ + 0.21679316,-0.00027758,\ + 0.24086510,0.97055861,\ + 0.24384828,-0.96981339,\ + 0.30997719,-0.55399954,\ + 0.31009251,0.55373347,\ + 0.33113518,0.19349214,\ + 0.33116892,-0.19395431,\ + 0.39954032,0.71482630,\ + 0.40007135,-0.71505270,\ + 0.41738752,0.90872859,\ + 0.42074496,-0.90717897,\ + 0.43070234,-0.00023714,\ + 0.44655149,0.38465307,\ + 0.44677948,-0.38475532,\ + 0.55227292,0.20180373,\ + 0.55259284,-0.20199491,\ + 0.55434766,-0.56898461,\ + 0.55436109,0.56926283,\ + 0.59389421,0.80454314,\ + 0.59763878,-0.80176548,\ + 0.63149306,0.00014085,\ + 0.68111571,0.39992460,\ + 0.68195669,-0.39979424,\ + 0.73619953,0.67676455,\ + 0.74034292,-0.67222940,\ + 0.76516557,0.20592197,\ + 0.76545629,-0.20472387,\ + 0.81831151,0.00107473,\ + 0.84515629,0.53451927,\ + 0.84868724,-0.52889504,\ + 0.92889024,0.37035513,\ + 0.93106261,-0.36485945,\ + 0.98208177,0.18845531,\ + 0.98320733,-0.18249206,\ + 0.99999396,0.00347432} + + + +#define DISC_6_FACE {\ + 47,41,32,\ + 29,35,22,\ + 23,32,29,\ + 29,41,35,\ + 29,32,41,\ + 8,1,3,\ + 6,12,15,\ + 15,12,24,\ + 71,80,73,\ + 73,80,84,\ + 19,24,12,\ + 87,80,82,\ + 48,37,44,\ + 22,11,18,\ + 18,29,22,\ + 23,29,18,\ + 44,37,36,\ + 22,35,31,\ + 31,36,25,\ + 16,11,22,\ + 22,31,16,\ + 16,31,25,\ + 16,25,13,\ + 13,7,16,\ + 16,7,11,\ + 11,7,4,\ + 13,25,20,\ + 69,73,78,\ + 78,73,84,\ + 49,60,55,\ + 60,68,55,\ + 47,32,40,\ + 45,33,39,\ + 45,40,33,\ + 0,1,5,\ + 1,8,5,\ + 39,33,30,\ + 33,21,30,\ + 15,24,30,\ + 30,21,15,\ + 39,30,34,\ + 34,30,24,\ + 24,19,26,\ + 26,34,24,\ + 74,80,71,\ + 74,82,80,\ + 77,82,74,\ + 0,5,2,\ + 27,36,37,\ + 25,36,27,\ + 27,20,25,\ + 89,91,84,\ + 89,80,87,\ + 84,80,89,\ + 85,83,78,\ + 85,78,84,\ + 92,90,85,\ + 85,90,83,\ + 85,91,92,\ + 84,91,85,\ + 69,58,64,\ + 64,73,69,\ + 64,56,71,\ + 71,73,64,\ + 81,76,75,\ + 63,58,69,\ + 78,83,72,\ + 70,63,72,\ + 69,78,72,\ + 72,63,69,\ + 35,41,46,\ + 28,21,33,\ + 28,32,23,\ + 28,40,32,\ + 33,40,28,\ + 15,21,10,\ + 10,6,15,\ + 10,8,3,\ + 3,6,10,\ + 47,40,52,\ + 52,58,47,\ + 52,45,56,\ + 40,45,52,\ + 56,64,52,\ + 52,64,58,\ + 56,45,50,\ + 50,45,39,\ + 9,18,11,\ + 9,2,5,\ + 11,4,9,\ + 4,2,9,\ + 23,18,14,\ + 14,5,8,\ + 18,9,14,\ + 14,9,5,\ + 34,26,38,\ + 83,90,88,\ + 79,86,81,\ + 81,75,79,\ + 79,88,86,\ + 83,88,79,\ + 79,75,70,\ + 79,72,83,\ + 70,72,79,\ + 65,76,67,\ + 65,75,76,\ + 44,36,42,\ + 36,31,42,\ + 42,31,35,\ + 35,46,42,\ + 53,46,41,\ + 58,63,53,\ + 53,41,47,\ + 47,58,53,\ + 66,55,68,\ + 66,68,77,\ + 77,74,66,\ + 55,50,43,\ + 49,55,43,\ + 39,34,43,\ + 43,50,39,\ + 43,38,49,\ + 34,38,43,\ + 21,28,17,\ + 17,14,8,\ + 17,28,23,\ + 23,14,17,\ + 8,10,17,\ + 17,10,21,\ + 67,59,54,\ + 54,65,67,\ + 54,48,44,\ + 54,59,48,\ + 56,50,61,\ + 71,56,61,\ + 61,50,55,\ + 55,66,61,\ + 61,74,71,\ + 61,66,74,\ + 65,54,62,\ + 70,75,62,\ + 75,65,62,\ + 57,63,70,\ + 70,62,57,\ + 57,53,63,\ + 46,53,57,\ + 51,62,54,\ + 44,42,51,\ + 51,54,44,\ + 51,42,46,\ + 46,57,51,\ + 51,57,62} + + + +#define DISC_7_VERT_CNT 362 +#define DISC_7_FACE_CNT 661 + +#define DISC_7_VERT {\ + -0.99985012,-0.01731283,\ + -0.99556874,0.09403658,\ + -0.99269568,-0.12064526,\ + -0.98039504,0.19704206,\ + -0.97660422,-0.21504466,\ + -0.95669950,0.29107744,\ + -0.94719722,-0.32065158,\ + -0.91977079,0.03476785,\ + -0.91823426,0.39603769,\ + -0.90908498,-0.41661072,\ + -0.90313047,-0.06293630,\ + -0.89624049,0.13116999,\ + -0.88570063,-0.24444376,\ + -0.87285522,-0.15209581,\ + -0.87078228,0.49166881,\ + -0.86621517,0.31178476,\ + -0.86248349,-0.50608521,\ + -0.86017585,0.21863311,\ + -0.84170474,-0.33621087,\ + -0.83205774,0.03180995,\ + -0.81606482,0.40088958,\ + -0.81347977,0.58159321,\ + -0.80880317,-0.58807944,\ + -0.80047290,-0.06419645,\ + -0.79753133,-0.42481233,\ + -0.79434394,0.12621664,\ + -0.79156344,-0.23872410,\ + -0.77379916,0.30089105,\ + -0.76620500,0.48702609,\ + -0.75497372,-0.15055590,\ + -0.74996331,-0.66147943,\ + -0.74615778,0.66576915,\ + -0.74611496,-0.50995249,\ + -0.74439196,0.21110739,\ + -0.73952765,-0.33215782,\ + -0.73635629,0.02967200,\ + -0.71613187,0.39251295,\ + -0.70989537,0.56969253,\ + -0.69311231,-0.06262365,\ + -0.68888730,-0.42215438,\ + -0.68873812,-0.23996840,\ + -0.68789855,0.12079206,\ + -0.68775964,-0.59169999,\ + -0.68761735,-0.72607326,\ + -0.67300273,0.29907892,\ + -0.66617715,0.74579354,\ + -0.66034349,0.48247936,\ + -0.64930335,0.65326738,\ + -0.64251884,-0.15112779,\ + -0.63423915,-0.33167241,\ + -0.63367295,0.20893327,\ + -0.63361786,0.02834503,\ + -0.63317442,-0.50922682,\ + -0.62397583,-0.67129702,\ + -0.61668934,-0.78720662,\ + -0.61382861,0.39179584,\ + -0.60234812,0.57529738,\ + -0.58817577,0.80873313,\ + -0.58517097,-0.06261579,\ + -0.58292803,-0.24113073,\ + -0.58101739,0.11895848,\ + -0.58024100,-0.42164527,\ + -0.57342995,-0.59754731,\ + -0.57023518,0.30017942,\ + -0.56388698,0.69225657,\ + -0.55806447,0.48787902,\ + -0.54436708,-0.83884712,\ + -0.53471317,-0.71163670,\ + -0.53277416,-0.15208460,\ + -0.52845305,0.02779153,\ + -0.52768053,-0.33226318,\ + -0.52626233,0.20945175,\ + -0.52436621,-0.51257293,\ + -0.51322968,0.39629252,\ + -0.51105321,0.59594241,\ + -0.50070339,0.77148078,\ + -0.49284021,0.87011983,\ + -0.47766585,-0.06283164,\ + -0.47709422,-0.61587203,\ + -0.47617996,-0.24250847,\ + -0.47492414,0.11881178,\ + -0.47332691,-0.42421072,\ + -0.46891581,-0.78503521,\ + -0.46830639,0.30307673,\ + -0.46266258,0.49810707,\ + -0.45708829,-0.88942133,\ + -0.45576672,0.68708769,\ + -0.42465108,-0.15303707,\ + -0.42294844,-0.52165201,\ + -0.42277498,0.02774444,\ + -0.42158212,-0.33447625,\ + -0.42154903,-0.70155621,\ + -0.42118179,0.21073835,\ + -0.41518847,0.40003793,\ + -0.41014166,0.59223339,\ + -0.40595577,0.79572144,\ + -0.40144096,0.91588490,\ + -0.37410706,-0.80603587,\ + -0.37388864,-0.92747360,\ + -0.37084162,-0.06308696,\ + -0.36985500,-0.24430894,\ + -0.36964213,0.11912678,\ + -0.36961888,-0.42913077,\ + -0.36950835,-0.61160154,\ + -0.36692104,0.30426278,\ + -0.36331898,0.49302705,\ + -0.36031066,0.69223951,\ + -0.32677443,0.86283108,\ + -0.32032475,-0.70802805,\ + -0.31779837,-0.15404597,\ + -0.31719767,0.02776004,\ + -0.31689202,0.21090281,\ + -0.31681891,-0.33738863,\ + -0.31666303,-0.52066428,\ + -0.31403760,0.39611874,\ + -0.31388083,0.58661750,\ + -0.30063918,-0.86564249,\ + -0.29840707,0.77598627,\ + -0.29118453,0.95666691,\ + -0.27578713,-0.96121874,\ + -0.26535904,-0.61453199,\ + -0.26467581,-0.06343296,\ + -0.26450873,0.11893627,\ + -0.26436076,-0.24601298,\ + -0.26386624,-0.42936315,\ + -0.26351074,0.30183375,\ + -0.26320615,0.48755038,\ + -0.26150637,0.67168140,\ + -0.25770896,-0.78444210,\ + -0.22690147,0.87620642,\ + -0.21182740,-0.15487168,\ + -0.21176386,0.20951421,\ + -0.21163724,0.02750973,\ + -0.21161899,-0.52224104,\ + -0.21148579,0.57469681,\ + -0.21131707,0.39241729,\ + -0.21124026,-0.33783321,\ + -0.20788702,-0.69983910,\ + -0.20780059,0.75906391,\ + -0.20086952,-0.88206044,\ + -0.18395374,0.98293490,\ + -0.18078187,-0.98352322,\ + -0.15939234,0.66160521,\ + -0.15922292,0.48080959,\ + -0.15893189,0.29974615,\ + -0.15886894,0.11819996,\ + -0.15886163,-0.06382544,\ + -0.15865438,-0.24633350,\ + -0.15863082,-0.43008689,\ + -0.15661700,-0.61168462,\ + -0.15237093,-0.79492683,\ + -0.14769055,0.83062127,\ + -0.11501977,0.91942011,\ + -0.10707064,0.56845652,\ + -0.10625023,0.38888883,\ + -0.10607360,0.20858766,\ + -0.10591811,-0.15506199,\ + -0.10588439,0.02705488,\ + -0.10583335,-0.33807967,\ + -0.10494986,-0.52098388,\ + -0.10231280,-0.70603494,\ + -0.10202215,-0.89648205,\ + -0.10135860,0.74182236,\ + -0.08783976,-0.99613462,\ + -0.06094913,0.99814087,\ + -0.05373065,0.47736208,\ + -0.05303096,0.29846240,\ + -0.05301161,0.11775088,\ + -0.05297384,-0.24644770,\ + -0.05294683,-0.06402902,\ + -0.05272878,-0.42928088,\ + -0.05216873,0.65274562,\ + -0.05171351,-0.61420930,\ + -0.05042272,-0.80478224,\ + -0.04981932,0.83515353,\ + -0.00000007,0.92251132,\ + -0.00000000,-1.00000000,\ + -0.00000000,-0.52141188,\ + -0.00000000,-0.33767237,\ + 0.00000000,-0.70961850,\ + 0.00000000,-0.15521184,\ + 0,0,\ + 0.00000000,-0.90399448,\ + 0.00000001,0.74344973,\ + 0.00000001,0.20812464,\ + 0.00000002,0.38816625,\ + 0.00000004,0.56389175,\ + 0.04981924,0.83515359,\ + 0.05042272,-0.80478224,\ + 0.05171351,-0.61420930,\ + 0.05216879,0.65274567,\ + 0.05272878,-0.42928088,\ + 0.05294684,-0.06402902,\ + 0.05297384,-0.24644770,\ + 0.05301163,0.11775088,\ + 0.05303099,0.29846241,\ + 0.05373071,0.47736209,\ + 0.06094900,0.99814088,\ + 0.08783975,-0.99613462,\ + 0.10135861,0.74182246,\ + 0.10202215,-0.89648204,\ + 0.10231280,-0.70603494,\ + 0.10494986,-0.52098388,\ + 0.10583335,-0.33807967,\ + 0.10588440,0.02705488,\ + 0.10591811,-0.15506198,\ + 0.10607362,0.20858767,\ + 0.10625027,0.38888884,\ + 0.10707069,0.56845657,\ + 0.11501967,0.91942019,\ + 0.14769048,0.83062141,\ + 0.15237093,-0.79492683,\ + 0.15661700,-0.61168462,\ + 0.15863082,-0.43008689,\ + 0.15865438,-0.24633350,\ + 0.15886163,-0.06382544,\ + 0.15886894,0.11819997,\ + 0.15893191,0.29974616,\ + 0.15922296,0.48080963,\ + 0.15939237,0.66160531,\ + 0.18078187,-0.98352322,\ + 0.18395366,0.98293492,\ + 0.20086952,-0.88206044,\ + 0.20780058,0.75906403,\ + 0.20788703,-0.69983910,\ + 0.21124026,-0.33783320,\ + 0.21131709,0.39241733,\ + 0.21148582,0.57469689,\ + 0.21161899,-0.52224104,\ + 0.21163724,0.02750974,\ + 0.21176387,0.20951422,\ + 0.21182740,-0.15487168,\ + 0.22690143,0.87620648,\ + 0.25770896,-0.78444210,\ + 0.26150638,0.67168149,\ + 0.26320616,0.48755045,\ + 0.26351075,0.30183378,\ + 0.26386624,-0.42936315,\ + 0.26436076,-0.24601298,\ + 0.26450874,0.11893627,\ + 0.26467581,-0.06343296,\ + 0.26535904,-0.61453199,\ + 0.27578713,-0.96121874,\ + 0.29118450,0.95666692,\ + 0.29840708,0.77598632,\ + 0.30063918,-0.86564249,\ + 0.31388084,0.58661757,\ + 0.31403760,0.39611878,\ + 0.31666303,-0.52066428,\ + 0.31681891,-0.33738863,\ + 0.31689202,0.21090283,\ + 0.31719767,0.02776005,\ + 0.31779837,-0.15404597,\ + 0.32032475,-0.70802805,\ + 0.32677441,0.86283111,\ + 0.36031068,0.69223956,\ + 0.36331898,0.49302709,\ + 0.36692104,0.30426281,\ + 0.36950835,-0.61160154,\ + 0.36961888,-0.42913077,\ + 0.36964213,0.11912679,\ + 0.36985500,-0.24430894,\ + 0.37084162,-0.06308696,\ + 0.37388864,-0.92747360,\ + 0.37410706,-0.80603587,\ + 0.40144095,0.91588491,\ + 0.40595578,0.79572146,\ + 0.41014167,0.59223341,\ + 0.41518846,0.40003796,\ + 0.42118179,0.21073837,\ + 0.42154903,-0.70155621,\ + 0.42158212,-0.33447625,\ + 0.42277498,0.02774444,\ + 0.42294844,-0.52165201,\ + 0.42465107,-0.15303707,\ + 0.45576674,0.68708770,\ + 0.45708829,-0.88942133,\ + 0.46266258,0.49810708,\ + 0.46830638,0.30307674,\ + 0.46891581,-0.78503521,\ + 0.47332691,-0.42421072,\ + 0.47492414,0.11881178,\ + 0.47617996,-0.24250847,\ + 0.47709422,-0.61587202,\ + 0.47766585,-0.06283163,\ + 0.49284023,0.87011983,\ + 0.50070341,0.77148078,\ + 0.51105323,0.59594241,\ + 0.51322967,0.39629252,\ + 0.52436621,-0.51257293,\ + 0.52626232,0.20945175,\ + 0.52768053,-0.33226318,\ + 0.52845304,0.02779153,\ + 0.53277416,-0.15208460,\ + 0.53471317,-0.71163670,\ + 0.54436708,-0.83884712,\ + 0.55806447,0.48787901,\ + 0.56388700,0.69225656,\ + 0.57023517,0.30017942,\ + 0.57342995,-0.59754731,\ + 0.58024100,-0.42164527,\ + 0.58101739,0.11895848,\ + 0.58292803,-0.24113073,\ + 0.58517097,-0.06261579,\ + 0.58817579,0.80873311,\ + 0.60234813,0.57529736,\ + 0.61382860,0.39179583,\ + 0.61668934,-0.78720662,\ + 0.62397583,-0.67129702,\ + 0.63317442,-0.50922682,\ + 0.63361786,0.02834503,\ + 0.63367294,0.20893326,\ + 0.63423914,-0.33167241,\ + 0.64251884,-0.15112779,\ + 0.64930338,0.65326736,\ + 0.66034350,0.48247934,\ + 0.66617718,0.74579352,\ + 0.67300273,0.29907891,\ + 0.68761735,-0.72607326,\ + 0.68775963,-0.59169999,\ + 0.68789855,0.12079206,\ + 0.68873811,-0.23996840,\ + 0.68888730,-0.42215438,\ + 0.69311230,-0.06262365,\ + 0.70989539,0.56969251,\ + 0.71613187,0.39251294,\ + 0.73635629,0.02967200,\ + 0.73952765,-0.33215782,\ + 0.74439196,0.21110738,\ + 0.74611496,-0.50995249,\ + 0.74615780,0.66576913,\ + 0.74996331,-0.66147943,\ + 0.75497372,-0.15055590,\ + 0.76620501,0.48702606,\ + 0.77379915,0.30089103,\ + 0.79156344,-0.23872410,\ + 0.79434393,0.12621664,\ + 0.79753133,-0.42481233,\ + 0.80047290,-0.06419645,\ + 0.80880317,-0.58807944,\ + 0.81347979,0.58159318,\ + 0.81606483,0.40088956,\ + 0.83205774,0.03180994,\ + 0.84170474,-0.33621087,\ + 0.86017584,0.21863310,\ + 0.86248349,-0.50608521,\ + 0.86621517,0.31178474,\ + 0.87078230,0.49166878,\ + 0.87285522,-0.15209581,\ + 0.88570063,-0.24444376,\ + 0.89624049,0.13116998,\ + 0.90313047,-0.06293631,\ + 0.90908498,-0.41661073,\ + 0.91823427,0.39603766,\ + 0.91977079,0.03476784,\ + 0.94719721,-0.32065159,\ + 0.95669950,0.29107742,\ + 0.97660422,-0.21504466,\ + 0.98039504,0.19704204,\ + 0.99269568,-0.12064527,\ + 0.99556874,0.09403657,\ + 0.99985012,-0.01731283} + + + +#define DISC_7_FACE {\ + 130,146,121,\ + 17,5,3,\ + 197,209,221,\ + 243,221,232,\ + 232,221,209,\ + 226,218,207,\ + 67,66,82,\ + 132,121,146,\ + 122,110,132,\ + 132,110,121,\ + 3,1,11,\ + 11,17,3,\ + 54,67,53,\ + 66,67,54,\ + 294,279,295,\ + 295,279,276,\ + 279,294,270,\ + 272,251,262,\ + 357,349,355,\ + 236,230,250,\ + 50,44,33,\ + 63,44,50,\ + 277,296,287,\ + 226,236,247,\ + 166,184,195,\ + 155,184,166,\ + 93,104,114,\ + 135,114,125,\ + 125,114,104,\ + 76,95,96,\ + 106,115,127,\ + 115,106,94,\ + 117,106,127,\ + 95,106,117,\ + 196,207,218,\ + 126,114,135,\ + 127,115,134,\ + 115,126,134,\ + 162,174,151,\ + 53,67,62,\ + 72,52,62,\ + 237,259,249,\ + 136,148,158,\ + 122,132,145,\ + 130,121,109,\ + 109,123,130,\ + 100,123,109,\ + 0,2,10,\ + 37,47,31,\ + 31,47,45,\ + 45,47,64,\ + 64,57,45,\ + 75,57,64,\ + 75,95,76,\ + 76,57,75,\ + 28,14,20,\ + 361,360,354,\ + 321,327,335,\ + 357,359,348,\ + 348,349,357,\ + 348,335,349,\ + 315,296,306,\ + 315,306,325,\ + 41,50,33,\ + 329,339,345,\ + 242,222,220,\ + 176,198,182,\ + 182,163,176,\ + 161,182,173,\ + 163,182,161,\ + 53,30,43,\ + 43,54,53,\ + 24,9,16,\ + 42,22,30,\ + 42,30,53,\ + 53,62,42,\ + 42,62,52,\ + 307,294,295,\ + 299,294,308,\ + 308,307,318,\ + 294,307,308,\ + 283,299,289,\ + 283,294,299,\ + 283,270,294,\ + 225,237,249,\ + 217,207,195,\ + 230,236,217,\ + 226,207,217,\ + 217,236,226,\ + 250,230,239,\ + 239,229,251,\ + 251,272,260,\ + 260,272,281,\ + 281,269,260,\ + 260,239,251,\ + 260,269,250,\ + 250,239,260,\ + 268,247,257,\ + 257,247,236,\ + 250,269,257,\ + 257,236,250,\ + 314,330,316,\ + 244,232,223,\ + 84,73,93,\ + 84,94,74,\ + 83,73,63,\ + 83,92,104,\ + 83,104,93,\ + 93,73,83,\ + 55,44,63,\ + 63,73,55,\ + 246,227,235,\ + 235,227,218,\ + 235,218,226,\ + 226,247,235,\ + 181,204,194,\ + 155,166,144,\ + 135,125,144,\ + 80,92,71,\ + 63,50,71,\ + 71,83,63,\ + 92,83,71,\ + 101,110,122,\ + 101,92,80,\ + 131,144,125,\ + 122,145,131,\ + 131,145,155,\ + 155,144,131,\ + 111,125,104,\ + 104,92,111,\ + 111,131,125,\ + 122,131,111,\ + 111,101,122,\ + 92,101,111,\ + 164,140,152,\ + 151,174,152,\ + 138,117,127,\ + 138,162,151,\ + 129,140,118,\ + 129,138,151,\ + 117,138,129,\ + 151,152,129,\ + 129,152,140,\ + 107,96,95,\ + 95,117,107,\ + 118,96,107,\ + 107,129,118,\ + 117,129,107,\ + 197,164,175,\ + 175,209,197,\ + 164,152,175,\ + 175,152,174,\ + 218,227,208,\ + 208,196,218,\ + 195,207,185,\ + 207,196,185,\ + 196,165,185,\ + 185,166,195,\ + 143,126,135,\ + 143,165,153,\ + 153,134,143,\ + 143,134,126,\ + 93,114,105,\ + 114,126,105,\ + 105,84,93,\ + 105,126,115,\ + 115,94,105,\ + 94,84,105,\ + 78,62,67,\ + 78,88,72,\ + 72,62,78,\ + 13,2,4,\ + 13,10,2,\ + 18,9,24,\ + 18,6,9,\ + 61,52,72,\ + 61,39,52,\ + 49,39,61,\ + 150,161,173,\ + 81,61,72,\ + 72,88,81,\ + 88,102,81,\ + 124,148,136,\ + 238,252,231,\ + 238,225,249,\ + 249,259,271,\ + 136,158,147,\ + 147,123,136,\ + 130,123,147,\ + 68,77,58,\ + 58,48,68,\ + 155,145,167,\ + 167,184,155,\ + 181,194,167,\ + 167,194,184,\ + 112,123,100,\ + 136,123,112,\ + 112,124,136,\ + 102,124,112,\ + 89,101,80,\ + 110,101,89,\ + 87,68,79,\ + 77,68,87,\ + 79,100,87,\ + 100,109,87,\ + 21,14,28,\ + 28,37,21,\ + 21,37,31,\ + 86,64,74,\ + 86,75,64,\ + 74,94,86,\ + 86,94,106,\ + 86,106,95,\ + 95,75,86,\ + 8,20,14,\ + 33,44,27,\ + 27,17,33,\ + 351,359,361,\ + 351,348,359,\ + 361,354,351,\ + 351,354,342,\ + 321,335,332,\ + 335,348,332,\ + 310,301,292,\ + 281,272,292,\ + 292,301,281,\ + 353,346,356,\ + 344,346,334,\ + 358,356,344,\ + 344,356,346,\ + 288,306,296,\ + 288,277,268,\ + 296,277,288,\ + 298,306,288,\ + 334,346,341,\ + 341,325,334,\ + 341,346,353,\ + 341,353,347,\ + 50,41,60,\ + 80,71,60,\ + 60,71,50,\ + 309,329,322,\ + 322,300,309,\ + 289,299,309,\ + 309,300,289,\ + 173,182,188,\ + 188,179,173,\ + 211,201,188,\ + 188,201,179,\ + 200,182,198,\ + 198,220,200,\ + 200,220,222,\ + 200,188,182,\ + 200,222,211,\ + 211,188,200,\ + 163,161,141,\ + 82,66,85,\ + 116,98,119,\ + 22,42,32,\ + 24,16,32,\ + 32,16,22,\ + 32,39,24,\ + 32,42,52,\ + 52,39,32,\ + 331,308,318,\ + 212,201,224,\ + 224,201,211,\ + 258,253,270,\ + 270,283,258,\ + 245,242,263,\ + 245,222,242,\ + 237,225,213,\ + 231,252,240,\ + 251,229,240,\ + 262,251,240,\ + 240,252,262,\ + 349,335,343,\ + 343,335,327,\ + 343,355,349,\ + 343,352,355,\ + 333,341,347,\ + 315,325,333,\ + 325,341,333,\ + 340,333,347,\ + 330,314,324,\ + 315,333,324,\ + 324,340,330,\ + 333,340,324,\ + 287,296,305,\ + 305,296,315,\ + 315,324,305,\ + 305,324,314,\ + 234,244,223,\ + 234,227,246,\ + 243,232,254,\ + 232,244,254,\ + 254,265,243,\ + 46,37,28,\ + 44,55,36,\ + 36,27,44,\ + 20,27,36,\ + 55,46,36,\ + 28,20,36,\ + 36,46,28,\ + 246,235,256,\ + 268,277,256,\ + 256,247,268,\ + 256,235,247,\ + 206,217,195,\ + 230,217,206,\ + 195,184,206,\ + 184,194,206,\ + 142,138,127,\ + 162,138,142,\ + 127,134,142,\ + 142,134,153,\ + 210,199,223,\ + 223,232,210,\ + 210,232,209,\ + 174,162,183,\ + 219,208,227,\ + 227,234,219,\ + 223,199,219,\ + 219,234,223,\ + 196,208,186,\ + 153,165,186,\ + 186,165,196,\ + 135,144,154,\ + 154,143,135,\ + 165,143,154,\ + 154,144,166,\ + 166,185,154,\ + 154,185,165,\ + 82,85,97,\ + 97,85,98,\ + 98,116,97,\ + 34,39,49,\ + 34,18,24,\ + 24,39,34,\ + 6,18,12,\ + 4,6,12,\ + 12,13,4,\ + 202,213,191,\ + 137,150,160,\ + 179,172,160,\ + 173,179,160,\ + 160,150,173,\ + 61,81,70,\ + 49,61,70,\ + 113,102,88,\ + 113,124,102,\ + 291,271,280,\ + 289,300,280,\ + 280,300,291,\ + 280,271,259,\ + 261,238,249,\ + 249,271,261,\ + 252,238,261,\ + 322,327,312,\ + 312,300,322,\ + 312,327,321,\ + 321,302,312,\ + 291,300,312,\ + 312,302,291,\ + 156,146,130,\ + 130,147,156,\ + 59,68,48,\ + 49,70,59,\ + 79,68,59,\ + 59,70,79,\ + 157,167,145,\ + 157,132,146,\ + 157,145,132,\ + 181,167,157,\ + 90,112,100,\ + 90,70,81,\ + 90,81,102,\ + 102,112,90,\ + 90,100,79,\ + 79,70,90,\ + 99,89,77,\ + 77,87,99,\ + 99,87,109,\ + 110,89,99,\ + 121,110,99,\ + 99,109,121,\ + 0,10,7,\ + 10,19,7,\ + 7,1,0,\ + 7,11,1,\ + 7,19,11,\ + 25,41,33,\ + 11,19,25,\ + 33,17,25,\ + 17,11,25,\ + 23,19,10,\ + 23,13,29,\ + 10,13,23,\ + 20,8,15,\ + 15,27,20,\ + 15,8,5,\ + 5,17,15,\ + 17,27,15,\ + 321,332,313,\ + 313,302,321,\ + 293,302,313,\ + 338,351,342,\ + 348,351,338,\ + 338,332,348,\ + 281,301,290,\ + 290,269,281,\ + 298,290,311,\ + 311,290,301,\ + 310,292,303,\ + 293,313,303,\ + 284,272,262,\ + 284,292,272,\ + 293,303,284,\ + 284,303,292,\ + 342,354,350,\ + 350,354,360,\ + 350,360,358,\ + 358,344,350,\ + 268,257,278,\ + 278,288,268,\ + 298,288,278,\ + 278,257,269,\ + 278,290,298,\ + 269,290,278,\ + 51,60,41,\ + 139,141,161,\ + 161,150,139,\ + 119,141,139,\ + 139,116,119,\ + 308,331,319,\ + 299,308,319,\ + 319,309,299,\ + 329,309,319,\ + 339,329,319,\ + 319,331,339,\ + 241,224,253,\ + 253,258,241,\ + 212,224,241,\ + 259,237,248,\ + 248,241,258,\ + 253,224,233,\ + 233,224,211,\ + 211,222,233,\ + 222,245,233,\ + 279,270,264,\ + 270,253,264,\ + 264,276,279,\ + 253,233,264,\ + 264,233,245,\ + 263,276,264,\ + 264,245,263,\ + 214,238,231,\ + 225,238,214,\ + 191,213,203,\ + 203,213,225,\ + 203,178,191,\ + 193,178,203,\ + 225,214,203,\ + 203,214,193,\ + 215,229,204,\ + 231,240,215,\ + 215,240,229,\ + 345,352,337,\ + 352,343,337,\ + 337,329,345,\ + 322,329,337,\ + 337,327,322,\ + 337,343,327,\ + 287,275,267,\ + 267,277,287,\ + 246,256,267,\ + 267,256,277,\ + 244,234,255,\ + 255,267,275,\ + 255,234,246,\ + 246,267,255,\ + 297,275,287,\ + 287,305,297,\ + 297,305,314,\ + 297,316,304,\ + 297,314,316,\ + 265,254,266,\ + 266,255,275,\ + 266,254,244,\ + 244,255,266,\ + 65,84,74,\ + 73,84,65,\ + 65,55,73,\ + 65,46,55,\ + 47,37,56,\ + 37,46,56,\ + 46,65,56,\ + 56,65,74,\ + 74,64,56,\ + 64,47,56,\ + 216,239,230,\ + 230,206,216,\ + 229,239,216,\ + 204,229,216,\ + 216,194,204,\ + 216,206,194,\ + 199,210,187,\ + 187,183,199,\ + 174,183,187,\ + 187,175,174,\ + 209,175,187,\ + 187,210,209,\ + 171,183,162,\ + 153,186,171,\ + 171,142,153,\ + 162,142,171,\ + 91,67,82,\ + 91,78,67,\ + 82,97,91,\ + 97,108,91,\ + 128,97,116,\ + 128,108,97,\ + 137,108,128,\ + 128,150,137,\ + 128,139,150,\ + 116,139,128,\ + 40,34,49,\ + 49,59,40,\ + 40,59,48,\ + 40,48,29,\ + 18,34,26,\ + 26,12,18,\ + 34,40,26,\ + 26,40,29,\ + 29,13,26,\ + 13,12,26,\ + 191,178,170,\ + 170,178,158,\ + 170,158,148,\ + 148,159,170,\ + 172,159,149,\ + 137,160,149,\ + 149,160,172,\ + 212,202,189,\ + 189,201,212,\ + 179,201,189,\ + 189,172,179,\ + 120,108,137,\ + 137,149,120,\ + 293,284,274,\ + 274,284,262,\ + 262,252,274,\ + 252,261,274,\ + 282,271,291,\ + 282,261,271,\ + 282,274,261,\ + 293,274,282,\ + 282,302,293,\ + 291,302,282,\ + 192,204,181,\ + 192,215,204,\ + 146,156,169,\ + 169,157,146,\ + 181,157,169,\ + 169,192,181,\ + 323,313,332,\ + 332,338,323,\ + 310,303,323,\ + 323,303,313,\ + 298,311,317,\ + 317,306,298,\ + 317,325,306,\ + 334,325,317,\ + 328,344,334,\ + 334,317,328,\ + 328,317,311,\ + 310,323,326,\ + 326,338,342,\ + 326,323,338,\ + 60,51,69,\ + 80,60,69,\ + 58,77,69,\ + 69,51,58,\ + 69,89,80,\ + 77,89,69,\ + 19,23,35,\ + 35,51,41,\ + 41,25,35,\ + 35,25,19,\ + 273,248,258,\ + 273,283,289,\ + 273,258,283,\ + 259,248,273,\ + 289,280,273,\ + 273,280,259,\ + 237,213,228,\ + 228,248,237,\ + 228,202,212,\ + 213,202,228,\ + 212,241,228,\ + 241,248,228,\ + 168,178,193,\ + 168,156,147,\ + 168,147,158,\ + 158,178,168,\ + 286,297,304,\ + 275,297,286,\ + 286,266,275,\ + 199,183,190,\ + 183,171,190,\ + 190,219,199,\ + 208,219,190,\ + 190,186,208,\ + 190,171,186,\ + 177,189,202,\ + 177,170,159,\ + 177,159,172,\ + 172,189,177,\ + 177,202,191,\ + 191,170,177,\ + 113,120,133,\ + 133,120,149,\ + 148,124,133,\ + 124,113,133,\ + 133,159,148,\ + 133,149,159,\ + 103,91,108,\ + 108,120,103,\ + 88,78,103,\ + 78,91,103,\ + 103,113,88,\ + 103,120,113,\ + 215,192,205,\ + 205,214,231,\ + 231,215,205,\ + 193,214,205,\ + 192,169,180,\ + 193,205,180,\ + 180,205,192,\ + 180,168,193,\ + 180,169,156,\ + 156,168,180,\ + 336,350,344,\ + 344,328,336,\ + 342,350,336,\ + 336,326,342,\ + 320,328,311,\ + 320,311,301,\ + 320,336,328,\ + 326,336,320,\ + 320,301,310,\ + 310,326,320,\ + 38,35,23,\ + 29,48,38,\ + 38,23,29,\ + 38,48,58,\ + 58,51,38,\ + 51,35,38,\ + 265,266,285,\ + 266,286,285,\ + 285,286,304} + +#endif /* __DISC_H__ */ diff --git a/cuslines/metal_shaders/generate_streamlines_metal.metal b/cuslines/metal_shaders/generate_streamlines_metal.metal new file mode 100644 index 0000000..4a0a681 --- /dev/null +++ b/cuslines/metal_shaders/generate_streamlines_metal.metal @@ -0,0 +1,400 @@ +/* Metal port of cuslines/cuda_c/generate_streamlines_cuda.cu + * + * Main streamline generation kernels for probabilistic and PTT tracking. + * Bootstrap kernels are in boot.metal. + */ + +#include "globals.h" +#include "types.h" +#include "philox_rng.h" + +// Forward declarations from tracking_helpers.metal and utils.metal +inline int trilinear_interp(const int dimx, const int dimy, const int dimz, + const int dimt, int dimt_idx, + const device float* dataf, + const float3 point, + threadgroup float* vox_data, + uint tidx); + +inline int check_point(const float tc_threshold, + const float3 point, + const int dimx, const int dimy, const int dimz, + const device float* metric_map, + threadgroup float* interp_out, + uint tidx, uint tidy); + +inline int peak_directions(const threadgroup float* odf, + threadgroup float3* dirs, + const device packed_float3* sphere_vertices, + const device int2* sphere_edges, + const int num_edges, + int samplm_nr, + threadgroup int* shInd, + const float relative_peak_thres, + const float min_separation_angle, + uint tidx); + +inline float simd_max_reduce(int n, const threadgroup float* src, float minVal, uint tidx); + +inline void prefix_sum_sh(threadgroup float* num_sh, int len, uint tidx); + +// ── Parameter struct for Prob/PTT kernels ──────────────────────────── +// Guarded: may already be defined by ptt.metal (compiled first). + +#ifndef PROB_TRACKING_PARAMS_DEFINED +#define PROB_TRACKING_PARAMS_DEFINED +struct ProbTrackingParams { + float max_angle; + float tc_threshold; + float step_size; + float relative_peak_thresh; + float min_separation_angle; + int rng_seed_lo; + int rng_seed_hi; + int rng_offset; + int nseed; + int dimx; + int dimy; + int dimz; + int dimt; + int samplm_nr; + int num_edges; + int model_type; // PROB=2 or PTT=3 +}; +#endif + +// ── max threadgroup memory dimensions ──────────────────────────────── +// BLOCK_Y and MAX_N32DIMT are defined in globals.h + +// ── probabilistic direction getter ─────────────────────────────────── + +inline int get_direction_prob(thread PhiloxState& st, + const device float* pmf, + const float max_angle, + const float relative_peak_thres, + const float min_separation_angle, + float3 dir, + const int dimx, const int dimy, + const int dimz, const int dimt, + const float3 point, + const device packed_float3* sphere_vertices, + const device int2* sphere_edges, + const int num_edges, + threadgroup float3* out_dirs, + threadgroup float* sh_mem, + threadgroup int* sh_ind, + bool is_start, + uint tidx, uint tidy) { + + const int n32dimt = ((dimt + 31) / 32) * 32; + threadgroup float* pmf_data_sh = sh_mem + tidy * n32dimt; + + // pmf = trilinear interpolation at point + simdgroup_barrier(mem_flags::mem_threadgroup); + const int rv = trilinear_interp(dimx, dimy, dimz, dimt, -1, pmf, point, pmf_data_sh, tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + if (rv != 0) { + return 0; + } + + // absolute pmf threshold + const float absolpmf_thresh = PMF_THRESHOLD_P * simd_max_reduce(dimt, pmf_data_sh, REAL_MIN, tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + + // zero out entries below threshold + for (int i = int(tidx); i < dimt; i += THR_X_SL) { + if (pmf_data_sh[i] < absolpmf_thresh) { + pmf_data_sh[i] = 0.0f; + } + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (is_start) { + return peak_directions(pmf_data_sh, + out_dirs, + sphere_vertices, + sphere_edges, + num_edges, + dimt, + sh_ind, + relative_peak_thres, + min_separation_angle, + tidx); + } else { + // Filter by angle similarity + const float cos_similarity = COS(max_angle); + + for (int i = int(tidx); i < dimt; i += THR_X_SL) { + float3 sv = load_f3(sphere_vertices, uint(i)); + const float dot = dir.x * sv.x + dir.y * sv.y + dir.z * sv.z; + if (FABS(dot) < cos_similarity) { + pmf_data_sh[i] = 0.0f; + } + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Prefix sum for CDF + prefix_sum_sh(pmf_data_sh, dimt, tidx); + + float last_cdf = pmf_data_sh[dimt - 1]; + if (last_cdf == 0.0f) { + return 0; + } + + // Sample from CDF + float tmp; + if (tidx == 0) { + tmp = philox_uniform(st) * last_cdf; + } + float selected_cdf = simd_broadcast_first(tmp); + + // Binary search + ballot for insertion point + int low = 0; + int high = dimt - 1; + while ((high - low) >= THR_X_SL) { + const int mid = (low + high) / 2; + if (pmf_data_sh[mid] < selected_cdf) { + low = mid; + } else { + high = mid; + } + } + const bool ballot_pred = (low + int(tidx) <= high) ? (selected_cdf < pmf_data_sh[low + tidx]) : false; + const uint msk = SIMD_BALLOT_MASK(ballot_pred); + const int indProb = (msk != 0) ? (low + int(ctz(msk))) : (dimt - 1); + + // Select direction, flip if needed + if (tidx == 0) { + float3 sv = load_f3(sphere_vertices, uint(indProb)); + if ((dir.x * sv.x + dir.y * sv.y + dir.z * sv.z) > 0) { + *out_dirs = sv; + } else { + *out_dirs = -sv; + } + } + + return 1; + } +} + +// ── tracker — step along streamline ────────────────────────────────── + +inline int tracker_prob(thread PhiloxState& st, + const float max_angle, + const float tc_threshold, + const float step_size, + const float relative_peak_thres, + const float min_separation_angle, + float3 seed, + float3 first_step, + const float3 voxel_size, + const int dimx, const int dimy, + const int dimz, const int dimt, + const device float* dataf, + const device float* metric_map, + const int samplm_nr, + const device packed_float3* sphere_vertices, + const device int2* sphere_edges, + const int num_edges, + threadgroup int* nsteps, + device packed_float3* streamline, + threadgroup float3* sh_new_dir, + threadgroup float* sh_mem, + threadgroup float* interp_out, + threadgroup int* sh_ind, + uint tidx, uint tidy) { + + int tissue_class = TRACKPOINT; + float3 point = seed; + float3 direction = first_step; + + if (tidx == 0) { + store_f3(streamline, 0, point); + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + int i; + for (i = 1; i < MAX_SLINE_LEN; i++) { + int ndir = get_direction_prob(st, dataf, max_angle, + relative_peak_thres, min_separation_angle, + direction, dimx, dimy, dimz, dimt, + point, sphere_vertices, sphere_edges, + num_edges, sh_new_dir + tidy, + sh_mem, sh_ind, false, tidx, tidy); + simdgroup_barrier(mem_flags::mem_threadgroup); + direction = sh_new_dir[tidy]; + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (ndir == 0) { + break; + } + + point.x += (direction.x / voxel_size.x) * step_size; + point.y += (direction.y / voxel_size.y) * step_size; + point.z += (direction.z / voxel_size.z) * step_size; + + if (tidx == 0) { + store_f3(streamline, uint(i), point); + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + tissue_class = check_point(tc_threshold, point, dimx, dimy, dimz, + metric_map, interp_out, tidx, tidy); + + if (tissue_class == ENDPOINT || + tissue_class == INVALIDPOINT || + tissue_class == OUTSIDEIMAGE) { + break; + } + } + nsteps[0] = i; + return tissue_class; +} + +// ── getNumStreamlinesProb_k ────────────────────────────────────────── + +kernel void getNumStreamlinesProb_k( + constant ProbTrackingParams& params [[buffer(0)]], + const device packed_float3* seeds [[buffer(1)]], + const device float* dataf [[buffer(2)]], + const device packed_float3* sphere_vertices [[buffer(3)]], + const device int2* sphere_edges [[buffer(4)]], + device packed_float3* shDir0 [[buffer(5)]], + device int* slineOutOff [[buffer(6)]], + uint2 tid [[thread_position_in_threadgroup]], + uint2 gid [[threadgroup_position_in_grid]]) +{ + const uint tidx = tid.x; + const uint tidy = tid.y; + const uint slid = gid.x * BLOCK_Y + tidy; + + if (int(slid) >= params.nseed) return; + + const uint global_id = gid.x * BLOCK_Y * THR_X_SL + THR_X_SL * tidy + tidx; + PhiloxState st = philox_init(uint(params.rng_seed_lo), uint(params.rng_seed_hi), global_id, 0); + + const int n32dimt = ((params.dimt + 31) / 32) * 32; + + // Threadgroup memory + threadgroup float sh_mem[BLOCK_Y * MAX_N32DIMT]; + threadgroup int sh_ind[BLOCK_Y * MAX_N32DIMT]; + threadgroup float3 dirs_sh[BLOCK_Y * MAX_SLINES_PER_SEED]; + + threadgroup float* my_sh = sh_mem + tidy * n32dimt; + threadgroup int* my_ind = sh_ind + tidy * n32dimt; + + float3 seed = load_f3(seeds, slid); + device packed_float3* my_shDir = shDir0 + slid * params.dimt; + + int ndir = get_direction_prob(st, dataf, params.max_angle, + params.relative_peak_thresh, + params.min_separation_angle, + float3(0, 0, 0), + params.dimx, params.dimy, params.dimz, params.dimt, + seed, sphere_vertices, sphere_edges, + params.num_edges, + dirs_sh + tidy * MAX_SLINES_PER_SEED, + my_sh, my_ind, true, tidx, tidy); + + // Copy found directions to global memory + if (tidx == 0) { + for (int d = 0; d < ndir; d++) { + store_f3(my_shDir, uint(d), dirs_sh[tidy * MAX_SLINES_PER_SEED + d]); + } + slineOutOff[slid] = ndir; + } +} + +// ── genStreamlinesMergeProb_k ──────────────────────────────────────── + +kernel void genStreamlinesMergeProb_k( + constant ProbTrackingParams& params [[buffer(0)]], + const device packed_float3* seeds [[buffer(1)]], + const device float* dataf [[buffer(2)]], + const device float* metric_map [[buffer(3)]], + const device packed_float3* sphere_vertices [[buffer(4)]], + const device int2* sphere_edges [[buffer(5)]], + const device int* slineOutOff [[buffer(6)]], + device packed_float3* shDir0 [[buffer(7)]], + device int* slineSeed [[buffer(8)]], + device int* slineLen [[buffer(9)]], + device packed_float3* sline [[buffer(10)]], + uint2 tid [[thread_position_in_threadgroup]], + uint2 gid [[threadgroup_position_in_grid]]) +{ + const uint tidx = tid.x; + const uint tidy = tid.y; + const uint slid = gid.x * BLOCK_Y + tidy; + + if (int(slid) >= params.nseed) return; + + const uint global_id = gid.x * BLOCK_Y * THR_X_SL + THR_X_SL * tidy + tidx; + PhiloxState st = philox_init(uint(params.rng_seed_lo), uint(params.rng_seed_hi), global_id + 1, 0); + + const int n32dimt = ((params.dimt + 31) / 32) * 32; + + // Threadgroup memory + threadgroup float sh_mem[BLOCK_Y * MAX_N32DIMT]; + threadgroup int sh_ind[BLOCK_Y * MAX_N32DIMT]; + threadgroup float3 sh_new_dir[BLOCK_Y]; + threadgroup float interp_out[BLOCK_Y]; + threadgroup int stepsB_sh[BLOCK_Y]; + threadgroup int stepsF_sh[BLOCK_Y]; + + float3 seed = load_f3(seeds, slid); + + int ndir = slineOutOff[slid + 1] - slineOutOff[slid]; + simdgroup_barrier(mem_flags::mem_threadgroup); + + int slineOff = slineOutOff[slid]; + + for (int i = 0; i < ndir; i++) { + float3 first_step = load_f3(shDir0, uint(int(slid) * params.samplm_nr + i)); + + device packed_float3* currSline = sline + slineOff * MAX_SLINE_LEN * 2; + + if (tidx == 0) { + slineSeed[slineOff] = int(slid); + } + + // Backward tracking + tracker_prob(st, params.max_angle, params.tc_threshold, + params.step_size, params.relative_peak_thresh, + params.min_separation_angle, + seed, float3(-first_step.x, -first_step.y, -first_step.z), + float3(1, 1, 1), + params.dimx, params.dimy, params.dimz, params.dimt, + dataf, metric_map, params.samplm_nr, + sphere_vertices, sphere_edges, params.num_edges, + stepsB_sh + tidy, currSline, + sh_new_dir, sh_mem, interp_out, + sh_ind + tidy * n32dimt, tidx, tidy); + + int stepsB = stepsB_sh[tidy]; + + // Reverse backward streamline + for (int j = int(tidx); j < stepsB / 2; j += THR_X_SL) { + float3 p = load_f3(currSline, uint(j)); + store_f3(currSline, uint(j), load_f3(currSline, uint(stepsB - 1 - j))); + store_f3(currSline, uint(stepsB - 1 - j), p); + } + + // Forward tracking + tracker_prob(st, params.max_angle, params.tc_threshold, + params.step_size, params.relative_peak_thresh, + params.min_separation_angle, + seed, first_step, float3(1, 1, 1), + params.dimx, params.dimy, params.dimz, params.dimt, + dataf, metric_map, params.samplm_nr, + sphere_vertices, sphere_edges, params.num_edges, + stepsF_sh + tidy, currSline + (stepsB - 1), + sh_new_dir, sh_mem, interp_out, + sh_ind + tidy * n32dimt, tidx, tidy); + + if (tidx == 0) { + slineLen[slineOff] = stepsB - 1 + stepsF_sh[tidy]; + } + + slineOff += 1; + } +} diff --git a/cuslines/metal_shaders/globals.h b/cuslines/metal_shaders/globals.h new file mode 100644 index 0000000..c6eb014 --- /dev/null +++ b/cuslines/metal_shaders/globals.h @@ -0,0 +1,61 @@ +/* Metal-adapted globals — mirrors cuslines/cuda_c/globals.h. + * Metal only supports float (no double), so REAL_SIZE is always 4. + */ + +#ifndef __GLOBALS_H__ +#define __GLOBALS_H__ + +#include +using namespace metal; + +// ── precision ──────────────────────────────────────────────────────── +#define REAL_SIZE 4 + +#define REAL float +#define FLOOR floor +#define LOG fast::log +#define EXP fast::exp +#define COS fast::cos +#define SIN fast::sin +#define FABS abs +#define SQRT sqrt +#define RSQRT rsqrt +#define ACOS acos +#define REAL_MAX FLT_MAX +#define REAL_MIN (-FLT_MAX) + +// ── geometry constants ─────────────────────────────────────────────── +#define MAX_SLINE_LEN (501) +#define PMF_THRESHOLD_P ((REAL)0.05) + +#define THR_X_BL (64) +#define THR_X_SL (32) +#define BLOCK_Y (THR_X_BL / THR_X_SL) // = 2 +#define MAX_N32DIMT 512 + +#define MAX_SLINES_PER_SEED (10) + +#define MIN(x,y) (((x)<(y))?(x):(y)) +#define MAX(x,y) (((x)>(y))?(x):(y)) +#define POW2(n) (1 << (n)) + +#define DIV_UP(a,b) (((a)+((b)-1))/(b)) + +// simd_ballot returns simd_vote; extract bits via ulong then truncate to uint +#define SIMD_BALLOT_MASK(pred) uint(ulong(simd_ballot(pred))) + +#define EXCESS_ALLOC_FACT 2 + +#define NORM_EPS ((REAL)1e-8) + +// ── model types ────────────────────────────────────────────────────── +enum ModelType { + OPDT = 0, + CSA = 1, + PROB = 2, + PTT = 3, +}; + +enum { OUTSIDEIMAGE, INVALIDPOINT, TRACKPOINT, ENDPOINT }; + +#endif diff --git a/cuslines/metal_shaders/philox_rng.h b/cuslines/metal_shaders/philox_rng.h new file mode 100644 index 0000000..8ac3ce7 --- /dev/null +++ b/cuslines/metal_shaders/philox_rng.h @@ -0,0 +1,152 @@ +/* Philox4x32-10 counter-based RNG for Metal Shading Language. + * + * This implements the same algorithm as curandStatePhilox4_32_10_t so that, + * given the same seed and sequence, the Metal and CUDA paths produce + * identical random streams. + * + * Reference: Salmon et al., "Parallel Random Numbers: As Easy as 1, 2, 3" + * (SC '11). DOI 10.1145/2063384.2063405 + */ + +#ifndef __PHILOX_RNG_H__ +#define __PHILOX_RNG_H__ + +#include +using namespace metal; + +// Philox constants +constant uint PHILOX_M4x32_0 = 0xD2511F53u; +constant uint PHILOX_M4x32_1 = 0xCD9E8D57u; +constant uint PHILOX_W32_0 = 0x9E3779B9u; +constant uint PHILOX_W32_1 = 0xBB67AE85u; + +struct PhiloxState { + uint4 counter; // 128-bit counter (ctr) + uint2 key; // 64-bit key + uint4 output; // cached output of last round + uint idx; // 0..3 index into output + float cached_normal; // Box-Muller second output cache + bool has_cached; // true if cached_normal is valid +}; + +// ── single Philox round ────────────────────────────────────────────── + +inline uint mulhi32(uint a, uint b) { + return uint((ulong(a) * ulong(b)) >> 32); +} + +inline uint4 philox4x32_single_round(uint4 ctr, uint2 key) { + uint lo0 = ctr.x * PHILOX_M4x32_0; + uint hi0 = mulhi32(ctr.x, PHILOX_M4x32_0); + uint lo1 = ctr.z * PHILOX_M4x32_1; + uint hi1 = mulhi32(ctr.z, PHILOX_M4x32_1); + + return uint4(hi1 ^ ctr.y ^ key.x, + lo1, + hi0 ^ ctr.w ^ key.y, + lo0); +} + +// ── 10-round Philox4x32 ───────────────────────────────────────────── + +inline uint4 philox4x32_10(uint4 ctr, uint2 key) { + ctr = philox4x32_single_round(ctr, key); key += uint2(PHILOX_W32_0, PHILOX_W32_1); + ctr = philox4x32_single_round(ctr, key); key += uint2(PHILOX_W32_0, PHILOX_W32_1); + ctr = philox4x32_single_round(ctr, key); key += uint2(PHILOX_W32_0, PHILOX_W32_1); + ctr = philox4x32_single_round(ctr, key); key += uint2(PHILOX_W32_0, PHILOX_W32_1); + ctr = philox4x32_single_round(ctr, key); key += uint2(PHILOX_W32_0, PHILOX_W32_1); + ctr = philox4x32_single_round(ctr, key); key += uint2(PHILOX_W32_0, PHILOX_W32_1); + ctr = philox4x32_single_round(ctr, key); key += uint2(PHILOX_W32_0, PHILOX_W32_1); + ctr = philox4x32_single_round(ctr, key); key += uint2(PHILOX_W32_0, PHILOX_W32_1); + ctr = philox4x32_single_round(ctr, key); key += uint2(PHILOX_W32_0, PHILOX_W32_1); + ctr = philox4x32_single_round(ctr, key); + return ctr; +} + +// ── curand-compatible initialisation ───────────────────────────────── +// Matches curand_init(seed, subsequence, offset, &state) + +inline PhiloxState philox_init(uint seed_lo, uint seed_hi, uint subsequence, uint offset) { + PhiloxState s; + // curand packs the 64-bit seed into the two key words + s.key = uint2(seed_lo, seed_hi); + // subsequence goes into counter.y/z, offset into counter.x + s.counter = uint4(0, 0, 0, 0); + + // Advance by subsequence (each subsequence = 2^67 values) + // In practice subsequence fits in 32 bits; mirror curand layout. + ulong subseq = ulong(subsequence); + s.counter.y += uint(subseq); + s.counter.z += uint(subseq >> 32); + + // Advance by offset (each offset = 4 outputs since Philox produces 4 uint per call) + uint advance = offset / 4; + uint remainder = offset % 4; + s.counter.x += advance; + + // Generate first batch + s.output = philox4x32_10(s.counter, s.key); + s.idx = remainder; + s.has_cached = false; + s.cached_normal = 0.0f; + return s; +} + +// ── advance counter ────────────────────────────────────────────────── + +inline void philox_next(thread PhiloxState& s) { + s.counter.x += 1; + if (s.counter.x == 0) { // overflow + s.counter.y += 1; + if (s.counter.y == 0) { + s.counter.z += 1; + if (s.counter.z == 0) { + s.counter.w += 1; + } + } + } + s.output = philox4x32_10(s.counter, s.key); + s.idx = 0; +} + +// ── generate uniform float in (0, 1] ──────────────────────────────── +// Matches curand_uniform(&state) + +inline float philox_uniform(thread PhiloxState& s) { + if (s.idx >= 4) { + philox_next(s); + } + uint bits; + switch (s.idx) { + case 0: bits = s.output.x; break; + case 1: bits = s.output.y; break; + case 2: bits = s.output.z; break; + default: bits = s.output.w; break; + } + s.idx++; + // curand maps uint to (0, 1] then we mirror to [0, 1) + // curand_uniform: result = uint * (1/2^32) but never 0 + // We use the same approach + return float(bits) * 2.3283064365386963e-10f + 2.3283064365386963e-10f; +} + +// ── generate standard normal via Box-Muller ────────────────────────── +// Matches curand_normal(&state) — caches second output for efficiency. + +inline float philox_normal(thread PhiloxState& s) { + if (s.has_cached) { + s.has_cached = false; + return s.cached_normal; + } + float u1 = philox_uniform(s); + float u2 = philox_uniform(s); + // Ensure u1 is not exactly 0 for the log + u1 = max(u1, 1.0e-38f); + float r = sqrt(-2.0f * log(u1)); + float theta = 2.0f * M_PI_F * u2; + s.cached_normal = r * sin(theta); + s.has_cached = true; + return r * cos(theta); +} + +#endif diff --git a/cuslines/metal_shaders/ptt.metal b/cuslines/metal_shaders/ptt.metal new file mode 100644 index 0000000..dff952e --- /dev/null +++ b/cuslines/metal_shaders/ptt.metal @@ -0,0 +1,1061 @@ +/* Metal port of cuslines/cuda_c/ptt.cu — Parallel Transport Tractography. + * + * Aydogan DB, Shi Y. Parallel Transport Tractography. IEEE Trans Med Imaging. + * 2021 Feb;40(2):635-647. doi: 10.1109/TMI.2020.3034038. + * + * Translation rules applied: + * __device__ -> inline functions + * threadIdx.x / threadIdx.y -> tidx / tidy parameters + * __syncwarp(WMASK) -> simdgroup_barrier(mem_flags::mem_threadgroup) + * __shfl_xor_sync(WMASK, v, d, BDX) -> simd_shuffle_xor(v, ushort(d)) + * __shfl_sync(WMASK, v, l, BDX) -> simd_shuffle(v, ushort(l)) + * curandStatePhilox4_32_10_t -> PhiloxState + * curand_init / uniform / normal -> philox_init / philox_uniform / philox_normal + * __shared__ -> threadgroup (at kernel scope only) + * REAL_T -> float + * REAL3_T -> float3 (registers) / packed_float3 (device) + * MAKE_REAL3(x,y,z) -> float3(x,y,z) + * Templates removed — concrete float types throughout. + */ + +#include "globals.h" +#include "types.h" +#include "philox_rng.h" + +// ── disc data ──────────────────────────────────────────────────────── +// Include the raw disc vertex/face macros, +// then declare Metal constant-address-space arrays for SAMPLING_QUALITY == 2. + +#include "disc.h" + +// ── PTT constants (from ptt.cuh) ───────────────────────────────────── +#define STEP_FRAC (20) +#define PROBE_FRAC (2) +#define PROBE_QUALITY (4) +#define SAMPLING_QUALITY (2) +#define ALLOW_WEAK_LINK (0) +#define TRIES_PER_REJECTION_SAMPLING (1024) +#define K_SMALL (0.0001f) + +#define DISC_VERT_CNT DISC_2_VERT_CNT +#define DISC_FACE_CNT DISC_2_FACE_CNT + +constant float DISC_VERT[DISC_VERT_CNT * 2] = DISC_2_VERT; +constant int DISC_FACE[DISC_FACE_CNT * 3] = DISC_2_FACE; + +// ── forward declarations of helpers defined in other .metal files ──── +// (These are compiled together into a single Metal library.) + +inline float simd_max_reduce_dev(int n, const device float* src, float minVal, + uint tidx); + +inline void prefix_sum_sh(threadgroup float* num_sh, int len, uint tidx); + +inline int trilinear_interp(const int dimx, const int dimy, const int dimz, + const int dimt, int dimt_idx, + const device float* dataf, + const float3 point, + threadgroup float* vox_data, + uint tidx); + +inline int check_point(const float tc_threshold, + const float3 point, + const int dimx, const int dimy, const int dimz, + const device float* metric_map, + threadgroup float* interp_out, + uint tidx, uint tidy); + +// ── norm3 ──────────────────────────────────────────────────────────── +// Normalise a 3-vector in place. On degenerate input set axis fail_ind to 1. + +inline void norm3(thread float* num, int fail_ind) { + const float scale = SQRT(num[0] * num[0] + num[1] * num[1] + num[2] * num[2]); + + if (scale > NORM_EPS) { + num[0] /= scale; + num[1] /= scale; + num[2] /= scale; + } else { + num[0] = num[1] = num[2] = 0; + num[fail_ind] = 1.0f; + } +} + +// threadgroup overload +inline void norm3(threadgroup float* num, int fail_ind) { + const float scale = SQRT(num[0] * num[0] + num[1] * num[1] + num[2] * num[2]); + + if (scale > NORM_EPS) { + num[0] /= scale; + num[1] /= scale; + num[2] /= scale; + } else { + num[0] = num[1] = num[2] = 0; + num[fail_ind] = 1.0f; + } +} + +// ── crossnorm3 ────────────────────────────────────────────────────── +// dest = normalise(src1 x src2) + +inline void crossnorm3(threadgroup float* dest, + const threadgroup float* src1, + const threadgroup float* src2, + int fail_ind) { + dest[0] = src1[1] * src2[2] - src1[2] * src2[1]; + dest[1] = src1[2] * src2[0] - src1[0] * src2[2]; + dest[2] = src1[0] * src2[1] - src1[1] * src2[0]; + + norm3(dest, fail_ind); +} + +// ── interp4 ───────────────────────────────────────────────────────── +// Find the ODF sphere vertex closest to `frame` direction, then +// trilinearly interpolate the PMF at that vertex index. + +inline float interp4(const float3 pos, + const threadgroup float* frame, + const device float* pmf, + const int dimx, const int dimy, + const int dimz, const int dimt, + const device packed_float3* odf_sphere_vertices, + threadgroup float* interp_scratch, + uint tidx) { + + int closest_odf_idx = 0; + float max_cos = 0.0f; + + for (int ii = int(tidx); ii < dimt; ii += THR_X_SL) { + float3 sv = load_f3(odf_sphere_vertices, uint(ii)); + float cos_sim = FABS(sv.x * frame[0] + + sv.y * frame[1] + + sv.z * frame[2]); + if (cos_sim > max_cos) { + max_cos = cos_sim; + closest_odf_idx = ii; + } + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Reduce across the SIMD group + for (int i = THR_X_SL / 2; i > 0; i /= 2) { + const float tmp = simd_shuffle_xor(max_cos, ushort(i)); + const int tmp_idx = simd_shuffle_xor(closest_odf_idx, ushort(i)); + if (tmp > max_cos || + (tmp == max_cos && tmp_idx < closest_odf_idx)) { + max_cos = tmp; + closest_odf_idx = tmp_idx; + } + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Trilinear interpolation at the closest ODF vertex + const int rv = trilinear_interp(dimx, dimy, dimz, dimt, + closest_odf_idx, pmf, pos, + interp_scratch, tidx); + + if (rv != 0) { + return 0.0f; // No support + } else { + return *interp_scratch; + } +} + +// ── prepare_propagator ────────────────────────────────────────────── +// Build 3x3 propagator matrix from curvatures k1, k2 and arclength. + +inline void prepare_propagator(float k1, float k2, float arclength, + threadgroup float* propagator) { + if ((FABS(k1) < K_SMALL) && (FABS(k2) < K_SMALL)) { + propagator[0] = arclength; + propagator[1] = 0; + propagator[2] = 0; + propagator[3] = 1; + propagator[4] = 0; + propagator[5] = 0; + propagator[6] = 0; + propagator[7] = 0; + propagator[8] = 1; + } else { + if (FABS(k1) < K_SMALL) { + k1 = K_SMALL; + } + if (FABS(k2) < K_SMALL) { + k2 = K_SMALL; + } + const float k = SQRT(k1 * k1 + k2 * k2); + const float sinkt = SIN(k * arclength); + const float coskt = COS(k * arclength); + const float kk = 1.0f / (k * k); + + propagator[0] = sinkt / k; + propagator[1] = k1 * (1.0f - coskt) * kk; + propagator[2] = k2 * (1.0f - coskt) * kk; + propagator[3] = coskt; + propagator[4] = k1 * sinkt / k; + propagator[5] = k2 * sinkt / k; + propagator[6] = -propagator[5]; + propagator[7] = k1 * k2 * (coskt - 1.0f) * kk; + propagator[8] = (k1 * k1 + k2 * k2 * coskt) * kk; + } +} + +// ── random_normal_ptt ─────────────────────────────────────────────── +// Generate a random normal vector perpendicular to probing_frame[0..2]. + +inline void random_normal_ptt(thread PhiloxState& st, + threadgroup float* probing_frame) { + probing_frame[3] = philox_normal(st); + probing_frame[4] = philox_normal(st); + probing_frame[5] = philox_normal(st); + + float dot = probing_frame[3] * probing_frame[0] + + probing_frame[4] * probing_frame[1] + + probing_frame[5] * probing_frame[2]; + + probing_frame[3] -= dot * probing_frame[0]; + probing_frame[4] -= dot * probing_frame[1]; + probing_frame[5] -= dot * probing_frame[2]; + + float n2 = probing_frame[3] * probing_frame[3] + + probing_frame[4] * probing_frame[4] + + probing_frame[5] * probing_frame[5]; + + if (n2 < NORM_EPS) { + float abs_x = FABS(probing_frame[0]); + float abs_y = FABS(probing_frame[1]); + float abs_z = FABS(probing_frame[2]); + + if (abs_x <= abs_y && abs_x <= abs_z) { + probing_frame[3] = 0.0f; + probing_frame[4] = probing_frame[2]; + probing_frame[5] = -probing_frame[1]; + } + else if (abs_y <= abs_z) { + probing_frame[3] = -probing_frame[2]; + probing_frame[4] = 0.0f; + probing_frame[5] = probing_frame[0]; + } + else { + probing_frame[3] = probing_frame[1]; + probing_frame[4] = -probing_frame[0]; + probing_frame[5] = 0.0f; + } + } +} + +// ── get_probing_frame ─────────────────────────────────────────────── +// IS_INIT variant: build a fresh probing frame from the tangent direction. +// Non-init variant: just copy the existing frame. + +inline void get_probing_frame_init(const threadgroup float* frame, + thread PhiloxState& st, + threadgroup float* probing_frame) { + for (int ii = 0; ii < 3; ii++) { + probing_frame[ii] = frame[ii]; + } + norm3(probing_frame, 0); + + random_normal_ptt(st, probing_frame); + norm3(probing_frame + 3, 1); + + // binorm = tangent x normal + crossnorm3(probing_frame + 2 * 3, probing_frame, probing_frame + 3, 2); +} + +inline void get_probing_frame_noinit(const threadgroup float* frame, + threadgroup float* probing_frame) { + for (int ii = 0; ii < 9; ii++) { + probing_frame[ii] = frame[ii]; + } +} + +// ── propagate_frame ───────────────────────────────────────────────── +// Apply propagator matrix to the frame, re-orthonormalise, and output direction. + +inline void propagate_frame(threadgroup float* propagator, + threadgroup float* frame, + threadgroup float* direc) { + float tmp[3]; + + for (int ii = 0; ii < 3; ii++) { + direc[ii] = propagator[0] * frame[ii] + propagator[1] * frame[3 + ii] + propagator[2] * frame[6 + ii]; + tmp[ii] = propagator[3] * frame[ii] + propagator[4] * frame[3 + ii] + propagator[5] * frame[6 + ii]; + frame[2*3 + ii] = propagator[6] * frame[ii] + propagator[7] * frame[3 + ii] + propagator[8] * frame[6 + ii]; + } + + norm3(tmp, 0); // normalise tangent + + // Write normalised tangent back to frame[0..2] so crossnorm3 can + // operate on threadgroup pointers (Metal requires address-space-qualified args). + for (int ii = 0; ii < 3; ii++) { + frame[ii] = tmp[ii]; + } + + crossnorm3(frame + 3, frame + 2 * 3, frame, 1); // normal = cross(binorm, tangent) + crossnorm3(frame + 2 * 3, frame, frame + 3, 2); // binorm = cross(tangent, normal) +} + +// ── calculate_data_support ────────────────────────────────────────── +// Probe forward along a candidate curve and accumulate FOD amplitudes. + +inline float calculate_data_support( + float support, + const float3 pos, + const device float* pmf, + const int dimx, const int dimy, const int dimz, const int dimt, + const float probe_step_size, + const float absolpmf_thresh, + const device packed_float3* odf_sphere_vertices, + threadgroup float* probing_prop_sh, + threadgroup float* direc_sh, + threadgroup float3* probing_pos_sh, + threadgroup float* k1_sh, + threadgroup float* k2_sh, + threadgroup float* probing_frame_sh, + threadgroup float* interp_scratch, + uint tidx) { + + if (tidx == 0) { + prepare_propagator(*k1_sh, *k2_sh, probe_step_size, probing_prop_sh); + *probing_pos_sh = pos; + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (int ii = 0; ii < PROBE_QUALITY; ii++) { + if (tidx == 0) { + propagate_frame(probing_prop_sh, probing_frame_sh, direc_sh); + + float3 pp = *probing_pos_sh; + pp.x += direc_sh[0]; + pp.y += direc_sh[1]; + pp.z += direc_sh[2]; + *probing_pos_sh = pp; + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + const float fod_amp = interp4( + *probing_pos_sh, probing_frame_sh, pmf, + dimx, dimy, dimz, dimt, + odf_sphere_vertices, interp_scratch, tidx); + + if (!ALLOW_WEAK_LINK && (fod_amp < absolpmf_thresh)) { + return 0.0f; + } + support += fod_amp; + } + return support; +} + +// ── get_direction_ptt (IS_INIT == true) ───────────────────────────── +// Workspace threadgroup arrays are declared at kernel scope and passed +// as pre-offset (by tidy) pointers. + +inline int get_direction_ptt_init( + thread PhiloxState& st, + const device float* pmf, + const float max_angle, + const float step_size, + float3 dir, + threadgroup float* frame_sh, + const int dimx, const int dimy, const int dimz, const int dimt, + float3 pos, + const device packed_float3* odf_sphere_vertices, + threadgroup packed_float3* dirs, + // PTT workspace (pre-offset by tidy from kernel scope) + threadgroup float* my_face_cdf_sh, + threadgroup float* my_vert_pdf_sh, + threadgroup float* my_probing_frame_sh, + threadgroup float* my_k1_probe_sh, + threadgroup float* my_k2_probe_sh, + threadgroup float* my_probing_prop_sh, + threadgroup float* my_direc_sh, + threadgroup float3* my_probing_pos_sh, + threadgroup float* my_interp_scratch, + uint tidx) { + + const float probe_step_size = ((step_size / PROBE_FRAC) / (PROBE_QUALITY - 1)); + const float max_curvature = 2.0f * SIN(max_angle / 2.0f) / (step_size / PROBE_FRAC); + const float absolpmf_thresh = PMF_THRESHOLD_P * simd_max_reduce_dev(dimt, pmf, REAL_MIN, tidx); + + simdgroup_barrier(mem_flags::mem_threadgroup); + + // IS_INIT: set frame tangent from dir + if (tidx == 0) { + frame_sh[0] = dir.x; + frame_sh[1] = dir.y; + frame_sh[2] = dir.z; + } + + const float first_val = interp4( + pos, frame_sh, pmf, + dimx, dimy, dimz, dimt, + odf_sphere_vertices, my_interp_scratch, tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Calculate vert_pdf_sh + bool support_found = false; + for (int ii = 0; ii < DISC_VERT_CNT; ii++) { + if (tidx == 0) { + *my_k1_probe_sh = DISC_VERT[ii * 2] * max_curvature; + *my_k2_probe_sh = DISC_VERT[ii * 2 + 1] * max_curvature; + get_probing_frame_init(frame_sh, st, my_probing_frame_sh); + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + const float this_support = calculate_data_support( + first_val, + pos, pmf, dimx, dimy, dimz, dimt, + probe_step_size, + absolpmf_thresh, + odf_sphere_vertices, + my_probing_prop_sh, my_direc_sh, my_probing_pos_sh, + my_k1_probe_sh, my_k2_probe_sh, + my_probing_frame_sh, my_interp_scratch, tidx); + + if (this_support < PROBE_QUALITY * absolpmf_thresh) { + if (tidx == 0) { + my_vert_pdf_sh[ii] = 0; + } + } else { + if (tidx == 0) { + my_vert_pdf_sh[ii] = this_support; + } + support_found = true; + } + } + if (!support_found) { + return 0; + } + + // Initialise face_cdf_sh + for (int ii = int(tidx); ii < DISC_FACE_CNT; ii += THR_X_SL) { + my_face_cdf_sh[ii] = 0; + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Move vert PDF to face PDF + for (int ii = int(tidx); ii < DISC_FACE_CNT; ii += THR_X_SL) { + bool all_verts_valid = true; + for (int jj = 0; jj < 3; jj++) { + float vert_val = my_vert_pdf_sh[DISC_FACE[ii * 3 + jj]]; + if (vert_val == 0) { + all_verts_valid = true; // IS_INIT: even go with faces that are not fully supported + } + my_face_cdf_sh[ii] += vert_val; + } + if (!all_verts_valid) { + my_face_cdf_sh[ii] = 0; + } + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Prefix sum and check for zero total + prefix_sum_sh(my_face_cdf_sh, DISC_FACE_CNT, tidx); + float last_cdf = my_face_cdf_sh[DISC_FACE_CNT - 1]; + + if (last_cdf == 0) { + return 0; + } + + // Rejection sampling + for (int ii = 0; ii < TRIES_PER_REJECTION_SAMPLING; ii++) { + float tmp_sample; + if (tidx == 0) { + float r1 = philox_uniform(st); + float r2 = philox_uniform(st); + if (r1 + r2 > 1.0f) { + r1 = 1.0f - r1; + r2 = 1.0f - r2; + } + + tmp_sample = philox_uniform(st) * last_cdf; + int jj; + for (jj = 0; jj < DISC_FACE_CNT; jj++) { + if (my_face_cdf_sh[jj] >= tmp_sample) + break; + } + + const float vx0 = max_curvature * DISC_VERT[DISC_FACE[jj * 3] * 2]; + const float vx1 = max_curvature * DISC_VERT[DISC_FACE[jj * 3 + 1] * 2]; + const float vx2 = max_curvature * DISC_VERT[DISC_FACE[jj * 3 + 2] * 2]; + + const float vy0 = max_curvature * DISC_VERT[DISC_FACE[jj * 3] * 2 + 1]; + const float vy1 = max_curvature * DISC_VERT[DISC_FACE[jj * 3 + 1] * 2 + 1]; + const float vy2 = max_curvature * DISC_VERT[DISC_FACE[jj * 3 + 2] * 2 + 1]; + + *my_k1_probe_sh = vx0 + r1 * (vx1 - vx0) + r2 * (vx2 - vx0); + *my_k2_probe_sh = vy0 + r1 * (vy1 - vy0) + r2 * (vy2 - vy0); + get_probing_frame_init(frame_sh, st, my_probing_frame_sh); + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + const float this_support = calculate_data_support( + first_val, + pos, pmf, dimx, dimy, dimz, dimt, + probe_step_size, + absolpmf_thresh, + odf_sphere_vertices, + my_probing_prop_sh, my_direc_sh, my_probing_pos_sh, + my_k1_probe_sh, my_k2_probe_sh, + my_probing_frame_sh, my_interp_scratch, tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (this_support < PROBE_QUALITY * absolpmf_thresh) { + continue; + } + + // IS_INIT: just store the original direction + if (tidx == 0) { + store_f3(dirs, 0, dir); + } + + if (tidx < 9) { + frame_sh[tidx] = my_probing_frame_sh[tidx]; + } + simdgroup_barrier(mem_flags::mem_threadgroup); + return 1; + } + return 0; +} + +// ── get_direction_ptt (IS_INIT == false) ──────────────────────────── +// Workspace threadgroup arrays are declared at kernel scope and passed +// as pre-offset (by tidy) pointers. + +inline int get_direction_ptt_noinit( + thread PhiloxState& st, + const device float* pmf, + const float max_angle, + const float step_size, + float3 dir, + threadgroup float* frame_sh, + const int dimx, const int dimy, const int dimz, const int dimt, + float3 pos, + const device packed_float3* odf_sphere_vertices, + threadgroup packed_float3* dirs, + // PTT workspace (pre-offset by tidy from kernel scope) + threadgroup float* my_face_cdf_sh, + threadgroup float* my_vert_pdf_sh, + threadgroup float* my_probing_frame_sh, + threadgroup float* my_k1_probe_sh, + threadgroup float* my_k2_probe_sh, + threadgroup float* my_probing_prop_sh, + threadgroup float* my_direc_sh, + threadgroup float3* my_probing_pos_sh, + threadgroup float* my_interp_scratch, + uint tidx) { + + const float probe_step_size = ((step_size / PROBE_FRAC) / (PROBE_QUALITY - 1)); + const float max_curvature = 2.0f * SIN(max_angle / 2.0f) / (step_size / PROBE_FRAC); + const float absolpmf_thresh = PMF_THRESHOLD_P * simd_max_reduce_dev(dimt, pmf, REAL_MIN, tidx); + + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Non-init: frame_sh is already populated + + const float first_val = interp4( + pos, frame_sh, pmf, + dimx, dimy, dimz, dimt, + odf_sphere_vertices, my_interp_scratch, tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Calculate vert_pdf_sh + bool support_found = false; + for (int ii = 0; ii < DISC_VERT_CNT; ii++) { + if (tidx == 0) { + *my_k1_probe_sh = DISC_VERT[ii * 2] * max_curvature; + *my_k2_probe_sh = DISC_VERT[ii * 2 + 1] * max_curvature; + get_probing_frame_noinit(frame_sh, my_probing_frame_sh); + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + const float this_support = calculate_data_support( + first_val, + pos, pmf, dimx, dimy, dimz, dimt, + probe_step_size, + absolpmf_thresh, + odf_sphere_vertices, + my_probing_prop_sh, my_direc_sh, my_probing_pos_sh, + my_k1_probe_sh, my_k2_probe_sh, + my_probing_frame_sh, my_interp_scratch, tidx); + + if (this_support < PROBE_QUALITY * absolpmf_thresh) { + if (tidx == 0) { + my_vert_pdf_sh[ii] = 0; + } + } else { + if (tidx == 0) { + my_vert_pdf_sh[ii] = this_support; + } + support_found = true; + } + } + if (!support_found) { + return 0; + } + + // Initialise face_cdf_sh + for (int ii = int(tidx); ii < DISC_FACE_CNT; ii += THR_X_SL) { + my_face_cdf_sh[ii] = 0; + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Move vert PDF to face PDF + for (int ii = int(tidx); ii < DISC_FACE_CNT; ii += THR_X_SL) { + bool all_verts_valid = true; + for (int jj = 0; jj < 3; jj++) { + float vert_val = my_vert_pdf_sh[DISC_FACE[ii * 3 + jj]]; + if (vert_val == 0) { + all_verts_valid = false; // Non-init: reject faces with unsupported vertices + } + my_face_cdf_sh[ii] += vert_val; + } + if (!all_verts_valid) { + my_face_cdf_sh[ii] = 0; + } + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Prefix sum and check for zero total + prefix_sum_sh(my_face_cdf_sh, DISC_FACE_CNT, tidx); + float last_cdf = my_face_cdf_sh[DISC_FACE_CNT - 1]; + + if (last_cdf == 0) { + return 0; + } + + // Rejection sampling + for (int ii = 0; ii < TRIES_PER_REJECTION_SAMPLING; ii++) { + float tmp_sample; + if (tidx == 0) { + float r1 = philox_uniform(st); + float r2 = philox_uniform(st); + if (r1 + r2 > 1.0f) { + r1 = 1.0f - r1; + r2 = 1.0f - r2; + } + + tmp_sample = philox_uniform(st) * last_cdf; + int jj; + for (jj = 0; jj < DISC_FACE_CNT; jj++) { + if (my_face_cdf_sh[jj] >= tmp_sample) + break; + } + + const float vx0 = max_curvature * DISC_VERT[DISC_FACE[jj * 3] * 2]; + const float vx1 = max_curvature * DISC_VERT[DISC_FACE[jj * 3 + 1] * 2]; + const float vx2 = max_curvature * DISC_VERT[DISC_FACE[jj * 3 + 2] * 2]; + + const float vy0 = max_curvature * DISC_VERT[DISC_FACE[jj * 3] * 2 + 1]; + const float vy1 = max_curvature * DISC_VERT[DISC_FACE[jj * 3 + 1] * 2 + 1]; + const float vy2 = max_curvature * DISC_VERT[DISC_FACE[jj * 3 + 2] * 2 + 1]; + + *my_k1_probe_sh = vx0 + r1 * (vx1 - vx0) + r2 * (vx2 - vx0); + *my_k2_probe_sh = vy0 + r1 * (vy1 - vy0) + r2 * (vy2 - vy0); + get_probing_frame_noinit(frame_sh, my_probing_frame_sh); + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + const float this_support = calculate_data_support( + first_val, + pos, pmf, dimx, dimy, dimz, dimt, + probe_step_size, + absolpmf_thresh, + odf_sphere_vertices, + my_probing_prop_sh, my_direc_sh, my_probing_pos_sh, + my_k1_probe_sh, my_k2_probe_sh, + my_probing_frame_sh, my_interp_scratch, tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (this_support < PROBE_QUALITY * absolpmf_thresh) { + continue; + } + + // Non-init: propagate 1/STEP_FRAC of a step and output direction + if (tidx == 0) { + prepare_propagator( + *my_k1_probe_sh, *my_k2_probe_sh, + step_size / STEP_FRAC, my_probing_prop_sh); + get_probing_frame_noinit(frame_sh, my_probing_frame_sh); + propagate_frame(my_probing_prop_sh, my_probing_frame_sh, my_direc_sh); + + // norm3 on threadgroup memory + norm3(my_direc_sh, 0); + + store_f3(dirs, 0, float3(my_direc_sh[0], my_direc_sh[1], my_direc_sh[2])); + } + + if (tidx < 9) { + frame_sh[tidx] = my_probing_frame_sh[tidx]; + } + simdgroup_barrier(mem_flags::mem_threadgroup); + return 1; + } + return 0; +} + +// ── init_frame_ptt ────────────────────────────────────────────────── +// Initialise the parallel transport frame for a new streamline. +// Tries the negative direction first, then the positive, and flips if needed. + +inline bool init_frame_ptt( + thread PhiloxState& st, + const device float* pmf, + const float max_angle, + const float step_size, + float3 first_step, + const int dimx, const int dimy, const int dimz, const int dimt, + float3 seed, + const device packed_float3* sphere_vertices, + threadgroup float* frame, + threadgroup packed_float3* tmp_dir, + // PTT workspace (pre-offset by tidy from kernel scope) + threadgroup float* my_face_cdf_sh, + threadgroup float* my_vert_pdf_sh, + threadgroup float* my_probing_frame_sh, + threadgroup float* my_k1_probe_sh, + threadgroup float* my_k2_probe_sh, + threadgroup float* my_probing_prop_sh, + threadgroup float* my_direc_sh, + threadgroup float3* my_probing_pos_sh, + threadgroup float* my_interp_scratch, + uint tidx) { + + bool init_norm_success; + + // Try with negated direction first + init_norm_success = (bool)get_direction_ptt_init( + st, + pmf, + max_angle, + step_size, + float3(-first_step.x, -first_step.y, -first_step.z), + frame, + dimx, dimy, dimz, dimt, + seed, + sphere_vertices, + tmp_dir, + my_face_cdf_sh, my_vert_pdf_sh, + my_probing_frame_sh, + my_k1_probe_sh, my_k2_probe_sh, + my_probing_prop_sh, my_direc_sh, + my_probing_pos_sh, my_interp_scratch, + tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (!init_norm_success) { + // Try the other direction + init_norm_success = (bool)get_direction_ptt_init( + st, + pmf, + max_angle, + step_size, + float3(first_step.x, first_step.y, first_step.z), + frame, + dimx, dimy, dimz, dimt, + seed, + sphere_vertices, + tmp_dir, + my_face_cdf_sh, my_vert_pdf_sh, + my_probing_frame_sh, + my_k1_probe_sh, my_k2_probe_sh, + my_probing_prop_sh, my_direc_sh, + my_probing_pos_sh, my_interp_scratch, + tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (!init_norm_success) { + return false; + } else { + if (tidx == 0) { + for (int ii = 0; ii < 9; ii++) { + frame[ii] = -frame[ii]; + } + } + simdgroup_barrier(mem_flags::mem_threadgroup); + } + } + + // Save flipped frame for second run + if (tidx == 0) { + for (int ii = 0; ii < 9; ii++) { + frame[9 + ii] = -frame[ii]; + } + } + simdgroup_barrier(mem_flags::mem_threadgroup); + return true; +} + +// ── ProbTrackingParams struct ──────────────────────────────────────── +// Shared with generate_streamlines_metal.metal. Guard against +// duplicate definitions since both files are compiled into one library. + +#ifndef PROB_TRACKING_PARAMS_DEFINED +#define PROB_TRACKING_PARAMS_DEFINED +struct ProbTrackingParams { + float max_angle; + float tc_threshold; + float step_size; + float relative_peak_thresh; + float min_separation_angle; + int rng_seed_lo; + int rng_seed_hi; + int rng_offset; + int nseed; + int dimx; + int dimy; + int dimz; + int dimt; + int samplm_nr; + int num_edges; + int model_type; // PROB=2 or PTT=3 +}; +#endif + +// ── tracker_ptt — step along streamline with parallel transport ───── +// Mirrors tracker_d from CUDA: takes fractional steps (STEP_FRAC +// sub-steps per full step), only stores every STEP_FRAC'th point. + +inline int tracker_ptt(thread PhiloxState& st, + const float max_angle, + const float tc_threshold, + const float step_size, + float3 seed, + float3 first_step, + const float3 voxel_size, + const int dimx, const int dimy, + const int dimz, const int dimt, + const device float* dataf, + const device float* metric_map, + const device packed_float3* sphere_vertices, + threadgroup int* nsteps, + device packed_float3* streamline, + threadgroup float* frame_sh, + threadgroup float* interp_out, + // PTT workspace (pre-offset by tidy) + threadgroup packed_float3* ptt_dirs, + threadgroup float* my_face_cdf_sh, + threadgroup float* my_vert_pdf_sh, + threadgroup float* my_probing_frame_sh, + threadgroup float* my_k1_probe_sh, + threadgroup float* my_k2_probe_sh, + threadgroup float* my_probing_prop_sh, + threadgroup float* my_direc_sh, + threadgroup float3* my_probing_pos_sh, + threadgroup float* my_interp_scratch, + uint tidx, uint tidy) { + + int tissue_class = TRACKPOINT; + float3 point = seed; + float3 direction = first_step; + + if (tidx == 0) { + store_f3(streamline, 0, point); + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + int i; + for (i = 1; i < MAX_SLINE_LEN * STEP_FRAC; i++) { + int ndir = get_direction_ptt_noinit(st, dataf, max_angle, step_size, + direction, frame_sh, + dimx, dimy, dimz, dimt, + point, sphere_vertices, + ptt_dirs, + my_face_cdf_sh, my_vert_pdf_sh, + my_probing_frame_sh, + my_k1_probe_sh, my_k2_probe_sh, + my_probing_prop_sh, my_direc_sh, + my_probing_pos_sh, my_interp_scratch, + tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + direction = load_f3(ptt_dirs, 0); + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (ndir == 0) { + break; + } + + point.x += (direction.x / voxel_size.x) * (step_size / float(STEP_FRAC)); + point.y += (direction.y / voxel_size.y) * (step_size / float(STEP_FRAC)); + point.z += (direction.z / voxel_size.z) * (step_size / float(STEP_FRAC)); + + if ((tidx == 0) && ((i % STEP_FRAC) == 0)) { + store_f3(streamline, uint(i / STEP_FRAC), point); + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + if ((i % STEP_FRAC) == 0) { + tissue_class = check_point(tc_threshold, point, dimx, dimy, dimz, + metric_map, interp_out, tidx, tidy); + + if (tissue_class == ENDPOINT || + tissue_class == INVALIDPOINT || + tissue_class == OUTSIDEIMAGE) { + break; + } + } + } + + nsteps[0] = i / STEP_FRAC; + // If stopped mid-fraction, store the final point + if (((i % STEP_FRAC) != 0) && (i < STEP_FRAC * (MAX_SLINE_LEN - 1))) { + nsteps[0] += 1; + if (tidx == 0) { + store_f3(streamline, uint(nsteps[0]), point); + } + } + return tissue_class; +} + +// ── genStreamlinesMergePtt_k ───────────────────────────────────────── +// PTT generation kernel. Uses the same buffer layout as the Prob kernel +// so the Python dispatch code is shared. PTT reuses Prob's getNum kernel +// for initial direction finding. + +kernel void genStreamlinesMergePtt_k( + constant ProbTrackingParams& params [[buffer(0)]], + const device packed_float3* seeds [[buffer(1)]], + const device float* dataf [[buffer(2)]], + const device float* metric_map [[buffer(3)]], + const device packed_float3* sphere_vertices [[buffer(4)]], + const device int2* sphere_edges [[buffer(5)]], + const device int* slineOutOff [[buffer(6)]], + device packed_float3* shDir0 [[buffer(7)]], + device int* slineSeed [[buffer(8)]], + device int* slineLen [[buffer(9)]], + device packed_float3* sline [[buffer(10)]], + uint2 tid [[thread_position_in_threadgroup]], + uint2 gid [[threadgroup_position_in_grid]]) +{ + const uint tidx = tid.x; + const uint tidy = tid.y; + const uint slid = gid.x * BLOCK_Y + tidy; + + if (int(slid) >= params.nseed) return; + + const uint global_id = gid.x * BLOCK_Y * THR_X_SL + THR_X_SL * tidy + tidx; + PhiloxState st = philox_init(uint(params.rng_seed_lo), uint(params.rng_seed_hi), global_id + 1, 0); + + // ── PTT-specific threadgroup memory ───────────────────────────── + threadgroup float frame_sh[BLOCK_Y * 18]; // 9 backward + 9 forward + threadgroup packed_float3 tmp_dir_sh[BLOCK_Y]; // for init_frame_ptt + threadgroup packed_float3 ptt_dirs_sh[BLOCK_Y]; // direction output + threadgroup float interp_out[BLOCK_Y]; + threadgroup int stepsB_sh[BLOCK_Y]; + threadgroup int stepsF_sh[BLOCK_Y]; + + // PTT workspace arrays + threadgroup float face_cdf[BLOCK_Y * DISC_FACE_CNT]; + threadgroup float vert_pdf[BLOCK_Y * DISC_VERT_CNT]; + threadgroup float probing_frame[BLOCK_Y * 9]; + threadgroup float k1_probe[BLOCK_Y]; + threadgroup float k2_probe[BLOCK_Y]; + threadgroup float probing_prop[BLOCK_Y * 9]; + threadgroup float direc[BLOCK_Y * 3]; + threadgroup float3 probing_pos[BLOCK_Y]; + threadgroup float interp_scratch[BLOCK_Y]; + + // Pre-offset pointers for this tidy + threadgroup float* my_frame = frame_sh + tidy * 18; + threadgroup packed_float3* my_tmpdir = tmp_dir_sh + tidy; + threadgroup packed_float3* my_dirs = ptt_dirs_sh + tidy; + + threadgroup float* my_face_cdf = face_cdf + tidy * DISC_FACE_CNT; + threadgroup float* my_vert_pdf = vert_pdf + tidy * DISC_VERT_CNT; + threadgroup float* my_pfr = probing_frame + tidy * 9; + threadgroup float* my_k1 = k1_probe + tidy; + threadgroup float* my_k2 = k2_probe + tidy; + threadgroup float* my_pprop = probing_prop + tidy * 9; + threadgroup float* my_direc = direc + tidy * 3; + threadgroup float3* my_ppos = probing_pos + tidy; + threadgroup float* my_iscratch = interp_scratch + tidy; + + // ── per-seed loop ─────────────────────────────────────────────── + float3 seed = load_f3(seeds, slid); + + int ndir = slineOutOff[slid + 1] - slineOutOff[slid]; + simdgroup_barrier(mem_flags::mem_threadgroup); + + int slineOff = slineOutOff[slid]; + + for (int i = 0; i < ndir; i++) { + float3 first_step = load_f3(shDir0, uint(int(slid) * params.samplm_nr + i)); + + device packed_float3* currSline = sline + slineOff * MAX_SLINE_LEN * 2; + + if (tidx == 0) { + slineSeed[slineOff] = int(slid); + } + + // PTT frame initialization + if (!init_frame_ptt(st, dataf, params.max_angle, params.step_size, + first_step, + params.dimx, params.dimy, params.dimz, params.dimt, + seed, sphere_vertices, + my_frame, + my_tmpdir, + my_face_cdf, my_vert_pdf, + my_pfr, my_k1, my_k2, + my_pprop, my_direc, + my_ppos, my_iscratch, + tidx)) { + // Init failed — store single-point streamline + if (tidx == 0) { + slineLen[slineOff] = 1; + store_f3(currSline, 0, seed); + } + simdgroup_barrier(mem_flags::mem_threadgroup); + slineOff += 1; + continue; + } + + // Backward tracking (using frame[0:9]) + tracker_ptt(st, params.max_angle, params.tc_threshold, + params.step_size, + seed, float3(-first_step.x, -first_step.y, -first_step.z), + float3(1, 1, 1), + params.dimx, params.dimy, params.dimz, params.dimt, + dataf, metric_map, sphere_vertices, + stepsB_sh + tidy, currSline, + my_frame, // backward frame = first 9 elements + interp_out, + my_dirs, + my_face_cdf, my_vert_pdf, + my_pfr, my_k1, my_k2, + my_pprop, my_direc, + my_ppos, my_iscratch, + tidx, tidy); + + int stepsB = stepsB_sh[tidy]; + + // Reverse backward streamline + for (int j = int(tidx); j < stepsB / 2; j += THR_X_SL) { + float3 p = load_f3(currSline, uint(j)); + store_f3(currSline, uint(j), load_f3(currSline, uint(stepsB - 1 - j))); + store_f3(currSline, uint(stepsB - 1 - j), p); + } + + // Forward tracking (using frame[9:18]) + tracker_ptt(st, params.max_angle, params.tc_threshold, + params.step_size, + seed, first_step, float3(1, 1, 1), + params.dimx, params.dimy, params.dimz, params.dimt, + dataf, metric_map, sphere_vertices, + stepsF_sh + tidy, currSline + (stepsB - 1), + my_frame + 9, // forward frame = last 9 elements + interp_out, + my_dirs, + my_face_cdf, my_vert_pdf, + my_pfr, my_k1, my_k2, + my_pprop, my_direc, + my_ppos, my_iscratch, + tidx, tidy); + + if (tidx == 0) { + slineLen[slineOff] = stepsB - 1 + stepsF_sh[tidy]; + } + + slineOff += 1; + } +} diff --git a/cuslines/metal_shaders/tracking_helpers.metal b/cuslines/metal_shaders/tracking_helpers.metal new file mode 100644 index 0000000..8ef2148 --- /dev/null +++ b/cuslines/metal_shaders/tracking_helpers.metal @@ -0,0 +1,221 @@ +/* Metal port of cuslines/cuda_c/tracking_helpers.cu + * + * Trilinear interpolation, tissue checking, and peak direction finding. + */ + +#include "globals.h" +#include "types.h" + +// ── trilinear interpolation helper (inner loop) ────────────────────── + +inline float interpolation_helper(const device float* dataf, + const float wgh[3][2], + const long coo[3][2], + int dimy, int dimz, int dimt, int t) { + float tmp = 0.0f; + for (int i = 0; i < 2; i++) { + for (int j = 0; j < 2; j++) { + for (int k = 0; k < 2; k++) { + tmp += wgh[0][i] * wgh[1][j] * wgh[2][k] * + dataf[coo[0][i] * dimy * dimz * dimt + + coo[1][j] * dimz * dimt + + coo[2][k] * dimt + + t]; + } + } + } + return tmp; +} + +// ── trilinear interpolation ────────────────────────────────────────── +// All threads in the SIMD group compute boundary checks together. +// Thread-parallel loop over the dimt dimension. + +inline int trilinear_interp(const int dimx, const int dimy, const int dimz, + const int dimt, int dimt_idx, + const device float* dataf, + const float3 point, + threadgroup float* vox_data, + uint tidx) { + const float HALF = 0.5f; + + if (point.x < -HALF || point.x + HALF >= float(dimx) || + point.y < -HALF || point.y + HALF >= float(dimy) || + point.z < -HALF || point.z + HALF >= float(dimz)) { + return -1; + } + + long coo[3][2]; // 64-bit to avoid overflow in index computation (CUDA uses long long) + float wgh[3][2]; + + const float3 fl = floor(point); + + wgh[0][1] = point.x - fl.x; + wgh[0][0] = 1.0f - wgh[0][1]; + coo[0][0] = MAX(0, int(fl.x)); + coo[0][1] = MIN(int(dimx - 1), coo[0][0] + 1); + + wgh[1][1] = point.y - fl.y; + wgh[1][0] = 1.0f - wgh[1][1]; + coo[1][0] = MAX(0, int(fl.y)); + coo[1][1] = MIN(int(dimy - 1), coo[1][0] + 1); + + wgh[2][1] = point.z - fl.z; + wgh[2][0] = 1.0f - wgh[2][1]; + coo[2][0] = MAX(0, int(fl.z)); + coo[2][1] = MIN(int(dimz - 1), coo[2][0] + 1); + + if (dimt_idx == -1) { + for (int t = int(tidx); t < dimt; t += THR_X_SL) { + vox_data[t] = interpolation_helper(dataf, wgh, coo, dimy, dimz, dimt, t); + } + } else { + *vox_data = interpolation_helper(dataf, wgh, coo, dimy, dimz, dimt, dimt_idx); + } + return 0; +} + +// ── tissue check at a point ────────────────────────────────────────── + +inline int check_point(const float tc_threshold, + const float3 point, + const int dimx, const int dimy, const int dimz, + const device float* metric_map, + threadgroup float* interp_out, // length BLOCK_Y + uint tidx, uint tidy) { + + const int rv = trilinear_interp(dimx, dimy, dimz, 1, 0, + metric_map, point, + interp_out + tidy, tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (rv != 0) { + return OUTSIDEIMAGE; + } + return (interp_out[tidy] > tc_threshold) ? TRACKPOINT : ENDPOINT; +} + +// ── peak direction finding ─────────────────────────────────────────── +// Finds local maxima on the ODF sphere, filters by relative threshold +// and minimum separation angle. + +inline int peak_directions(const threadgroup float* odf, + threadgroup float3* dirs, + const device packed_float3* sphere_vertices, + const device int2* sphere_edges, + const int num_edges, + int samplm_nr, + threadgroup int* shInd, + const float relative_peak_thres, + const float min_separation_angle, + uint tidx) { + // Initialize index array + for (int j = int(tidx); j < samplm_nr; j += THR_X_SL) { + shInd[j] = 0; + } + + float odf_min = simd_min_reduce(samplm_nr, odf, REAL_MAX, tidx); + odf_min = MAX(0.0f, odf_min); + + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Local maxima detection using sphere edges + // atomics on threadgroup memory for benign race conditions + for (int j = 0; j < num_edges; j += THR_X_SL) { + if (j + int(tidx) < num_edges) { + const int u_ind = sphere_edges[j + tidx].x; + const int v_ind = sphere_edges[j + tidx].y; + + const float u_val = odf[u_ind]; + const float v_val = odf[v_ind]; + + if (u_val < v_val) { + atomic_store_explicit( + (volatile threadgroup atomic_int*)(shInd + u_ind), -1, + memory_order_relaxed); + atomic_fetch_or_explicit( + (volatile threadgroup atomic_int*)(shInd + v_ind), 1, + memory_order_relaxed); + } else if (v_val < u_val) { + atomic_store_explicit( + (volatile threadgroup atomic_int*)(shInd + v_ind), -1, + memory_order_relaxed); + atomic_fetch_or_explicit( + (volatile threadgroup atomic_int*)(shInd + u_ind), 1, + memory_order_relaxed); + } + } + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + const float compThres = relative_peak_thres * + simd_max_mask_transl(samplm_nr, shInd, odf, -odf_min, REAL_MIN, tidx); + + // Compact indices of positive values (local maxima above threshold) + int n = 0; + const uint lmask = (1u << tidx) - 1u; // lanes below me + + for (int j = 0; j < samplm_nr; j += THR_X_SL) { + const int v = (j + int(tidx) < samplm_nr) ? shInd[j + tidx] : -1; + const bool keep = (v > 0) && ((odf[j + tidx] - odf_min) >= compThres); + + // simd_ballot returns a simd_vote on Metal; we can extract the uint mask + uint msk = SIMD_BALLOT_MASK(keep); + + if (keep) { + const int myoff = popcount(msk & lmask); + shInd[n + myoff] = j + int(tidx); + } + n += popcount(msk); + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Sort local maxima by ODF value (descending) + if (n > 0 && n < THR_X_SL) { + float k = REAL_MIN; + int val = 0; + if (int(tidx) < n) { + val = shInd[tidx]; + k = odf[val]; + } + warp_sort_kv(k, val, tidx); + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (int(tidx) < n) { + shInd[tidx] = val; + } + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Remove similar vertices (single-threaded) + if (n != 0) { + if (tidx == 0) { + const float cos_similarity = COS(min_separation_angle); + + dirs[0] = load_f3(sphere_vertices, uint(shInd[0])); + + int k = 1; + for (int i = 1; i < n; i++) { + const float3 abc = load_f3(sphere_vertices, uint(shInd[i])); + + int j = 0; + for (; j < k; j++) { + const float cs = FABS(abc.x * dirs[j].x + + abc.y * dirs[j].y + + abc.z * dirs[j].z); + if (cs > cos_similarity) { + break; + } + } + if (j == k) { + dirs[k++] = abc; + } + } + n = k; + } + n = simd_broadcast_first(n); + simdgroup_barrier(mem_flags::mem_threadgroup); + } + + return n; +} diff --git a/cuslines/metal_shaders/types.h b/cuslines/metal_shaders/types.h new file mode 100644 index 0000000..e84f0a3 --- /dev/null +++ b/cuslines/metal_shaders/types.h @@ -0,0 +1,50 @@ +/* Metal type helpers — handles the packed_float3 / float3 alignment difference. + * + * In CUDA, float3 is 12 bytes in arrays (no padding). + * In Metal, float3 is 16 bytes. packed_float3 is 12 bytes. + * + * Strategy: + * - Device buffers use packed_float3 (12 bytes) → matches CUDA layout and + * Python numpy dtype, so all buffer size calculations remain unchanged. + * - Computation uses float3 (16 bytes) in registers/threadgroup memory. + * - load/store helpers convert between the two. + */ + +#ifndef __TYPES_H__ +#define __TYPES_H__ + +#include +using namespace metal; + +// ── buffer ↔ register conversions ──────────────────────────────────── + +inline float3 load_f3(const device packed_float3* p, uint idx) { + return float3(p[idx]); +} + +inline float3 load_f3(const device packed_float3& p) { + return float3(p); +} + +inline void store_f3(device packed_float3* p, uint idx, float3 v) { + p[idx] = packed_float3(v); +} + +inline void store_f3(device packed_float3& p, float3 v) { + p = packed_float3(v); +} + +// threadgroup load/store — threadgroup memory can use float3 directly +// but we sometimes index packed arrays in threadgroup memory too +inline float3 load_f3(const threadgroup packed_float3* p, uint idx) { + return float3(p[idx]); +} + +inline void store_f3(threadgroup packed_float3* p, uint idx, float3 v) { + p[idx] = packed_float3(v); +} + +// ── CUDA MAKE_REAL3 replacement ────────────────────────────────────── +#define MAKE_REAL3(x, y, z) float3((x), (y), (z)) + +#endif diff --git a/cuslines/metal_shaders/utils.metal b/cuslines/metal_shaders/utils.metal new file mode 100644 index 0000000..6f4aa48 --- /dev/null +++ b/cuslines/metal_shaders/utils.metal @@ -0,0 +1,107 @@ +/* Metal port of cuslines/cuda_c/utils.cu — reduction and prefix-sum primitives. + * + * CUDA warp operations → Metal SIMD group operations: + * __shfl_xor_sync(WMASK, v, delta, BDIM_X) → simd_shuffle_xor(v, delta) + * __shfl_up_sync(WMASK, v, delta, BDIM_X) → simd_shuffle_up(v, delta) + * __syncwarp(WMASK) → simdgroup_barrier(mem_flags::mem_threadgroup) + * + * Since BDIM_X == THR_X_SL == 32 == Apple GPU SIMD width, the custom + * WMASK always covers the full SIMD group so no masking is needed. + */ + +#include "globals.h" + +// ── max reduction across SIMD group ────────────────────────────────── + +inline float simd_max_reduce(int n, const threadgroup float* src, float minVal, + uint tidx) { + float m = minVal; + for (int i = tidx; i < n; i += THR_X_SL) { + m = MAX(m, src[i]); + } + for (int i = THR_X_SL / 2; i > 0; i /= 2) { + float tmp = simd_shuffle_xor(m, ushort(i)); + m = MAX(m, tmp); + } + return m; +} + +// ── min reduction across SIMD group ────────────────────────────────── + +inline float simd_min_reduce(int n, const threadgroup float* src, float maxVal, + uint tidx) { + float m = maxVal; + for (int i = tidx; i < n; i += THR_X_SL) { + m = MIN(m, src[i]); + } + for (int i = THR_X_SL / 2; i > 0; i /= 2) { + float tmp = simd_shuffle_xor(m, ushort(i)); + m = MIN(m, tmp); + } + return m; +} + +// ── max-with-mask reduction ────────────────────────────────────────── +// Only considers entries where srcMsk[i] > 0, applies offset to value. + +inline float simd_max_mask_transl(int n, + const threadgroup int* srcMsk, + const threadgroup float* srcVal, + float offset, float minVal, + uint tidx) { + float m = minVal; + for (int i = tidx; i < n; i += THR_X_SL) { + int sel = srcMsk[i]; + if (sel > 0) { + m = MAX(m, srcVal[i] + offset); + } + } + for (int i = THR_X_SL / 2; i > 0; i /= 2) { + float tmp = simd_shuffle_xor(m, ushort(i)); + m = MAX(m, tmp); + } + return m; +} + +// ── max from device buffer ─────────────────────────────────────────── + +inline float simd_max_reduce_dev(int n, const device float* src, float minVal, + uint tidx) { + float m = minVal; + for (int i = tidx; i < n; i += THR_X_SL) { + m = MAX(m, src[i]); + } + for (int i = THR_X_SL / 2; i > 0; i /= 2) { + float tmp = simd_shuffle_xor(m, ushort(i)); + m = MAX(m, tmp); + } + return m; +} + +// ── inclusive prefix sum in threadgroup memory ──────────────────────── +// Operates on threadgroup float array of length __len. +// All threads in the SIMD group participate. + +inline void prefix_sum_sh(threadgroup float* num_sh, int len, uint tidx) { + for (int j = 0; j < len; j += THR_X_SL) { + if ((tidx == 0) && (j != 0)) { + num_sh[j] += num_sh[j - 1]; + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + float t_pmf = 0.0f; + if (j + int(tidx) < len) { + t_pmf = num_sh[j + tidx]; + } + for (int i = 1; i < THR_X_SL; i *= 2) { + float tmp = simd_shuffle_up(t_pmf, ushort(i)); + if ((int(tidx) >= i) && (j + int(tidx) < len)) { + t_pmf += tmp; + } + } + if (j + int(tidx) < len) { + num_sh[j + tidx] = t_pmf; + } + simdgroup_barrier(mem_flags::mem_threadgroup); + } +} diff --git a/cuslines/metal_shaders/warp_sort.metal b/cuslines/metal_shaders/warp_sort.metal new file mode 100644 index 0000000..9da9e56 --- /dev/null +++ b/cuslines/metal_shaders/warp_sort.metal @@ -0,0 +1,109 @@ +/* Metal port of cuslines/cuda_c/cuwsort.cuh — bitonic merge sort within a SIMD group. + * + * CUDA __shfl_sync → Metal simd_shuffle. + * Swap networks are embedded as constant arrays. + */ + +#include "globals.h" + +// ── sort direction ─────────────────────────────────────────────────── +#define WSORT_DIR_DEC 0 +#define WSORT_DIR_INC 1 + +// ── swap networks ──────────────────────────────────────────────────── +// Batcher's bitonic merge sort comparator networks. + +constant int swap32[15][32] = { + {16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15}, + { 8, 9,10,11,12,13,14,15, 0, 1, 2, 3, 4, 5, 6, 7,24,25,26,27,28,29,30,31,16,17,18,19,20,21,22,23}, + { 4, 5, 6, 7, 0, 1, 2, 3,16,17,18,19,20,21,22,23, 8, 9,10,11,12,13,14,15,28,29,30,31,24,25,26,27}, + { 2, 3, 0, 1, 4, 5, 6, 7,12,13,14,15, 8, 9,10,11,20,21,22,23,16,17,18,19,24,25,26,27,30,31,28,29}, + { 1, 0, 2, 3,16,17,18,19, 8, 9,10,11,24,25,26,27, 4, 5, 6, 7,20,21,22,23,12,13,14,15,28,29,31,30}, + { 0, 1, 2, 3, 8, 9,10,11, 4, 5, 6, 7,16,17,18,19,12,13,14,15,24,25,26,27,20,21,22,23,28,29,30,31}, + { 0, 1, 2, 3, 6, 7, 4, 5,10,11, 8, 9,14,15,12,13,18,19,16,17,22,23,20,21,26,27,24,25,28,29,30,31}, + { 0, 1,16,17, 4, 5,20,21, 8, 9,24,25,12,13,28,29, 2, 3,18,19, 6, 7,22,23,10,11,26,27,14,15,30,31}, + { 0, 1, 8, 9, 4, 5,12,13, 2, 3,16,17, 6, 7,20,21,10,11,24,25,14,15,28,29,18,19,26,27,22,23,30,31}, + { 0, 1, 4, 5, 2, 3, 8, 9, 6, 7,12,13,10,11,16,17,14,15,20,21,18,19,24,25,22,23,28,29,26,27,30,31}, + { 0, 1, 3, 2, 5, 4, 7, 6, 9, 8,11,10,13,12,15,14,17,16,19,18,21,20,23,22,25,24,27,26,29,28,30,31}, + { 0,16, 2,18, 4,20, 6,22, 8,24,10,26,12,28,14,30, 1,17, 3,19, 5,21, 7,23, 9,25,11,27,13,29,15,31}, + { 0, 8, 2,10, 4,12, 6,14, 1,16, 3,18, 5,20, 7,22, 9,24,11,26,13,28,15,30,17,25,19,27,21,29,23,31}, + { 0, 4, 2, 6, 1, 8, 3,10, 5,12, 7,14, 9,16,11,18,13,20,15,22,17,24,19,26,21,28,23,30,25,29,27,31}, + { 0, 2, 1, 4, 3, 6, 5, 8, 7,10, 9,12,11,14,13,16,15,18,17,20,19,22,21,24,23,26,25,28,27,30,29,31} +}; + +constant int swap16[10][16] = { + { 8, 9,10,11,12,13,14,15, 0, 1, 2, 3, 4, 5, 6, 7}, + { 4, 5, 6, 7, 0, 1, 2, 3,12,13,14,15, 8, 9,10,11}, + { 2, 3, 0, 1, 8, 9,10,11, 4, 5, 6, 7,14,15,12,13}, + { 1, 0, 2, 3, 6, 7, 4, 5,10,11, 8, 9,12,13,15,14}, + { 0, 1, 8, 9, 4, 5,12,13, 2, 3,10,11, 6, 7,14,15}, + { 0, 1, 4, 5, 2, 3, 8, 9, 6, 7,12,13,10,11,14,15}, + { 0, 1, 3, 2, 5, 4, 7, 6, 9, 8,11,10,13,12,14,15}, + { 0, 8, 2,10, 4,12, 6,14, 1, 9, 3,11, 5,13, 7,15}, + { 0, 4, 2, 6, 1, 8, 3,10, 5,12, 7,14, 9,13,11,15}, + { 0, 2, 1, 4, 3, 6, 5, 8, 7,10, 9,12,11,14,13,15} +}; + +constant int swap8[6][8] = { + { 4, 5, 6, 7, 0, 1, 2, 3}, + { 2, 3, 0, 1, 6, 7, 4, 5}, + { 1, 0, 4, 5, 2, 3, 7, 6}, + { 0, 1, 3, 2, 5, 4, 6, 7}, + { 0, 4, 2, 6, 1, 5, 3, 7}, + { 0, 2, 1, 4, 3, 6, 5, 7} +}; + +constant int swap4[3][4] = { + { 2, 3, 0, 1}, + { 1, 0, 3, 2}, + { 0, 2, 1, 3} +}; + +constant int swap2[1][2] = { + { 1, 0} +}; + +// ── key-only sort ──────────────────────────────────────────────────── + +template +inline float warp_sort_key(float v, uint gid) { + const int NSWAP = (GSIZE == 2) ? 1 : (GSIZE == 4) ? 3 : (GSIZE == 8) ? 6 : (GSIZE == 16) ? 10 : 15; + + for (int i = 0; i < NSWAP; i++) { + int srclane; + if (GSIZE == 32) srclane = swap32[i][gid]; + else if (GSIZE == 16) srclane = swap16[i][gid]; + else if (GSIZE == 8) srclane = swap8[i][gid]; + else if (GSIZE == 4) srclane = swap4[i][gid]; + else srclane = swap2[i][gid]; + + float a = simd_shuffle(v, ushort(srclane)); + v = ((int(gid) < srclane) == DIRECTION) ? MIN(a, v) : MAX(a, v); + } + return v; +} + +// ── key-value sort ─────────────────────────────────────────────────── + +template +inline void warp_sort_kv(thread float& k, thread int& val, uint gid) { + const int NSWAP = (GSIZE == 2) ? 1 : (GSIZE == 4) ? 3 : (GSIZE == 8) ? 6 : (GSIZE == 16) ? 10 : 15; + + for (int i = 0; i < NSWAP; i++) { + int srclane; + if (GSIZE == 32) srclane = swap32[i][gid]; + else if (GSIZE == 16) srclane = swap16[i][gid]; + else if (GSIZE == 8) srclane = swap8[i][gid]; + else if (GSIZE == 4) srclane = swap4[i][gid]; + else srclane = swap2[i][gid]; + + float a = simd_shuffle(k, ushort(srclane)); + int b = simd_shuffle(val, ushort(srclane)); + + if ((int(gid) < srclane) == DIRECTION) { + if (a < k) { k = a; val = b; } + } else { + if (a > k) { k = a; val = b; } + } + } +} diff --git a/pyproject.toml b/pyproject.toml index 4276dd6..d9581e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,10 @@ cu12 = [ "cuda-cccl[cu12]" ] +metal = [ + "pyobjc-framework-Metal", + "pyobjc-framework-MetalPerformanceShaders", +] [tool.setuptools.packages.find] where = ["."] diff --git a/run_gpu_streamlines.py b/run_gpu_streamlines.py index 0d6c447..cc01f10 100644 --- a/run_gpu_streamlines.py +++ b/run_gpu_streamlines.py @@ -57,6 +57,7 @@ from trx.io import save as save_trx from cuslines import ( + BACKEND, BootDirectionGetter, GPUTracker, ProbDirectionGetter, @@ -86,7 +87,7 @@ def get_img(ep2_seq): parser.add_argument("bvecs", nargs='?', default='hardi', help="path to the bvecs") parser.add_argument("mask_nifti", nargs='?', default='hardi', help="path to the mask file") parser.add_argument("roi_nifti", nargs='?', default='hardi', help="path to the ROI file") -parser.add_argument("--device", type=str, default ='gpu', choices=['cpu', 'gpu'], help="Whether to use cpu or gpu") +parser.add_argument("--device", type=str, default ='gpu', choices=['cpu', 'gpu', 'metal'], help="Whether to use cpu, gpu (auto-detect), or metal") parser.add_argument("--output-prefix", type=str, default ='', help="path to the output file") parser.add_argument("--chunk-size", type=int, default=100000, help="how many seeds to process per sweep, per GPU") parser.add_argument("--nseeds", type=int, default=100000, help="how many seeds to process in total") @@ -105,6 +106,17 @@ def get_img(ep2_seq): args = parser.parse_args() +if args.device == "metal": + if BACKEND != "metal": + raise RuntimeError("Metal backend requested but not available. " + "Install: pip install 'cuslines[metal]'") + if args.ngpus > 1: + print("WARNING: Metal backend supports only 1 GPU, ignoring --ngpus %d" % args.ngpus) + args.ngpus = 1 + args.device = "gpu" # use the GPU code path +elif args.device == "gpu": + print("Using %s backend" % BACKEND) + if args.device == "cpu" and args.write_method != "trk": print("WARNING: only trk write method is implemented for cpu testing.") write_method = "trk" diff --git a/setup.py b/setup.py index 46c718c..7aff1f7 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ def run(self): setup( cmdclass={"build_py": build_py_with_cuda}, package_data={ - "cuslines": ["cuda_c/*"], + "cuslines": ["cuda_c/*", "metal_shaders/*"], }, project_urls={ "Homepage": "https://github.com/dipy/GPUStreamlines", From 50593fc7fc74df7f053e7585c9592fa625856a51 Mon Sep 17 00:00:00 2001 From: neurolabusc Date: Mon, 2 Mar 2026 12:12:21 -0500 Subject: [PATCH 2/2] Fix Docker CI: install git and add setuptools-scm fallback_version setuptools-scm requires git to determine the package version from tags. The Dockerfile was missing git, causing pip install to fail in CI. Two fixes: - Install git alongside curl in the Dockerfile so setuptools-scm can read the version from the copied .git history - Add fallback_version = "0.0.0" to [tool.setuptools_scm] in pyproject.toml as a safety net for git-free environments (shallow clones, tarballs, GitHub zip archives) Co-Authored-By: Claude Sonnet 4.6 --- Dockerfile | 2 +- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index f27a2d0..9490519 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,7 +5,7 @@ SHELL ["/bin/bash", "-c"] ENV DEBIAN_FRONTEND=noninteractive -RUN apt-get update && apt-get install --assume-yes curl +RUN apt-get update && apt-get install --assume-yes curl git RUN curl -L "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \ -o "/tmp/Miniconda3.sh" diff --git a/pyproject.toml b/pyproject.toml index d9581e8..a704678 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,7 @@ requires = ["setuptools>=64.0", "setuptools_scm>=8"] build-backend = "setuptools.build_meta" [tool.setuptools_scm] +fallback_version = "0.0.0" [project] name = "cuslines"