[FDConfig] Remove splitwise_role and engine_worker_queue_port from FDConfig (#4147)

* remove splitwise_role and engine_worker_queue_port

* fix xpu

* fix xpu

* fix xpu

* fix unittest

* resolve conflict
Author: YuanRisheng
Date: 2025-09-19 17:01:52 +08:00
Committed by: GitHub
Parent: ee9d8a840a
Commit: 24180fba0a
23 changed files with 129 additions and 89 deletions
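Editor's note: the net effect for downstream code is a relocation of fields onto their owning sub-configs: splitwise_role now lives on SchedulerConfig, engine_worker_queue_port (normalized to a list of port strings, one per data-parallel rank) on ParallelConfig, and the derived moe_phase on ModelConfig. A minimal sketch of the new access pattern, assuming an already-constructed FDConfig named cfg (the helper below is illustrative, not part of the repository):

def moved_fields(cfg):
    # Illustrative helper, not FastDeploy source: shows where the relocated fields live now.
    role = cfg.scheduler_config.splitwise_role              # was cfg.splitwise_role
    ports = cfg.parallel_config.engine_worker_queue_port    # was cfg.engine_worker_queue_port
    phase = cfg.model_config.moe_phase.phase                # was cfg.parallel_config.moe_phase.phase
    port = int(ports[cfg.parallel_config.local_data_parallel_id])
    return {"splitwise_role": role, "engine_worker_queue_port": port, "moe_phase": phase}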

View File

@@ -296,8 +296,6 @@ class ParallelConfig:
         # Do profile or not
         self.do_profile: bool = False
-        # splitwise role
-        self.splitwise_role: str = "mixed"
         # guided decoding backend
         self.guided_decoding_backend: str = None
         # disable any whitespace for guided decoding
@@ -319,14 +317,6 @@ class ParallelConfig:
         else:
             self.expert_parallel_size = 1
         self.use_ep = self.expert_parallel_size > 1
-        if self.splitwise_role == "mixed":
-            self.moe_phase = MoEPhase(phase="prefill")
-        elif self.splitwise_role == "prefill":
-            self.moe_phase = MoEPhase(phase="prefill")
-        elif self.splitwise_role == "decode":
-            self.moe_phase = MoEPhase(phase="decode")
-        else:
-            raise NotImplementedError
         # pd_disaggregation
         use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
@@ -1116,10 +1106,8 @@ class FDConfig:
         max_model_len: int = 8192,
         ips: str = None,
         use_warmup: bool = False,
-        engine_worker_queue_port: str = "8002",
         limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
-        splitwise_role: str = "mixed",
         innode_prefill_ports: Optional[List[int]] = None,
         max_num_partial_prefills: int = 1,
         max_long_partial_prefills: int = 1,
@@ -1182,7 +1170,6 @@ class FDConfig:
         self.limit_mm_per_prompt = limit_mm_per_prompt
         self.mm_processor_kwargs = mm_processor_kwargs
         self.use_warmup = use_warmup
-        self.splitwise_role = splitwise_role
         self.innode_prefill_ports = innode_prefill_ports
         self.max_num_partial_prefills = max_num_partial_prefills
         self.max_long_partial_prefills = max_long_partial_prefills
@@ -1190,11 +1177,7 @@ class FDConfig:
         self.reasoning_parser = reasoning_parser
         self.guided_decoding_backend = guided_decoding_backend
         self.disable_any_whitespace = disable_any_whitespace
-        self.engine_worker_queue_port = engine_worker_queue_port
         self._str_to_list("innode_prefill_ports", int)
-        if isinstance(engine_worker_queue_port, int):
-            self.engine_worker_queue_port = str(engine_worker_queue_port)
-        self._str_to_list("engine_worker_queue_port", str)
         if envs.FD_FOR_TORCH_MODEL_FORMAT:
             self.model_config.model_format = "torch"
@@ -1267,6 +1250,15 @@ class FDConfig:
         else:
             self.guided_decoding_backend = "xgrammar"
+        if self.scheduler_config.splitwise_role == "mixed":
+            self.model_config.moe_phase = MoEPhase(phase="prefill")
+        elif self.scheduler_config.splitwise_role == "prefill":
+            self.model_config.moe_phase = MoEPhase(phase="prefill")
+        elif self.scheduler_config.splitwise_role == "decode":
+            self.model_config.moe_phase = MoEPhase(phase="decode")
+        else:
+            raise NotImplementedError
     def check(self):
         """
         check the legality of config
@@ -1301,7 +1293,7 @@ class FDConfig:
             f"max_long_partial_prefills: {self.max_long_partial_prefills} should "
             f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}"
         )
-        assert self.splitwise_role in ["mixed", "prefill", "decode"]
+        assert self.scheduler_config.splitwise_role in ["mixed", "prefill", "decode"]
         # TODO(@wufeisheng): TP and EP need to be supported simultaneously.
         assert (self.parallel_config.tensor_parallel_size == 1 and self.parallel_config.expert_parallel_size >= 1) or (
             self.parallel_config.tensor_parallel_size >= 1 and self.parallel_config.expert_parallel_size == 1
@@ -1387,8 +1379,8 @@ class FDConfig:
         initialize cache info
         """
         disaggregate_info = {}
-        if self.splitwise_role != "mixed":
-            disaggregate_info["role"] = self.splitwise_role
+        if self.scheduler_config.splitwise_role != "mixed":
+            disaggregate_info["role"] = self.scheduler_config.splitwise_role
             disaggregate_info["cache_info"] = dict()
             current_protocol = self.cache_config.cache_transfer_protocol.split(",")
             disaggregate_info["transfer_protocol"] = current_protocol
@@ -1396,7 +1388,9 @@ class FDConfig:
             if protocol == "ipc":
                 disaggregate_info["cache_info"][protocol] = {
                     "ip": self.host_ip,
-                    "port": self.engine_worker_queue_port[self.parallel_config.local_data_parallel_id],
+                    "port": self.parallel_config.engine_worker_queue_port[
+                        self.parallel_config.local_data_parallel_id
+                    ],
                     "device_ids": self.local_device_ids,
                 }
            elif protocol == "rdma":
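Editor's note: for reference, the structure that init_cache_info() builds for a non-"mixed" role with the "ipc" transfer protocol looks roughly like the following; values are placeholders, and the "rdma" branch adds its own cache_info entry.

# Rough shape of disaggregate_info for an "ipc"-only deployment (placeholder values).
disaggregate_info = {
    "role": "prefill",                 # scheduler_config.splitwise_role
    "transfer_protocol": ["ipc"],      # cache_config.cache_transfer_protocol.split(",")
    "cache_info": {
        "ipc": {
            "ip": "127.0.0.1",         # self.host_ip
            "port": "8002",            # parallel_config.engine_worker_queue_port[local_data_parallel_id]
            "device_ids": ["0"],       # self.local_device_ids
        },
    },
}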

View File

@@ -1019,6 +1019,11 @@ class EngineArgs:
         else:
             self.max_num_batched_tokens = self.max_model_len
+        if isinstance(self.engine_worker_queue_port, int):
+            self.engine_worker_queue_port = str(self.engine_worker_queue_port)
+        if isinstance(self.engine_worker_queue_port, str):
+            self.engine_worker_queue_port = self.engine_worker_queue_port.split(",")
         all_dict = asdict(self)
         all_dict["model_cfg"] = model_cfg
         cache_cfg = CacheConfig(all_dict)
@@ -1032,11 +1037,6 @@ class EngineArgs:
         early_stop_cfg = self.create_early_stop_config()
         early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
-        if isinstance(self.engine_worker_queue_port, int):
-            self.engine_worker_queue_port = str(self.engine_worker_queue_port)
-        if isinstance(self.engine_worker_queue_port, str):
-            self.engine_worker_queue_port = self.engine_worker_queue_port.split(",")
         assert is_port_available(
             "0.0.0.0", int(self.engine_worker_queue_port[parallel_cfg.local_data_parallel_id])
         ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
@@ -1052,12 +1052,10 @@ class EngineArgs:
             speculative_config=speculative_cfg,
             ips=self.ips,
             use_warmup=self.use_warmup,
-            engine_worker_queue_port=self.engine_worker_queue_port,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             mm_processor_kwargs=self.mm_processor_kwargs,
             reasoning_parser=self.reasoning_parser,
             tool_parser=self.tool_call_parser,
-            splitwise_role=self.splitwise_role,
             innode_prefill_ports=self.innode_prefill_ports,
             max_num_partial_prefills=self.max_num_partial_prefills,
             max_long_partial_prefills=self.max_long_partial_prefills,
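Editor's note: the two EngineArgs hunks above only move the port normalization earlier so it has already run when the sub-configs are built; the behaviour itself is unchanged. A standalone sketch of that normalization (not the FastDeploy source, just the same isinstance cascade in isolation):

def normalize_queue_ports(engine_worker_queue_port):
    # Accept 8002, "8002", or "8002,8003" and return a list of port strings,
    # one entry per data-parallel rank.
    if isinstance(engine_worker_queue_port, int):
        engine_worker_queue_port = str(engine_worker_queue_port)
    if isinstance(engine_worker_queue_port, str):
        engine_worker_queue_port = engine_worker_queue_port.split(",")
    return engine_worker_queue_port

assert normalize_queue_ports(8002) == ["8002"]
assert normalize_queue_ports("8002,8003") == ["8002", "8003"]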

View File

@@ -76,10 +76,10 @@ class EngineService:
                 cfg.scheduler_config.max_num_seqs,
                 cfg,
                 cfg.parallel_config.tensor_parallel_size,
-                cfg.splitwise_role,
+                cfg.scheduler_config.splitwise_role,
                 cfg.parallel_config.local_data_parallel_id,
             )
-            if cfg.splitwise_role != "mixed":
+            if cfg.scheduler_config.splitwise_role != "mixed":
                 raise NotImplementedError(
                     "Currently ENABLE_V1_KVCACHE_SCHEDULER=1 only supported in mixed sampling now."
                 )
@@ -88,13 +88,13 @@ class EngineService:
                 cfg.scheduler_config.max_num_seqs,
                 cfg,
                 cfg.parallel_config.tensor_parallel_size,
-                cfg.splitwise_role,
+                cfg.scheduler_config.splitwise_role,
                 cfg.parallel_config.local_data_parallel_id,
             )
         self.start_worker_queue_service(start_queue)
-        os.environ["INFERENCE_MSG_QUEUE_ID"] = self.cfg.engine_worker_queue_port[
+        os.environ["INFERENCE_MSG_QUEUE_ID"] = self.cfg.parallel_config.engine_worker_queue_port[
             self.cfg.parallel_config.local_data_parallel_id
         ]
@@ -137,7 +137,9 @@ class EngineService:
         self.token_processor.run()
     def _init_worker_monitor_signals(self):  # exist_task_signal lets each worker process detect whether a new task needs handling
-        current_suffix = int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id])
+        current_suffix = int(
+            self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
+        )
         llm_logger.info(f"current_suffix: {current_suffix}")
         exist_task_signal_data = np.zeros([1], dtype=np.int32)
         self.exist_task_signal = IPCSignal(
@@ -195,7 +197,7 @@ class EngineService:
         """
         address = (
             self.cfg.master_ip,
-            int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]),
+            int(self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]),
         )
         if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"):
@@ -209,7 +211,7 @@ class EngineService:
         if (
             self.cfg.cache_config.enable_prefix_caching
-            or self.cfg.splitwise_role != "mixed"
+            or self.cfg.scheduler_config.splitwise_role != "mixed"
             and self.cfg.parallel_config.local_data_parallel_id == 0
         ):
             self.cache_task_queue = EngineCacheQueue(
@@ -253,7 +255,10 @@ class EngineService:
                     del self.resource_manager.req_dict[task.request_id]
                     cur_task = self.resource_manager.tasks_list[cur_task_idx]
                     cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
-                    if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode":
+                    if (
+                        self.cfg.speculative_config.method in ["mtp"]
+                        and self.cfg.scheduler_config.splitwise_role == "decode"
+                    ):
                         cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids)
                     if task.error_code != 200:
                         self.resource_manager.stop_flags[cur_task_idx] = True
@@ -478,7 +483,10 @@ class EngineService:
                     time.sleep(0.001)
                     continue
                 if hasattr(self, "exist_prefill_task_signal") and self.exist_prefill_task_signal.value[0] > 0:
-                    if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks():
+                    if (
+                        self.cfg.scheduler_config.splitwise_role == "mixed"
+                        or self.split_connector.has_splitwise_tasks()
+                    ):
                         time.sleep(0.005)
                         continue
                 if self.engine_worker_queue.num_cache_infos() > 0:
@@ -507,7 +515,7 @@ class EngineService:
                     continue
                 current_id = (current_id + 1) % 100003
-                if self.cfg.splitwise_role != "mixed":
+                if self.cfg.scheduler_config.splitwise_role != "mixed":
                     llm_logger.info("Inserting splitwise tasks")
                     self.split_connector.send_splitwise_tasks(tasks, current_id)
@@ -759,7 +767,7 @@ class EngineService:
             device_ids=device_ids,
             pod_ip=self.cfg.master_ip,
             engine_worker_queue_port=int(
-                self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
+                self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
            ),
            pid_suffix=ipc_signal_suffix,
        )
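Editor's note: every EngineService call site now reaches the port list through parallel_config; the rank-local selection itself is unchanged. Condensed, with illustrative values:

# Illustrative: how a data-parallel rank picks its own queue port after this change.
ports = ["8002", "8003"]              # cfg.parallel_config.engine_worker_queue_port
rank = 1                              # cfg.parallel_config.local_data_parallel_id
current_suffix = int(ports[rank])     # 8003, also used as the IPC signal suffix
# os.environ["INFERENCE_MSG_QUEUE_ID"] = ports[rank]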

View File

@@ -115,7 +115,7 @@ class LLMEngine:
         start_time = time.time()
         self.api_server_pid = api_server_pid
-        self.ipc_signal_suffix = self.cfg.engine_worker_queue_port[0]
+        self.ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
         self._init_worker_signals()
         self.data_processor = self.input_processor.create_processor()
@@ -127,7 +127,7 @@ class LLMEngine:
             self.engine.start_zmq_service(api_server_pid)
         if self.do_profile == 0 and (
-            self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed"
+            self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed"
         ):
             device_ids = self.cfg.device_ids.split(",")
             self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
@@ -161,7 +161,7 @@ class LLMEngine:
             self._stop_profile()
         # Launch components: scheduler, cache_manager, expert_service et.al.
         self.launch_components()
-        if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
+        if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
             self.launched_cache_manager_signal.value[0] = 1
         # Worker launched
@@ -311,7 +311,7 @@ class LLMEngine:
         )
         # launched_cache_manager_signal is used to detect whether the engine has started the cache_manager
-        if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
+        if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
             launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
             self.launched_cache_manager_signal = IPCSignal(
                 name="launched_cache_manager_signal",
@@ -426,10 +426,10 @@ class LLMEngine:
             }
         )
-        if self.cfg.splitwise_role != "mixed":
+        if self.cfg.scheduler_config.splitwise_role != "mixed":
             variables["FLAGS_use_pd_disaggregation"] = 1
             # TODO dynamic load environment variable
-            if self.cfg.splitwise_role == "prefill":
+            if self.cfg.scheduler_config.splitwise_role == "prefill":
                 variables["FLAGS_fmt_write_cache_completed_signal"] = 1
         if self.cfg.model_config.enable_mm:
@@ -463,7 +463,7 @@ class LLMEngine:
             else len(self.data_processor.tokenizer.vocab)
         )
-        ports = ",".join(self.cfg.engine_worker_queue_port)
+        ports = ",".join(self.cfg.parallel_config.engine_worker_queue_port)
         ips = None
         if self.cfg.ips is not None:
             ips = ",".join(self.cfg.ips)
@@ -481,9 +481,9 @@ class LLMEngine:
             f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"
             f" --eos_tokens_lens {self.data_processor.eos_token_id_len}"
             f" --pad_token_id {self.data_processor.pad_token_id}"
-            f" --engine_pid {self.cfg.engine_worker_queue_port[0]}"
+            f" --engine_pid {self.cfg.parallel_config.engine_worker_queue_port[0]}"
             f" --max_num_batched_tokens {self.cfg.scheduler_config.max_num_batched_tokens}"
-            f" --splitwise_role {self.cfg.splitwise_role}"
+            f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}"
             f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
             f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
             f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
@@ -602,7 +602,7 @@ class LLMEngine:
         num_gpu_blocks = self.get_profile_block_num_signal.value[0]
         self.cfg.cache_config.reset(num_gpu_blocks)
         self.engine.resource_manager.reset_cache_config(self.cfg.cache_config)
-        if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
+        if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
             device_ids = self.cfg.device_ids.split(",")
             self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
@@ -619,7 +619,7 @@ class LLMEngine:
         return True, ""
     def launch_components(self):
-        if self.cfg.splitwise_role != "mixed":
+        if self.cfg.scheduler_config.splitwise_role != "mixed":
             # single-machine logic
             self.engine.engine_worker_queue.available_prefill_instances.put(1)
             self.engine.split_mode_get_tasks()
@@ -632,7 +632,7 @@ class LLMEngine:
             self.cfg.init_cache_info()
-            role = self.cfg.splitwise_role
+            role = self.cfg.scheduler_config.splitwise_role
             host_ip = self.cfg.host_ip
             disaggregate = self.cfg.disaggregate_info
             if self.cfg.scheduler_config.name == "splitwise":
@@ -649,7 +649,7 @@ class LLMEngine:
             ):
                 address = (
                     self.cfg.master_ip,
-                    int(self.cfg.engine_worker_queue_port[i]),
+                    int(self.cfg.parallel_config.engine_worker_queue_port[i]),
                 )
                 llm_logger.info(f"dp start queue service {address}")
                 self.dp_engine_worker_queue_server.append(

View File

@@ -50,13 +50,13 @@ class ExpertService:
         self.cfg = cfg
         start_pos = (local_data_parallel_id * self.cfg.parallel_config.tensor_parallel_size) % cfg.worker_num_per_node
         end_pos = start_pos + self.cfg.parallel_config.tensor_parallel_size
-        if cfg.splitwise_role != "mixed":
+        if cfg.scheduler_config.splitwise_role != "mixed":
             self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
         self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
         llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}")
         self.cfg.disaggregate_info = None
-        if cfg.splitwise_role != "mixed":
+        if cfg.scheduler_config.splitwise_role != "mixed":
             if len(self.cfg.cache_config.pd_comm_port) == 1:
                 self.cfg.cache_config.pd_comm_port[0] = (
                     int(self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
@@ -84,21 +84,21 @@ class ExpertService:
             self.api_server_pid = ipc_signal_suffix
             self.engine.start_zmq_service(ipc_signal_suffix)
         else:
-            ipc_signal_suffix = self.cfg.engine_worker_queue_port[0]
+            ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
         llm_logger.info(f"start expert service {local_data_parallel_id}")
-        if self.cfg.splitwise_role != "mixed":
+        if self.cfg.scheduler_config.splitwise_role != "mixed":
             self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix)
             self.engine.split_mode_get_tasks()
         if self.cfg.scheduler_config.name == "splitwise":
             self.cfg.init_cache_info()
-            role = self.cfg.splitwise_role
+            role = self.cfg.scheduler_config.splitwise_role
             host_ip = self.cfg.host_ip
             disaggregate = self.cfg.disaggregate_info
             self.engine.scheduler.start(role, host_ip, disaggregate)
-        if self.cfg.splitwise_role != "mixed":
+        if self.cfg.scheduler_config.splitwise_role != "mixed":
             self.splitwise_receive_thread = threading.Thread(
                 target=self.engine.split_connector.start_receiver, args=()
             )

View File

@@ -58,7 +58,7 @@ class MoEMethodBase(QuantMethodBase):
             "top_k": layer.top_k,
             "hidden_size": layer.hidden_size,
             "num_experts": layer.num_experts,
-            "splitwise_role": layer.fd_config.parallel_config.splitwise_role,
+            "splitwise_role": layer.fd_config.scheduler_config.splitwise_role,
             "num_max_dispatch_tokens_per_rank": layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
             "ep_size": layer.ep_size,
             "ep_rank": layer.ep_rank,
@@ -67,7 +67,7 @@ class MoEMethodBase(QuantMethodBase):
         }
         config = layer.fd_config
-        splitwise_role = config.parallel_config.splitwise_role
+        splitwise_role = config.scheduler_config.splitwise_role
         load_strategy = config.load_config.load_strategy
         # For "mixed" splitwise role: conditionally initialize both or none
@@ -81,7 +81,7 @@ class MoEMethodBase(QuantMethodBase):
             return
         # For non-mixed ep
-        phase = config.parallel_config.moe_phase.phase
+        phase = config.model_config.moe_phase.phase
         if phase == "prefill":
             self.ep_prefill_runner = EPPrefillRunner(**common_args)
         else:
@@ -159,12 +159,12 @@ class MoEMethodBase(QuantMethodBase):
         Paddle Cutlass compute Fused MoE.
         """
         if layer.ep_size > 1:
-            if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
-                if layer.fd_config.parallel_config.splitwise_role == "mixed":
+            if layer.fd_config.model_config.moe_phase.phase == "prefill":
+                if layer.fd_config.scheduler_config.splitwise_role == "mixed":
                     self.ep_prefill_runner.clean_low_latency_buffer()
                 return self.apply_ep_prefill(layer, x, gate)
             else:
-                if layer.fd_config.parallel_config.splitwise_role == "mixed":
+                if layer.fd_config.scheduler_config.splitwise_role == "mixed":
                     self.ep_decoder_runner.clean_low_latency_buffer()
                 return self.apply_ep_decode(layer, x, gate)
         else:

View File

@@ -219,6 +219,7 @@ class SchedulerConfig:
         self.name = "local"  # "local" for LocalScheduler or "global" for GlobalScheduler
         self.max_num_batched_tokens = 2048
         self.max_num_seqs = 34
+        self.splitwise_role = "mixed"
         self.config = None
         for key, value in args.items():
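Editor's note: SchedulerConfig is now the single owner of the splitwise_role default. Callers that previously passed splitwise_role=... to FDConfig override the attribute on the scheduler config instead, as the updated unit tests further below do. A sketch (import of SchedulerConfig omitted, assumed as in those tests):

# The default is "mixed"; prefill/decode deployments override the attribute
# on SchedulerConfig before constructing FDConfig.
scheduler_config = SchedulerConfig({})
assert scheduler_config.splitwise_role == "mixed"
scheduler_config.splitwise_role = "prefill"   # was: FDConfig(..., splitwise_role="prefill")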

View File

@@ -150,7 +150,7 @@ class MTPProposer(Proposer):
             max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
         if not self.parallel_config.do_profile and (
-            self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
+            self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
         ):
             cache_kvs_list = []
             for i in range(
@@ -267,7 +267,7 @@ class MTPProposer(Proposer):
         self.main_model_num_gpu_blocks = num_gpu_blocks
         self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
-        if not (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if not (self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"):
             self.initialize_kv_cache()
         # Reset free list

View File

@@ -40,7 +40,7 @@ class InternalAdapter:
             target=self._recv_external_module_control_instruct, daemon=True
         )
         self.recv_external_instruct_thread.start()
-        if cfg.splitwise_role != "mixed":
+        if cfg.scheduler_config.splitwise_role != "mixed":
             self.response_external_instruct_thread = threading.Thread(
                 target=self._response_external_module_control_instruct, daemon=True
             )
@@ -54,7 +54,7 @@ class InternalAdapter:
         available_block_num = self.engine.resource_manager.available_block_num()
         server_info = {
-            "splitwise_role": self.cfg.splitwise_role,
+            "splitwise_role": self.cfg.scheduler_config.splitwise_role,
             "block_size": int(self.cfg.cache_config.block_size),
             "block_num": int(available_block_num),
             "max_block_num": int(self.cfg.cache_config.total_block_num),

View File

@@ -206,7 +206,7 @@ class SplitwiseConnector:
                 "cache_info": {
                     "ipc": {
                         "ip": "0.0.0.0",
-                        "port": self.cfg.engine_worker_queue_port[self.idx],
+                        "port": self.cfg.parallel_config.engine_worker_queue_port[self.idx],
                         "current_id": current_id,
                     },
                 },
@@ -289,7 +289,9 @@ class SplitwiseConnector:
             if port not in self.connect_innode_instances:
                 self.create_connection(port)
             for task in tasks:
-                task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.engine_worker_queue_port[self.idx]
+                task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.parallel_config.engine_worker_queue_port[
+                    self.idx
+                ]
             self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
             for task in tasks:
                 task.disaggregate_info["cache_info"]["ipc"]["port"] = port

View File

@@ -63,12 +63,12 @@ class DCUModelRunner(GPUModelRunner):
             only_decode_batch = True
             prefill_exists = None
             # mix ep in single node
-            if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+            if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
                 only_decode_batch_list = []
                 prefill_exists = self.exist_prefill()
                 paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
                 only_decode_batch = all(only_decode_batch_list)
-                self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
+                self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
             self.forward_meta.step_use_cudagraph = (
                 self.use_cudagraph

View File

@@ -651,7 +651,9 @@ class GCUModelRunner(ModelRunnerBase):
         )
         # local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if not profile and (
+            self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
+        ):
             raise NotImplementedError("prefix_caching is not support by GCUModelRunner.")
         else:
             for i in range(self.model_config.num_hidden_layers):
@@ -1069,7 +1071,7 @@ class GCUModelRunner(ModelRunnerBase):
                 reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
             )
-            if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
+            if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
                 skip_save_output = True
             else:
                 skip_save_output = False

View File

@@ -193,7 +193,7 @@ class GPUModelRunner(ModelRunnerBase):
         """
         if_only_prefill = True
         decode_exists = None
-        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
             only_prefill_batch_list = []
             decode_exists = self.exist_decode()
             paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists)
@@ -211,7 +211,7 @@ class GPUModelRunner(ModelRunnerBase):
         if_only_decode = True
         prefill_exists = None
         # mix ep in single node
-        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
             only_decode_batch_list = []
             prefill_exists = self.exist_prefill()
             paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
@@ -1103,8 +1103,8 @@ class GPUModelRunner(ModelRunnerBase):
         # Update config about moe for better performance
         # TODO(wanglongzhi):Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. It will be used in MoEMethodBase.apply()
-        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
-            self.fd_config.parallel_config.moe_phase.phase = "decode" if if_only_decode else "prefill"
+        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
+            self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill"
         # Update Batch type for cuda graph for only_prefill_batch
         only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill()
@@ -1145,7 +1145,9 @@ class GPUModelRunner(ModelRunnerBase):
             kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if not profile and (
+            self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
+        ):
             cache_kvs_list = []
             for i in range(self.model_config.num_hidden_layers):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -1711,7 +1713,7 @@ class GPUModelRunner(ModelRunnerBase):
                 stop_seqs_len=self.share_inputs["stop_seqs_len"],
             )
-            if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
+            if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
                 skip_save_output = True
             else:
                 skip_save_output = False

View File

@@ -905,12 +905,12 @@ class MetaxModelRunner(ModelRunnerBase):
             only_decode_batch = True
             prefill_exists = None
             # mix ep in single node
-            if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+            if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
                 only_decode_batch_list = []
                 prefill_exists = self.exist_prefill()
                 paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
                 only_decode_batch = all(only_decode_batch_list)
-                self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
+                self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
             self.forward_meta.step_use_cudagraph = (
                 self.use_cudagraph
@@ -947,7 +947,9 @@ class MetaxModelRunner(ModelRunnerBase):
             )
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if not profile and (
+            self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
+        ):
             cache_kvs_list = []
             for i in range(self.model_config.num_hidden_layers):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -1407,7 +1409,7 @@ class MetaxModelRunner(ModelRunnerBase):
                 stop_seqs_len=self.share_inputs["stop_seqs_len"],
             )
-            if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
+            if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
                 skip_save_output = True
             else:
                 skip_save_output = False

View File

@@ -51,6 +51,7 @@ class WorkerBase(ABC):
         self.parallel_config = fd_config.parallel_config
         self.device_config = fd_config.device_config
         self.cache_config = fd_config.cache_config
+        self.scheduler_config = fd_config.scheduler_config
         # ... config
         # Device and Runner

View File

@@ -151,6 +151,7 @@ class PaddleDisWorkerProc:
         self.fd_config = fd_config
         self.parallel_config = fd_config.parallel_config
         self.cache_config = fd_config.cache_config
+        self.scheduler_config = fd_config.scheduler_config
         # TODO(gongshaotian): Use worker factory to get worker
         self.worker = get_worker(fd_config=fd_config, local_rank=self.local_rank, rank=self.ranks)
@@ -412,7 +413,7 @@ class PaddleDisWorkerProc:
             num_blocks_local = self.fd_config.parallel_config.total_block_num
         logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
         # wait engine launch cache_manager
-        if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
+        if self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed":
             launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
             self.launched_cache_manager_signal = IPCSignal(
                 name="launched_cache_manager_signal",
@@ -762,7 +763,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         early_stop_config=early_stop_config,
         cache_config=cache_config,
         scheduler_config=scheduler_config,
-        engine_worker_queue_port=args.engine_worker_queue_port,
        ips=args.ips,
        moba_attention_config=moba_attention_config,
    )

View File

@@ -15,6 +15,7 @@
 """
 import unittest
+from unittest.mock import Mock
 import paddle
@@ -157,6 +158,7 @@ class TestCUDAGrpahSubgraph(unittest.TestCase):
         scheduler_config.max_num_seqs = 8
         cache_config = CacheConfig({})
         parallel_config = ParallelConfig(args={})
+        model_config = Mock()
         # Initialize cuda graph capture list
         graph_opt_config._set_cudagraph_sizes(max_num_seqs=scheduler_config.max_num_seqs)
         graph_opt_config.init_with_cudagrpah_size(max_capture_size=scheduler_config.max_num_seqs)
@@ -165,6 +167,7 @@
             scheduler_config=scheduler_config,
             parallel_config=parallel_config,
             cache_config=cache_config,
+            model_config=model_config,
             test_mode=True,
         )

View File

@@ -1,4 +1,5 @@
 import unittest
+from unittest.mock import Mock
 import paddle
@@ -95,10 +96,12 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
         cache_config = CacheConfig(args={})
         scheduler_config.max_num_seqs = 1
         parallel_config = ParallelConfig(args={})
+        model_config = Mock()
         fd_config = FDConfig(
             graph_opt_config=graph_opt_config,
             scheduler_config=scheduler_config,
             cache_config=cache_config,
+            model_config=model_config,
             parallel_config=parallel_config,
         )

View File

@@ -15,6 +15,7 @@
 """
 import unittest
+from unittest.mock import Mock
 import paddle
@@ -104,6 +105,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
         scheduler_config.max_num_seqs = 1
         cache_config = CacheConfig({})
         parallel_config = ParallelConfig(args={})
+        model_config = Mock()
         # Initialize cuda graph capture list
         graph_opt_config._set_cudagraph_sizes(max_num_seqs=scheduler_config.max_num_seqs)
         graph_opt_config.init_with_cudagrpah_size(max_capture_size=scheduler_config.max_num_seqs)
@@ -112,6 +114,7 @@
             scheduler_config=scheduler_config,
             cache_config=cache_config,
             parallel_config=parallel_config,
+            model_config=model_config,
             test_mode=True,
         )

View File

@@ -15,6 +15,7 @@
 """
 import unittest
+from unittest.mock import Mock
 import numpy as np
 import paddle
@@ -91,11 +92,13 @@ class TestGraphOptBackend(unittest.TestCase):
         baseline_cache_config = CacheConfig({})
         baseline_parallel_config = ParallelConfig(args={})
+        model_config = Mock()
         self.baseline_fd_config = FDConfig(
             graph_opt_config=baseline_graph_opt_config,
             scheduler_config=baseline_scheduler_config,
             cache_config=baseline_cache_config,
             parallel_config=baseline_parallel_config,
+            model_config=model_config,
             test_mode=True,
         )
@@ -137,6 +140,7 @@ class TestGraphOptBackend(unittest.TestCase):
         # Setup cache config
         cache_config = CacheConfig({})
         parallel_config = ParallelConfig(args={})
+        model_config = Mock()
         # Create FD config
         return FDConfig(
@@ -144,6 +148,7 @@ class TestGraphOptBackend(unittest.TestCase):
             scheduler_config=scheduler_config,
             cache_config=cache_config,
             parallel_config=parallel_config,
+            model_config=model_config,
             test_mode=True,
         )

View File

@@ -20,6 +20,7 @@ os.environ["FLAGS_cuda_graph_blacklist"] = "pd_op.matmul,pd_op.transpose"
 import unittest
+from unittest.mock import Mock
 import paddle
 import paddle.nn as nn
@@ -94,12 +95,13 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase):
         graph_opt_config.init_with_cudagrpah_size(max_capture_size=scheduler_config.max_num_seqs)
         cache_config = CacheConfig({})
         parallel_config = ParallelConfig(args={})
+        model_config = Mock()
         fd_config = FDConfig(
             graph_opt_config=graph_opt_config,
             scheduler_config=scheduler_config,
             cache_config=cache_config,
             parallel_config=parallel_config,
+            model_config=model_config,
             test_mode=True,
         )

View File

@@ -47,6 +47,7 @@ class FakeModelConfig:
         self.max_position_embeddings = 512
         self.tie_word_embeddings = True
         self.model_format = "auto"
+        self.enable_mm = False
 def get_default_test_fd_config():
@@ -56,12 +57,13 @@ def get_default_test_fd_config():
     parallel_config = ParallelConfig(args={})
     parallel_config.data_parallel_rank = 1
     cache_config = CacheConfig({})
+    model_config = FakeModelConfig()
     fd_config = FDConfig(
         graph_opt_config=graph_opt_config,
         parallel_config=parallel_config,
         cache_config=cache_config,
         scheduler_config=scheduler_config,
+        model_config=model_config,
         test_mode=True,
     )
-    fd_config.model_config = FakeModelConfig()
     return fd_config

View File

@@ -1,4 +1,5 @@
 import unittest
+from unittest.mock import Mock
 from fastdeploy import envs
 from fastdeploy.config import (
@@ -18,12 +19,14 @@ class TestConfig(unittest.TestCase):
         cache_config = CacheConfig({})
         load_config = LoadConfig({})
         scheduler_config = SchedulerConfig({})
+        model_config = Mock()
         fd_config = FDConfig(
             parallel_config=parallel_config,
             graph_opt_config=graph_opt_config,
             load_config=load_config,
             cache_config=cache_config,
             scheduler_config=scheduler_config,
+            model_config=model_config,
             ips=["1.1.1.1", "0.0.0.0"],
             test_mode=True,
         )
@@ -36,12 +39,14 @@ class TestConfig(unittest.TestCase):
         cache_config = CacheConfig({})
         load_config = LoadConfig({})
         scheduler_config = SchedulerConfig({})
+        model_config = Mock()
         fd_config = FDConfig(
             parallel_config=parallel_config,
             graph_opt_config=graph_opt_config,
             load_config=load_config,
             cache_config=cache_config,
             scheduler_config=scheduler_config,
+            model_config=model_config,
             ips="0.0.0.0",
             test_mode=True,
         )
@@ -54,12 +59,15 @@ class TestConfig(unittest.TestCase):
         load_config = LoadConfig({})
         cache_config.enable_chunked_prefill = True
         scheduler_config = SchedulerConfig({})
+        model_config = model_config = Mock()
         fd_config = FDConfig(
             parallel_config=parallel_config,
             graph_opt_config=graph_opt_config,
             cache_config=cache_config,
             load_config=load_config,
             scheduler_config=scheduler_config,
+            model_config=model_config,
             ips="0.0.0.0",
             test_mode=True,
         )
@@ -73,6 +81,7 @@ class TestConfig(unittest.TestCase):
             cache_config=cache_config,
             load_config=load_config,
             scheduler_config=scheduler_config,
+            model_config=model_config,
             ips="0.0.0.0",
             test_mode=True,
         )
@@ -87,13 +96,16 @@ class TestConfig(unittest.TestCase):
         cache_config.pd_comm_port = "2334"
         load_config = LoadConfig({})
         scheduler_config = SchedulerConfig({})
+        scheduler_config.splitwise_role = "prefill"
+        model_config = model_config = Mock()
         fd_config = FDConfig(
             parallel_config=parallel_config,
             graph_opt_config=graph_opt_config,
             cache_config=cache_config,
             load_config=load_config,
             scheduler_config=scheduler_config,
-            splitwise_role="prefill",
+            model_config=model_config,
             test_mode=True,
         )
         fd_config.init_cache_info()