""" This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py """ """Expert Parallelism Load Balancer (EPLB)""" from typing import Tuple import numpy as np def balanced_packing(weight: np.ndarray, num_packs: int) -> Tuple[np.ndarray, np.ndarray]: """ Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs are as balanced as possible. Parameters: weight: [X, n], the weight of each item num_packs: number of packs Returns: pack_index: [X, n], the pack index of each item rank_in_pack: [X, n], the rank of the item in the pack """ num_layers, num_groups = weight.shape assert num_groups % num_packs == 0 groups_per_pack = num_groups // num_packs if groups_per_pack == 1: pack_index = np.arange(weight.shape[-1], dtype=np.int32).reshape(1, -1).repeat(num_layers, axis=0) rank_in_pack = np.zeros_like(weight, dtype=np.int32) return pack_index, rank_in_pack indices = np.argsort(-weight.astype(np.float32), axis=-1) pack_index = np.full_like(weight, fill_value=-1, dtype=np.int32) rank_in_pack = np.full_like(pack_index, fill_value=-1) for i in range(num_layers): pack_weights = [0] * num_packs pack_items = [0] * num_packs for group in indices[i]: pack = min( (i for i in range(num_packs) if pack_items[i] < groups_per_pack), key=pack_weights.__getitem__, ) assert pack_items[pack] < groups_per_pack pack_index[i, group] = pack rank_in_pack[i, group] = pack_items[pack] pack_weights[pack] += weight[i, group] pack_items[pack] += 1 return pack_index, rank_in_pack def replicate_experts(weight: np.ndarray, num_phy: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized. Parameters: weight: [X, num_log] num_phy: total number of experts after replication Returns: phy2log: [X, num_phy], logical expert id of each physical expert rank: [X, num_phy], the replica rank logcnt: [X, num_log], number of replicas for each logical expert """ n, num_log = weight.shape num_redundant = num_phy - num_log assert num_redundant >= 0 phy2log = np.arange(num_phy, dtype=np.int32).reshape(1, -1).repeat(n, axis=0) rank = np.zeros((n, num_phy), dtype=np.int32) logcnt = np.ones((n, num_log), dtype=np.int32) arangen = np.arange(n, dtype=np.int32) for i in range(num_log, num_phy): redundant_indices = np.argmax(weight / logcnt, axis=-1) phy2log[:, i] = redundant_indices rank[:, i] = logcnt[arangen, redundant_indices] logcnt[arangen, redundant_indices] += 1 return phy2log, rank, logcnt def rebalance_experts_hierarchical( weight: np.ndarray, num_physical_experts: int, num_groups: int, num_nodes: int, num_gpus: int, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Parameters: weight: [num_moe_layers, num_logical_experts] num_physical_experts: number of physical experts after replication num_groups: number of expert groups num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster num_gpus: number of GPUs, must be a multiple of `num_nodes` Returns: physical_to_logical_map: [num_moe_layers, num_physical_experts] logical_to_physical_map: [num_moe_layers, num_logical_experts, X] logical_count: [num_moe_layers, num_logical_experts] """ num_layers, num_logical_experts = weight.shape assert num_logical_experts % num_groups == 0 group_size = num_logical_experts // num_groups assert num_groups % num_nodes == 0 groups_per_node = num_groups // num_nodes assert num_gpus % num_nodes == 0 assert num_physical_experts % num_gpus == 0 phy_experts_per_gpu = num_physical_experts // num_gpus def inverse(perm: np.ndarray) -> np.ndarray: inv = np.empty_like(perm) inv[np.arange(perm.shape[0])[:, None], perm] = np.arange(perm.shape[1], dtype=np.int32).reshape(1, -1) return inv # Step 1: pack groups to nodes tokens_per_group = weight.reshape(num_layers, num_groups, group_size).sum(axis=-1) group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes) log2mlog = ( ((group_pack_index * groups_per_node + group_rank_in_pack) * group_size)[:, :, None] + np.arange(group_size, dtype=np.int32) ).reshape(num_layers, -1) mlog2log = inverse(log2mlog) # Step 2: construct redundant experts within nodes tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=-1).reshape(-1, num_logical_experts // num_nodes) phy2mlog, phyrank, mlogcnt = replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes) # Step 3: pack physical_experts to GPUs tokens_per_phy = np.take_along_axis(tokens_per_mlog / mlogcnt, phy2mlog, axis=-1) pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes) phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack pphy2phy = inverse(phy2pphy) pphy2mlog = np.take_along_axis(phy2mlog, pphy2phy, axis=-1) # [num_layers * num_nodes, num_log_per_nodes] pphy2mlog = ( pphy2mlog.reshape(num_layers, num_nodes, -1) + np.arange( 0, num_logical_experts, num_logical_experts // num_nodes, dtype=np.int32, ).reshape(1, -1, 1) ).reshape(num_layers, -1) pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=-1) pphyrank = np.take_along_axis(phyrank, pphy2phy, axis=-1).reshape(num_layers, -1) logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=-1) return pphy2log, pphyrank, logcnt def rebalance_experts( weight: np.ndarray, num_replicas: int, num_groups: int, num_nodes: int, num_gpus: int, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Entry point for expert-parallelism load balancer. Parameters: weight: [layers, num_logical_experts], the load statistics for all logical experts num_replicas: number of physical experts, must be a multiple of `num_gpus` num_groups: number of expert groups num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster num_gpus: number of GPUs, must be a multiple of `num_nodes` Returns: physical_to_logical_map: [layers, num_replicas], the expert index of each replica logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert """ num_layers, num_logical_experts = weight.shape weight = weight.astype(np.float32) if num_groups % num_nodes == 0: # use hierarchical load-balance policy phy2log, phyrank, logcnt = rebalance_experts_hierarchical( weight, num_replicas, num_groups, num_nodes, num_gpus ) else: # use global load-balance policy phy2log, phyrank, logcnt = replicate_experts(weight, num_replicas) maxlogcnt = logcnt.max() log2phy = np.full((num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int32) np.put_along_axis( log2phy.reshape(num_layers, -1)[:, :, None], (phy2log * maxlogcnt + phyrank)[:, :, None], np.arange(num_replicas, dtype=np.int32).reshape(1, -1).repeat(num_layers, axis=0)[:, :, None], axis=1, ) return phy2log, log2phy, logcnt __all__ = ["rebalance_experts"] def main(): """ """ num_hidden_layers = 3 num_expert = 64 num_groups = 8 num_replicas = 64 num_nodes = 4 num_gpus = 4 * 8 model_tokens_per_expert_stats_list = np.random.randint(low=1, high=10, size=(num_hidden_layers, num_expert)) phy2log, phyrank, logcnt = rebalance_experts( model_tokens_per_expert_stats_list, num_replicas, num_groups, num_nodes, num_gpus, ) print(phy2log) print(phyrank) print(logcnt) if __name__ == "__main__": main()