Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
873 changes: 384 additions & 489 deletions docs/architecture/human-in-the-loop.md

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions go/adk/pkg/a2a/hitl.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import (
)

var (
denyWordPatterns []*regexp.Regexp
rejectWordPatterns []*regexp.Regexp
approveWordPatterns []*regexp.Regexp
)

func init() {
for _, keyword := range KAgentHitlResumeKeywordsDeny {
denyWordPatterns = append(denyWordPatterns, regexp.MustCompile(`(?i)\b`+regexp.QuoteMeta(keyword)+`\b`))
for _, keyword := range KAgentHitlResumeKeywordsReject {
rejectWordPatterns = append(rejectWordPatterns, regexp.MustCompile(`(?i)\b`+regexp.QuoteMeta(keyword)+`\b`))
}
for _, keyword := range KAgentHitlResumeKeywordsApprove {
approveWordPatterns = append(approveWordPatterns, regexp.MustCompile(`(?i)\b`+regexp.QuoteMeta(keyword)+`\b`))
Expand All @@ -28,20 +28,20 @@ const (
KAgentHitlInterruptTypeToolApproval = "tool_approval"
KAgentHitlDecisionTypeKey = "decision_type"
KAgentHitlDecisionTypeApprove = "approve"
KAgentHitlDecisionTypeDeny = "deny"
KAgentHitlDecisionTypeReject = "reject"
)

var (
KAgentHitlResumeKeywordsApprove = []string{"approved", "approve", "proceed", "yes", "continue"}
KAgentHitlResumeKeywordsDeny = []string{"denied", "deny", "reject", "no", "cancel", "stop"}
KAgentHitlResumeKeywordsReject = []string{"denied", "deny", "reject", "no", "cancel", "stop"}
)

// DecisionType represents a HITL decision.
type DecisionType string

const (
DecisionApprove DecisionType = "approve"
DecisionDeny DecisionType = "deny"
DecisionReject DecisionType = "reject"
)

// ToolApprovalRequest represents a tool call requiring user approval.
Expand All @@ -60,9 +60,9 @@ func GetKAgentMetadataKey(key string) string {
// keyword matching. Word boundaries prevent false positives from substrings
// (e.g. "no" inside "know", "yes" inside "yesterday").
func ExtractDecisionFromText(text string) DecisionType {
for _, pattern := range denyWordPatterns {
for _, pattern := range rejectWordPatterns {
if pattern.MatchString(text) {
return DecisionDeny
return DecisionReject
}
}
for _, pattern := range approveWordPatterns {
Expand All @@ -87,8 +87,8 @@ func ExtractDecisionFromMessage(message *a2atype.Message) DecisionType {
switch decision {
case KAgentHitlDecisionTypeApprove:
return DecisionApprove
case KAgentHitlDecisionTypeDeny:
return DecisionDeny
case KAgentHitlDecisionTypeReject:
return DecisionReject
}
}
}
Expand Down
30 changes: 15 additions & 15 deletions go/adk/pkg/a2a/hitl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ func TestExtractDecisionFromMessage_DataPart(t *testing.T) {
t.Errorf("ExtractDecisionFromMessage(approve DataPart) = %q, want %q", result, DecisionApprove)
}

denyData := map[string]any{
KAgentHitlDecisionTypeKey: KAgentHitlDecisionTypeDeny,
rejectData := map[string]any{
KAgentHitlDecisionTypeKey: KAgentHitlDecisionTypeReject,
}
message = a2atype.NewMessage(a2atype.MessageRoleUser,
&a2atype.DataPart{Data: denyData},
&a2atype.DataPart{Data: rejectData},
)
result = ExtractDecisionFromMessage(message)
if result != DecisionDeny {
t.Errorf("ExtractDecisionFromMessage(deny DataPart) = %q, want %q", result, DecisionDeny)
if result != DecisionReject {
t.Errorf("ExtractDecisionFromMessage(reject DataPart) = %q, want %q", result, DecisionReject)
}
}

Expand All @@ -66,8 +66,8 @@ func TestExtractDecisionFromMessage_TextPart(t *testing.T) {
a2atype.TextPart{Text: "Request denied, do not proceed"},
)
result = ExtractDecisionFromMessage(message)
if result != DecisionDeny {
t.Errorf("ExtractDecisionFromMessage(deny text) = %q, want %q", result, DecisionDeny)
if result != DecisionReject {
t.Errorf("ExtractDecisionFromMessage(reject text) = %q, want %q", result, DecisionReject)
}

message = a2atype.NewMessage(a2atype.MessageRoleUser,
Expand All @@ -84,13 +84,13 @@ func TestExtractDecisionFromMessage_Priority(t *testing.T) {
a2atype.TextPart{Text: "approved"},
&a2atype.DataPart{
Data: map[string]any{
KAgentHitlDecisionTypeKey: KAgentHitlDecisionTypeDeny,
KAgentHitlDecisionTypeKey: KAgentHitlDecisionTypeReject,
},
},
)
result := ExtractDecisionFromMessage(message)
if result != DecisionDeny {
t.Errorf("ExtractDecisionFromMessage(mixed parts) = %q, want %q (DataPart should take priority)", result, DecisionDeny)
if result != DecisionReject {
t.Errorf("ExtractDecisionFromMessage(mixed parts) = %q, want %q (DataPart should take priority)", result, DecisionReject)
}
}

Expand Down Expand Up @@ -125,12 +125,12 @@ func TestExtractDecisionFromText_WordBoundary(t *testing.T) {
{name: "yes inside yesterday should not match", text: "yesterday was fine", want: ""},
{name: "stop inside unstoppable should not match", text: "unstoppable progress", want: ""},
{name: "cancel inside cancellation should not match", text: "the cancellation policy", want: ""},
{name: "standalone no matches", text: "no, I do not agree", want: DecisionDeny},
{name: "standalone no matches", text: "no, I do not agree", want: DecisionReject},
{name: "standalone yes matches", text: "yes, go ahead", want: DecisionApprove},
{name: "standalone stop matches", text: "stop the process", want: DecisionDeny},
{name: "case insensitive whole word", text: "NO", want: DecisionDeny},
{name: "keyword at end of sentence", text: "the answer is no", want: DecisionDeny},
{name: "keyword with punctuation", text: "no!", want: DecisionDeny},
{name: "standalone stop matches", text: "stop the process", want: DecisionReject},
{name: "case insensitive whole word", text: "NO", want: DecisionReject},
{name: "keyword at end of sentence", text: "the answer is no", want: DecisionReject},
{name: "keyword with punctuation", text: "no!", want: DecisionReject},
{name: "continue inside discontinue should not match", text: "I will discontinue", want: ""},
{name: "approve as standalone", text: "I approve", want: DecisionApprove},
}
Expand Down
120 changes: 86 additions & 34 deletions python/packages/kagent-adk/src/kagent/adk/_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,17 +345,18 @@ async def _publish_failed_status_event(
logger.error("Failed to publish failure event: %s", enqueue_error, exc_info=True)

@staticmethod
def _find_pending_confirmations(session: Session) -> dict[str, str | None]:
def _find_pending_confirmations(session: Session) -> dict[str, tuple[str | None, dict | None]]:
"""Find pending adk_request_confirmation calls and their original tool call IDs.

Scans session events backwards for the most recent adk_request_confirmation
FunctionCall events that haven't been responded to yet.

Returns:
Dict mapping confirmation function_call_id to the original tool call ID
(from args.originalFunctionCall.id), or None if not available.
Dict mapping confirmation function_call_id to a tuple of:
- the original tool call ID (from args.originalFunctionCall.id), or None
- the original toolConfirmation payload (from args.toolConfirmation.payload), or None
"""
pending: dict[str, str | None] = {}
pending: dict[str, tuple[str | None, dict | None]] = {}
responded_ids: set[str] = set()

for event in reversed(session.events or []):
Expand All @@ -364,16 +365,23 @@ def _find_pending_confirmations(session: Session) -> dict[str, str | None]:
if fr.name == REQUEST_CONFIRMATION_FUNCTION_CALL_NAME and fr.id is not None:
responded_ids.add(fr.id)

# Collect requested confirmation IDs and extract original tool call ID
# Collect requested confirmation IDs and extract original tool call ID + payload
for fc in event.get_function_calls():
if fc.name == REQUEST_CONFIRMATION_FUNCTION_CALL_NAME and fc.id is not None:
# Extract original tool call ID from args.originalFunctionCall.id
original_id = None
original_payload = None
if fc.args and isinstance(fc.args, dict):
orig_fc = fc.args.get("originalFunctionCall")
if isinstance(orig_fc, dict):
original_id = orig_fc.get("id")
pending[fc.id] = original_id
tool_conf = fc.args.get("toolConfirmation")
if isinstance(tool_conf, dict):
original_payload = tool_conf.get("payload")
if isinstance(original_payload, dict):
original_payload = dict(original_payload)
else:
original_payload = None
pending[fc.id] = (original_id, original_payload)

# Stop scanning once we find confirmation requests (they're recent)
if pending:
Expand All @@ -385,6 +393,27 @@ def _find_pending_confirmations(session: Session) -> dict[str, str | None]:

return pending

@staticmethod
def _build_confirmation_payload(
original_payload: dict | None,
extra: dict | None,
) -> dict | None:
"""Merge the original request_confirmation payload with decision-specific data.

The original payload (set by the tool in ``request_confirmation()``) is
preserved so that the tool's ``_handle_resume`` can read its own state
(e.g. subagent task_id, context_id). Decision-specific keys (like
``rejection_reason``) are merged on top.
"""
if not original_payload and not extra:
return None
merged: dict = {}
if original_payload:
merged.update(original_payload)
if extra:
merged.update(extra)
return merged

def _process_hitl_decision(
self, session: Session, decision: str, message: Message
) -> list[genai_types.Part] | None:
Expand All @@ -394,9 +423,9 @@ def _process_hitl_decision(
return None

logger.info(
"HITL continuation detected: decision=%s, pending_confirmations=%d",
"HITL continuation: decision=%s, pending=%s",
decision,
len(pending_confirmations),
{fc_id: orig_id for fc_id, (orig_id, _) in pending_confirmations.items()},
)

# Check for ask-user answers — if present, build a single approved
Expand All @@ -405,8 +434,9 @@ def _process_hitl_decision(
ask_user_answers = extract_ask_user_answers_from_message(message)
if ask_user_answers is not None:
parts = []
for fc_id in pending_confirmations:
confirmation = ToolConfirmation(confirmed=True, payload={"answers": ask_user_answers})
for fc_id, (_, orig_payload) in pending_confirmations.items():
payload = self._build_confirmation_payload(orig_payload, {"answers": ask_user_answers})
confirmation = ToolConfirmation(confirmed=True, payload=payload)
parts.append(
genai_types.Part(
function_response=genai_types.FunctionResponse(
Expand All @@ -424,19 +454,37 @@ def _process_hitl_decision(
if decision == KAGENT_HITL_DECISION_TYPE_BATCH:
# Batch mode: per-tool decisions
batch_decisions = extract_batch_decisions_from_message(message) or {}
logger.info(
"HITL batch: batch_decisions=%s, rejection_reasons=%s",
batch_decisions,
rejection_reasons,
)
parts = []
for fc_id, original_id in pending_confirmations.items():
# Look up the per-tool decision using the original tool call ID
tool_decision = batch_decisions.get(original_id, KAGENT_HITL_DECISION_TYPE_APPROVE)
confirmed = tool_decision == KAGENT_HITL_DECISION_TYPE_APPROVE
# Attach rejection reason if provided for this specific tool
payload: dict | None = None
if not confirmed and rejection_reasons:
reason = rejection_reasons.get(original_id) if original_id else None
if reason:
payload = {"rejection_reason": reason}
confirmation = ToolConfirmation(confirmed=confirmed, payload=payload)
# Append a response for each tool call
for fc_id, (original_id, orig_payload) in pending_confirmations.items():
# Check if this is a subagent HITL request by checking if orig_payload has hitl_parts.
is_subagent = bool(orig_payload and orig_payload.get("hitl_parts"))

if is_subagent:
# Forward the entire batch decision to the tool so
# _handle_resume can relay it to the subagent as-is.
all_approved = all(d == KAGENT_HITL_DECISION_TYPE_APPROVE for d in batch_decisions.values())
extra: dict = {"batch_decisions": batch_decisions}
if rejection_reasons:
extra["rejection_reasons"] = rejection_reasons
payload = self._build_confirmation_payload(orig_payload, extra)
confirmation = ToolConfirmation(confirmed=all_approved, payload=payload)
else:
# Direct tool — look up by original_id as before
tool_decision = batch_decisions.get(original_id, KAGENT_HITL_DECISION_TYPE_APPROVE)
confirmed = tool_decision == KAGENT_HITL_DECISION_TYPE_APPROVE
extra_reject: dict | None = None
if not confirmed and rejection_reasons:
reason = rejection_reasons.get(original_id) if original_id else None
if reason:
extra_reject = {"rejection_reason": reason}
payload = self._build_confirmation_payload(orig_payload, extra_reject)
confirmation = ToolConfirmation(confirmed=confirmed, payload=payload)

parts.append(
genai_types.Part(
function_response=genai_types.FunctionResponse(
Expand All @@ -451,22 +499,26 @@ def _process_hitl_decision(
# Uniform mode: same decision for all pending tools
confirmed = decision == KAGENT_HITL_DECISION_TYPE_APPROVE
# Attach rejection reason if provided (uniform denial uses "*" sentinel)
payload = None
uniform_extra: dict | None = None
if not confirmed and rejection_reasons:
reason = rejection_reasons.get("*")
if reason:
payload = {"rejection_reason": reason}
confirmation = ToolConfirmation(confirmed=confirmed, payload=payload)
return [
genai_types.Part(
function_response=genai_types.FunctionResponse(
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
id=fc_id,
response={"response": confirmation.model_dump_json()},
uniform_extra = {"rejection_reason": reason}
parts = []
for fc_id, (_, orig_payload) in pending_confirmations.items():
merged_payload = self._build_confirmation_payload(orig_payload, uniform_extra)
confirmation = ToolConfirmation(confirmed=confirmed, payload=merged_payload)
serialized = confirmation.model_dump_json()
parts.append(
genai_types.Part(
function_response=genai_types.FunctionResponse(
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
id=fc_id,
response={"response": serialized},
)
)
)
for fc_id in pending_confirmations
]
return parts

async def _handle_request(
self,
Expand Down
Loading
Loading