Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions fastdeploy/cache_manager/cache_messager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,8 +1074,6 @@ def consume_signals(self):
layer_id = kv_signal_data[1].item()
if layer_id == self.num_layers - 1:
logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id} self.rank_id {self.rank_id}")
ready_engine_signals = []
pending_engine_signals = []
# format for signal to put in cache_prefilled_engine_ids_queue: [(engine_idx1, prefilled_token_num1), (engine_idx2, prefilled_token_num2)]
with self.engine_cache_task_thread_lock:
for bi in range(tasks_count):
Expand All @@ -1086,21 +1084,19 @@ def consume_signals(self):
self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id
self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = prefilled_token_num
if layer_id == 0:
if engine_idx in self.idx_cache_task_dict:
ready_engine_signals.append((engine_idx, prefilled_token_num))
else:
pending_engine_signals.append((engine_idx, prefilled_token_num))
if pending_engine_signals:
with self.pending_layer0_signal_lock:
for engine_idx, prefilled_token_num in pending_engine_signals:
with self.pending_layer0_signal_lock:
self.pending_layer0_signals[engine_idx] = (engine_idx, prefilled_token_num)
if pending_engine_signals:
logger.debug(f"cache_task_pending_layer0_signal: {pending_engine_signals}")
if ready_engine_signals:
logger.info(
f"Put batch_engine_signals {ready_engine_signals} into cache_prefilled_engine_ids_queue"
)
self.cache_prefilled_engine_ids_queue.put(ready_engine_signals)
# Recover signals for engine_idxs that already have cache_info registered.
# This handles the case where cache_info arrives before layer0 signal.
recovered_signals = []
with self.pending_layer0_signal_lock:
for engine_idx in list(self.pending_layer0_signals.keys()):
if engine_idx in self.idx_cache_task_dict:
recovered_signals.append(self.pending_layer0_signals.pop(engine_idx))

This comment was marked as outdated.

if recovered_signals:
for signal in recovered_signals:
logger.info(f"consume_signals recovered signal: {signal}")
self.cache_prefilled_engine_ids_queue.put([signal])
except Exception as e:
logger.error(f"Consume signals get exception: {e}, {traceback.format_exc()}")

Expand Down
Loading