From 897b415610f9360a3bab501e0a513ccd38988628 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Thu, 30 Apr 2026 15:40:05 +0200 Subject: [PATCH 01/14] Compare PyTorch vs CUDA --- perf/CondaPkg.toml | 8 ++ perf/Project.toml | 13 +++ perf/cuda_vs_pytorch.jl | 171 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 192 insertions(+) create mode 100644 perf/CondaPkg.toml create mode 100644 perf/Project.toml create mode 100644 perf/cuda_vs_pytorch.jl diff --git a/perf/CondaPkg.toml b/perf/CondaPkg.toml new file mode 100644 index 0000000..b9eb7fb --- /dev/null +++ b/perf/CondaPkg.toml @@ -0,0 +1,8 @@ +channels = ["conda-forge"] + +[deps] +python = ">=3.10,<3.13" + +[pip.deps] +torch = ">=2.2" +numpy = ">=1.24" diff --git a/perf/Project.toml b/perf/Project.toml new file mode 100644 index 0000000..ff01a45 --- /dev/null +++ b/perf/Project.toml @@ -0,0 +1,13 @@ +name = "ArrayDiffPerf" +uuid = "00000000-0000-0000-0000-000000000001" +authors = ["Benoît Legat "] +version = "0.0.0" + +[deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/perf/cuda_vs_pytorch.jl b/perf/cuda_vs_pytorch.jl new file mode 100644 index 0000000..c5cac9b --- /dev/null +++ b/perf/cuda_vs_pytorch.jl @@ -0,0 +1,171 @@ +# Compare hardcoded CUDA.jl forward+reverse for a 2-layer MLP gradient +# against PyTorch (via PythonCall) doing the equivalent loss + .backward(). +# +# Goals +# 1. Numerical equality of ∂L/∂W1 between the two paths. +# 2. Per-kernel CUDA trace from each, side by side, so we can see whether +# they issue the same GPU operations under the hood. +# 3. Wall-clock benchmark per hidden size h. +# +# Run +# cd ~/.julia/dev/ArrayDiff +# julia --project=perf -e 'using Pkg; Pkg.instantiate()' +# julia --project=perf perf/cuda_vs_pytorch.jl +# +# Formulas are transcribed verbatim from autodiff.jl:398-426. +# Loss: L = sum((W2 * tanh(W1 * X) - y) .^ 2) / size(y, 2) +# Returned grad: ∂L/∂W1 (∈ R^{h×d}). + +using Random, LinearAlgebra, Printf +using CUDA +using BenchmarkTools +using PythonCall + +# ------------------------------------------------------------------------- +# Hardcoded CUDA.jl path +# ------------------------------------------------------------------------- +function forward_pass(W1, W2, X, y) + y_1 = tanh.(W1 * X) + J_1 = 1 .- y_1 .^ 2 + J_2 = 2 .* (W2 * y_1 .- y) ./ size(y, 2) + return y_1, J_1, J_2 +end + +function reverse_diff(W1, W2, X, y) + _, J_1, J_2 = forward_pass(W1, W2, X, y) + return (J_1 .* (W2' * J_2)) * X' +end + +# ------------------------------------------------------------------------- +# PyTorch path +# ------------------------------------------------------------------------- +const torch = pyimport("torch") +const np = pyimport("numpy") +const profiler = pyimport("torch.profiler") + +# Build torch tensors once and reuse them across benchmark iterations, +# mirroring how the Julia path passes already-on-GPU CuArrays. +function build_torch_tensors(W1::Matrix, W2::Matrix, X::Matrix, y::Matrix) + npW1 = np.ascontiguousarray(np.asarray(PyArray(W1))) + npW2 = np.ascontiguousarray(np.asarray(PyArray(W2))) + npX = np.ascontiguousarray(np.asarray(PyArray(X))) + npY = np.ascontiguousarray(np.asarray(PyArray(y))) + W1t = torch.from_numpy(npW1).to("cuda").requires_grad_(true) + W2t = torch.from_numpy(npW2).to("cuda") + Xt = torch.from_numpy(npX).to("cuda") + yt = torch.from_numpy(npY).to("cuda") + return W1t, W2t, Xt, yt +end + +function pytorch_grad(W1t, W2t, Xt, yt) + y1 = torch.tanh(torch.matmul(W1t, Xt)) + diff = torch.matmul(W2t, y1) - yt + n = pyconvert(Int, yt.shape[1]) + loss = (diff * diff).sum() / n + grad = torch.autograd.grad(loss, W1t)[0] + return grad +end + +torch_to_julia(t) = pyconvert(Array, t.detach().cpu().numpy()) + +# ------------------------------------------------------------------------- +# Trace helpers +# ------------------------------------------------------------------------- +function julia_trace(f) + # Warmup so JIT + cuBLAS handle init don't show up. + f(); CUDA.synchronize() + return CUDA.@profile trace = true begin + f() + CUDA.synchronize() + end +end + +function pytorch_trace(f) + f(); torch.cuda.synchronize() # warmup + ProfilerActivity = profiler.ProfilerActivity + prof = profiler.profile(activities = pylist([ProfilerActivity.CUDA])) + prof.__enter__() + try + f() + torch.cuda.synchronize() + finally + prof.__exit__(pybuiltins.None, pybuiltins.None, pybuiltins.None) + end + return prof.key_averages().table(sort_by = "cuda_time_total") +end + +# ------------------------------------------------------------------------- +# Benchmark + verify for one (h, d, n) +# ------------------------------------------------------------------------- +function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) + println("\n" * "="^72) + @printf "h = %d, d = %d, n = %d\n" h d n + println("="^72) + + Random.seed!(0) + W1 = randn(Float32, h, d) + W2 = randn(Float32, 1, h) + X = randn(Float32, d, n) + y = randn(Float32, 1, n) + + # ----- Julia / CUDA.jl ----- + W1g = CuArray(W1); W2g = CuArray(W2); Xg = CuArray(X); yg = CuArray(y) + grad_julia = Array(reverse_diff(W1g, W2g, Xg, yg)) + CUDA.synchronize() + + # ----- PyTorch ----- + W1t, W2t, Xt, yt = build_torch_tensors(W1, W2, X, y) + grad_pytorch = torch_to_julia(pytorch_grad(W1t, W2t, Xt, yt)) + torch.cuda.synchronize() + + # ----- Numerical equivalence ----- + maxdiff = maximum(abs.(grad_julia .- grad_pytorch)) + relmag = maxdiff / max(maximum(abs.(grad_julia)), eps(Float32)) + @printf "max|Δ| = %.3e\n" maxdiff + @printf "max|Δ| / scale = %.3e\n" relmag + ok = isapprox(grad_julia, grad_pytorch; rtol = rtol, atol = 1f-4) + println("gradients match: ", ok) + + # ----- Benchmarks ----- + println("\n--- benchmark (median over many samples, includes CUDA.synchronize) ---") + bj = @benchmark begin + reverse_diff($W1g, $W2g, $Xg, $yg) + CUDA.synchronize() + end + bp = @benchmark begin + pytorch_grad($W1t, $W2t, $Xt, $yt) + $torch.cuda.synchronize() + end + @printf "Julia (CUDA.jl) : median %8.3f µs\n" 1e-3 * median(bj).time + @printf "PyTorch eager : median %8.3f µs\n" 1e-3 * median(bp).time + + # ----- CUDA traces ----- + println("\n--- CUDA trace: Julia / CUDA.jl ---") + show(stdout, "text/plain", julia_trace(() -> reverse_diff(W1g, W2g, Xg, yg))) + println() + + println("\n--- CUDA trace: PyTorch ---") + println(pytorch_trace(() -> pytorch_grad(W1t, W2t, Xt, yt))) + + return nothing +end + +# ------------------------------------------------------------------------- +# Main +# ------------------------------------------------------------------------- +function main() + if !CUDA.functional() + error("CUDA is not functional in this environment.") + end + if !pyconvert(Bool, torch.cuda.is_available()) + error("PyTorch reports CUDA is not available.") + end + println("CUDA.jl device : ", CUDA.name(CUDA.device())) + println("PyTorch device : ", pyconvert(String, torch.cuda.get_device_name(0))) + + for h in (16, 256, 4096) + run_one(; h = h) + end +end + +main() From ae2426ea111b9a51b21c14b47ec240aec82bb1f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Thu, 30 Apr 2026 15:45:47 +0200 Subject: [PATCH 02/14] Fix --- perf/cuda_vs_pytorch.jl | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/perf/cuda_vs_pytorch.jl b/perf/cuda_vs_pytorch.jl index c5cac9b..1181f86 100644 --- a/perf/cuda_vs_pytorch.jl +++ b/perf/cuda_vs_pytorch.jl @@ -80,6 +80,36 @@ function julia_trace(f) end end +# CUDATools.Profile.ProfileResults' default `show` calls `format_bytes` on a +# column that can contain `Inf` (e.g. for non-memcpy kernels), which throws +# `InexactError(Int64, Inf)`. Walk `trace.device` directly to sidestep it. +function summarize_julia_trace(io::IO, trace) + dev = trace.device + names_ = dev.name + starts = dev.start + stops = dev.stop + counts = Dict{String,Int}() + totals = Dict{String,Float64}() # in seconds + order = String[] + for i in eachindex(names_) + nm = String(names_[i]) + if !haskey(counts, nm) + push!(order, nm) + counts[nm] = 0 + totals[nm] = 0.0 + end + counts[nm] += 1 + totals[nm] += stops[i] - starts[i] + end + sorted = sort(order; by = nm -> -totals[nm]) + @printf io " %-66s %6s %10s\n" "kernel" "count" "total µs" + println(io, " ", "-"^66, " ", "-"^6, " ", "-"^10) + for nm in sorted + label = length(nm) <= 66 ? nm : first(nm, 63) * "..." + @printf io " %-66s %6d %10.2f\n" label counts[nm] 1e6 * totals[nm] + end +end + function pytorch_trace(f) f(); torch.cuda.synchronize() # warmup ProfilerActivity = profiler.ProfilerActivity @@ -141,8 +171,7 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) # ----- CUDA traces ----- println("\n--- CUDA trace: Julia / CUDA.jl ---") - show(stdout, "text/plain", julia_trace(() -> reverse_diff(W1g, W2g, Xg, yg))) - println() + summarize_julia_trace(stdout, julia_trace(() -> reverse_diff(W1g, W2g, Xg, yg))) println("\n--- CUDA trace: PyTorch ---") println(pytorch_trace(() -> pytorch_grad(W1t, W2t, Xt, yt))) From baf3422ba9eebf56f40a68f480ca9b5cf39f22ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Thu, 30 Apr 2026 15:56:15 +0200 Subject: [PATCH 03/14] Fix --- perf/cuda_vs_pytorch.jl | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/perf/cuda_vs_pytorch.jl b/perf/cuda_vs_pytorch.jl index 1181f86..e71843a 100644 --- a/perf/cuda_vs_pytorch.jl +++ b/perf/cuda_vs_pytorch.jl @@ -124,6 +124,8 @@ function pytorch_trace(f) return prof.key_averages().table(sort_by = "cuda_time_total") end +const _pygc = pyimport("gc") + # ------------------------------------------------------------------------- # Benchmark + verify for one (h, d, n) # ------------------------------------------------------------------------- @@ -157,15 +159,17 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) println("gradients match: ", ok) # ----- Benchmarks ----- - println("\n--- benchmark (median over many samples, includes CUDA.synchronize) ---") + # samples=30 evals=1 caps total iterations so the PyTorch caching allocator + # doesn't blow up at h=4096; setup= clears it between samples. + println("\n--- benchmark (median of 30 samples, post-sync) ---") bj = @benchmark begin reverse_diff($W1g, $W2g, $Xg, $yg) CUDA.synchronize() - end + end samples=30 evals=1 seconds=10 bp = @benchmark begin pytorch_grad($W1t, $W2t, $Xt, $yt) $torch.cuda.synchronize() - end + end samples=30 evals=1 seconds=10 setup=($(_pygc).collect(); $torch.cuda.empty_cache()) @printf "Julia (CUDA.jl) : median %8.3f µs\n" 1e-3 * median(bj).time @printf "PyTorch eager : median %8.3f µs\n" 1e-3 * median(bp).time @@ -194,6 +198,11 @@ function main() for h in (16, 256, 4096) run_one(; h = h) + # Release per-h tensors from both caching allocators before the next sweep. + GC.gc(true) + CUDA.reclaim() + _pygc.collect() + torch.cuda.empty_cache() end end From 77676649fe682918433538590ca8c2f36882d3d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Sat, 2 May 2026 22:45:11 +0200 Subject: [PATCH 04/14] Without eager as well --- perf/cuda_vs_pytorch.jl | 73 ++++++++++++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/perf/cuda_vs_pytorch.jl b/perf/cuda_vs_pytorch.jl index e71843a..fc8af1e 100644 --- a/perf/cuda_vs_pytorch.jl +++ b/perf/cuda_vs_pytorch.jl @@ -57,15 +57,28 @@ function build_torch_tensors(W1::Matrix, W2::Matrix, X::Matrix, y::Matrix) return W1t, W2t, Xt, yt end -function pytorch_grad(W1t, W2t, Xt, yt) - y1 = torch.tanh(torch.matmul(W1t, Xt)) - diff = torch.matmul(W2t, y1) - yt - n = pyconvert(Int, yt.shape[1]) - loss = (diff * diff).sum() / n - grad = torch.autograd.grad(loss, W1t)[0] - return grad +# Define eager + torch.compile'd versions in one Python namespace so the only +# difference between them is the `torch.compile` call. Each call goes through +# exactly one PythonCall round-trip, so wall-clock differences reflect what's +# actually happening on the GPU rather than per-op FFI cost. +const _grad_fn_eager, _grad_fn_compiled = let + nt = @pyexec """ +import torch +def _eager(W1, W2, X, y): + y1 = torch.tanh(torch.matmul(W1, X)) + diff = torch.matmul(W2, y1) - y + loss = (diff * diff).sum() / y.shape[1] + return torch.autograd.grad(loss, W1)[0] +# mode="default" — change to "reduce-overhead" for CUDA Graphs, or +# "max-autotune" for an aggressive autotune pass. +_compiled = torch.compile(_eager) +""" => (_eager::Py, _compiled::Py) + (nt._eager, nt._compiled) end +pytorch_grad_eager(W1t, W2t, Xt, yt) = _grad_fn_eager(W1t, W2t, Xt, yt) +pytorch_grad_compiled(W1t, W2t, Xt, yt) = _grad_fn_compiled(W1t, W2t, Xt, yt) + torch_to_julia(t) = pyconvert(Array, t.detach().cpu().numpy()) # ------------------------------------------------------------------------- @@ -147,16 +160,28 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) # ----- PyTorch ----- W1t, W2t, Xt, yt = build_torch_tensors(W1, W2, X, y) - grad_pytorch = torch_to_julia(pytorch_grad(W1t, W2t, Xt, yt)) + grad_pytorch_eager = torch_to_julia(pytorch_grad_eager(W1t, W2t, Xt, yt)) + torch.cuda.synchronize() + + # First call to the compiled fn for this shape triggers Inductor codegen + # (can take seconds). Time it so the user knows. + print("torch.compile codegen for h=$h ... "); flush(stdout) + t_compile = @elapsed begin + pytorch_grad_compiled(W1t, W2t, Xt, yt) + torch.cuda.synchronize() + end + @printf "%.2f s\n" t_compile + grad_pytorch_compiled = torch_to_julia(pytorch_grad_compiled(W1t, W2t, Xt, yt)) torch.cuda.synchronize() # ----- Numerical equivalence ----- - maxdiff = maximum(abs.(grad_julia .- grad_pytorch)) - relmag = maxdiff / max(maximum(abs.(grad_julia)), eps(Float32)) - @printf "max|Δ| = %.3e\n" maxdiff - @printf "max|Δ| / scale = %.3e\n" relmag - ok = isapprox(grad_julia, grad_pytorch; rtol = rtol, atol = 1f-4) - println("gradients match: ", ok) + for (name, g) in [("eager ", grad_pytorch_eager), + ("compiled", grad_pytorch_compiled)] + maxdiff = maximum(abs.(grad_julia .- g)) + relmag = maxdiff / max(maximum(abs.(grad_julia)), eps(Float32)) + ok = isapprox(grad_julia, g; rtol = rtol, atol = 1f-4) + @printf "PyTorch %s vs Julia: max|Δ| = %.3e (rel %.2e) match=%s\n" name maxdiff relmag ok + end # ----- Benchmarks ----- # samples=30 evals=1 caps total iterations so the PyTorch caching allocator @@ -166,19 +191,27 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) reverse_diff($W1g, $W2g, $Xg, $yg) CUDA.synchronize() end samples=30 evals=1 seconds=10 - bp = @benchmark begin - pytorch_grad($W1t, $W2t, $Xt, $yt) + be = @benchmark begin + pytorch_grad_eager($W1t, $W2t, $Xt, $yt) + $torch.cuda.synchronize() + end samples=30 evals=1 seconds=10 setup=($(_pygc).collect(); $torch.cuda.empty_cache()) + bc = @benchmark begin + pytorch_grad_compiled($W1t, $W2t, $Xt, $yt) $torch.cuda.synchronize() end samples=30 evals=1 seconds=10 setup=($(_pygc).collect(); $torch.cuda.empty_cache()) - @printf "Julia (CUDA.jl) : median %8.3f µs\n" 1e-3 * median(bj).time - @printf "PyTorch eager : median %8.3f µs\n" 1e-3 * median(bp).time + @printf "Julia (CUDA.jl) : median %8.3f µs\n" 1e-3 * median(bj).time + @printf "PyTorch eager : median %8.3f µs\n" 1e-3 * median(be).time + @printf "PyTorch compiled : median %8.3f µs\n" 1e-3 * median(bc).time # ----- CUDA traces ----- println("\n--- CUDA trace: Julia / CUDA.jl ---") summarize_julia_trace(stdout, julia_trace(() -> reverse_diff(W1g, W2g, Xg, yg))) - println("\n--- CUDA trace: PyTorch ---") - println(pytorch_trace(() -> pytorch_grad(W1t, W2t, Xt, yt))) + println("\n--- CUDA trace: PyTorch eager ---") + println(pytorch_trace(() -> pytorch_grad_eager(W1t, W2t, Xt, yt))) + + println("\n--- CUDA trace: PyTorch compiled ---") + println(pytorch_trace(() -> pytorch_grad_compiled(W1t, W2t, Xt, yt))) return nothing end From 6ccc613901ff84e0af510cbb4001027569314ee7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Sat, 2 May 2026 22:48:56 +0200 Subject: [PATCH 05/14] Fix --- perf/cuda_vs_pytorch.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/perf/cuda_vs_pytorch.jl b/perf/cuda_vs_pytorch.jl index fc8af1e..bd9a1e5 100644 --- a/perf/cuda_vs_pytorch.jl +++ b/perf/cuda_vs_pytorch.jl @@ -62,9 +62,14 @@ end # exactly one PythonCall round-trip, so wall-clock differences reflect what's # actually happening on the GPU rather than per-op FFI cost. const _grad_fn_eager, _grad_fn_compiled = let + # `import torch` inside _eager so it lands in the function's __globals__ + # at call time — @pyexec runs with separate globals/locals dicts, and a + # top-level `import torch` would only populate locals, leaving _eager + # unable to resolve `torch` when invoked later. nt = @pyexec """ import torch def _eager(W1, W2, X, y): + import torch y1 = torch.tanh(torch.matmul(W1, X)) diff = torch.matmul(W2, y1) - y loss = (diff * diff).sum() / y.shape[1] From fa4ed39da856bf2527285132c8c54a8a21385739 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Sat, 2 May 2026 23:09:42 +0200 Subject: [PATCH 06/14] Hardcoded kernel --- perf/cuda_vs_pytorch.jl | 118 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 111 insertions(+), 7 deletions(-) diff --git a/perf/cuda_vs_pytorch.jl b/perf/cuda_vs_pytorch.jl index bd9a1e5..2666d8c 100644 --- a/perf/cuda_vs_pytorch.jl +++ b/perf/cuda_vs_pytorch.jl @@ -18,6 +18,7 @@ using Random, LinearAlgebra, Printf using CUDA +using CUDA: AS using BenchmarkTools using PythonCall @@ -36,6 +37,95 @@ function reverse_diff(W1, W2, X, y) return (J_1 .* (W2' * J_2)) * X' end +# ------------------------------------------------------------------------- +# Hardcoded CUDA.jl path — vectorized custom-kernel version +# +# Replaces the two big elementwise broadcasts: +# 1. tanh.(W1 * X) and 1 .- y_1.^2 → one fused @cuda kernel +# 2. J_1 .* (W2' * J_2) → one @cuda kernel +# with kernels that issue `ld.global.v4.f32` / `st.global.v4.f32` PTX, +# matching PyTorch's `vectorized_elementwise_kernel<4, ...>` shape. +# +# The third broadcast (J_2 from a 1×n vector) is left as a regular .= since +# n is tiny (e.g. 178) and vectorization wouldn't measurably help. +# +# Pattern is taken from CUDA.jl's own ldg.jl tests, which assert that +# NTuple{4, Base.VecElement{Float32}} loads via Core.LLVMPtr lower to +# `ld.global.v4` PTX. +# ------------------------------------------------------------------------- +const Float4 = NTuple{4, Base.VecElement{Float32}} + +@inline _f4ptr(arr::CuDeviceArray{Float32}) = + reinterpret(Core.LLVMPtr{Float4, AS.Global}, pointer(arr)) + +@inline _vec4(t1, t2, t3, t4) = + (Base.VecElement(t1), Base.VecElement(t2), Base.VecElement(t3), Base.VecElement(t4)) + +function _tanh_and_jac_kernel!(y, J, x, n::Int) + i = (blockIdx().x - 1) * blockDim().x + threadIdx().x + base = 4 * (i - 1) + if base + 4 <= n + v = unsafe_load(_f4ptr(x), i, Val(16)) + t1 = tanh(v[1].value); t2 = tanh(v[2].value) + t3 = tanh(v[3].value); t4 = tanh(v[4].value) + unsafe_store!(_f4ptr(y), _vec4(t1, t2, t3, t4), i, Val(16)) + unsafe_store!(_f4ptr(J), + _vec4(1f0 - t1*t1, 1f0 - t2*t2, 1f0 - t3*t3, 1f0 - t4*t4), i, Val(16)) + elseif base < n + for k in 1:(n - base) + @inbounds t = tanh(x[base + k]) + @inbounds y[base + k] = t + @inbounds J[base + k] = 1f0 - t*t + end + end + return nothing +end + +function tanh_and_jac!(y::CuArray{Float32}, J::CuArray{Float32}, x::CuArray{Float32}) + n = length(x) + threads = 256 + blocks = cld(cld(n, 4), threads) + @cuda threads=threads blocks=blocks _tanh_and_jac_kernel!(y, J, x, n) + return (y, J) +end + +function _vmul_kernel!(out, a, b, n::Int) + i = (blockIdx().x - 1) * blockDim().x + threadIdx().x + base = 4 * (i - 1) + if base + 4 <= n + va = unsafe_load(_f4ptr(a), i, Val(16)) + vb = unsafe_load(_f4ptr(b), i, Val(16)) + unsafe_store!(_f4ptr(out), + _vec4(va[1].value*vb[1].value, va[2].value*vb[2].value, + va[3].value*vb[3].value, va[4].value*vb[4].value), i, Val(16)) + elseif base < n + for k in 1:(n - base) + @inbounds out[base + k] = a[base + k] * b[base + k] + end + end + return nothing +end + +function vmul!(out::CuArray{Float32}, a::CuArray{Float32}, b::CuArray{Float32}) + n = length(a) + threads = 256 + blocks = cld(cld(n, 4), threads) + @cuda threads=threads blocks=blocks _vmul_kernel!(out, a, b, n) + return out +end + +function reverse_diff_v4(W1, W2, X, y) + Z1 = W1 * X # GEMM (h, n) + y_1 = similar(Z1) + J_1 = similar(Z1) + tanh_and_jac!(y_1, J_1, Z1) # fused tanh + (1 - y²), vec=4 + J_2 = 2 .* (W2 * y_1 .- y) ./ size(y, 2) # 1×n broadcast, kept as-is + tmp = W2' * J_2 # GEMM (h, n) + out = similar(tmp) + vmul!(out, J_1, tmp) # J_1 .* tmp, vec=4 + return out * X' # GEMM (h, d) +end + # ------------------------------------------------------------------------- # PyTorch path # ------------------------------------------------------------------------- @@ -160,7 +250,8 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) # ----- Julia / CUDA.jl ----- W1g = CuArray(W1); W2g = CuArray(W2); Xg = CuArray(X); yg = CuArray(y) - grad_julia = Array(reverse_diff(W1g, W2g, Xg, yg)) + grad_julia = Array(reverse_diff(W1g, W2g, Xg, yg)) + grad_julia_v4 = Array(reverse_diff_v4(W1g, W2g, Xg, yg)) CUDA.synchronize() # ----- PyTorch ----- @@ -180,12 +271,13 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) torch.cuda.synchronize() # ----- Numerical equivalence ----- - for (name, g) in [("eager ", grad_pytorch_eager), - ("compiled", grad_pytorch_compiled)] + for (name, g) in [("Julia v4 (vec=4) ", grad_julia_v4), + ("PyTorch eager ", grad_pytorch_eager), + ("PyTorch compiled ", grad_pytorch_compiled)] maxdiff = maximum(abs.(grad_julia .- g)) relmag = maxdiff / max(maximum(abs.(grad_julia)), eps(Float32)) ok = isapprox(grad_julia, g; rtol = rtol, atol = 1f-4) - @printf "PyTorch %s vs Julia: max|Δ| = %.3e (rel %.2e) match=%s\n" name maxdiff relmag ok + @printf "%s vs Julia broadcast: max|Δ| = %.3e (rel %.2e) match=%s\n" name maxdiff relmag ok end # ----- Benchmarks ----- @@ -196,6 +288,10 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) reverse_diff($W1g, $W2g, $Xg, $yg) CUDA.synchronize() end samples=30 evals=1 seconds=10 + bj4 = @benchmark begin + reverse_diff_v4($W1g, $W2g, $Xg, $yg) + CUDA.synchronize() + end samples=30 evals=1 seconds=10 be = @benchmark begin pytorch_grad_eager($W1t, $W2t, $Xt, $yt) $torch.cuda.synchronize() @@ -204,14 +300,18 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) pytorch_grad_compiled($W1t, $W2t, $Xt, $yt) $torch.cuda.synchronize() end samples=30 evals=1 seconds=10 setup=($(_pygc).collect(); $torch.cuda.empty_cache()) - @printf "Julia (CUDA.jl) : median %8.3f µs\n" 1e-3 * median(bj).time + @printf "Julia broadcast : median %8.3f µs\n" 1e-3 * median(bj).time + @printf "Julia vec=4 : median %8.3f µs\n" 1e-3 * median(bj4).time @printf "PyTorch eager : median %8.3f µs\n" 1e-3 * median(be).time @printf "PyTorch compiled : median %8.3f µs\n" 1e-3 * median(bc).time # ----- CUDA traces ----- - println("\n--- CUDA trace: Julia / CUDA.jl ---") + println("\n--- CUDA trace: Julia broadcast ---") summarize_julia_trace(stdout, julia_trace(() -> reverse_diff(W1g, W2g, Xg, yg))) + println("\n--- CUDA trace: Julia vec=4 ---") + summarize_julia_trace(stdout, julia_trace(() -> reverse_diff_v4(W1g, W2g, Xg, yg))) + println("\n--- CUDA trace: PyTorch eager ---") println(pytorch_trace(() -> pytorch_grad_eager(W1t, W2t, Xt, yt))) @@ -231,7 +331,11 @@ function main() if !pyconvert(Bool, torch.cuda.is_available()) error("PyTorch reports CUDA is not available.") end - println("CUDA.jl device : ", CUDA.name(CUDA.device())) + # Match PyTorch's default of fast tanh / fast intrinsics. Affects BOTH + # Julia versions equally, so the broadcast-vs-vec=4 comparison still + # isolates the kernel-design effect. + CUDA.math_mode!(CUDA.FAST_MATH) + println("CUDA.jl device : ", CUDA.name(CUDA.device()), " (math_mode=FAST_MATH)") println("PyTorch device : ", pyconvert(String, torch.cuda.get_device_name(0))) for h in (16, 256, 4096) From 9322f27e0a0435f0e113abc2c3af45ef37bc861c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Sat, 2 May 2026 23:23:01 +0200 Subject: [PATCH 07/14] Custom GEMM --- perf/cuda_vs_pytorch.jl | 74 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 70 insertions(+), 4 deletions(-) diff --git a/perf/cuda_vs_pytorch.jl b/perf/cuda_vs_pytorch.jl index 2666d8c..b16c71f 100644 --- a/perf/cuda_vs_pytorch.jl +++ b/perf/cuda_vs_pytorch.jl @@ -126,6 +126,62 @@ function reverse_diff_v4(W1, W2, X, y) return out * X' # GEMM (h, d) end +# ------------------------------------------------------------------------- +# v5: vec=4 elementwise + SIMT cuBLAS GEMM (no TF32 tensor cores) +# +# Under FAST_MATH, gemmExComputeType picks CUBLAS_COMPUTE_32F_FAST_TF32, which +# routes Float32 matmul through TF32 tensor cores. For our awkward k dims +# (k=13 for W1*X, k=178 for out*X', k=1 for W2'*J_2), tensor cores can't be +# filled efficiently and the resulting cutlass_80_tensorop kernel runs much +# slower than a SIMT FP32 GEMM. We bypass gemmExComputeType by calling +# cublasGemmEx directly with CUBLAS_COMPUTE_32F + CUBLAS_GEMM_DEFAULT. +# ------------------------------------------------------------------------- +function _gemm_simt!(C::CuArray{Float32,2}, transA::Char, A::CuArray{Float32,2}, + transB::Char, B::CuArray{Float32,2}; + alpha::Float32 = 1f0, beta::Float32 = 0f0) + m = size(A, transA == 'N' ? 1 : 2) + k = size(A, transA == 'N' ? 2 : 1) + n = size(B, transB == 'N' ? 2 : 1) + lda = max(1, stride(A, 2)) + ldb = max(1, stride(B, 2)) + ldc = max(1, stride(C, 2)) + α = Ref{Float32}(alpha); β = Ref{Float32}(beta) + CUDA.CUBLAS.cublasGemmEx( + CUDA.CUBLAS.handle(), + transA, transB, m, n, k, + α, A, Float32, lda, + B, Float32, ldb, + β, C, Float32, ldc, + CUDA.CUBLAS.CUBLAS_COMPUTE_32F, + CUDA.CUBLAS.CUBLAS_GEMM_DEFAULT, + ) + return C +end + +function reverse_diff_v5(W1, W2, X, y) + h, d = size(W1) + nn = size(X, 2) + + Z1 = CuArray{Float32}(undef, h, nn) + _gemm_simt!(Z1, 'N', W1, 'N', X) # SIMT: (h,d) * (d,n) + + y_1 = similar(Z1) + J_1 = similar(Z1) + tanh_and_jac!(y_1, J_1, Z1) + + J_2 = 2 .* (W2 * y_1 .- y) ./ size(y, 2) # tiny, leave as broadcast + + tmp = CuArray{Float32}(undef, h, nn) + _gemm_simt!(tmp, 'T', W2, 'N', J_2) # SIMT: W2' * J_2 (k=1) + + out = similar(tmp) + vmul!(out, J_1, tmp) + + result = CuArray{Float32}(undef, h, d) + _gemm_simt!(result, 'N', out, 'T', X) # SIMT: out * X' (k=n) + return result +end + # ------------------------------------------------------------------------- # PyTorch path # ------------------------------------------------------------------------- @@ -252,6 +308,7 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) W1g = CuArray(W1); W2g = CuArray(W2); Xg = CuArray(X); yg = CuArray(y) grad_julia = Array(reverse_diff(W1g, W2g, Xg, yg)) grad_julia_v4 = Array(reverse_diff_v4(W1g, W2g, Xg, yg)) + grad_julia_v5 = Array(reverse_diff_v5(W1g, W2g, Xg, yg)) CUDA.synchronize() # ----- PyTorch ----- @@ -272,6 +329,7 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) # ----- Numerical equivalence ----- for (name, g) in [("Julia v4 (vec=4) ", grad_julia_v4), + ("Julia v5 (vec=4+SIMT)", grad_julia_v5), ("PyTorch eager ", grad_pytorch_eager), ("PyTorch compiled ", grad_pytorch_compiled)] maxdiff = maximum(abs.(grad_julia .- g)) @@ -292,6 +350,10 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) reverse_diff_v4($W1g, $W2g, $Xg, $yg) CUDA.synchronize() end samples=30 evals=1 seconds=10 + bj5 = @benchmark begin + reverse_diff_v5($W1g, $W2g, $Xg, $yg) + CUDA.synchronize() + end samples=30 evals=1 seconds=10 be = @benchmark begin pytorch_grad_eager($W1t, $W2t, $Xt, $yt) $torch.cuda.synchronize() @@ -300,10 +362,11 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) pytorch_grad_compiled($W1t, $W2t, $Xt, $yt) $torch.cuda.synchronize() end samples=30 evals=1 seconds=10 setup=($(_pygc).collect(); $torch.cuda.empty_cache()) - @printf "Julia broadcast : median %8.3f µs\n" 1e-3 * median(bj).time - @printf "Julia vec=4 : median %8.3f µs\n" 1e-3 * median(bj4).time - @printf "PyTorch eager : median %8.3f µs\n" 1e-3 * median(be).time - @printf "PyTorch compiled : median %8.3f µs\n" 1e-3 * median(bc).time + @printf "Julia broadcast : median %8.3f µs\n" 1e-3 * median(bj).time + @printf "Julia vec=4 : median %8.3f µs\n" 1e-3 * median(bj4).time + @printf "Julia vec=4 + SIMT : median %8.3f µs\n" 1e-3 * median(bj5).time + @printf "PyTorch eager : median %8.3f µs\n" 1e-3 * median(be).time + @printf "PyTorch compiled : median %8.3f µs\n" 1e-3 * median(bc).time # ----- CUDA traces ----- println("\n--- CUDA trace: Julia broadcast ---") @@ -312,6 +375,9 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) println("\n--- CUDA trace: Julia vec=4 ---") summarize_julia_trace(stdout, julia_trace(() -> reverse_diff_v4(W1g, W2g, Xg, yg))) + println("\n--- CUDA trace: Julia vec=4 + SIMT ---") + summarize_julia_trace(stdout, julia_trace(() -> reverse_diff_v5(W1g, W2g, Xg, yg))) + println("\n--- CUDA trace: PyTorch eager ---") println(pytorch_trace(() -> pytorch_grad_eager(W1t, W2t, Xt, yt))) From f7d54341bb752d1f001a798b3fa9f1bbd7b45a8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Sat, 2 May 2026 23:27:56 +0200 Subject: [PATCH 08/14] Fix --- perf/cuda_vs_pytorch.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/perf/cuda_vs_pytorch.jl b/perf/cuda_vs_pytorch.jl index b16c71f..3792d5e 100644 --- a/perf/cuda_vs_pytorch.jl +++ b/perf/cuda_vs_pytorch.jl @@ -145,7 +145,10 @@ function _gemm_simt!(C::CuArray{Float32,2}, transA::Char, A::CuArray{Float32,2}, lda = max(1, stride(A, 2)) ldb = max(1, stride(B, 2)) ldc = max(1, stride(C, 2)) - α = Ref{Float32}(alpha); β = Ref{Float32}(beta) + # CUDA.jl puts the cuBLAS handle in CUBLAS_POINTER_MODE_DEVICE, so alpha/beta + # MUST be device pointers. Passing host Ref{Float32} causes UVA fault handling + # per kernel launch (~100× slowdown but eventually-correct values). + α = CUDA.CuRef{Float32}(alpha); β = CUDA.CuRef{Float32}(beta) CUDA.CUBLAS.cublasGemmEx( CUDA.CUBLAS.handle(), transA, transB, m, n, k, From 38a9ea1e8382228d8d0360f1524009a920086aea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Sun, 3 May 2026 11:48:40 +0200 Subject: [PATCH 09/14] Lux --- perf/Project.toml | 2 + perf/cuda_vs_pytorch.jl | 82 +++++++++++++++++++++++++++++++++++------ 2 files changed, 73 insertions(+), 11 deletions(-) diff --git a/perf/Project.toml b/perf/Project.toml index ff01a45..b0c8c3e 100644 --- a/perf/Project.toml +++ b/perf/Project.toml @@ -8,6 +8,8 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/perf/cuda_vs_pytorch.jl b/perf/cuda_vs_pytorch.jl index 3792d5e..24228c2 100644 --- a/perf/cuda_vs_pytorch.jl +++ b/perf/cuda_vs_pytorch.jl @@ -21,6 +21,7 @@ using CUDA using CUDA: AS using BenchmarkTools using PythonCall +using Lux, Zygote # ------------------------------------------------------------------------- # Hardcoded CUDA.jl path @@ -146,18 +147,25 @@ function _gemm_simt!(C::CuArray{Float32,2}, transA::Char, A::CuArray{Float32,2}, ldb = max(1, stride(B, 2)) ldc = max(1, stride(C, 2)) # CUDA.jl puts the cuBLAS handle in CUBLAS_POINTER_MODE_DEVICE, so alpha/beta - # MUST be device pointers. Passing host Ref{Float32} causes UVA fault handling - # per kernel launch (~100× slowdown but eventually-correct values). + # MUST be device pointers (host Ref triggers UVA fault handling — 100× slowdown). α = CUDA.CuRef{Float32}(alpha); β = CUDA.CuRef{Float32}(beta) - CUDA.CUBLAS.cublasGemmEx( - CUDA.CUBLAS.handle(), - transA, transB, m, n, k, - α, A, Float32, lda, - B, Float32, ldb, - β, C, Float32, ldc, - CUDA.CUBLAS.CUBLAS_COMPUTE_32F, - CUDA.CUBLAS.CUBLAS_GEMM_DEFAULT, - ) + h = CUDA.CUBLAS.handle() + # Under FAST_MATH the handle's math mode is CUBLAS_TF32_TENSOR_OP_MATH, which + # forces TF32 tensor cores even when we ask for CUBLAS_COMPUTE_32F. Flip it to + # DEFAULT_MATH for this call so cuBLAS picks a SIMT FP32 kernel. + CUDA.CUBLAS.math_mode!(h, CUDA.DEFAULT_MATH) + try + CUDA.CUBLAS.cublasGemmEx( + h, transA, transB, m, n, k, + α, A, Float32, lda, + B, Float32, ldb, + β, C, Float32, ldc, + CUDA.CUBLAS.CUBLAS_COMPUTE_32F, + CUDA.CUBLAS.CUBLAS_GEMM_DEFAULT, + ) + finally + CUDA.CUBLAS.math_mode!(h, CUDA.math_mode()) # restore (FAST_MATH → TF32 tensor op) + end return C end @@ -185,6 +193,38 @@ function reverse_diff_v5(W1, W2, X, y) return result end +# ------------------------------------------------------------------------- +# Lux + Zygote path +# +# Builds an equivalent 2-layer MLP `Y = W2 * tanh(W1 * X)` (no bias) using +# Lux, plugs in the *same* CuArray weights so the gradient is comparable, +# and lets Zygote source-to-source the backward. This goes through the same +# CUDA.jl + cuBLAS stack as `reverse_diff`, so we expect similar kernels — +# the interesting thing is the AD/dispatch overhead Lux+Zygote add on top. +# ------------------------------------------------------------------------- +function build_lux(W1g::CuArray{Float32,2}, W2g::CuArray{Float32,2}) + h, d = size(W1g) + model = Lux.Chain( + Lux.Dense(d => h, tanh; use_bias = false), + Lux.Dense(h => 1, identity; use_bias = false), + ) + ps = ( + layer_1 = (weight = W1g,), + layer_2 = (weight = W2g,), + ) + st = Lux.initialstates(Random.default_rng(), model) + return model, ps, st +end + +function lux_grad(model, ps, st, Xg::CuArray, yg::CuArray) + function loss_fn(p) + y_hat, _ = model(Xg, p, st) + return sum((y_hat .- yg) .^ 2) / size(yg, 2) + end + ∂ps = first(Zygote.gradient(loss_fn, ps)) + return ∂ps.layer_1.weight +end + # ------------------------------------------------------------------------- # PyTorch path # ------------------------------------------------------------------------- @@ -314,6 +354,17 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) grad_julia_v5 = Array(reverse_diff_v5(W1g, W2g, Xg, yg)) CUDA.synchronize() + # Lux + Zygote warmup (first call compiles Zygote's pullback for this shape) + print("Lux+Zygote compile warmup for h=$h ... "); flush(stdout) + lux_model, lux_ps, lux_st = build_lux(W1g, W2g) + t_lux_compile = @elapsed begin + lux_grad(lux_model, lux_ps, lux_st, Xg, yg) + CUDA.synchronize() + end + @printf "%.2f s\n" t_lux_compile + grad_lux = Array(lux_grad(lux_model, lux_ps, lux_st, Xg, yg)) + CUDA.synchronize() + # ----- PyTorch ----- W1t, W2t, Xt, yt = build_torch_tensors(W1, W2, X, y) grad_pytorch_eager = torch_to_julia(pytorch_grad_eager(W1t, W2t, Xt, yt)) @@ -333,6 +384,7 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) # ----- Numerical equivalence ----- for (name, g) in [("Julia v4 (vec=4) ", grad_julia_v4), ("Julia v5 (vec=4+SIMT)", grad_julia_v5), + ("Lux + Zygote ", grad_lux), ("PyTorch eager ", grad_pytorch_eager), ("PyTorch compiled ", grad_pytorch_compiled)] maxdiff = maximum(abs.(grad_julia .- g)) @@ -357,6 +409,10 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) reverse_diff_v5($W1g, $W2g, $Xg, $yg) CUDA.synchronize() end samples=30 evals=1 seconds=10 + bjlux = @benchmark begin + lux_grad($lux_model, $lux_ps, $lux_st, $Xg, $yg) + CUDA.synchronize() + end samples=30 evals=1 seconds=10 be = @benchmark begin pytorch_grad_eager($W1t, $W2t, $Xt, $yt) $torch.cuda.synchronize() @@ -368,6 +424,7 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) @printf "Julia broadcast : median %8.3f µs\n" 1e-3 * median(bj).time @printf "Julia vec=4 : median %8.3f µs\n" 1e-3 * median(bj4).time @printf "Julia vec=4 + SIMT : median %8.3f µs\n" 1e-3 * median(bj5).time + @printf "Lux + Zygote : median %8.3f µs\n" 1e-3 * median(bjlux).time @printf "PyTorch eager : median %8.3f µs\n" 1e-3 * median(be).time @printf "PyTorch compiled : median %8.3f µs\n" 1e-3 * median(bc).time @@ -381,6 +438,9 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) println("\n--- CUDA trace: Julia vec=4 + SIMT ---") summarize_julia_trace(stdout, julia_trace(() -> reverse_diff_v5(W1g, W2g, Xg, yg))) + println("\n--- CUDA trace: Lux + Zygote ---") + summarize_julia_trace(stdout, julia_trace(() -> lux_grad(lux_model, lux_ps, lux_st, Xg, yg))) + println("\n--- CUDA trace: PyTorch eager ---") println(pytorch_trace(() -> pytorch_grad_eager(W1t, W2t, Xt, yt))) From 34d3bb344797bc65265f5607b6665bd700720291 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Mon, 4 May 2026 16:04:38 +0200 Subject: [PATCH 10/14] Add more --- perf/Project.toml | 2 +- perf/cuda_vs_pytorch.jl | 245 +++++++++++++++++++++++++++++++++++----- 2 files changed, 220 insertions(+), 27 deletions(-) diff --git a/perf/Project.toml b/perf/Project.toml index b0c8c3e..f7d180d 100644 --- a/perf/Project.toml +++ b/perf/Project.toml @@ -9,7 +9,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/perf/cuda_vs_pytorch.jl b/perf/cuda_vs_pytorch.jl index 24228c2..cd0d3f6 100644 --- a/perf/cuda_vs_pytorch.jl +++ b/perf/cuda_vs_pytorch.jl @@ -21,7 +21,8 @@ using CUDA using CUDA: AS using BenchmarkTools using PythonCall -using Lux, Zygote +using Lux +import Mooncake # ------------------------------------------------------------------------- # Hardcoded CUDA.jl path @@ -194,15 +195,182 @@ function reverse_diff_v5(W1, W2, X, y) end # ------------------------------------------------------------------------- -# Lux + Zygote path +# v6: vec=4 elementwise + cuBLASLt with per-shape heuristic-picked algo +# +# cuBLAS's standard heuristic, even with CUBLAS_COMPUTE_32F + DEFAULT_MATH, +# picks `cutlass_80_simt_sgemm_*` for our awkward shapes. PyTorch's process +# happens to land on `magma_sgemmEx_kernel` for the same compute type — same +# library, different choice. cuBLASLt exposes a fuller heuristic API with a +# workspace budget that often unlocks better algos. We build a matmul +# descriptor + layouts per (transA, transB, m, n, k) shape, ask cuBLASLt for +# its best algo, and reuse the cached plan on every call. +# ------------------------------------------------------------------------- +const _LT_WS_BYTES = Csize_t(32 * 1024 * 1024) # 32 MiB workspace + +# Lazy: created on first use, kept alive for the process. +const _LT_STATE = Ref{Any}(nothing) + +function _lt_state() + s = _LT_STATE[] + if s === nothing + h_ref = Ref{CUDA.CUBLAS.cublasLtHandle_t}(C_NULL) + CUDA.CUBLAS.cublasLtCreate(h_ref) + ws = CUDA.CuArray{UInt8}(undef, Int(_LT_WS_BYTES)) + s = (handle = h_ref[], ws = ws) + _LT_STATE[] = s + end + return s::NamedTuple{(:handle, :ws)} +end + +mutable struct LtPlan + desc::CUDA.CUBLAS.cublasLtMatmulDesc_t + Adesc::CUDA.CUBLAS.cublasLtMatrixLayout_t + Bdesc::CUDA.CUBLAS.cublasLtMatrixLayout_t + Cdesc::CUDA.CUBLAS.cublasLtMatrixLayout_t + algo::CUDA.CUBLAS.cublasLtMatmulAlgo_t +end + +function _build_lt_plan(transA::Char, transB::Char, + m::Int, n::Int, k::Int, + lda::Int, ldb::Int, ldc::Int) + state = _lt_state() + handle = state.handle + R32 = CUDA.CUDACore.R_32F # cudaDataType for Float32 + + desc_ref = Ref{CUDA.CUBLAS.cublasLtMatmulDesc_t}(C_NULL) + CUDA.CUBLAS.cublasLtMatmulDescCreate(desc_ref, CUDA.CUBLAS.CUBLAS_COMPUTE_32F, R32) + desc = desc_ref[] + + # Set transpose attributes. + tA = (transA == 'N') ? CUDA.CUBLAS.CUBLAS_OP_N : CUDA.CUBLAS.CUBLAS_OP_T + tB = (transB == 'N') ? CUDA.CUBLAS.CUBLAS_OP_N : CUDA.CUBLAS.CUBLAS_OP_T + let r = Ref(tA) + CUDA.CUBLAS.cublasLtMatmulDescSetAttribute( + desc, CUDA.CUBLAS.CUBLASLT_MATMUL_DESC_TRANSA, r, sizeof(tA)) + end + let r = Ref(tB) + CUDA.CUBLAS.cublasLtMatmulDescSetAttribute( + desc, CUDA.CUBLAS.CUBLASLT_MATMUL_DESC_TRANSB, r, sizeof(tB)) + end + + # Layout shape is the *storage* shape (pre-transpose). + Arows = transA == 'N' ? m : k + Acols = transA == 'N' ? k : m + Brows = transB == 'N' ? k : n + Bcols = transB == 'N' ? n : k + + Aref = Ref{CUDA.CUBLAS.cublasLtMatrixLayout_t}(C_NULL) + Bref = Ref{CUDA.CUBLAS.cublasLtMatrixLayout_t}(C_NULL) + Cref = Ref{CUDA.CUBLAS.cublasLtMatrixLayout_t}(C_NULL) + CUDA.CUBLAS.cublasLtMatrixLayoutCreate(Aref, R32, UInt64(Arows), UInt64(Acols), Int64(lda)) + CUDA.CUBLAS.cublasLtMatrixLayoutCreate(Bref, R32, UInt64(Brows), UInt64(Bcols), Int64(ldb)) + CUDA.CUBLAS.cublasLtMatrixLayoutCreate(Cref, R32, UInt64(m), UInt64(n), Int64(ldc)) + + # Preference: tell the heuristic how much workspace it can use. + pref_ref = Ref{CUDA.CUBLAS.cublasLtMatmulPreference_t}(C_NULL) + CUDA.CUBLAS.cublasLtMatmulPreferenceCreate(pref_ref) + pref = pref_ref[] + let r = Ref(_LT_WS_BYTES) + CUDA.CUBLAS.cublasLtMatmulPreferenceSetAttribute( + pref, CUDA.CUBLAS.CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + r, sizeof(_LT_WS_BYTES)) + end + + # Heuristic: top-1 algorithm. + heur = Vector{CUDA.CUBLAS.cublasLtMatmulHeuristicResult_t}(undef, 1) + returned = Ref{Cint}(0) + CUDA.CUBLAS.cublasLtMatmulAlgoGetHeuristic( + handle, desc, Aref[], Bref[], Cref[], Cref[], + pref, Cint(1), heur, returned) + returned[] < 1 && error("cuBLASLt has no algo for shape (m=$m,n=$n,k=$k,trans=$transA$transB)") + + return LtPlan(desc, Aref[], Bref[], Cref[], heur[1].algo) +end + +function _gemm_lt!(plan::LtPlan, + C::CuArray{Float32,2}, A::CuArray{Float32,2}, B::CuArray{Float32,2}; + alpha::Float32 = 1f0, beta::Float32 = 0f0) + state = _lt_state() + # cuBLASLt's matmul descriptor defaults to CUBLASLT_POINTER_MODE_HOST + # (independent of the cuBLAS handle's pointer mode), so alpha/beta are + # plain host Refs here — using CuRef would trigger UVA faults. + α = Ref{Float32}(alpha) + β = Ref{Float32}(beta) + algo_ref = Ref(plan.algo) + CUDA.CUBLAS.cublasLtMatmul( + state.handle, plan.desc, + α, A, plan.Adesc, + B, plan.Bdesc, + β, C, plan.Cdesc, + C, plan.Cdesc, # D = C in place + algo_ref, + state.ws, sizeof(state.ws), + CUDA.stream(), + ) + return C +end + +# Three plans for our specific 2-layer MLP shape. +struct LtPlans + p1::LtPlan # W1 * X : (h,d) * (d,n) → (h,n) + p2::LtPlan # W2' * J_2 : store (1,h),'T' * (1,n) → (h,n) + p3::LtPlan # out * X' : (h,n) * store (d,n),'T' → (h,d) +end + +function build_lt_plans(W1::CuArray{Float32,2}, W2::CuArray{Float32,2}, + X::CuArray{Float32,2}) + h, d = size(W1) + nn = size(X, 2) + p1 = _build_lt_plan('N', 'N', h, nn, d, h, d, h) + p2 = _build_lt_plan('T', 'N', h, nn, 1, 1, 1, h) + p3 = _build_lt_plan('N', 'T', h, d, nn, h, d, h) + return LtPlans(p1, p2, p3) +end + +function reverse_diff_v6(plans::LtPlans, W1, W2, X, y) + h, d = size(W1) + nn = size(X, 2) + + Z1 = CuArray{Float32}(undef, h, nn) + _gemm_lt!(plans.p1, Z1, W1, X) + + y_1 = similar(Z1) + J_1 = similar(Z1) + tanh_and_jac!(y_1, J_1, Z1) + + J_2 = 2 .* (W2 * y_1 .- y) ./ size(y, 2) + + tmp = CuArray{Float32}(undef, h, nn) + _gemm_lt!(plans.p2, tmp, W2, J_2) + + out = similar(tmp) + vmul!(out, J_1, tmp) + + result = CuArray{Float32}(undef, h, d) + _gemm_lt!(plans.p3, result, out, X) + return result +end + +# ------------------------------------------------------------------------- +# Lux + Mooncake path # # Builds an equivalent 2-layer MLP `Y = W2 * tanh(W1 * X)` (no bias) using # Lux, plugs in the *same* CuArray weights so the gradient is comparable, -# and lets Zygote source-to-source the backward. This goes through the same -# CUDA.jl + cuBLAS stack as `reverse_diff`, so we expect similar kernels — -# the interesting thing is the AD/dispatch overhead Lux+Zygote add on top. +# and uses Mooncake (the modern Julia 1.12-friendly reverse-mode AD) for the +# backward. Goes through the same CUDA.jl + cuBLAS stack as `reverse_diff`, +# so kernels should look similar — what we're measuring is the AD/dispatch +# overhead Lux+Mooncake add on top. # ------------------------------------------------------------------------- -function build_lux(W1g::CuArray{Float32,2}, W2g::CuArray{Float32,2}) +struct LuxMooncake{M,P,S,L,R} + model::M + ps::P + st::S + loss_fn::L + rule::R +end + +function build_lux(W1g::CuArray{Float32,2}, W2g::CuArray{Float32,2}, + Xg::CuArray, yg::CuArray) h, d = size(W1g) model = Lux.Chain( Lux.Dense(d => h, tanh; use_bias = false), @@ -213,15 +381,24 @@ function build_lux(W1g::CuArray{Float32,2}, W2g::CuArray{Float32,2}) layer_2 = (weight = W2g,), ) st = Lux.initialstates(Random.default_rng(), model) - return model, ps, st -end -function lux_grad(model, ps, st, Xg::CuArray, yg::CuArray) - function loss_fn(p) - y_hat, _ = model(Xg, p, st) - return sum((y_hat .- yg) .^ 2) / size(yg, 2) + # Closure captures Xg, yg, model, st — only `p` is the differentiated arg. + loss_fn = let model = model, st = st, Xg = Xg, yg = yg + p -> begin + y_hat, _ = model(Xg, p, st) + return sum((y_hat .- yg) .^ 2) / size(yg, 2) + end end - ∂ps = first(Zygote.gradient(loss_fn, ps)) + + # build_rrule is the expensive step (compiles the reverse pass for these + # types) — do it once at setup so the per-call cost in the benchmark is + # just the actual fwd+bwd execution. + rule = Mooncake.build_rrule(loss_fn, ps) + return LuxMooncake(model, ps, st, loss_fn, rule) +end + +function lux_grad(lm::LuxMooncake) + _, (_, ∂ps) = Mooncake.value_and_gradient!!(lm.rule, lm.loss_fn, lm.ps) return ∂ps.layer_1.weight end @@ -352,17 +529,24 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) grad_julia = Array(reverse_diff(W1g, W2g, Xg, yg)) grad_julia_v4 = Array(reverse_diff_v4(W1g, W2g, Xg, yg)) grad_julia_v5 = Array(reverse_diff_v5(W1g, W2g, Xg, yg)) + print("cuBLASLt build_lt_plans for h=$h ... "); flush(stdout) + t_lt_build = @elapsed lt_plans = build_lt_plans(W1g, W2g, Xg) + @printf "%.3f s\n" t_lt_build + grad_julia_v6 = Array(reverse_diff_v6(lt_plans, W1g, W2g, Xg, yg)) CUDA.synchronize() - # Lux + Zygote warmup (first call compiles Zygote's pullback for this shape) - print("Lux+Zygote compile warmup for h=$h ... "); flush(stdout) - lux_model, lux_ps, lux_st = build_lux(W1g, W2g) - t_lux_compile = @elapsed begin - lux_grad(lux_model, lux_ps, lux_st, Xg, yg) - CUDA.synchronize() + # Lux + Mooncake setup. build_rrule compiles the reverse pass for these + # types (one-time cost per shape); first call afterwards still does some + # JIT, so we time both separately. + print("Lux+Mooncake build_rrule for h=$h ... "); flush(stdout) + t_lux_build = @elapsed lm = build_lux(W1g, W2g, Xg, yg) + @printf "%.2f s, " t_lux_build + print("first call ... "); flush(stdout) + t_lux_first = @elapsed begin + lux_grad(lm); CUDA.synchronize() end - @printf "%.2f s\n" t_lux_compile - grad_lux = Array(lux_grad(lux_model, lux_ps, lux_st, Xg, yg)) + @printf "%.2f s\n" t_lux_first + grad_lux = Array(lux_grad(lm)) CUDA.synchronize() # ----- PyTorch ----- @@ -384,7 +568,8 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) # ----- Numerical equivalence ----- for (name, g) in [("Julia v4 (vec=4) ", grad_julia_v4), ("Julia v5 (vec=4+SIMT)", grad_julia_v5), - ("Lux + Zygote ", grad_lux), + ("Julia v6 (vec=4+Lt)", grad_julia_v6), + ("Lux + Mooncake ", grad_lux), ("PyTorch eager ", grad_pytorch_eager), ("PyTorch compiled ", grad_pytorch_compiled)] maxdiff = maximum(abs.(grad_julia .- g)) @@ -409,8 +594,12 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) reverse_diff_v5($W1g, $W2g, $Xg, $yg) CUDA.synchronize() end samples=30 evals=1 seconds=10 + bj6 = @benchmark begin + reverse_diff_v6($lt_plans, $W1g, $W2g, $Xg, $yg) + CUDA.synchronize() + end samples=30 evals=1 seconds=10 bjlux = @benchmark begin - lux_grad($lux_model, $lux_ps, $lux_st, $Xg, $yg) + lux_grad($lm) CUDA.synchronize() end samples=30 evals=1 seconds=10 be = @benchmark begin @@ -424,7 +613,8 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) @printf "Julia broadcast : median %8.3f µs\n" 1e-3 * median(bj).time @printf "Julia vec=4 : median %8.3f µs\n" 1e-3 * median(bj4).time @printf "Julia vec=4 + SIMT : median %8.3f µs\n" 1e-3 * median(bj5).time - @printf "Lux + Zygote : median %8.3f µs\n" 1e-3 * median(bjlux).time + @printf "Julia vec=4 + cuBLASLt: median %8.3f µs\n" 1e-3 * median(bj6).time + @printf "Lux + Mooncake : median %8.3f µs\n" 1e-3 * median(bjlux).time @printf "PyTorch eager : median %8.3f µs\n" 1e-3 * median(be).time @printf "PyTorch compiled : median %8.3f µs\n" 1e-3 * median(bc).time @@ -438,8 +628,11 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) println("\n--- CUDA trace: Julia vec=4 + SIMT ---") summarize_julia_trace(stdout, julia_trace(() -> reverse_diff_v5(W1g, W2g, Xg, yg))) - println("\n--- CUDA trace: Lux + Zygote ---") - summarize_julia_trace(stdout, julia_trace(() -> lux_grad(lux_model, lux_ps, lux_st, Xg, yg))) + println("\n--- CUDA trace: Julia vec=4 + cuBLASLt ---") + summarize_julia_trace(stdout, julia_trace(() -> reverse_diff_v6(lt_plans, W1g, W2g, Xg, yg))) + + println("\n--- CUDA trace: Lux + Mooncake ---") + summarize_julia_trace(stdout, julia_trace(() -> lux_grad(lm))) println("\n--- CUDA trace: PyTorch eager ---") println(pytorch_trace(() -> pytorch_grad_eager(W1t, W2t, Xt, yt))) From 0dca3381285d478c7f581f3c2c955a8533c670ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 5 May 2026 10:56:28 +0200 Subject: [PATCH 11/14] Add GPU support --- src/ArrayDiff.jl | 21 ++++-- src/mathoptinterface_api.jl | 9 ++- src/reverse_mode.jl | 140 ++++++++++++++++-------------------- src/sizes.jl | 55 +++++++++++--- src/types.jl | 37 +++++----- test/Optimisers_GPU.jl | 64 +++++++++++++++++ test/Project.toml | 2 + test/runtests.jl | 1 + 8 files changed, 217 insertions(+), 112 deletions(-) create mode 100644 test/Optimisers_GPU.jl diff --git a/src/ArrayDiff.jl b/src/ArrayDiff.jl index c5d4741..f4d5570 100644 --- a/src/ArrayDiff.jl +++ b/src/ArrayDiff.jl @@ -7,17 +7,25 @@ module ArrayDiff import ForwardDiff +import LinearAlgebra import MathOptInterface as MOI const Nonlinear = MOI.Nonlinear import SparseArrays import OrderedCollections """ - Mode() <: MOI.Nonlinear.AbstractAutomaticDifferentiation + Mode{S}() <: MOI.Nonlinear.AbstractAutomaticDifferentiation Fork of `MOI.Nonlinear.SparseReverseMode` to add array support. + +The type parameter `S` is the storage type used for the AD tape (forward, +partials, and reverse storage of each subexpression). It must satisfy +`S<:AbstractVector{Float64}`. Defaults to `Vector{Float64}`. Pass a different +`S` (for example `CuVector{Float64}`) to keep the tape on a GPU. """ -struct Mode <: MOI.Nonlinear.AbstractAutomaticDifferentiation end +struct Mode{S<:AbstractVector{Float64}} <: MOI.Nonlinear.AbstractAutomaticDifferentiation end + +Mode() = Mode{Vector{Float64}}() # Override basic math functions to return NaN instead of throwing errors. # This is what NLP solvers expect, and sometimes the results aren't needed @@ -56,7 +64,8 @@ include("evaluator.jl") include("array_nonlinear_function.jl") include("parse_moi.jl") -model(::Mode) = Model() +model(::Mode{S}) where {S} = Model() +storage_type(::Mode{S}) where {S} = S # Extend MOI.Nonlinear.set_objective so that solvers calling # MOI.Nonlinear.set_objective(arraydiff_model, snf) dispatch here. @@ -73,10 +82,10 @@ end # Create an ArrayDiff Evaluator from an ArrayDiff Model. function Evaluator( model::ArrayDiff.Model, - ::Mode, + ::Mode{S}, ordered_variables::Vector{MOI.VariableIndex}, -) - return Evaluator(model, NLPEvaluator(model, ordered_variables)) +) where {S<:AbstractVector{Float64}} + return Evaluator(model, NLPEvaluator{S}(model, ordered_variables)) end # Called by solvers via MOI.Nonlinear.Evaluator(nlp_model, ad_backend, vars). diff --git a/src/mathoptinterface_api.jl b/src/mathoptinterface_api.jl index 3043176..384a0b8 100644 --- a/src/mathoptinterface_api.jl +++ b/src/mathoptinterface_api.jl @@ -19,7 +19,7 @@ function MOI.features_available(d::NLPEvaluator) return [:Grad, :Jac, :JacVec, :Hess, :HessVec] end -function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol}) +function MOI.initialize(d::NLPEvaluator{S}, requested_features::Vector{Symbol}) where {S<:AbstractVector{Float64}} # Check that we support the features requested by the user. available_features = MOI.features_available(d) for feature in requested_features @@ -39,7 +39,7 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol}) d.residual = nothing d.user_output_buffer = zeros(largest_user_input_dimension) d.jac_storage = zeros(max(N, largest_user_input_dimension)) - d.constraints = _FunctionStorage[] + d.constraints = _FunctionStorage{S}[] d.last_x = fill(NaN, N) d.want_hess = :Hess in requested_features want_hess_storage = (:HessVec in requested_features) || d.want_hess @@ -63,7 +63,7 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol}) subexpression_variables = Vector{Vector{Int}}(undef, num_subexpressions) subexpression_edgelist = Vector{Set{Tuple{Int,Int}}}(undef, num_subexpressions) - d.subexpressions = Vector{_SubexpressionStorage}(undef, num_subexpressions) + d.subexpressions = Vector{_SubexpressionStorage{S}}(undef, num_subexpressions) d.subexpression_forward_values = zeros(num_subexpressions) d.subexpression_reverse_values = zeros(num_subexpressions) for k in d.subexpression_order @@ -75,6 +75,7 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol}) moi_index_to_consecutive_index, Float64[], d, + S, ) d.subexpressions[k] = subex d.subexpression_linearity[k] = subex.linearity @@ -115,6 +116,7 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol}) moi_index_to_consecutive_index, shared_partials_storage_ϵ, d, + S, ) objective = _FunctionStorage( subexpr, @@ -163,6 +165,7 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol}) moi_index_to_consecutive_index, shared_partials_storage_ϵ, d, + S, ) push!( d.constraints, diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 3a1dfab..6fe9060 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -165,20 +165,13 @@ function _forward_eval( idx2 = last(children_indices) @inbounds ix1 = children_arr[idx1] @inbounds ix2 = children_arr[idx2] - v1 = zeros(_size(f.sizes, ix1)...) - v2 = zeros(_size(f.sizes, ix2)...) - for j in _eachindex(f.sizes, ix1) - v1[j] = @j f.forward_storage[ix1] - @j f.partials_storage[ix2] = v1[j] - end - for j in _eachindex(f.sizes, ix2) - v2[j] = @j f.forward_storage[ix2] - @j f.partials_storage[ix1] = v2[j] - end - v_prod = v1 * v2 - for j in _eachindex(f.sizes, k) - @j f.forward_storage[k] = v_prod[j] - end + v1 = _view_array(f.forward_storage, f.sizes, ix1) + v2 = _view_array(f.forward_storage, f.sizes, ix2) + out = _view_array(f.forward_storage, f.sizes, k) + LinearAlgebra.mul!(out, v1, v2) + # We deliberately don't write v1/v2 into partials_storage + # here: the matmul reverse branch reads forward_storage + # directly, so those writes were dead. # Node `k` is scalar else tmp_prod = one(T) @@ -350,12 +343,9 @@ function _forward_eval( elseif node.index == 15 # sum @assert N == 1 ix = children_arr[first(children_indices)] - tmp_sum = zero(T) - for j in _eachindex(f.sizes, ix) - @j f.partials_storage[ix] = one(T) - tmp_sum += @j f.forward_storage[ix] - end - @s f.forward_storage[k] = tmp_sum + inp = _view_array(f.forward_storage, f.sizes, ix) + fill!(_view_array(f.partials_storage, f.sizes, ix), one(T)) + @s f.forward_storage[k] = sum(inp) elseif node.index == 16 # row for j in _eachindex(f.sizes, k) ix = children_arr[children_indices[j]] @@ -403,12 +393,12 @@ function _forward_eval( child1 = first(children_indices) @inbounds ix1 = children_arr[child1] @inbounds ix2 = children_arr[child1+1] - for j in _eachindex(f.sizes, k) - @j f.partials_storage[ix1] = one(T) - @j f.partials_storage[ix2] = -one(T) - @j f.forward_storage[k] = - @j(f.forward_storage[ix1]) - @j(f.forward_storage[ix2]) - end + out = _view_array(f.forward_storage, f.sizes, k) + v1 = _view_array(f.forward_storage, f.sizes, ix1) + v2 = _view_array(f.forward_storage, f.sizes, ix2) + out .= v1 .- v2 + fill!(_view_array(f.partials_storage, f.sizes, ix1), one(T)) + fill!(_view_array(f.partials_storage, f.sizes, ix2), -one(T)) elseif node.index == 3 # :* (broadcasted) # Node `k` is not scalar, so we do matrix multiplication if f.sizes.ndims[k] != 0 @@ -472,21 +462,19 @@ function _forward_eval( @inbounds ix1 = children_arr[idx1] @inbounds ix2 = children_arr[idx2] @assert f.sizes.ndims[ix2] == 0 "Broadcasted ^ requires scalar exponent" - @inbounds exponent = - f.forward_storage[f.sizes.storage_offset[ix2]+1] - for j in _eachindex(f.sizes, k) - base = @j f.forward_storage[ix1] - if exponent == 2 - @j f.forward_storage[k] = base * base - @j f.partials_storage[ix1] = 2 * base - elseif exponent == 1 - @j f.forward_storage[k] = base - @j f.partials_storage[ix1] = one(T) - else - @j f.forward_storage[k] = pow(base, exponent) - @j f.partials_storage[ix1] = - exponent * pow(base, exponent - 1) - end + exponent = _scalar_load(f.forward_storage, f.sizes.storage_offset[ix2]+1) + out = _view_array(f.forward_storage, f.sizes, k) + inp = _view_array(f.forward_storage, f.sizes, ix1) + partials = _view_array(f.partials_storage, f.sizes, ix1) + if exponent == 2 + out .= inp .* inp + partials .= 2 .* inp + elseif exponent == 1 + out .= inp + fill!(partials, one(T)) + else + out .= pow.(inp, exponent) + partials .= exponent .* pow.(inp, exponent - 1) end end elseif node.type == NODE_CALL_UNIVARIATE @@ -526,6 +514,12 @@ function _forward_eval( val = @j f.forward_storage[child_idx] @j f.forward_storage[k] = -val end + elseif operators.univariate_operators[node.index] === :tanh + out = _view_array(f.forward_storage, f.sizes, k) + inp = _view_array(f.forward_storage, f.sizes, child_idx) + partials = _view_array(f.partials_storage, f.sizes, child_idx) + out .= tanh.(inp) + partials .= one(T) .- out .* out else for j in _eachindex(f.sizes, k) ret_f, ret_f′ = eval_univariate_function_and_gradient( @@ -611,31 +605,23 @@ function _reverse_eval( op = DEFAULT_MULTIVARIATE_OPERATORS[node.index] if op == :* if f.sizes.ndims[k] != 0 - # Node `k` is not scalar, so we do matrix multiplication or broadcasted multiplication + # Matrix multiplication: rev_v1 = rev_parent * v2', + # rev_v2 = v1' * rev_parent. Both v1 and v2 are read + # straight from forward_storage (the matmul forward + # branch deliberately doesn't snapshot them into + # partials_storage), and the reverse views are written + # in place. idx1 = first(children_indices) idx2 = last(children_indices) ix1 = children_arr[idx1] ix2 = children_arr[idx2] - v1 = zeros(_size(f.sizes, ix1)...) - v2 = zeros(_size(f.sizes, ix2)...) - for j in _eachindex(f.sizes, ix1) - v1[j] = @j f.forward_storage[ix1] - end - for j in _eachindex(f.sizes, ix2) - v2[j] = @j f.forward_storage[ix2] - end - rev_parent = zeros(_size(f.sizes, k)...) - for j in _eachindex(f.sizes, k) - rev_parent[j] = @j f.reverse_storage[k] - end - rev_v1 = rev_parent * v2' - rev_v2 = v1' * rev_parent - for j in _eachindex(f.sizes, ix1) - @j f.reverse_storage[ix1] = rev_v1[j] - end - for j in _eachindex(f.sizes, ix2) - @j f.reverse_storage[ix2] = rev_v2[j] - end + v1 = _view_array(f.forward_storage, f.sizes, ix1) + v2 = _view_array(f.forward_storage, f.sizes, ix2) + rev_parent = _view_array(f.reverse_storage, f.sizes, k) + rev_v1 = _view_array(f.reverse_storage, f.sizes, ix1) + rev_v2 = _view_array(f.reverse_storage, f.sizes, ix2) + LinearAlgebra.mul!(rev_v1, rev_parent, v2') + LinearAlgebra.mul!(rev_v2, v1', rev_parent) continue end elseif op == :vect @@ -888,21 +874,21 @@ function _reverse_eval( continue end # Node `k` has same size as its children. - # The Jacobian (between the vectorized versions) is diagonal and the diagonal entries - # are stored in `f.partials_storage` - for j in _eachindex(f.sizes, k) - rev_parent = @j f.reverse_storage[k] - for child_idx in children_indices - ix = children_arr[child_idx] - @assert _size(f.sizes, k) == _size(f.sizes, ix) - partial = @j f.partials_storage[ix] - val = ifelse( - rev_parent == 0.0 && !isfinite(partial), - rev_parent, - rev_parent * partial, - ) - @j f.reverse_storage[ix] = val - end + # The Jacobian (between the vectorized versions) is diagonal and the + # diagonal entries are stored in `f.partials_storage`. We broadcast + # `rev_child .= rev_parent .* partial` over the whole array (with the + # 0 * Inf guard preserved). + rev_parent = _view_array(f.reverse_storage, f.sizes, k) + for child_idx in children_indices + ix = children_arr[child_idx] + @assert _size(f.sizes, k) == _size(f.sizes, ix) + rev_child = _view_array(f.reverse_storage, f.sizes, ix) + partial = _view_array(f.partials_storage, f.sizes, ix) + rev_child .= ifelse.( + (rev_parent .== 0) .& .!isfinite.(partial), + rev_parent, + rev_parent .* partial, + ) end end return diff --git a/src/sizes.jl b/src/sizes.jl index 3e8e3cc..04c45d5 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -57,6 +57,40 @@ function _setindex!(x, value, sizes::Sizes, k::Int, j) return x[sizes.storage_offset[k]+j] = value end +""" + _scalar_load(storage, idx) -> Float64 + +Read a single Float64 from `storage` at linear index `idx`. The default +implementation just calls `getindex`; this is a hook for storage backends +(such as `CuVector`) that disallow scalar indexing and need to dispatch to a +1-element transfer instead. +""" +_scalar_load(storage::AbstractVector, idx::Int) = @inbounds storage[idx] + +""" + _view_array(storage, sizes, k) -> AbstractArray + +Return a view of the slice of `storage` that holds node `k`'s array value, +reshaped to that node's natural shape. The view aliases the underlying +`storage` (no copy), so mutating the returned array writes back into the tape. +For a scalar (`ndims[k] == 0`) node this returns a length-1 vector view. +""" +function _view_array(storage::AbstractVector, sizes::Sizes, k::Int) + nd = sizes.ndims[k] + offset = sizes.storage_offset[k] + if nd == 0 + return view(storage, (offset+1):(offset+1)) + elseif nd == 1 + n = sizes.size[sizes.size_offset[k]+1] + return view(storage, (offset+1):(offset+n)) + else + N = _length(sizes, k) + v = view(storage, (offset+1):(offset+N)) + szs = ntuple(d -> sizes.size[sizes.size_offset[k]+d], nd) + return reshape(v, szs) + end +end + """ @s(storage[node]) -> _getscalar(storage, f.sizes, node) @s(storage[node] = value) -> _setscalar!(storage, value, f.sizes, node) @@ -366,14 +400,14 @@ function _infer_sizes( return sizes end -struct _SubexpressionStorage +struct _SubexpressionStorage{S<:AbstractVector{Float64}} nodes::Vector{Node} adj::SparseArrays.SparseMatrixCSC{Bool,Int} sizes::Sizes const_values::Vector{Float64} - forward_storage::Vector{Float64} - partials_storage::Vector{Float64} - reverse_storage::Vector{Float64} + forward_storage::S + partials_storage::S + reverse_storage::S partials_storage_ϵ::Vector{Float64} linearity::Linearity @@ -383,17 +417,18 @@ struct _SubexpressionStorage const_values::Vector{Float64}, partials_storage_ϵ::Vector{Float64}, linearity::Linearity, - ) + ::Type{S} = Vector{Float64}, + ) where {S<:AbstractVector{Float64}} sizes = _infer_sizes(nodes, adj) N = _length(sizes) - return new( + return new{S}( nodes, adj, - _infer_sizes(nodes, adj), + sizes, const_values, - zeros(N), # forward_storage, - zeros(N), # partials_storage, - zeros(N), # reverse_storage, + fill!(S(undef, N), 0.0), # forward_storage, + fill!(S(undef, N), 0.0), # partials_storage, + fill!(S(undef, N), 0.0), # reverse_storage, partials_storage_ϵ, linearity, ) diff --git a/src/types.jl b/src/types.jl index 58358b2..8ac67d5 100644 --- a/src/types.jl +++ b/src/types.jl @@ -90,7 +90,8 @@ function _subexpression_and_linearity( moi_index_to_consecutive_index, partials_storage_ϵ::Vector{Float64}, d, -) + ::Type{S} = Vector{Float64}, +) where {S<:AbstractVector{Float64}} nodes = _replace_moi_variables(expr.nodes, moi_index_to_consecutive_index) adj = adjacency_matrix(nodes) linearity = if d.want_hess @@ -104,12 +105,13 @@ function _subexpression_and_linearity( convert(Vector{Float64}, expr.values), partials_storage_ϵ, linearity[1], + S, ), linearity end -struct _FunctionStorage - expr::_SubexpressionStorage +struct _FunctionStorage{S<:AbstractVector{Float64}} + expr::_SubexpressionStorage{S} grad_sparsity::Vector{Int} # Nonzero pattern of Hessian matrix hess_I::Vector{Int} @@ -120,16 +122,16 @@ struct _FunctionStorage dependent_subexpressions::Vector{Int} function _FunctionStorage( - expr::_SubexpressionStorage, + expr::_SubexpressionStorage{S}, num_variables, coloring_storage::Coloring.IndexedSet, want_hess::Bool, - subexpressions::Vector{_SubexpressionStorage}, + subexpressions::Vector{_SubexpressionStorage{S}}, dependent_subexpressions, subexpression_edgelist, subexpression_variables, linearity::Vector{Linearity}, - ) + ) where {S<:AbstractVector{Float64}} empty!(coloring_storage) _compute_gradient_sparsity!(coloring_storage, expr.nodes) for k in dependent_subexpressions @@ -154,7 +156,7 @@ struct _FunctionStorage coloring_storage, ) seed_matrix = Coloring.seed_matrix(rinfo) - return new( + return new{S}( expr, grad_sparsity, hess_I, @@ -164,7 +166,7 @@ struct _FunctionStorage dependent_subexpressions, ) else - return new( + return new{S}( expr, grad_sparsity, Int[], @@ -293,14 +295,14 @@ interface. !!! warning Before using, you must initialize the evaluator using `MOI.initialize`. """ -mutable struct NLPEvaluator <: MOI.AbstractNLPEvaluator +mutable struct NLPEvaluator{S<:AbstractVector{Float64}} <: MOI.AbstractNLPEvaluator data::Model ordered_variables::Vector{MOI.VariableIndex} - objective::Union{Nothing,_FunctionStorage} - residual::Union{Nothing,_FunctionStorage} - constraints::Vector{_FunctionStorage} - subexpressions::Vector{_SubexpressionStorage} + objective::Union{Nothing,_FunctionStorage{S}} + residual::Union{Nothing,_FunctionStorage{S}} + constraints::Vector{_FunctionStorage{S}} + subexpressions::Vector{_SubexpressionStorage{S}} subexpression_order::Vector{Int} # Storage for the subexpressions in reverse-mode automatic differentiation. subexpression_forward_values::Vector{Float64} @@ -330,10 +332,13 @@ mutable struct NLPEvaluator <: MOI.AbstractNLPEvaluator hessian_sparsity::Vector{Tuple{Int64,Int64}} max_chunk::Int # chunk size for which we've allocated storage - function NLPEvaluator( + function NLPEvaluator{S}( data::Model, ordered_variables::Vector{MOI.VariableIndex}, - ) - return new(data, ordered_variables) + ) where {S<:AbstractVector{Float64}} + return new{S}(data, ordered_variables) end end + +NLPEvaluator(data::Model, ordered_variables::Vector{MOI.VariableIndex}) = + NLPEvaluator{Vector{Float64}}(data, ordered_variables) diff --git a/test/Optimisers_GPU.jl b/test/Optimisers_GPU.jl new file mode 100644 index 0000000..762b8f9 --- /dev/null +++ b/test/Optimisers_GPU.jl @@ -0,0 +1,64 @@ +module TestWithOptimisersGPU + +using Test + +using JuMP +using ArrayDiff +import CUDA +import LinearAlgebra +import MathOptInterface as MOI +import NLPModelsJuMP + +include(joinpath(@__DIR__, "OptimisersSolver.jl")) + +function runtests() + if !CUDA.functional() + @info "CUDA is not functional in this environment; skipping GPU tests." + return + end + for name in names(@__MODULE__; all = true) + if startswith("$(name)", "test_") + @testset "$(name)" begin + getfield(@__MODULE__, name)() + end + end + end + return +end + +function test_neural_optimisers_gpu() + n = 2 + X = [1.0 0.5; 0.3 0.8] + target = [0.5 0.2; 0.1 0.7] + model = Model(NLPModelsJuMP.Optimizer) + set_attribute(model, "solver", OptimisersSolver) + set_attribute( + model, + MOI.AutomaticDifferentiationBackend(), + ArrayDiff.Mode{CUDA.CuVector{Float64}}(), + ) + @variable(model, W1[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) + @variable(model, W2[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) + start_W1 = [0.3 -0.2; 0.1 0.4] + start_W2 = [-0.1 0.5; 0.2 -0.3] + for i in 1:n, j in 1:n + set_start_value(W1[i, j], start_W1[i, j]) + set_start_value(W2[i, j], start_W2[i, j]) + end + Y = W2 * tanh.(W1 * X) + loss = sum((Y .- target) .^ 2) + @objective(model, Min, loss) + set_attribute(model, "max_iter", 20_000) + set_attribute(model, "tol", 1e-6) + # The variable-load and gradient-extract paths still do scalar reads/writes + # against the GPU-resident tape (forward_storage, reverse_storage). Those + # are what `@allowscalar` permits. They will be batched in a follow-up; + # for now this test is a correctness check, not a performance benchmark. + CUDA.@allowscalar optimize!(model) + @test objective_value(model) < 1e-3 + return +end + +end + +TestWithOptimisersGPU.runtests() diff --git a/test/Project.toml b/test/Project.toml index 68b2a1c..c46b933 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,8 @@ [deps] ArrayDiff = "c45fa1ca-6901-44ac-ae5b-5513a4852d50" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" GenOpt = "f2c049d8-7489-4223-990c-4f1c121a4cde" JSOSolvers = "10dff2fc-5484-5881-a0e0-c90441020f8a" JuMP = "4076af6c-e467-56ae-b986-b466b2749572" diff --git a/test/runtests.jl b/test/runtests.jl index fb70b82..e5aa663 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,4 +8,5 @@ if VERSION >= v"1.11" # Needs https://github.com/JuliaSmoothOptimizers/NLPModelsJuMP.jl/pull/229 include("NLPModelsJuMP.jl") include("Optimisers.jl") + include("Optimisers_GPU.jl") end From 0cf58bebf977dbb658b7fa745aa5ae5b8d2196e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 5 May 2026 17:46:37 +0200 Subject: [PATCH 12/14] Fix --- src/ArrayDiff.jl | 3 ++- src/mathoptinterface_api.jl | 8 ++++++-- src/reverse_mode.jl | 5 ++++- src/types.jl | 8 +++++--- test/Project.toml | 7 +++++++ 5 files changed, 24 insertions(+), 7 deletions(-) diff --git a/src/ArrayDiff.jl b/src/ArrayDiff.jl index f4d5570..2856161 100644 --- a/src/ArrayDiff.jl +++ b/src/ArrayDiff.jl @@ -23,7 +23,8 @@ partials, and reverse storage of each subexpression). It must satisfy `S<:AbstractVector{Float64}`. Defaults to `Vector{Float64}`. Pass a different `S` (for example `CuVector{Float64}`) to keep the tape on a GPU. """ -struct Mode{S<:AbstractVector{Float64}} <: MOI.Nonlinear.AbstractAutomaticDifferentiation end +struct Mode{S<:AbstractVector{Float64}} <: + MOI.Nonlinear.AbstractAutomaticDifferentiation end Mode() = Mode{Vector{Float64}}() diff --git a/src/mathoptinterface_api.jl b/src/mathoptinterface_api.jl index 384a0b8..bab095a 100644 --- a/src/mathoptinterface_api.jl +++ b/src/mathoptinterface_api.jl @@ -19,7 +19,10 @@ function MOI.features_available(d::NLPEvaluator) return [:Grad, :Jac, :JacVec, :Hess, :HessVec] end -function MOI.initialize(d::NLPEvaluator{S}, requested_features::Vector{Symbol}) where {S<:AbstractVector{Float64}} +function MOI.initialize( + d::NLPEvaluator{S}, + requested_features::Vector{Symbol}, +) where {S<:AbstractVector{Float64}} # Check that we support the features requested by the user. available_features = MOI.features_available(d) for feature in requested_features @@ -63,7 +66,8 @@ function MOI.initialize(d::NLPEvaluator{S}, requested_features::Vector{Symbol}) subexpression_variables = Vector{Vector{Int}}(undef, num_subexpressions) subexpression_edgelist = Vector{Set{Tuple{Int,Int}}}(undef, num_subexpressions) - d.subexpressions = Vector{_SubexpressionStorage{S}}(undef, num_subexpressions) + d.subexpressions = + Vector{_SubexpressionStorage{S}}(undef, num_subexpressions) d.subexpression_forward_values = zeros(num_subexpressions) d.subexpression_reverse_values = zeros(num_subexpressions) for k in d.subexpression_order diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 6fe9060..70489f7 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -462,7 +462,10 @@ function _forward_eval( @inbounds ix1 = children_arr[idx1] @inbounds ix2 = children_arr[idx2] @assert f.sizes.ndims[ix2] == 0 "Broadcasted ^ requires scalar exponent" - exponent = _scalar_load(f.forward_storage, f.sizes.storage_offset[ix2]+1) + exponent = _scalar_load( + f.forward_storage, + f.sizes.storage_offset[ix2]+1, + ) out = _view_array(f.forward_storage, f.sizes, k) inp = _view_array(f.forward_storage, f.sizes, ix1) partials = _view_array(f.partials_storage, f.sizes, ix1) diff --git a/src/types.jl b/src/types.jl index 8ac67d5..d043b52 100644 --- a/src/types.jl +++ b/src/types.jl @@ -295,7 +295,8 @@ interface. !!! warning Before using, you must initialize the evaluator using `MOI.initialize`. """ -mutable struct NLPEvaluator{S<:AbstractVector{Float64}} <: MOI.AbstractNLPEvaluator +mutable struct NLPEvaluator{S<:AbstractVector{Float64}} <: + MOI.AbstractNLPEvaluator data::Model ordered_variables::Vector{MOI.VariableIndex} @@ -340,5 +341,6 @@ mutable struct NLPEvaluator{S<:AbstractVector{Float64}} <: MOI.AbstractNLPEvalua end end -NLPEvaluator(data::Model, ordered_variables::Vector{MOI.VariableIndex}) = - NLPEvaluator{Vector{Float64}}(data, ordered_variables) +function NLPEvaluator(data::Model, ordered_variables::Vector{MOI.VariableIndex}) + return NLPEvaluator{Vector{Float64}}(data, ordered_variables) +end diff --git a/test/Project.toml b/test/Project.toml index c46b933..7788561 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -19,6 +19,13 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[compat] +# NLPModelsModifiers (a transitive dep of JSOSolvers) calls +# `@default_counters Model inner (excluded,)`, which only became a 3-arg +# macro in NLPModels 0.21.12. Older 0.21.x versions only define the 2-arg +# form and break precompilation. +NLPModels = "0.21.12" + [sources] ArrayDiff = {path = ".."} NLopt = {url = "https://github.com/jump-dev/NLopt.jl/", rev = "bl/diff_backend"} From f0fe786961b79774b2920245741d68562ece0542 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 5 May 2026 20:48:29 +0200 Subject: [PATCH 13/14] Fix --- perf/cuda_vs_pytorch.jl | 380 ++++++++++++++++++++++++++++------------ 1 file changed, 264 insertions(+), 116 deletions(-) diff --git a/perf/cuda_vs_pytorch.jl b/perf/cuda_vs_pytorch.jl index cd0d3f6..e216088 100644 --- a/perf/cuda_vs_pytorch.jl +++ b/perf/cuda_vs_pytorch.jl @@ -55,69 +55,91 @@ end # NTuple{4, Base.VecElement{Float32}} loads via Core.LLVMPtr lower to # `ld.global.v4` PTX. # ------------------------------------------------------------------------- -const Float4 = NTuple{4, Base.VecElement{Float32}} +const Float4 = NTuple{4,Base.VecElement{Float32}} @inline _f4ptr(arr::CuDeviceArray{Float32}) = - reinterpret(Core.LLVMPtr{Float4, AS.Global}, pointer(arr)) + reinterpret(Core.LLVMPtr{Float4,AS.Global}, pointer(arr)) -@inline _vec4(t1, t2, t3, t4) = - (Base.VecElement(t1), Base.VecElement(t2), Base.VecElement(t3), Base.VecElement(t4)) +@inline _vec4(t1, t2, t3, t4) = ( + Base.VecElement(t1), + Base.VecElement(t2), + Base.VecElement(t3), + Base.VecElement(t4), +) function _tanh_and_jac_kernel!(y, J, x, n::Int) - i = (blockIdx().x - 1) * blockDim().x + threadIdx().x + i = (blockIdx().x - 1) * blockDim().x + threadIdx().x base = 4 * (i - 1) if base + 4 <= n - v = unsafe_load(_f4ptr(x), i, Val(16)) - t1 = tanh(v[1].value); t2 = tanh(v[2].value) - t3 = tanh(v[3].value); t4 = tanh(v[4].value) + v = unsafe_load(_f4ptr(x), i, Val(16)) + t1 = tanh(v[1].value); + t2 = tanh(v[2].value) + t3 = tanh(v[3].value); + t4 = tanh(v[4].value) unsafe_store!(_f4ptr(y), _vec4(t1, t2, t3, t4), i, Val(16)) - unsafe_store!(_f4ptr(J), - _vec4(1f0 - t1*t1, 1f0 - t2*t2, 1f0 - t3*t3, 1f0 - t4*t4), i, Val(16)) + unsafe_store!( + _f4ptr(J), + _vec4(1.0f0 - t1*t1, 1.0f0 - t2*t2, 1.0f0 - t3*t3, 1.0f0 - t4*t4), + i, + Val(16), + ) elseif base < n - for k in 1:(n - base) - @inbounds t = tanh(x[base + k]) - @inbounds y[base + k] = t - @inbounds J[base + k] = 1f0 - t*t + for k in 1:(n-base) + @inbounds t = tanh(x[base+k]) + @inbounds y[base+k] = t + @inbounds J[base+k] = 1.0f0 - t*t end end return nothing end -function tanh_and_jac!(y::CuArray{Float32}, J::CuArray{Float32}, x::CuArray{Float32}) - n = length(x) +function tanh_and_jac!( + y::CuArray{Float32}, + J::CuArray{Float32}, + x::CuArray{Float32}, +) + n = length(x) threads = 256 - blocks = cld(cld(n, 4), threads) + blocks = cld(cld(n, 4), threads) @cuda threads=threads blocks=blocks _tanh_and_jac_kernel!(y, J, x, n) return (y, J) end function _vmul_kernel!(out, a, b, n::Int) - i = (blockIdx().x - 1) * blockDim().x + threadIdx().x + i = (blockIdx().x - 1) * blockDim().x + threadIdx().x base = 4 * (i - 1) if base + 4 <= n va = unsafe_load(_f4ptr(a), i, Val(16)) vb = unsafe_load(_f4ptr(b), i, Val(16)) - unsafe_store!(_f4ptr(out), - _vec4(va[1].value*vb[1].value, va[2].value*vb[2].value, - va[3].value*vb[3].value, va[4].value*vb[4].value), i, Val(16)) + unsafe_store!( + _f4ptr(out), + _vec4( + va[1].value*vb[1].value, + va[2].value*vb[2].value, + va[3].value*vb[3].value, + va[4].value*vb[4].value, + ), + i, + Val(16), + ) elseif base < n - for k in 1:(n - base) - @inbounds out[base + k] = a[base + k] * b[base + k] + for k in 1:(n-base) + @inbounds out[base+k] = a[base+k] * b[base+k] end end return nothing end function vmul!(out::CuArray{Float32}, a::CuArray{Float32}, b::CuArray{Float32}) - n = length(a) + n = length(a) threads = 256 - blocks = cld(cld(n, 4), threads) + blocks = cld(cld(n, 4), threads) @cuda threads=threads blocks=blocks _vmul_kernel!(out, a, b, n) return out end function reverse_diff_v4(W1, W2, X, y) - Z1 = W1 * X # GEMM (h, n) + Z1 = W1 * X # GEMM (h, n) y_1 = similar(Z1) J_1 = similar(Z1) tanh_and_jac!(y_1, J_1, Z1) # fused tanh + (1 - y²), vec=4 @@ -138,18 +160,25 @@ end # slower than a SIMT FP32 GEMM. We bypass gemmExComputeType by calling # cublasGemmEx directly with CUBLAS_COMPUTE_32F + CUBLAS_GEMM_DEFAULT. # ------------------------------------------------------------------------- -function _gemm_simt!(C::CuArray{Float32,2}, transA::Char, A::CuArray{Float32,2}, - transB::Char, B::CuArray{Float32,2}; - alpha::Float32 = 1f0, beta::Float32 = 0f0) - m = size(A, transA == 'N' ? 1 : 2) - k = size(A, transA == 'N' ? 2 : 1) - n = size(B, transB == 'N' ? 2 : 1) +function _gemm_simt!( + C::CuArray{Float32,2}, + transA::Char, + A::CuArray{Float32,2}, + transB::Char, + B::CuArray{Float32,2}; + alpha::Float32 = 1.0f0, + beta::Float32 = 0.0f0, +) + m = size(A, transA == 'N' ? 1 : 2) + k = size(A, transA == 'N' ? 2 : 1) + n = size(B, transB == 'N' ? 2 : 1) lda = max(1, stride(A, 2)) ldb = max(1, stride(B, 2)) ldc = max(1, stride(C, 2)) # CUDA.jl puts the cuBLAS handle in CUBLAS_POINTER_MODE_DEVICE, so alpha/beta # MUST be device pointers (host Ref triggers UVA fault handling — 100× slowdown). - α = CUDA.CuRef{Float32}(alpha); β = CUDA.CuRef{Float32}(beta) + α = CUDA.CuRef{Float32}(alpha); + β = CUDA.CuRef{Float32}(beta) h = CUDA.CUBLAS.handle() # Under FAST_MATH the handle's math mode is CUBLAS_TF32_TENSOR_OP_MATH, which # forces TF32 tensor cores even when we ask for CUBLAS_COMPUTE_32F. Flip it to @@ -157,10 +186,23 @@ function _gemm_simt!(C::CuArray{Float32,2}, transA::Char, A::CuArray{Float32,2}, CUDA.CUBLAS.math_mode!(h, CUDA.DEFAULT_MATH) try CUDA.CUBLAS.cublasGemmEx( - h, transA, transB, m, n, k, - α, A, Float32, lda, - B, Float32, ldb, - β, C, Float32, ldc, + h, + transA, + transB, + m, + n, + k, + α, + A, + Float32, + lda, + B, + Float32, + ldb, + β, + C, + Float32, + ldc, CUDA.CUBLAS.CUBLAS_COMPUTE_32F, CUDA.CUBLAS.CUBLAS_GEMM_DEFAULT, ) @@ -172,7 +214,7 @@ end function reverse_diff_v5(W1, W2, X, y) h, d = size(W1) - nn = size(X, 2) + nn = size(X, 2) Z1 = CuArray{Float32}(undef, h, nn) _gemm_simt!(Z1, 'N', W1, 'N', X) # SIMT: (h,d) * (d,n) @@ -230,15 +272,26 @@ mutable struct LtPlan algo::CUDA.CUBLAS.cublasLtMatmulAlgo_t end -function _build_lt_plan(transA::Char, transB::Char, - m::Int, n::Int, k::Int, - lda::Int, ldb::Int, ldc::Int) - state = _lt_state() +function _build_lt_plan( + transA::Char, + transB::Char, + m::Int, + n::Int, + k::Int, + lda::Int, + ldb::Int, + ldc::Int, +) + state = _lt_state() handle = state.handle - R32 = CUDA.CUDACore.R_32F # cudaDataType for Float32 + R32 = CUDA.CUDACore.R_32F # cudaDataType for Float32 desc_ref = Ref{CUDA.CUBLAS.cublasLtMatmulDesc_t}(C_NULL) - CUDA.CUBLAS.cublasLtMatmulDescCreate(desc_ref, CUDA.CUBLAS.CUBLAS_COMPUTE_32F, R32) + CUDA.CUBLAS.cublasLtMatmulDescCreate( + desc_ref, + CUDA.CUBLAS.CUBLAS_COMPUTE_32F, + R32, + ) desc = desc_ref[] # Set transpose attributes. @@ -246,11 +299,19 @@ function _build_lt_plan(transA::Char, transB::Char, tB = (transB == 'N') ? CUDA.CUBLAS.CUBLAS_OP_N : CUDA.CUBLAS.CUBLAS_OP_T let r = Ref(tA) CUDA.CUBLAS.cublasLtMatmulDescSetAttribute( - desc, CUDA.CUBLAS.CUBLASLT_MATMUL_DESC_TRANSA, r, sizeof(tA)) + desc, + CUDA.CUBLAS.CUBLASLT_MATMUL_DESC_TRANSA, + r, + sizeof(tA), + ) end let r = Ref(tB) CUDA.CUBLAS.cublasLtMatmulDescSetAttribute( - desc, CUDA.CUBLAS.CUBLASLT_MATMUL_DESC_TRANSB, r, sizeof(tB)) + desc, + CUDA.CUBLAS.CUBLASLT_MATMUL_DESC_TRANSB, + r, + sizeof(tB), + ) end # Layout shape is the *storage* shape (pre-transpose). @@ -262,9 +323,27 @@ function _build_lt_plan(transA::Char, transB::Char, Aref = Ref{CUDA.CUBLAS.cublasLtMatrixLayout_t}(C_NULL) Bref = Ref{CUDA.CUBLAS.cublasLtMatrixLayout_t}(C_NULL) Cref = Ref{CUDA.CUBLAS.cublasLtMatrixLayout_t}(C_NULL) - CUDA.CUBLAS.cublasLtMatrixLayoutCreate(Aref, R32, UInt64(Arows), UInt64(Acols), Int64(lda)) - CUDA.CUBLAS.cublasLtMatrixLayoutCreate(Bref, R32, UInt64(Brows), UInt64(Bcols), Int64(ldb)) - CUDA.CUBLAS.cublasLtMatrixLayoutCreate(Cref, R32, UInt64(m), UInt64(n), Int64(ldc)) + CUDA.CUBLAS.cublasLtMatrixLayoutCreate( + Aref, + R32, + UInt64(Arows), + UInt64(Acols), + Int64(lda), + ) + CUDA.CUBLAS.cublasLtMatrixLayoutCreate( + Bref, + R32, + UInt64(Brows), + UInt64(Bcols), + Int64(ldb), + ) + CUDA.CUBLAS.cublasLtMatrixLayoutCreate( + Cref, + R32, + UInt64(m), + UInt64(n), + Int64(ldc), + ) # Preference: tell the heuristic how much workspace it can use. pref_ref = Ref{CUDA.CUBLAS.cublasLtMatmulPreference_t}(C_NULL) @@ -272,39 +351,66 @@ function _build_lt_plan(transA::Char, transB::Char, pref = pref_ref[] let r = Ref(_LT_WS_BYTES) CUDA.CUBLAS.cublasLtMatmulPreferenceSetAttribute( - pref, CUDA.CUBLAS.CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - r, sizeof(_LT_WS_BYTES)) + pref, + CUDA.CUBLAS.CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + r, + sizeof(_LT_WS_BYTES), + ) end # Heuristic: top-1 algorithm. - heur = Vector{CUDA.CUBLAS.cublasLtMatmulHeuristicResult_t}(undef, 1) + heur = Vector{CUDA.CUBLAS.cublasLtMatmulHeuristicResult_t}(undef, 1) returned = Ref{Cint}(0) CUDA.CUBLAS.cublasLtMatmulAlgoGetHeuristic( - handle, desc, Aref[], Bref[], Cref[], Cref[], - pref, Cint(1), heur, returned) - returned[] < 1 && error("cuBLASLt has no algo for shape (m=$m,n=$n,k=$k,trans=$transA$transB)") + handle, + desc, + Aref[], + Bref[], + Cref[], + Cref[], + pref, + Cint(1), + heur, + returned, + ) + returned[] < 1 && error( + "cuBLASLt has no algo for shape (m=$m,n=$n,k=$k,trans=$transA$transB)", + ) return LtPlan(desc, Aref[], Bref[], Cref[], heur[1].algo) end -function _gemm_lt!(plan::LtPlan, - C::CuArray{Float32,2}, A::CuArray{Float32,2}, B::CuArray{Float32,2}; - alpha::Float32 = 1f0, beta::Float32 = 0f0) - state = _lt_state() +function _gemm_lt!( + plan::LtPlan, + C::CuArray{Float32,2}, + A::CuArray{Float32,2}, + B::CuArray{Float32,2}; + alpha::Float32 = 1.0f0, + beta::Float32 = 0.0f0, +) + state = _lt_state() # cuBLASLt's matmul descriptor defaults to CUBLASLT_POINTER_MODE_HOST # (independent of the cuBLAS handle's pointer mode), so alpha/beta are # plain host Refs here — using CuRef would trigger UVA faults. - α = Ref{Float32}(alpha) - β = Ref{Float32}(beta) + α = Ref{Float32}(alpha) + β = Ref{Float32}(beta) algo_ref = Ref(plan.algo) CUDA.CUBLAS.cublasLtMatmul( - state.handle, plan.desc, - α, A, plan.Adesc, - B, plan.Bdesc, - β, C, plan.Cdesc, - C, plan.Cdesc, # D = C in place + state.handle, + plan.desc, + α, + A, + plan.Adesc, + B, + plan.Bdesc, + β, + C, + plan.Cdesc, + C, + plan.Cdesc, # D = C in place algo_ref, - state.ws, sizeof(state.ws), + state.ws, + sizeof(state.ws), CUDA.stream(), ) return C @@ -317,10 +423,13 @@ struct LtPlans p3::LtPlan # out * X' : (h,n) * store (d,n),'T' → (h,d) end -function build_lt_plans(W1::CuArray{Float32,2}, W2::CuArray{Float32,2}, - X::CuArray{Float32,2}) +function build_lt_plans( + W1::CuArray{Float32,2}, + W2::CuArray{Float32,2}, + X::CuArray{Float32,2}, +) h, d = size(W1) - nn = size(X, 2) + nn = size(X, 2) p1 = _build_lt_plan('N', 'N', h, nn, d, h, d, h) p2 = _build_lt_plan('T', 'N', h, nn, 1, 1, 1, h) p3 = _build_lt_plan('N', 'T', h, d, nn, h, d, h) @@ -329,7 +438,7 @@ end function reverse_diff_v6(plans::LtPlans, W1, W2, X, y) h, d = size(W1) - nn = size(X, 2) + nn = size(X, 2) Z1 = CuArray{Float32}(undef, h, nn) _gemm_lt!(plans.p1, Z1, W1, X) @@ -369,17 +478,18 @@ struct LuxMooncake{M,P,S,L,R} rule::R end -function build_lux(W1g::CuArray{Float32,2}, W2g::CuArray{Float32,2}, - Xg::CuArray, yg::CuArray) - h, d = size(W1g) +function build_lux( + W1g::CuArray{Float32,2}, + W2g::CuArray{Float32,2}, + Xg::CuArray, + yg::CuArray, +) + h, d = size(W1g) model = Lux.Chain( Lux.Dense(d => h, tanh; use_bias = false), Lux.Dense(h => 1, identity; use_bias = false), ) - ps = ( - layer_1 = (weight = W1g,), - layer_2 = (weight = W2g,), - ) + ps = (layer_1 = (weight = W1g,), layer_2 = (weight = W2g,)) st = Lux.initialstates(Random.default_rng(), model) # Closure captures Xg, yg, model, st — only `p` is the differentiated arg. @@ -405,8 +515,8 @@ end # ------------------------------------------------------------------------- # PyTorch path # ------------------------------------------------------------------------- -const torch = pyimport("torch") -const np = pyimport("numpy") +const torch = pyimport("torch") +const np = pyimport("numpy") const profiler = pyimport("torch.profiler") # Build torch tensors once and reuse them across benchmark iterations, @@ -414,12 +524,12 @@ const profiler = pyimport("torch.profiler") function build_torch_tensors(W1::Matrix, W2::Matrix, X::Matrix, y::Matrix) npW1 = np.ascontiguousarray(np.asarray(PyArray(W1))) npW2 = np.ascontiguousarray(np.asarray(PyArray(W2))) - npX = np.ascontiguousarray(np.asarray(PyArray(X))) - npY = np.ascontiguousarray(np.asarray(PyArray(y))) + npX = np.ascontiguousarray(np.asarray(PyArray(X))) + npY = np.ascontiguousarray(np.asarray(PyArray(y))) W1t = torch.from_numpy(npW1).to("cuda").requires_grad_(true) W2t = torch.from_numpy(npW2).to("cuda") - Xt = torch.from_numpy(npX).to("cuda") - yt = torch.from_numpy(npY).to("cuda") + Xt = torch.from_numpy(npX).to("cuda") + yt = torch.from_numpy(npY).to("cuda") return W1t, W2t, Xt, yt end @@ -447,7 +557,7 @@ _compiled = torch.compile(_eager) (nt._eager, nt._compiled) end -pytorch_grad_eager(W1t, W2t, Xt, yt) = _grad_fn_eager(W1t, W2t, Xt, yt) +pytorch_grad_eager(W1t, W2t, Xt, yt) = _grad_fn_eager(W1t, W2t, Xt, yt) pytorch_grad_compiled(W1t, W2t, Xt, yt) = _grad_fn_compiled(W1t, W2t, Xt, yt) torch_to_julia(t) = pyconvert(Array, t.detach().cpu().numpy()) @@ -457,7 +567,8 @@ torch_to_julia(t) = pyconvert(Array, t.detach().cpu().numpy()) # ------------------------------------------------------------------------- function julia_trace(f) # Warmup so JIT + cuBLAS handle init don't show up. - f(); CUDA.synchronize() + f(); + CUDA.synchronize() return CUDA.@profile trace = true begin f() CUDA.synchronize() @@ -471,10 +582,10 @@ function summarize_julia_trace(io::IO, trace) dev = trace.device names_ = dev.name starts = dev.start - stops = dev.stop + stops = dev.stop counts = Dict{String,Int}() totals = Dict{String,Float64}() # in seconds - order = String[] + order = String[] for i in eachindex(names_) nm = String(names_[i]) if !haskey(counts, nm) @@ -495,7 +606,8 @@ function summarize_julia_trace(io::IO, trace) end function pytorch_trace(f) - f(); torch.cuda.synchronize() # warmup + f(); + torch.cuda.synchronize() # warmup ProfilerActivity = profiler.ProfilerActivity prof = profiler.profile(activities = pylist([ProfilerActivity.CUDA])) prof.__enter__() @@ -513,7 +625,7 @@ const _pygc = pyimport("gc") # ------------------------------------------------------------------------- # Benchmark + verify for one (h, d, n) # ------------------------------------------------------------------------- -function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) +function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1.0f-3) println("\n" * "="^72) @printf "h = %d, d = %d, n = %d\n" h d n println("="^72) @@ -521,15 +633,19 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) Random.seed!(0) W1 = randn(Float32, h, d) W2 = randn(Float32, 1, h) - X = randn(Float32, d, n) - y = randn(Float32, 1, n) + X = randn(Float32, d, n) + y = randn(Float32, 1, n) # ----- Julia / CUDA.jl ----- - W1g = CuArray(W1); W2g = CuArray(W2); Xg = CuArray(X); yg = CuArray(y) - grad_julia = Array(reverse_diff(W1g, W2g, Xg, yg)) + W1g = CuArray(W1); + W2g = CuArray(W2); + Xg = CuArray(X); + yg = CuArray(y) + grad_julia = Array(reverse_diff(W1g, W2g, Xg, yg)) grad_julia_v4 = Array(reverse_diff_v4(W1g, W2g, Xg, yg)) grad_julia_v5 = Array(reverse_diff_v5(W1g, W2g, Xg, yg)) - print("cuBLASLt build_lt_plans for h=$h ... "); flush(stdout) + print("cuBLASLt build_lt_plans for h=$h ... "); + flush(stdout) t_lt_build = @elapsed lt_plans = build_lt_plans(W1g, W2g, Xg) @printf "%.3f s\n" t_lt_build grad_julia_v6 = Array(reverse_diff_v6(lt_plans, W1g, W2g, Xg, yg)) @@ -538,12 +654,15 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) # Lux + Mooncake setup. build_rrule compiles the reverse pass for these # types (one-time cost per shape); first call afterwards still does some # JIT, so we time both separately. - print("Lux+Mooncake build_rrule for h=$h ... "); flush(stdout) + print("Lux+Mooncake build_rrule for h=$h ... "); + flush(stdout) t_lux_build = @elapsed lm = build_lux(W1g, W2g, Xg, yg) @printf "%.2f s, " t_lux_build - print("first call ... "); flush(stdout) + print("first call ... "); + flush(stdout) t_lux_first = @elapsed begin - lux_grad(lm); CUDA.synchronize() + lux_grad(lm); + CUDA.synchronize() end @printf "%.2f s\n" t_lux_first grad_lux = Array(lux_grad(lm)) @@ -556,25 +675,29 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) # First call to the compiled fn for this shape triggers Inductor codegen # (can take seconds). Time it so the user knows. - print("torch.compile codegen for h=$h ... "); flush(stdout) + print("torch.compile codegen for h=$h ... "); + flush(stdout) t_compile = @elapsed begin pytorch_grad_compiled(W1t, W2t, Xt, yt) torch.cuda.synchronize() end @printf "%.2f s\n" t_compile - grad_pytorch_compiled = torch_to_julia(pytorch_grad_compiled(W1t, W2t, Xt, yt)) + grad_pytorch_compiled = + torch_to_julia(pytorch_grad_compiled(W1t, W2t, Xt, yt)) torch.cuda.synchronize() # ----- Numerical equivalence ----- - for (name, g) in [("Julia v4 (vec=4) ", grad_julia_v4), - ("Julia v5 (vec=4+SIMT)", grad_julia_v5), - ("Julia v6 (vec=4+Lt)", grad_julia_v6), - ("Lux + Mooncake ", grad_lux), - ("PyTorch eager ", grad_pytorch_eager), - ("PyTorch compiled ", grad_pytorch_compiled)] + for (name, g) in [ + ("Julia v4 (vec=4) ", grad_julia_v4), + ("Julia v5 (vec=4+SIMT)", grad_julia_v5), + ("Julia v6 (vec=4+Lt)", grad_julia_v6), + ("Lux + Mooncake ", grad_lux), + ("PyTorch eager ", grad_pytorch_eager), + ("PyTorch compiled ", grad_pytorch_compiled), + ] maxdiff = maximum(abs.(grad_julia .- g)) - relmag = maxdiff / max(maximum(abs.(grad_julia)), eps(Float32)) - ok = isapprox(grad_julia, g; rtol = rtol, atol = 1f-4) + relmag = maxdiff / max(maximum(abs.(grad_julia)), eps(Float32)) + ok = isapprox(grad_julia, g; rtol = rtol, atol = 1.0f-4) @printf "%s vs Julia broadcast: max|Δ| = %.3e (rel %.2e) match=%s\n" name maxdiff relmag ok end @@ -605,11 +728,17 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) be = @benchmark begin pytorch_grad_eager($W1t, $W2t, $Xt, $yt) $torch.cuda.synchronize() - end samples=30 evals=1 seconds=10 setup=($(_pygc).collect(); $torch.cuda.empty_cache()) + end samples=30 evals=1 seconds=10 setup=( + $(_pygc).collect(); + $torch.cuda.empty_cache() + ) bc = @benchmark begin pytorch_grad_compiled($W1t, $W2t, $Xt, $yt) $torch.cuda.synchronize() - end samples=30 evals=1 seconds=10 setup=($(_pygc).collect(); $torch.cuda.empty_cache()) + end samples=30 evals=1 seconds=10 setup=( + $(_pygc).collect(); + $torch.cuda.empty_cache() + ) @printf "Julia broadcast : median %8.3f µs\n" 1e-3 * median(bj).time @printf "Julia vec=4 : median %8.3f µs\n" 1e-3 * median(bj4).time @printf "Julia vec=4 + SIMT : median %8.3f µs\n" 1e-3 * median(bj5).time @@ -620,16 +749,28 @@ function run_one(; h::Int, d::Int = 13, n::Int = 178, rtol::Float32 = 1f-3) # ----- CUDA traces ----- println("\n--- CUDA trace: Julia broadcast ---") - summarize_julia_trace(stdout, julia_trace(() -> reverse_diff(W1g, W2g, Xg, yg))) + summarize_julia_trace( + stdout, + julia_trace(() -> reverse_diff(W1g, W2g, Xg, yg)), + ) println("\n--- CUDA trace: Julia vec=4 ---") - summarize_julia_trace(stdout, julia_trace(() -> reverse_diff_v4(W1g, W2g, Xg, yg))) + summarize_julia_trace( + stdout, + julia_trace(() -> reverse_diff_v4(W1g, W2g, Xg, yg)), + ) println("\n--- CUDA trace: Julia vec=4 + SIMT ---") - summarize_julia_trace(stdout, julia_trace(() -> reverse_diff_v5(W1g, W2g, Xg, yg))) + summarize_julia_trace( + stdout, + julia_trace(() -> reverse_diff_v5(W1g, W2g, Xg, yg)), + ) println("\n--- CUDA trace: Julia vec=4 + cuBLASLt ---") - summarize_julia_trace(stdout, julia_trace(() -> reverse_diff_v6(lt_plans, W1g, W2g, Xg, yg))) + summarize_julia_trace( + stdout, + julia_trace(() -> reverse_diff_v6(lt_plans, W1g, W2g, Xg, yg)), + ) println("\n--- CUDA trace: Lux + Mooncake ---") summarize_julia_trace(stdout, julia_trace(() -> lux_grad(lm))) @@ -657,8 +798,15 @@ function main() # Julia versions equally, so the broadcast-vs-vec=4 comparison still # isolates the kernel-design effect. CUDA.math_mode!(CUDA.FAST_MATH) - println("CUDA.jl device : ", CUDA.name(CUDA.device()), " (math_mode=FAST_MATH)") - println("PyTorch device : ", pyconvert(String, torch.cuda.get_device_name(0))) + println( + "CUDA.jl device : ", + CUDA.name(CUDA.device()), + " (math_mode=FAST_MATH)", + ) + println( + "PyTorch device : ", + pyconvert(String, torch.cuda.get_device_name(0)), + ) for h in (16, 256, 4096) run_one(; h = h) From 95103c72dd29ef3c88a7c08430acba699119cb4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 5 May 2026 21:02:13 +0200 Subject: [PATCH 14/14] Remove unused --- src/ArrayDiff.jl | 1 - src/types.jl | 4 ---- 2 files changed, 5 deletions(-) diff --git a/src/ArrayDiff.jl b/src/ArrayDiff.jl index 2856161..9595d82 100644 --- a/src/ArrayDiff.jl +++ b/src/ArrayDiff.jl @@ -66,7 +66,6 @@ include("array_nonlinear_function.jl") include("parse_moi.jl") model(::Mode{S}) where {S} = Model() -storage_type(::Mode{S}) where {S} = S # Extend MOI.Nonlinear.set_objective so that solvers calling # MOI.Nonlinear.set_objective(arraydiff_model, snf) dispatch here. diff --git a/src/types.jl b/src/types.jl index d043b52..1127f62 100644 --- a/src/types.jl +++ b/src/types.jl @@ -340,7 +340,3 @@ mutable struct NLPEvaluator{S<:AbstractVector{Float64}} <: return new{S}(data, ordered_variables) end end - -function NLPEvaluator(data::Model, ordered_variables::Vector{MOI.VariableIndex}) - return NLPEvaluator{Vector{Float64}}(data, ordered_variables) -end