Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 00:33:03 +08:00)
[FDConfig]Remove splitwise_role and engine_worker_queue_port in FDConfig (#4147)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* remove splitwise_role and engine_worker_queue_port
* fix xpu
* fix xpu
* fix xpu
* fix unittest
* resolve conflict
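In practice, call sites migrate from the removed top-level fields to the sub-configs: `splitwise_role` now lives on `SchedulerConfig`, `engine_worker_queue_port` (normalized to a list of port strings) on `ParallelConfig`, and the derived `moe_phase` moves from `ParallelConfig` to `ModelConfig`. A minimal sketch of the new access paths, using stand-in dataclasses rather than the real FastDeploy config classes:

from dataclasses import dataclass, field
from typing import List


# Illustrative stand-ins only; the real classes live in fastdeploy.config and the scheduler config module.
@dataclass
class SchedulerConfig:
    splitwise_role: str = "mixed"  # previously FDConfig.splitwise_role / ParallelConfig.splitwise_role


@dataclass
class ParallelConfig:
    # previously FDConfig.engine_worker_queue_port; now a list of port strings indexed per data-parallel rank
    engine_worker_queue_port: List[str] = field(default_factory=lambda: ["8002"])
    local_data_parallel_id: int = 0


@dataclass
class FDConfig:
    scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig)
    parallel_config: ParallelConfig = field(default_factory=ParallelConfig)


cfg = FDConfig()
# Old call sites: cfg.splitwise_role, cfg.engine_worker_queue_port[dp_id]
# New call sites:
role = cfg.scheduler_config.splitwise_role
port = cfg.parallel_config.engine_worker_queue_port[cfg.parallel_config.local_data_parallel_id]
print(role, port)  # -> mixed 8002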
@@ -296,8 +296,6 @@ class ParallelConfig:
         # Do profile or not
         self.do_profile: bool = False

-        # splitwise role
-        self.splitwise_role: str = "mixed"
         # guided decoding backend
         self.guided_decoding_backend: str = None
         # disable any whitespace for guided decoding
@@ -319,14 +317,6 @@ class ParallelConfig:
         else:
             self.expert_parallel_size = 1
         self.use_ep = self.expert_parallel_size > 1
-        if self.splitwise_role == "mixed":
-            self.moe_phase = MoEPhase(phase="prefill")
-        elif self.splitwise_role == "prefill":
-            self.moe_phase = MoEPhase(phase="prefill")
-        elif self.splitwise_role == "decode":
-            self.moe_phase = MoEPhase(phase="decode")
-        else:
-            raise NotImplementedError

         # pd_disaggregation
         use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
@@ -1116,10 +1106,8 @@ class FDConfig:
         max_model_len: int = 8192,
         ips: str = None,
         use_warmup: bool = False,
-        engine_worker_queue_port: str = "8002",
         limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
-        splitwise_role: str = "mixed",
         innode_prefill_ports: Optional[List[int]] = None,
         max_num_partial_prefills: int = 1,
         max_long_partial_prefills: int = 1,
@@ -1182,7 +1170,6 @@ class FDConfig:
         self.limit_mm_per_prompt = limit_mm_per_prompt
         self.mm_processor_kwargs = mm_processor_kwargs
         self.use_warmup = use_warmup
-        self.splitwise_role = splitwise_role
         self.innode_prefill_ports = innode_prefill_ports
         self.max_num_partial_prefills = max_num_partial_prefills
         self.max_long_partial_prefills = max_long_partial_prefills
@@ -1190,11 +1177,7 @@ class FDConfig:
         self.reasoning_parser = reasoning_parser
         self.guided_decoding_backend = guided_decoding_backend
         self.disable_any_whitespace = disable_any_whitespace
-        self.engine_worker_queue_port = engine_worker_queue_port
         self._str_to_list("innode_prefill_ports", int)
-        if isinstance(engine_worker_queue_port, int):
-            self.engine_worker_queue_port = str(engine_worker_queue_port)
-        self._str_to_list("engine_worker_queue_port", str)

         if envs.FD_FOR_TORCH_MODEL_FORMAT:
             self.model_config.model_format = "torch"
@@ -1267,6 +1250,15 @@ class FDConfig:
         else:
             self.guided_decoding_backend = "xgrammar"

+        if self.scheduler_config.splitwise_role == "mixed":
+            self.model_config.moe_phase = MoEPhase(phase="prefill")
+        elif self.scheduler_config.splitwise_role == "prefill":
+            self.model_config.moe_phase = MoEPhase(phase="prefill")
+        elif self.scheduler_config.splitwise_role == "decode":
+            self.model_config.moe_phase = MoEPhase(phase="decode")
+        else:
+            raise NotImplementedError
+
     def check(self):
         """
         check the legality of config
@@ -1301,7 +1293,7 @@ class FDConfig:
                 f"max_long_partial_prefills: {self.max_long_partial_prefills} should "
                 f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}"
             )
-        assert self.splitwise_role in ["mixed", "prefill", "decode"]
+        assert self.scheduler_config.splitwise_role in ["mixed", "prefill", "decode"]
         # TODO(@wufeisheng): TP and EP need to be supported simultaneously.
         assert (self.parallel_config.tensor_parallel_size == 1 and self.parallel_config.expert_parallel_size >= 1) or (
             self.parallel_config.tensor_parallel_size >= 1 and self.parallel_config.expert_parallel_size == 1
@@ -1387,8 +1379,8 @@ class FDConfig:
         initialize cache info
         """
         disaggregate_info = {}
-        if self.splitwise_role != "mixed":
-            disaggregate_info["role"] = self.splitwise_role
+        if self.scheduler_config.splitwise_role != "mixed":
+            disaggregate_info["role"] = self.scheduler_config.splitwise_role
             disaggregate_info["cache_info"] = dict()
             current_protocol = self.cache_config.cache_transfer_protocol.split(",")
             disaggregate_info["transfer_protocol"] = current_protocol
@@ -1396,7 +1388,9 @@ class FDConfig:
                 if protocol == "ipc":
                     disaggregate_info["cache_info"][protocol] = {
                         "ip": self.host_ip,
-                        "port": self.engine_worker_queue_port[self.parallel_config.local_data_parallel_id],
+                        "port": self.parallel_config.engine_worker_queue_port[
+                            self.parallel_config.local_data_parallel_id
+                        ],
                         "device_ids": self.local_device_ids,
                     }
                 elif protocol == "rdma":
@@ -1019,6 +1019,11 @@ class EngineArgs:
         else:
             self.max_num_batched_tokens = self.max_model_len

+        if isinstance(self.engine_worker_queue_port, int):
+            self.engine_worker_queue_port = str(self.engine_worker_queue_port)
+        if isinstance(self.engine_worker_queue_port, str):
+            self.engine_worker_queue_port = self.engine_worker_queue_port.split(",")
+
         all_dict = asdict(self)
         all_dict["model_cfg"] = model_cfg
         cache_cfg = CacheConfig(all_dict)
@@ -1032,11 +1037,6 @@ class EngineArgs:
         early_stop_cfg = self.create_early_stop_config()
         early_stop_cfg.update_enable_early_stop(self.enable_early_stop)

-        if isinstance(self.engine_worker_queue_port, int):
-            self.engine_worker_queue_port = str(self.engine_worker_queue_port)
-        if isinstance(self.engine_worker_queue_port, str):
-            self.engine_worker_queue_port = self.engine_worker_queue_port.split(",")
-
         assert is_port_available(
             "0.0.0.0", int(self.engine_worker_queue_port[parallel_cfg.local_data_parallel_id])
         ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
@@ -1052,12 +1052,10 @@ class EngineArgs:
             speculative_config=speculative_cfg,
             ips=self.ips,
             use_warmup=self.use_warmup,
-            engine_worker_queue_port=self.engine_worker_queue_port,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             mm_processor_kwargs=self.mm_processor_kwargs,
             reasoning_parser=self.reasoning_parser,
             tool_parser=self.tool_call_parser,
-            splitwise_role=self.splitwise_role,
             innode_prefill_ports=self.innode_prefill_ports,
             max_num_partial_prefills=self.max_num_partial_prefills,
             max_long_partial_prefills=self.max_long_partial_prefills,
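One detail worth noting from the EngineArgs hunks above: the int/str normalization of `engine_worker_queue_port` is moved up so it runs before the sub-configs are built from `asdict(self)`, instead of just before the port-availability check. A rough, self-contained sketch of that normalization (the helper name is illustrative, not part of the codebase):

def normalize_ports(engine_worker_queue_port):
    # 8002 -> "8002" -> ["8002"]; "8002,8003" -> ["8002", "8003"]; lists pass through unchanged
    if isinstance(engine_worker_queue_port, int):
        engine_worker_queue_port = str(engine_worker_queue_port)
    if isinstance(engine_worker_queue_port, str):
        engine_worker_queue_port = engine_worker_queue_port.split(",")
    return engine_worker_queue_port


assert normalize_ports(8002) == ["8002"]
assert normalize_ports("8002,8003") == ["8002", "8003"]
assert normalize_ports(["8002"]) == ["8002"]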
@@ -76,10 +76,10 @@ class EngineService:
                 cfg.scheduler_config.max_num_seqs,
                 cfg,
                 cfg.parallel_config.tensor_parallel_size,
-                cfg.splitwise_role,
+                cfg.scheduler_config.splitwise_role,
                 cfg.parallel_config.local_data_parallel_id,
             )
-            if cfg.splitwise_role != "mixed":
+            if cfg.scheduler_config.splitwise_role != "mixed":
                 raise NotImplementedError(
                     "Currently ENABLE_V1_KVCACHE_SCHEDULER=1 only supported in mixed sampling now."
                 )
@@ -88,13 +88,13 @@ class EngineService:
                 cfg.scheduler_config.max_num_seqs,
                 cfg,
                 cfg.parallel_config.tensor_parallel_size,
-                cfg.splitwise_role,
+                cfg.scheduler_config.splitwise_role,
                 cfg.parallel_config.local_data_parallel_id,
             )

         self.start_worker_queue_service(start_queue)

-        os.environ["INFERENCE_MSG_QUEUE_ID"] = self.cfg.engine_worker_queue_port[
+        os.environ["INFERENCE_MSG_QUEUE_ID"] = self.cfg.parallel_config.engine_worker_queue_port[
             self.cfg.parallel_config.local_data_parallel_id
         ]

@@ -137,7 +137,9 @@ class EngineService:
         self.token_processor.run()

     def _init_worker_monitor_signals(self):  # exist_task_signal 用于各worker进程感知是否有新Task需要处理
-        current_suffix = int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id])
+        current_suffix = int(
+            self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
+        )
         llm_logger.info(f"current_suffix: {current_suffix}")
         exist_task_signal_data = np.zeros([1], dtype=np.int32)
         self.exist_task_signal = IPCSignal(
@@ -195,7 +197,7 @@ class EngineService:
         """
         address = (
             self.cfg.master_ip,
-            int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]),
+            int(self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]),
         )

         if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"):
@@ -209,7 +211,7 @@ class EngineService:

         if (
             self.cfg.cache_config.enable_prefix_caching
-            or self.cfg.splitwise_role != "mixed"
+            or self.cfg.scheduler_config.splitwise_role != "mixed"
             and self.cfg.parallel_config.local_data_parallel_id == 0
         ):
             self.cache_task_queue = EngineCacheQueue(
@@ -253,7 +255,10 @@ class EngineService:
                 del self.resource_manager.req_dict[task.request_id]
                 cur_task = self.resource_manager.tasks_list[cur_task_idx]
                 cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
-                if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode":
+                if (
+                    self.cfg.speculative_config.method in ["mtp"]
+                    and self.cfg.scheduler_config.splitwise_role == "decode"
+                ):
                     cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids)
                 if task.error_code != 200:
                     self.resource_manager.stop_flags[cur_task_idx] = True
@@ -478,7 +483,10 @@ class EngineService:
                     time.sleep(0.001)
                     continue
                 if hasattr(self, "exist_prefill_task_signal") and self.exist_prefill_task_signal.value[0] > 0:
-                    if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks():
+                    if (
+                        self.cfg.scheduler_config.splitwise_role == "mixed"
+                        or self.split_connector.has_splitwise_tasks()
+                    ):
                         time.sleep(0.005)
                         continue
                 if self.engine_worker_queue.num_cache_infos() > 0:
@@ -507,7 +515,7 @@ class EngineService:
                     continue

                 current_id = (current_id + 1) % 100003
-                if self.cfg.splitwise_role != "mixed":
+                if self.cfg.scheduler_config.splitwise_role != "mixed":
                     llm_logger.info("Inserting splitwise tasks")
                     self.split_connector.send_splitwise_tasks(tasks, current_id)

@@ -759,7 +767,7 @@ class EngineService:
             device_ids=device_ids,
             pod_ip=self.cfg.master_ip,
             engine_worker_queue_port=int(
-                self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
+                self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
             ),
             pid_suffix=ipc_signal_suffix,
         )
@@ -115,7 +115,7 @@ class LLMEngine:
         start_time = time.time()

         self.api_server_pid = api_server_pid
-        self.ipc_signal_suffix = self.cfg.engine_worker_queue_port[0]
+        self.ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
         self._init_worker_signals()

         self.data_processor = self.input_processor.create_processor()
@@ -127,7 +127,7 @@ class LLMEngine:
         self.engine.start_zmq_service(api_server_pid)

         if self.do_profile == 0 and (
-            self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed"
+            self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed"
         ):
             device_ids = self.cfg.device_ids.split(",")
             self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
@@ -161,7 +161,7 @@ class LLMEngine:
             self._stop_profile()
         # Launch components: scheduler, cache_manager, expert_service et.al.
         self.launch_components()
-        if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
+        if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
             self.launched_cache_manager_signal.value[0] = 1

         # Worker launched
@@ -311,7 +311,7 @@ class LLMEngine:
         )

         # launched_cache_manager_signal 用于感知engine是否启动了cache_manager
-        if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
+        if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
             launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
             self.launched_cache_manager_signal = IPCSignal(
                 name="launched_cache_manager_signal",
@@ -426,10 +426,10 @@ class LLMEngine:
             }
         )

-        if self.cfg.splitwise_role != "mixed":
+        if self.cfg.scheduler_config.splitwise_role != "mixed":
             variables["FLAGS_use_pd_disaggregation"] = 1
             # TODO dynamic load environment variable
-            if self.cfg.splitwise_role == "prefill":
+            if self.cfg.scheduler_config.splitwise_role == "prefill":
                 variables["FLAGS_fmt_write_cache_completed_signal"] = 1

         if self.cfg.model_config.enable_mm:
@@ -463,7 +463,7 @@ class LLMEngine:
             else len(self.data_processor.tokenizer.vocab)
         )

-        ports = ",".join(self.cfg.engine_worker_queue_port)
+        ports = ",".join(self.cfg.parallel_config.engine_worker_queue_port)
         ips = None
         if self.cfg.ips is not None:
             ips = ",".join(self.cfg.ips)
@@ -481,9 +481,9 @@ class LLMEngine:
             f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"
             f" --eos_tokens_lens {self.data_processor.eos_token_id_len}"
             f" --pad_token_id {self.data_processor.pad_token_id}"
-            f" --engine_pid {self.cfg.engine_worker_queue_port[0]}"
+            f" --engine_pid {self.cfg.parallel_config.engine_worker_queue_port[0]}"
             f" --max_num_batched_tokens {self.cfg.scheduler_config.max_num_batched_tokens}"
-            f" --splitwise_role {self.cfg.splitwise_role}"
+            f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}"
             f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
             f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
             f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
@@ -602,7 +602,7 @@ class LLMEngine:
             num_gpu_blocks = self.get_profile_block_num_signal.value[0]
             self.cfg.cache_config.reset(num_gpu_blocks)
             self.engine.resource_manager.reset_cache_config(self.cfg.cache_config)
-            if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
+            if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
                 device_ids = self.cfg.device_ids.split(",")
                 self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)

@@ -619,7 +619,7 @@ class LLMEngine:
         return True, ""

     def launch_components(self):
-        if self.cfg.splitwise_role != "mixed":
+        if self.cfg.scheduler_config.splitwise_role != "mixed":
             # 单机逻辑
             self.engine.engine_worker_queue.available_prefill_instances.put(1)
             self.engine.split_mode_get_tasks()
@@ -632,7 +632,7 @@ class LLMEngine:

             self.cfg.init_cache_info()

-            role = self.cfg.splitwise_role
+            role = self.cfg.scheduler_config.splitwise_role
             host_ip = self.cfg.host_ip
             disaggregate = self.cfg.disaggregate_info
             if self.cfg.scheduler_config.name == "splitwise":
@@ -649,7 +649,7 @@ class LLMEngine:
                 ):
                     address = (
                         self.cfg.master_ip,
-                        int(self.cfg.engine_worker_queue_port[i]),
+                        int(self.cfg.parallel_config.engine_worker_queue_port[i]),
                     )
                     llm_logger.info(f"dp start queue service {address}")
                     self.dp_engine_worker_queue_server.append(
@@ -50,13 +50,13 @@ class ExpertService:
         self.cfg = cfg
         start_pos = (local_data_parallel_id * self.cfg.parallel_config.tensor_parallel_size) % cfg.worker_num_per_node
         end_pos = start_pos + self.cfg.parallel_config.tensor_parallel_size
-        if cfg.splitwise_role != "mixed":
+        if cfg.scheduler_config.splitwise_role != "mixed":
             self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
         self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
         llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}")
         self.cfg.disaggregate_info = None

-        if cfg.splitwise_role != "mixed":
+        if cfg.scheduler_config.splitwise_role != "mixed":
             if len(self.cfg.cache_config.pd_comm_port) == 1:
                 self.cfg.cache_config.pd_comm_port[0] = (
                     int(self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
@@ -84,21 +84,21 @@ class ExpertService:
             self.api_server_pid = ipc_signal_suffix
             self.engine.start_zmq_service(ipc_signal_suffix)
         else:
-            ipc_signal_suffix = self.cfg.engine_worker_queue_port[0]
+            ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]

         llm_logger.info(f"start expert service {local_data_parallel_id}")
-        if self.cfg.splitwise_role != "mixed":
+        if self.cfg.scheduler_config.splitwise_role != "mixed":
             self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix)
             self.engine.split_mode_get_tasks()

         if self.cfg.scheduler_config.name == "splitwise":
             self.cfg.init_cache_info()
-            role = self.cfg.splitwise_role
+            role = self.cfg.scheduler_config.splitwise_role
             host_ip = self.cfg.host_ip
             disaggregate = self.cfg.disaggregate_info
             self.engine.scheduler.start(role, host_ip, disaggregate)

-        if self.cfg.splitwise_role != "mixed":
+        if self.cfg.scheduler_config.splitwise_role != "mixed":
             self.splitwise_receive_thread = threading.Thread(
                 target=self.engine.split_connector.start_receiver, args=()
             )
@@ -58,7 +58,7 @@ class MoEMethodBase(QuantMethodBase):
                 "top_k": layer.top_k,
                 "hidden_size": layer.hidden_size,
                 "num_experts": layer.num_experts,
-                "splitwise_role": layer.fd_config.parallel_config.splitwise_role,
+                "splitwise_role": layer.fd_config.scheduler_config.splitwise_role,
                 "num_max_dispatch_tokens_per_rank": layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                 "ep_size": layer.ep_size,
                 "ep_rank": layer.ep_rank,
@@ -67,7 +67,7 @@ class MoEMethodBase(QuantMethodBase):
             }

             config = layer.fd_config
-            splitwise_role = config.parallel_config.splitwise_role
+            splitwise_role = config.scheduler_config.splitwise_role
             load_strategy = config.load_config.load_strategy

             # For "mixed" splitwise role: conditionally initialize both or none
@@ -81,7 +81,7 @@ class MoEMethodBase(QuantMethodBase):
                 return

             # For non-mixed ep
-            phase = config.parallel_config.moe_phase.phase
+            phase = config.model_config.moe_phase.phase
             if phase == "prefill":
                 self.ep_prefill_runner = EPPrefillRunner(**common_args)
             else:
@@ -159,12 +159,12 @@ class MoEMethodBase(QuantMethodBase):
         Paddle Cutlass compute Fused MoE.
         """
         if layer.ep_size > 1:
-            if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
-                if layer.fd_config.parallel_config.splitwise_role == "mixed":
+            if layer.fd_config.model_config.moe_phase.phase == "prefill":
+                if layer.fd_config.scheduler_config.splitwise_role == "mixed":
                     self.ep_prefill_runner.clean_low_latency_buffer()
                 return self.apply_ep_prefill(layer, x, gate)
             else:
-                if layer.fd_config.parallel_config.splitwise_role == "mixed":
+                if layer.fd_config.scheduler_config.splitwise_role == "mixed":
                     self.ep_decoder_runner.clean_low_latency_buffer()
                 return self.apply_ep_decode(layer, x, gate)
         else:
@@ -219,6 +219,7 @@ class SchedulerConfig:
         self.name = "local"  # "local" for LocalScheduler or "global" for GlobalScheduler
         self.max_num_batched_tokens = 2048
         self.max_num_seqs = 34
+        self.splitwise_role = "mixed"
         self.config = None

         for key, value in args.items():
@@ -150,7 +150,7 @@ class MTPProposer(Proposer):
             max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
         if not self.parallel_config.do_profile and (
-            self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
+            self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
         ):
             cache_kvs_list = []
             for i in range(
@@ -267,7 +267,7 @@ class MTPProposer(Proposer):

         self.main_model_num_gpu_blocks = num_gpu_blocks
         self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
-        if not (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if not (self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"):
             self.initialize_kv_cache()

         # Reset free list
@@ -40,7 +40,7 @@ class InternalAdapter:
             target=self._recv_external_module_control_instruct, daemon=True
         )
         self.recv_external_instruct_thread.start()
-        if cfg.splitwise_role != "mixed":
+        if cfg.scheduler_config.splitwise_role != "mixed":
             self.response_external_instruct_thread = threading.Thread(
                 target=self._response_external_module_control_instruct, daemon=True
             )
@@ -54,7 +54,7 @@ class InternalAdapter:

         available_block_num = self.engine.resource_manager.available_block_num()
         server_info = {
-            "splitwise_role": self.cfg.splitwise_role,
+            "splitwise_role": self.cfg.scheduler_config.splitwise_role,
             "block_size": int(self.cfg.cache_config.block_size),
             "block_num": int(available_block_num),
             "max_block_num": int(self.cfg.cache_config.total_block_num),
@@ -206,7 +206,7 @@ class SplitwiseConnector:
                 "cache_info": {
                     "ipc": {
                         "ip": "0.0.0.0",
-                        "port": self.cfg.engine_worker_queue_port[self.idx],
+                        "port": self.cfg.parallel_config.engine_worker_queue_port[self.idx],
                         "current_id": current_id,
                     },
                 },
@@ -289,7 +289,9 @@ class SplitwiseConnector:
             if port not in self.connect_innode_instances:
                 self.create_connection(port)
             for task in tasks:
-                task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.engine_worker_queue_port[self.idx]
+                task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.parallel_config.engine_worker_queue_port[
+                    self.idx
+                ]
             self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
             for task in tasks:
                 task.disaggregate_info["cache_info"]["ipc"]["port"] = port
@@ -63,12 +63,12 @@ class DCUModelRunner(GPUModelRunner):
         only_decode_batch = True
         prefill_exists = None
         # mix ep in single node
-        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
             only_decode_batch_list = []
             prefill_exists = self.exist_prefill()
             paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
             only_decode_batch = all(only_decode_batch_list)
-            self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
+            self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"

         self.forward_meta.step_use_cudagraph = (
             self.use_cudagraph
@@ -651,7 +651,9 @@ class GCUModelRunner(ModelRunnerBase):
         )
         # local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

-        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if not profile and (
+            self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
+        ):
             raise NotImplementedError("prefix_caching is not support by GCUModelRunner.")
         else:
             for i in range(self.model_config.num_hidden_layers):
@@ -1069,7 +1071,7 @@ class GCUModelRunner(ModelRunnerBase):
             reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
         )

-        if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
+        if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
             skip_save_output = True
         else:
             skip_save_output = False
@@ -193,7 +193,7 @@ class GPUModelRunner(ModelRunnerBase):
         """
         if_only_prefill = True
         decode_exists = None
-        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
             only_prefill_batch_list = []
             decode_exists = self.exist_decode()
             paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists)
@@ -211,7 +211,7 @@ class GPUModelRunner(ModelRunnerBase):
         if_only_decode = True
         prefill_exists = None
         # mix ep in single node
-        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
             only_decode_batch_list = []
             prefill_exists = self.exist_prefill()
             paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
@@ -1103,8 +1103,8 @@ class GPUModelRunner(ModelRunnerBase):

         # Update config about moe for better performance
         # TODO(wanglongzhi):Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. It will be used in MoEMethodBase.apply()
-        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
-            self.fd_config.parallel_config.moe_phase.phase = "decode" if if_only_decode else "prefill"
+        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
+            self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill"

         # Update Batch type for cuda graph for only_prefill_batch
         only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill()
@@ -1145,7 +1145,9 @@ class GPUModelRunner(ModelRunnerBase):
             kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

-        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if not profile and (
+            self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
+        ):
             cache_kvs_list = []
             for i in range(self.model_config.num_hidden_layers):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -1711,7 +1713,7 @@ class GPUModelRunner(ModelRunnerBase):
             stop_seqs_len=self.share_inputs["stop_seqs_len"],
         )

-        if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
+        if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
             skip_save_output = True
         else:
             skip_save_output = False
@@ -905,12 +905,12 @@ class MetaxModelRunner(ModelRunnerBase):
         only_decode_batch = True
         prefill_exists = None
         # mix ep in single node
-        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
             only_decode_batch_list = []
             prefill_exists = self.exist_prefill()
             paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
             only_decode_batch = all(only_decode_batch_list)
-            self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
+            self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"

         self.forward_meta.step_use_cudagraph = (
             self.use_cudagraph
@@ -947,7 +947,9 @@ class MetaxModelRunner(ModelRunnerBase):
             )
             local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

-        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if not profile and (
+            self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
+        ):
             cache_kvs_list = []
             for i in range(self.model_config.num_hidden_layers):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -1407,7 +1409,7 @@ class MetaxModelRunner(ModelRunnerBase):
             stop_seqs_len=self.share_inputs["stop_seqs_len"],
         )

-        if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
+        if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
             skip_save_output = True
         else:
             skip_save_output = False
@@ -51,6 +51,7 @@ class WorkerBase(ABC):
         self.parallel_config = fd_config.parallel_config
         self.device_config = fd_config.device_config
         self.cache_config = fd_config.cache_config
+        self.scheduler_config = fd_config.scheduler_config
         # ... config

         # Device and Runner
@@ -151,6 +151,7 @@ class PaddleDisWorkerProc:
         self.fd_config = fd_config
         self.parallel_config = fd_config.parallel_config
         self.cache_config = fd_config.cache_config
+        self.scheduler_config = fd_config.scheduler_config

         # TODO(gongshaotian): Use worker factory to get worker
         self.worker = get_worker(fd_config=fd_config, local_rank=self.local_rank, rank=self.ranks)
@@ -412,7 +413,7 @@ class PaddleDisWorkerProc:
         num_blocks_local = self.fd_config.parallel_config.total_block_num
         logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
         # wait engine launch cache_manager
-        if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
+        if self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed":
             launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
             self.launched_cache_manager_signal = IPCSignal(
                 name="launched_cache_manager_signal",
@@ -762,7 +763,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         early_stop_config=early_stop_config,
         cache_config=cache_config,
         scheduler_config=scheduler_config,
-        engine_worker_queue_port=args.engine_worker_queue_port,
         ips=args.ips,
         moba_attention_config=moba_attention_config,
     )
@@ -15,6 +15,7 @@
 """

 import unittest
+from unittest.mock import Mock

 import paddle

@@ -157,6 +158,7 @@ class TestCUDAGrpahSubgraph(unittest.TestCase):
         scheduler_config.max_num_seqs = 8
         cache_config = CacheConfig({})
         parallel_config = ParallelConfig(args={})
+        model_config = Mock()
         # Initialize cuda graph capture list
         graph_opt_config._set_cudagraph_sizes(max_num_seqs=scheduler_config.max_num_seqs)
         graph_opt_config.init_with_cudagrpah_size(max_capture_size=scheduler_config.max_num_seqs)
@@ -165,6 +167,7 @@ class TestCUDAGrpahSubgraph(unittest.TestCase):
             scheduler_config=scheduler_config,
             parallel_config=parallel_config,
             cache_config=cache_config,
+            model_config=model_config,
             test_mode=True,
         )

@@ -1,4 +1,5 @@
 import unittest
+from unittest.mock import Mock

 import paddle

@@ -95,10 +96,12 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
         cache_config = CacheConfig(args={})
         scheduler_config.max_num_seqs = 1
         parallel_config = ParallelConfig(args={})
+        model_config = Mock()
         fd_config = FDConfig(
             graph_opt_config=graph_opt_config,
             scheduler_config=scheduler_config,
             cache_config=cache_config,
+            model_config=model_config,
             parallel_config=parallel_config,
         )

@@ -15,6 +15,7 @@
 """

 import unittest
+from unittest.mock import Mock

 import paddle

@@ -104,6 +105,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
         scheduler_config.max_num_seqs = 1
         cache_config = CacheConfig({})
         parallel_config = ParallelConfig(args={})
+        model_config = Mock()
         # Initialize cuda graph capture list
         graph_opt_config._set_cudagraph_sizes(max_num_seqs=scheduler_config.max_num_seqs)
         graph_opt_config.init_with_cudagrpah_size(max_capture_size=scheduler_config.max_num_seqs)
@@ -112,6 +114,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
             scheduler_config=scheduler_config,
             cache_config=cache_config,
             parallel_config=parallel_config,
+            model_config=model_config,
             test_mode=True,
         )

@@ -15,6 +15,7 @@
 """

 import unittest
+from unittest.mock import Mock

 import numpy as np
 import paddle
@@ -91,11 +92,13 @@ class TestGraphOptBackend(unittest.TestCase):

         baseline_cache_config = CacheConfig({})
         baseline_parallel_config = ParallelConfig(args={})
+        model_config = Mock()
         self.baseline_fd_config = FDConfig(
             graph_opt_config=baseline_graph_opt_config,
             scheduler_config=baseline_scheduler_config,
             cache_config=baseline_cache_config,
             parallel_config=baseline_parallel_config,
+            model_config=model_config,
             test_mode=True,
         )

@@ -137,6 +140,7 @@ class TestGraphOptBackend(unittest.TestCase):
         # Setup cache config
         cache_config = CacheConfig({})
         parallel_config = ParallelConfig(args={})
+        model_config = Mock()

         # Create FD config
         return FDConfig(
@@ -144,6 +148,7 @@ class TestGraphOptBackend(unittest.TestCase):
             scheduler_config=scheduler_config,
             cache_config=cache_config,
             parallel_config=parallel_config,
+            model_config=model_config,
             test_mode=True,
         )

@@ -20,6 +20,7 @@ os.environ["FLAGS_cuda_graph_blacklist"] = "pd_op.matmul,pd_op.transpose"


 import unittest
+from unittest.mock import Mock

 import paddle
 import paddle.nn as nn
@@ -94,12 +95,13 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase):
         graph_opt_config.init_with_cudagrpah_size(max_capture_size=scheduler_config.max_num_seqs)
         cache_config = CacheConfig({})
         parallel_config = ParallelConfig(args={})
-
+        model_config = Mock()
         fd_config = FDConfig(
             graph_opt_config=graph_opt_config,
             scheduler_config=scheduler_config,
             cache_config=cache_config,
             parallel_config=parallel_config,
+            model_config=model_config,
             test_mode=True,
         )

@@ -47,6 +47,7 @@ class FakeModelConfig:
         self.max_position_embeddings = 512
         self.tie_word_embeddings = True
         self.model_format = "auto"
+        self.enable_mm = False


 def get_default_test_fd_config():
@@ -56,12 +57,13 @@ def get_default_test_fd_config():
     parallel_config = ParallelConfig(args={})
     parallel_config.data_parallel_rank = 1
     cache_config = CacheConfig({})
+    model_config = FakeModelConfig()
     fd_config = FDConfig(
         graph_opt_config=graph_opt_config,
         parallel_config=parallel_config,
         cache_config=cache_config,
         scheduler_config=scheduler_config,
+        model_config=model_config,
         test_mode=True,
     )
-    fd_config.model_config = FakeModelConfig()
     return fd_config
@@ -1,4 +1,5 @@
 import unittest
+from unittest.mock import Mock

 from fastdeploy import envs
 from fastdeploy.config import (
@@ -18,12 +19,14 @@ class TestConfig(unittest.TestCase):
         cache_config = CacheConfig({})
         load_config = LoadConfig({})
         scheduler_config = SchedulerConfig({})
+        model_config = Mock()
         fd_config = FDConfig(
             parallel_config=parallel_config,
             graph_opt_config=graph_opt_config,
             load_config=load_config,
             cache_config=cache_config,
             scheduler_config=scheduler_config,
+            model_config=model_config,
             ips=["1.1.1.1", "0.0.0.0"],
             test_mode=True,
         )
@@ -36,12 +39,14 @@ class TestConfig(unittest.TestCase):
         cache_config = CacheConfig({})
         load_config = LoadConfig({})
         scheduler_config = SchedulerConfig({})
+        model_config = Mock()
         fd_config = FDConfig(
             parallel_config=parallel_config,
             graph_opt_config=graph_opt_config,
             load_config=load_config,
             cache_config=cache_config,
             scheduler_config=scheduler_config,
+            model_config=model_config,
             ips="0.0.0.0",
             test_mode=True,
         )
@@ -54,12 +59,15 @@ class TestConfig(unittest.TestCase):
         load_config = LoadConfig({})
         cache_config.enable_chunked_prefill = True
         scheduler_config = SchedulerConfig({})
+        model_config = model_config = Mock()
+
         fd_config = FDConfig(
             parallel_config=parallel_config,
             graph_opt_config=graph_opt_config,
             cache_config=cache_config,
             load_config=load_config,
             scheduler_config=scheduler_config,
+            model_config=model_config,
             ips="0.0.0.0",
             test_mode=True,
         )
@@ -73,6 +81,7 @@ class TestConfig(unittest.TestCase):
             cache_config=cache_config,
             load_config=load_config,
             scheduler_config=scheduler_config,
+            model_config=model_config,
             ips="0.0.0.0",
             test_mode=True,
         )
@@ -87,13 +96,16 @@ class TestConfig(unittest.TestCase):
         cache_config.pd_comm_port = "2334"
         load_config = LoadConfig({})
         scheduler_config = SchedulerConfig({})
+        scheduler_config.splitwise_role = "prefill"
+        model_config = model_config = Mock()
+
         fd_config = FDConfig(
             parallel_config=parallel_config,
             graph_opt_config=graph_opt_config,
             cache_config=cache_config,
             load_config=load_config,
             scheduler_config=scheduler_config,
-            splitwise_role="prefill",
+            model_config=model_config,
             test_mode=True,
         )
         fd_config.init_cache_info()