Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,10 @@ def _validate_split_kv_size(value: int) -> int:
"FD_FP8_QUANT_WITH_POW2SCALE": lambda: bool(int(os.getenv("FD_FP8_QUANT_WITH_POW2SCALE", "0"))),
# Whether to enable top_p=1.0 optimization.
"FD_ENABLE_TOP_P_ONE_OPT": lambda: bool(int(os.getenv("FD_ENABLE_TOP_P_ONE_OPT", "1"))),
# Sub-switches of --enable-flashinfer-allreduce-fusion; they only take effect when that flag is on.
# Both default to "1" (enabled); set either to "0" to independently disable the attention-side
# allreduce+rmsnorm fusion or the MoE-side fusion.
"FD_ENABLE_ATTN_ALLREDUCE_FUSION": lambda: bool(int(os.getenv("FD_ENABLE_ATTN_ALLREDUCE_FUSION", "1"))),
"FD_ENABLE_MOE_ALLREDUCE_FUSION": lambda: bool(int(os.getenv("FD_ENABLE_MOE_ALLREDUCE_FUSION", "1"))),
}


Expand Down
11 changes: 9 additions & 2 deletions fastdeploy/model_executor/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
else:
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm

from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.ops.triton_ops import _TRITON_AVAILABLE, qk_rmsnorm_fused

Expand Down Expand Up @@ -123,9 +124,15 @@ def __init__(
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
self.tp_group = self.fd_config.parallel_config.tp_group
is_input_norm = prefix.endswith(".input_layernorm")
# post_attention_layernorm fuses the attention-side allreduce; input_layernorm
# (of the next layer) fuses the MoE-side allreduce. Each side can be disabled independently
# via envs.FD_ENABLE_ATTN_ALLREDUCE_FUSION and envs.FD_ENABLE_MOE_ALLREDUCE_FUSION.
self.enable_all_reduce_fusion = fd_config.parallel_config.enable_flashinfer_allreduce_fusion and (
("post_attention_layernorm" in prefix)
or (("input_layernorm" in prefix and layer_id != 0) and not fd_config.parallel_config.use_ep)
(("post_attention_layernorm" in prefix) and envs.FD_ENABLE_ATTN_ALLREDUCE_FUSION)
or (
("input_layernorm" in prefix and layer_id != 0)
and not fd_config.parallel_config.use_ep
and envs.FD_ENABLE_MOE_ALLREDUCE_FUSION
)
)

self.is_last_norm = prefix.endswith(".norm")
Expand Down
12 changes: 9 additions & 3 deletions fastdeploy/model_executor/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from paddleformers.utils.log import logger

import fastdeploy
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.forward_meta import ForwardMeta
Expand Down Expand Up @@ -65,7 +66,9 @@ def __init__(
) -> None:
super().__init__()
self.enable_all_reduce_fusion = (
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and not reduce_results
fd_config.parallel_config.enable_flashinfer_allreduce_fusion
and envs.FD_ENABLE_MOE_ALLREDUCE_FUSION
and not reduce_results
)

# Shared experts are not split when use_sequence_parallel_moe is used with EP + TP.
Expand Down Expand Up @@ -139,7 +142,9 @@ def __init__(
self.use_tp = self.tensor_parallel_size > 1
self.last_layer_id = fd_config.model_config.num_hidden_layers - 1
self.enable_all_reduce_fusion = (
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and layer_id != self.last_layer_id
fd_config.parallel_config.enable_flashinfer_allreduce_fusion
and envs.FD_ENABLE_MOE_ALLREDUCE_FUSION
and layer_id != self.last_layer_id
)
self.n_routed_experts: int = fd_config.model_config.n_routed_experts
self.n_shared_experts: int = fd_config.model_config.n_shared_experts
Expand Down Expand Up @@ -239,7 +244,8 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim,
output_size=fd_config.model_config.hidden_size,
layer_id=layer_id,
enable_all_reduce_fusion=fd_config.parallel_config.enable_flashinfer_allreduce_fusion,
enable_all_reduce_fusion=fd_config.parallel_config.enable_flashinfer_allreduce_fusion
and envs.FD_ENABLE_ATTN_ALLREDUCE_FUSION,
)

self.attn = Attention(
Expand Down
Loading