diff --git a/tools/deep_gemm_pre-compile/generate_config.py b/tools/deep_gemm_pre-compile/generate_config.py
index 46bfa4347..d40498ad1 100644
--- a/tools/deep_gemm_pre-compile/generate_config.py
+++ b/tools/deep_gemm_pre-compile/generate_config.py
@@ -41,51 +41,26 @@ def generate_kn_pairs(args, model_cfg: dict) -> Tuple[List, List, List]:
     gemm_kn_pairs = []
     grouped_gemm_contiguous_kn_pairs = []
     grouped_gemm_masked_kn_pairs = []
-    if tp_size > 1 and ep_size == 1:
-        logger.debug("Generating kn pairs for tensor parallel.")
-        # Dense normal gemm
-        gemm_kn_pairs.extend(
-            [
-                [int(intermediate_size / tp_size), hidden_size],
-                [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2) / tp_size)],
-                [hidden_size, int(intermediate_size * 2 / tp_size)],
-                [int(hidden_size / tp_size), hidden_size],
-            ]
-        )
+    logger.debug("Generating kn pairs for tensor parallel.")
+    # Dense normal gemm
+    gemm_kn_pairs.extend(
+        [
+            [int(intermediate_size / tp_size), hidden_size],
+            [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2) / tp_size)],
+            [hidden_size, int(intermediate_size * 2 / tp_size)],
+            [int(hidden_size / tp_size), hidden_size],
+        ]
+    )
 
-        # Moe grouped gemm contiguous
-        grouped_gemm_contiguous_kn_pairs.extend(
-            [
-                [int(moe_intermediate_size / tp_size), hidden_size],
-                [hidden_size, int(moe_intermediate_size * 2 / tp_size)],
-            ]
-        )
-        if has_shared_experts:
-            logger.debug("Generating kn pairs for models with shared experts.")
-            gemm_kn_pairs.extend(
-                [
-                    [hidden_size, int(moe_intermediate_size * 4 / tp_size)],
-                    [int(moe_intermediate_size * 2 / tp_size), hidden_size],
-                ]
-            )
-    elif tp_size == 1 and ep_size > 1:
-        logger.debug("Generating kn pairs for expert parallel.")
-        # Dense normal gemm
-        gemm_kn_pairs.extend(
-            [
-                [intermediate_size, hidden_size],
-                [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2))],
-                [hidden_size, int(intermediate_size * 2)],
-                [hidden_size, hidden_size],
-            ]
-        )
-        # Moe grouped gemm contiguous
-        grouped_gemm_contiguous_kn_pairs.extend(
-            [
-                [moe_intermediate_size, hidden_size],
-                [hidden_size, int(moe_intermediate_size * 2)],
-            ]
-        )
+    # Moe grouped gemm contiguous
+    grouped_gemm_contiguous_kn_pairs.extend(
+        [
+            [int(moe_intermediate_size / tp_size), hidden_size],
+            [hidden_size, int(moe_intermediate_size * 2 / tp_size)],
+        ]
+    )
+
+    if ep_size > 1:
         # Moe grouped gemm masked
         grouped_gemm_masked_kn_pairs.extend(
             [
@@ -93,18 +68,14 @@ def generate_kn_pairs(args, model_cfg: dict) -> Tuple[List, List, List]:
                 [hidden_size, int(moe_intermediate_size * 2)],
             ]
         )
-        if has_shared_experts:
-            logger.debug("Generating kn pairs for models with shared experts.")
-            gemm_kn_pairs.extend(
-                [
-                    [hidden_size, int(moe_intermediate_size * 4)],
-                    [int(moe_intermediate_size * 2), hidden_size],
-                ]
-            )
-    elif tp_size > 1 and ep_size > 1:
-        raise ValueError("Not supported to enable EP and TP at the same time for now.")
-    else:
-        raise ValueError("Please check the tensor parallel size and expert parallel size.")
+    if has_shared_experts:
+        logger.debug("Generating kn pairs for models with shared experts.")
+        gemm_kn_pairs.extend(
+            [
+                [hidden_size, int(moe_intermediate_size * 4 / tp_size)],
+                [int(moe_intermediate_size * 2 / tp_size), hidden_size],
+            ]
+        )
 
     return (
         gemm_kn_pairs,
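As a reviewer aid, here is a minimal standalone sketch (not part of the patch) of why the unified branch can subsume the removed expert-parallel branch: when `tp_size == 1`, every `/ tp_size` division is a no-op, so the same k/n shapes fall out as in the old EP-only code path. The `dense_kn_pairs` helper and the model dimensions below are hypothetical, chosen only for this demonstration.

```python
# Illustrative sketch only: a trimmed-down copy of the unified "dense normal gemm"
# expressions from the diff, used to check the tp_size == 1 (expert-parallel) case.

def dense_kn_pairs(hidden_size, intermediate_size, head_dim,
                   num_attention_heads, num_key_value_heads, tp_size):
    # Same expressions as the unified branch; with tp_size == 1 the divisions
    # are no-ops and the old expert-parallel shapes are reproduced.
    return [
        [int(intermediate_size / tp_size), hidden_size],
        [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2) / tp_size)],
        [hidden_size, int(intermediate_size * 2 / tp_size)],
        [int(hidden_size / tp_size), hidden_size],
    ]

# Hypothetical dimensions, not taken from any real model config:
print(dense_kn_pairs(hidden_size=4096, intermediate_size=12288, head_dim=128,
                     num_attention_heads=32, num_key_value_heads=8, tp_size=1))
# -> [[12288, 4096], [4096, 6144], [4096, 24576], [4096, 4096]]
```

The same reasoning applies to the MoE contiguous pairs, which is why the patch can drop the separate TP/EP branches and the `ValueError` guards and instead gate only the masked grouped gemm pairs on `ep_size > 1`.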