
[STF] Migrate __stf/allocators/ from cuda_safe_call to cuda_try #9147

Merged
andralex merged 10 commits into NVIDIA:main from andralex:andralex/stf-cuda-try-allocators
Jun 10, 2026

@andralex

Contributor

Summary

First PR in a series migrating production STF headers off the abort-on-failure cuda_safe_call and onto the throw-on-failure cuda_try, so callers (Python wrappers, any exception-aware control flow) can recover from CUDA errors instead of having the process aborted underneath them.

Intentionally small -- two call sites in cudax/include/cuda/experimental/__stf/allocators/ -- so the conversion patterns can be reviewed before scaling up.

Changes

pooled_allocator.cuh

block_data_pool's constructor used cudaGetDeviceProperties via cuda_safe_call. Now uses the templated cuda_try<F>(args...) form -- which deduces the first-output substitution and returns the populated struct -- so the variable can be const-initialized:

const cudaDeviceProp prop = cuda_try<cudaGetDeviceProperties>(dev);
const size_t max_mem = prop.totalGlobalMem;
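The idea behind the templated form can be sketched without CUDA. This is a minimal illustration, assuming C++17, and not the actual cuda_try implementation: fake_props, fake_get_props, first_out and try_call are all hypothetical names standing in for a C-style API that writes its result through a first pointer parameter and returns a status code.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for a C-style API: the status code is returned and
// the result is written through the first pointer parameter.
struct fake_props
{
  unsigned long long total_mem;
};

inline int fake_get_props(fake_props* out, int device)
{
  assert(out != nullptr);
  if (device < 0)
  {
    return 1; // nonzero status means failure
  }
  out->total_mem = 1024ull * static_cast<unsigned long long>(device + 1);
  return 0;
}

// Extracts T from a function pointer type whose first parameter is T*.
template <class Fn>
struct first_out;

template <class R, class T, class... Rest>
struct first_out<R (*)(T*, Rest...)>
{
  using type = T;
};

// Calls F(&result, args...), throws on a nonzero status, and returns the
// populated result by value so the caller can const-initialize from it.
template <auto F, class... Args>
auto try_call(Args&&... args)
{
  typename first_out<decltype(F)>::type result{};
  const int status = F(&result, static_cast<Args&&>(args)...);
  if (status != 0)
  {
    throw std::runtime_error("call failed with status " + std::to_string(status));
  }
  return result;
}

// Usage: query and const-initialize in a single statement.
inline unsigned long long total_mem_of(int device)
{
  const fake_props props = try_call<fake_get_props>(device);
  return props.total_mem;
}
```

Returning the struct by value is what removes the declare-then-fill two-step, which is why the caller's variable can be const.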

Leak audit on throw:

  • The throw point is upstream of the only GPU allocation in the constructor (root_allocator.allocate(...)).
  • All member subobjects unwound at the throw point have noexcept-clean destructors (data_place and block_allocator_untyped are shared_ptr pimpls).
  • The single call site wraps construction in map.emplace(...), which is exception-safe.

adapters.cuh

stream_adapter::clear() was a for-each over the to_free vector with a lazy sync before the first blocking deallocation. Under cuda_safe_call a sync failure just aborted; converting to cuda_try exposed the lack of exception safety -- on a thrown sync, the unprocessed buffers would have been silently abandoned, and the destructor's _CCCL_ASSERT(cleared_or_moved, ...) would either lie or fire spuriously.

Rewritten to be transactional:

bool stream_synchronized = false;
while (!adapter_state->to_free.empty())
{
  const auto b = mv(adapter_state->to_free.back());
  adapter_state->to_free.pop_back();
  SCOPE(exit)
  {
    b.memory_node.deallocate(b.ptr, b.sz, stream);
  };

  if (!stream_synchronized && !b.memory_node.allocation_is_stream_ordered())
  {
    cuda_try(cudaStreamSynchronize(stream));
    stream_synchronized = true;
  }
}
cleared_or_moved = true;

On a throw from cudaStreamSynchronize:

  • The just-popped buffer is still deallocated by its SCOPE(exit).
  • to_free accurately holds the remaining un-deallocated entries.
  • cleared_or_moved stays false.
  • The caller can catch and retry clear(), or let the destructor's assertion fire with truthful state.

Inter-buffer deallocation order doesn't matter (each raw_buffer is independent), so popping from the back is the natural O(1) choice. mv(...back()) + pop_back() skips one shared_ptr refcount bump on data_place per iteration; the move is noexcept, so no half-moved-not-popped risk.
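For reviewers who want to poke at the invariant in isolation, here is a CUDA-free sketch of the pop-then-guard loop, assuming C++17. The scope_exit guard and the pending_queue type are hypothetical stand-ins for SCOPE(exit) and the adapter state, and the release step here is deliberately something that cannot throw.

```cpp
#include <cassert>
#include <stdexcept>
#include <utility>
#include <vector>

// Minimal scope-exit guard (hypothetical; stands in for SCOPE(exit)).
template <class F>
struct scope_exit
{
  F fn;
  explicit scope_exit(F f)
      : fn(std::move(f))
  {}
  scope_exit(const scope_exit&)            = delete;
  scope_exit& operator=(const scope_exit&) = delete;
  ~scope_exit()
  {
    fn(); // destructors are implicitly noexcept, so fn must not throw
  }
};

struct pending_queue
{
  std::vector<int> to_free; // ids of buffers awaiting release
  int released   = 0;       // number of buffers released so far
  bool cleared   = false;
  bool fail_sync = false;   // test hook: makes the next sync() throw once

  void sync()
  {
    if (fail_sync)
    {
      fail_sync = false;
      throw std::runtime_error("sync failed");
    }
  }

  void clear()
  {
    bool synchronized = false;
    while (!to_free.empty())
    {
      const int id = to_free.back();
      to_free.pop_back();
      (void) id; // a real release would use the id; this one only counts
      // Runs on normal exit and during unwinding: the popped item is never lost.
      scope_exit release([this]() noexcept {
        ++released;
      });
      if (!synchronized)
      {
        sync(); // may throw; the guard above still releases the popped item
        synchronized = true;
      }
    }
    assert(to_free.empty());
    cleared = true;
  }
};
```

After a throw from sync(), to_free holds exactly the entries that were never popped, so calling clear() again resumes where the failed call stopped.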

Explicit #include <cuda/experimental/__stf/utility/scope_guard.cuh> added rather than relying on transitive inclusion.

Migration pattern notes (for follow-ups)

  • SAFE: pure queries with no in-flight CUDA state → direct cuda_try substitution. Prefer the templated cuda_try<F>(args...) form when the function has an out-parameter, so the result can be const-initialized.
  • GUARDED: operations that need an undo step on failure → cuda_try + SCOPE(fail) for the rollback. Inside the SCOPE(fail) body, use cuda_safe_call, not cuda_try -- guard destructors are noexcept, so a thrown exception during unwinding would std::terminate.
  • TRANSACTIONAL (this PR's clear() pattern): when looping over state to release, pop incrementally and use per-iteration SCOPE(exit) so the in-flight item is always freed and the remaining queue stays accurate on throw.
  • KEEP: destructors and CUDA host callbacks remain cuda_safe_call. Same rationale -- those contexts are noexcept.
  • Audit: "if this throws, will any subobject destructor on the unwind path need to make CUDA calls?" If yes, add a SCOPE(fail) rollback in the constructor body.

Test plan

  • CI green (/ok to test triggered in comment below)
  • No behavioral change on the success path

First in a series migrating production STF headers off the
abort-on-failure ``cuda_safe_call`` onto the throw-on-failure
``cuda_try``, so callers (Python wrappers, exception-aware control flow)
can recover from CUDA errors instead of having the process aborted.

pooled_allocator.cuh:
  - ``cudaGetDeviceProperties`` query in ``block_data_pool``'s
    constructor: convert to the templated
    ``cuda_try<cudaGetDeviceProperties>(dev)`` form, which deduces the
    first-output substitution and returns the populated struct, so the
    variable can be const-initialized. Mark adjacent ``max_mem`` const
    as well.
  - Leak audit: if the ``cuda_try`` throws, no GPU resources have been
    allocated yet (the only ``root_allocator.allocate`` is downstream).
    Member subobjects unwound at the throw point have noexcept-clean
    destructors (``data_place`` and ``block_allocator_untyped`` are
    shared_ptr pimpls). The single call site
    (``block_data_pool_set::get_pool``) wraps construction in
    ``map.emplace``, which is exception-safe.

adapters.cuh:
  - ``stream_adapter::clear()`` rewritten to be transactional. The
    original for-each over ``to_free`` + lazy sync would, on a thrown
    sync, silently abandon the remaining buffers and leave the
    ``cleared_or_moved`` flag in a contradictory state (either lying
    that cleanup succeeded, or firing the destructor's sanity assertion
    spuriously). New form pops one buffer at a time, installs a
    ``SCOPE(exit)`` that frees just that buffer, then syncs lazily. On
    throw: the in-flight buffer is freed, ``to_free`` still holds the
    remaining pending entries, ``cleared_or_moved`` stays false, and
    the caller can catch + retry (or let the destructor's assertion
    fire with accurate state). Inter-buffer order is irrelevant (each
    ``raw_buffer`` is independent), so popping from the back is the
    O(1) choice.
  - Move-from-back + ``pop_back`` skips one shared_ptr refcount bump
    per iteration on ``data_place``. The move is noexcept, so no
    half-moved-not-popped risk.
  - Marked ``stream`` and the per-iteration ``b`` const.
  - Pulled ``scope_guard.cuh`` in explicitly rather than relying on
    transitive inclusion.

Pilot PR for the broader migration -- intentionally small (2 sites) so
the conversion patterns (transactional SCOPE-based cleanup, templated
``cuda_try<F>`` form, const-correctness sweep) can be reviewed before
scaling up.
@andralex

Contributor Author

/ok to test b81c5df

@andralex andralex requested a review from a team as a code owner May 27, 2026 22:50
@andralex andralex requested a review from caugonnet May 27, 2026 22:50
@github-project-automation github-project-automation Bot moved this to Todo in CCCL May 27, 2026
@copy-pr-bot

copy-pr-bot Bot commented May 27, 2026

Contributor

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cccl-authenticator-app cccl-authenticator-app Bot moved this from Todo to In Review in CCCL May 27, 2026
@coderabbitai

coderabbitai Bot commented May 27, 2026

Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 382cf721-e36a-4a8f-8799-6b6f62d7ee5f

📥 Commits

Reviewing files that changed from the base of the PR and between a9d287a and 5200237.

📒 Files selected for processing (1)
  • cudax/include/cuda/experimental/__stf/allocators/adapters.cuh
🚧 Files skipped from review as they are similar to previous changes (1)
  • cudax/include/cuda/experimental/__stf/allocators/adapters.cuh

📝 Walkthrough

Summary by CodeRabbit

  • Bug Fixes

    • More robust resource cleanup: deallocations proceed even when stream synchronization reports errors; synchronization failures are reported after cleanup.
    • Improved exception-safety to avoid descriptor/resource leaks during cleanup failures.
  • Performance Improvements

    • Stream synchronization is now performed lazily only when needed, reducing unnecessary stalls.
    • Reduced buffer handling overhead during deallocation for lower allocation latency.
  • Stability

    • Device memory capacity discovery updated for more reliable device queries while preserving allocation caps.

Walkthrough

Stream adapter and pooled allocator update error-handling: stream_adapter::clear() drains pending buffers with lazy stream synchronization and deferred error propagation; pooled allocator constructor queries device properties via cuda_try initialization.

Changes

STF allocator error handling updates

Layer / File(s): Stream adapter clear() draining and lazy sync (cudax/include/cuda/experimental/__stf/allocators/adapters.cuh)
Summary: Drains adapter_state->to_free with a pop_back() loop, captures the CUDA stream once, performs a single lazy cudaStreamSynchronize(stream) when encountering the first non-stream-ordered buffer (records the error), deallocates each popped raw_buffer immediately, and surfaces the recorded sync error via cuda_try.

Layer / File(s): Pooled allocator device property query modernization (cudax/include/cuda/experimental/__stf/allocators/pooled_allocator.cuh)
Summary: Replaces two-step cuda_safe_call(cudaGetDeviceProperties(&prop, dev)) with inline cuda_try<cudaGetDeviceProperties>(dev) initialization while preserving prop.totalGlobalMem-based capacity computation.

Suggested labels

stf

Suggested reviewers

  • caugonnet
  • alliepiper

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 1

🧹 Nitpick comments (1)
cudax/include/cuda/experimental/__stf/allocators/pooled_allocator.cuh (1)

76-76: ⚡ Quick win

suggestion: qualify the new cuda_try<cudaGetDeviceProperties> call from the global namespace instead of relying on unqualified lookup in this header. As per coding guidelines, "All calls to free functions must be fully qualified starting from the global namespace, e.g., ::cuda::ceil_div, including calls to functions in the same namespace."
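A small standalone illustration of why the guideline exists, assuming C++14 or later; the vendor and app namespaces and the describe overloads are hypothetical and unrelated to STF. An unqualified call inside a template is subject to argument-dependent lookup, so an overload from the argument's namespace can be selected instead of the intended one, while a call qualified from the global namespace cannot be redirected this way.

```cpp
namespace vendor
{
struct widget
{};

// A non-template overload that is only reachable through
// argument-dependent lookup on vendor::widget.
constexpr int describe(const widget&)
{
  return 2;
}
} // namespace vendor

namespace app
{
// The function the author of app intends to call.
template <class T>
constexpr int describe(const T&)
{
  return 1;
}

// Unqualified call: ADL adds vendor::describe to the overload set, and the
// non-template overload wins over the equally good template.
template <class T>
constexpr int unqualified(const T& t)
{
  return describe(t);
}

// Fully qualified call: ADL is suppressed, so app::describe is always used.
template <class T>
constexpr int qualified(const T& t)
{
  return ::app::describe(t);
}
} // namespace app
```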


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: d250ca80-365d-4dfb-bde7-d7c11380d77c

📥 Commits

Reviewing files that changed from the base of the PR and between 740b3c0 and b81c5df.

📒 Files selected for processing (2)
  • cudax/include/cuda/experimental/__stf/allocators/adapters.cuh
  • cudax/include/cuda/experimental/__stf/allocators/pooled_allocator.cuh

Comment thread cudax/include/cuda/experimental/__stf/allocators/adapters.cuh Outdated
Review bot flagged that the SCOPE(exit) cleanup in the previous version
of clear() calls b.memory_node.deallocate(...), which can throw --
data_place_*::deallocate all use cuda_try internally for cudaFreeHost /
cudaFree / cudaFreeAsync, and the invalid / affine / device_auto
variants throw std::logic_error unconditionally. SCOPE bodies are
noexcept, so a deallocate-throw during unwinding from a sync failure
would call std::terminate.

Restructured to avoid putting deallocate() in a noexcept context.
Capture the sync status (do not throw on the spot), do the
deallocation normally, then surface the captured error via
cuda_try(sync_err) afterwards.

Failure modes:
  - Sync ok, deallocate ok: loop continues, cleared_or_moved = true.
  - Sync fails, deallocate ok: deallocate runs, cuda_try(sync_err)
    throws cuda_exception, to_free holds the rest, cleared_or_moved
    stays false. Caller can retry.
  - Sync ok, deallocate throws: deallocate's exception propagates,
    to_free holds the rest, cleared_or_moved stays false.
  - Both fail (correlated -- likely the same sticky CUDA error):
    deallocate's exception wins, sync_err is lost. User-visible
    diagnostic is equivalent because both reflect the same root cause.

Also dropped the now-unused scope_guard.cuh include.
@andralex

Contributor Author

Thanks for the review -- you were right on all counts. Fixed in 65d7e9a.

Restructured to avoid putting deallocate() in a noexcept context. The new flow captures the sync status, performs the deallocation normally (which can now throw safely), and then surfaces the sync error via cuda_try(sync_err) at a convenient throw point:

cudaError_t sync_err = cudaSuccess;
if (!stream_synchronized && !b.memory_node.allocation_is_stream_ordered())
{
  sync_err = cudaStreamSynchronize(stream);
  if (sync_err == cudaSuccess) { stream_synchronized = true; }
}

b.memory_node.deallocate(b.ptr, b.sz, stream);
cuda_try(sync_err);

No SCOPE guards, no noexcept body containing throwing code. The transactional invariant ("on throw, to_free accurately reflects what is pending and cleared_or_moved stays false") is preserved.
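The capture-then-surface flow can be exercised without CUDA with a sketch like the one below, assuming C++17. Every name in it is hypothetical: status and throw_if_failed stand in for cudaError_t and cuda_try(sync_err), and item and drain_state stand in for raw_buffer and the adapter state.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical status type and checker standing in for cudaError_t / cuda_try.
enum class status
{
  ok,
  sync_failed
};

inline void throw_if_failed(status s)
{
  if (s != status::ok)
  {
    throw std::runtime_error("stream synchronization failed");
  }
}

struct item
{
  bool stream_ordered; // true: releasing it needs no prior synchronization
};

struct drain_state
{
  std::vector<item> to_free;
  int released   = 0;
  bool cleared   = false;
  bool fail_sync = false; // test hook: makes the next sync() fail once

  status sync() noexcept
  {
    if (fail_sync)
    {
      fail_sync = false;
      return status::sync_failed;
    }
    return status::ok;
  }

  // Ordinary (potentially throwing) context: no guard, no noexcept body.
  void release(const item&)
  {
    ++released;
  }

  void clear()
  {
    bool synchronized = false;
    while (!to_free.empty())
    {
      const item b = to_free.back();
      to_free.pop_back();

      // Record the sync status instead of throwing on the spot.
      status sync_result = status::ok;
      if (!synchronized && !b.stream_ordered)
      {
        sync_result = sync();
        if (sync_result == status::ok)
        {
          synchronized = true;
        }
      }

      release(b);                   // the popped item is released first
      throw_if_failed(sync_result); // then the recorded error is surfaced
    }
    assert(to_free.empty());
    cleared = true;
  }
};
```

Because the throw happens only after release(b) has run, the popped item is never abandoned, and a failed clear() can simply be retried on the remaining entries.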

Also dropped the now-unused scope_guard.cuh include.

Side note for a follow-up PR: data_place_device::deallocate (data_place_impl.cuh:284) has the same cuda_try-inside-SCOPE(exit) anti-pattern for the cudaSetDevice rollback. Outside this PR's scope but worth a separate fix.

@andralex

Contributor Author

/ok to test 65d7e9a

Rename ``stream_synchronized`` -> ``cudaStreamSynchronize_was_called``
and ``sync_err`` -> ``cudaStreamSynchronize_result``. The new flag name
matches the new semantics: it tracks whether the sync was attempted at
all, not whether it succeeded. Drop the now-dead
``if (sync_err == cudaSuccess) { ... = true; }`` -- on failure the
subsequent ``cuda_try`` throws and exits the loop, so the "retry on
next iteration" branch is unreachable.

Add a short inline comment ahead of the two throwing statements noting
that on throw the loop is left in steady state (``to_free`` accurate,
``cleared_or_moved`` false), and refresh the upstream comment block to
reference the renamed local.

No behavioral change.
@andralex

Contributor Author

/ok to test a9d287a

Test ``cudax.test.stf.threads.axpy-threads-graph`` aborts with
``malloc(): smallbin double linked list corrupted`` on H100 with this
PR's changes. The CI investigator's leading hypothesis is the
``raw_buffer`` ctor's ``memory_node(mv(memory_node_))`` (slipped in
during the move-semantics conversation) interacting badly with the
``clear()`` loop's ``mv(...back()) + pop_back()`` pattern.

Code review of both changes does not reveal an obvious double-free, so
revert the ctor change and keep the ``clear()`` restructure. If CI then
goes green, the ctor mv was the culprit and can be reintroduced more
carefully (or left out -- the saving is one shared_ptr refcount bump
per emplace, negligible in this code path). If CI still fails, the bug
is in the ``clear()`` loop and needs a deeper look.

No behavioral change relative to ``main``'s original raw_buffer ctor.
@andralex

Contributor Author

Bisect: reverted the 5200237 raw_buffer ctor change (memory_node(mv(memory_node_)) -> memory_node(memory_node_)) to isolate the H100 crash in cudax.test.stf.threads.axpy-threads-graph. CI investigator's leading hypothesis was that this ctor change combined with the new mv(...back()) + pop_back() in clear() was producing the malloc(): smallbin double linked list corrupted abort. Code review did not pin it down, so testing empirically. If this run goes green, the ctor mv was the culprit; if still red, the bug is in the clear() loop and needs deeper analysis.

@andralex

Contributor Author

/ok to test 5200237

@coderabbitai

coderabbitai Bot commented May 28, 2026

Contributor

Actionable comments posted: 0

@caugonnet

Contributor

#9186 will hopefully deal with these concurrency issues.

@andralex andralex enabled auto-merge (squash) June 8, 2026 21:43
@andralex

andralex commented Jun 8, 2026

Contributor Author

/ok to test b772be7

@andralex

andralex commented Jun 8, 2026

Contributor Author

/ok to test 7273e6a

@andralex

andralex commented Jun 9, 2026

Contributor Author

/ok to test 8a3e21b

@andralex

andralex commented Jun 9, 2026

Contributor Author

/ok to test a5af2e6

@andralex

Contributor Author

/ok to test

@copy-pr-bot

copy-pr-bot Bot commented Jun 10, 2026

Contributor

/ok to test

@andralex, there was an error processing your request: E1

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/1/

@andralex

Contributor Author

/ok to test a13bdb9

@github-actions

Contributor

🥳 CI Workflow Results

🟩 Finished in 1h 45m: Pass: 100%/55 | Total: 1d 09h | Max: 1h 03m | Hits: 13%/191739

See results here.

@andralex andralex merged commit 420d8e0 into NVIDIA:main Jun 10, 2026
76 checks passed
@github-project-automation github-project-automation Bot moved this from In Review to Done in CCCL Jun 10, 2026

2 participants