polish code with new pre-commit rule (#2923)

Author: Zero Rains
Date: 2025-07-19 23:19:27 +08:00
Committed by: GitHub
Parent: b8676d71a8
Commit: 25698d56d1
424 changed files with 14307 additions and 13518 deletions
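All 424 files receive the same mechanical style cleanup. As a rough before/after sketch (hypothetical snippet; the commit message does not name the hook, so a Black/Ruff-style formatting rule is assumed here), the new pre-commit rule turns code like the first class below into the second:

import logging
import time

# Stand-in logger for illustration only; the real module uses fastdeploy's console_logger.
console_logger = logging.getLogger("console")


class ExpertServiceBefore(object):  # old style: redundant `object` base class
    def report(self, start_time):
        # manual line wrapping plus str.format(), as in the removed lines of the diff below
        console_logger.info(
            "Worker processes are launched with {} seconds.".format(
                time.time() - start_time))


class ExpertServiceAfter:  # after the hook: bare class, f-string, single flat call
    def report(self, start_time):
        console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")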


@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import os
@@ -32,7 +33,7 @@ from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, llm_logger
-class ExpertService(object):
+class ExpertService:
"""
Engine class responsible for managing the Large Language Model (LLM) operations.
@@ -51,17 +52,14 @@ class ExpertService(object):
self.cfg = cfg
start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % self.cfg.worker_num_per_node
end_pos = ((local_data_parallel_id + 1) * self.cfg.tensor_parallel_size) % self.cfg.worker_num_per_node
-self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[
-start_pos:end_pos]
-self.cfg.local_device_ids = self.cfg.device_ids.split(
-",")[start_pos:end_pos]
+self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
+self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
self.cfg.disaggregate_info = None
self.scheduler = cfg.scheduler_config.scheduler()
-self.scheduler.reset_nodeid(
-f"{self.scheduler.infer.nodeid}_{str(local_data_parallel_id)}")
+self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
@@ -73,33 +71,41 @@ class ExpertService(object):
num_client=cfg.tensor_parallel_size,
local_data_parallel_id=local_data_parallel_id,
)
-self.resource_manager = ResourceManager(cfg.max_num_seqs, cfg, \
-cfg.tensor_parallel_size, cfg.splitwise_role, local_data_parallel_id)
+self.resource_manager = ResourceManager(
+cfg.max_num_seqs,
+cfg,
+cfg.tensor_parallel_size,
+cfg.splitwise_role,
+local_data_parallel_id,
+)
if len(self.cfg.cache_config.pd_comm_port) == 1:
-self.cfg.cache_config.pd_comm_port[0] = int(
-self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
+self.cfg.cache_config.pd_comm_port[0] = int(self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
else:
-self.cfg.cache_config.pd_comm_port = [
-self.cfg.cache_config.pd_comm_port[local_data_parallel_id]
-]
+self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]
-self.split_connector = SplitwiseConnector(self.cfg, self.scheduler,
-self.engine_worker_queue,
-self.resource_manager)
+self.split_connector = SplitwiseConnector(
+self.cfg,
+self.scheduler,
+self.engine_worker_queue,
+self.resource_manager,
+)
self.token_processor = TokenProcessor(
cfg=cfg,
cached_generated_tokens=self.scheduler,
engine_worker_queue=self.engine_worker_queue,
-split_connector=self.split_connector)
+split_connector=self.split_connector,
+)
self.token_processor.set_resource_manager(self.resource_manager)
-self.partial_chunked_tokens = [0] * (
-self.cfg.max_num_partial_prefills + 1)
+self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
for idx in range(1, self.cfg.max_num_partial_prefills + 1):
-self.partial_chunked_tokens[idx] = (self.cfg.max_num_batched_tokens // idx) \
-// self.cfg.cache_config.block_size * self.cfg.cache_config.block_size
+self.partial_chunked_tokens[idx] = (
+(self.cfg.max_num_batched_tokens // idx)
+// self.cfg.cache_config.block_size
+* self.cfg.cache_config.block_size
+)
self._finalizer = weakref.finalize(self, self._exit_sub_services)
@@ -120,17 +126,15 @@ class ExpertService(object):
device_ids=self.cfg.local_device_ids,
pod_ip=self.cfg.master_ip,
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
-pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}"
+pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
)
-self.insert_task_to_worker_thread = threading.Thread(
-target=self._insert_task_to_worker, args=())
+self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, args=())
self.insert_task_to_worker_thread.daemon = True
self.insert_task_to_worker_thread.start()
# Start TokenProcessor thread
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
self.token_processor.run()
@@ -144,9 +148,7 @@ class ExpertService(object):
self.scheduler.start(role, host_ip, disaggregate)
self.cfg.print()
-console_logger.info(
-"Worker processes are launched with {} seconds.".format(
-time.time() - start_time))
+console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
return True
def _insert_task_to_worker(self):
@@ -169,17 +171,17 @@ class ExpertService(object):
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
-self.cfg.max_prefill_batch)
+self.cfg.max_prefill_batch,
+)
self.resource_manager.check_and_free_block_tables()
tasks = self.scheduler.get_requests(
-available_blocks=self.resource_manager.available_block_num(
-),
+available_blocks=self.resource_manager.available_block_num(),
block_size=self.cfg.cache_config.block_size,
-reserved_output_blocks=self.cfg.cache_config.
-enc_dec_block_num,
+reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
max_num_batched_tokens=self.cfg.max_num_batched_tokens,
-batch=num_prefill_batch)
+batch=num_prefill_batch,
+)
if len(tasks) == 0:
time.sleep(0.001)
@@ -187,8 +189,7 @@ class ExpertService(object):
if self.cfg.splitwise_role != "mixed":
llm_logger.info("Inserting splitwise tasks")
-self.split_connector.send_splitwise_tasks(
-tasks, current_id)
+self.split_connector.send_splitwise_tasks(tasks, current_id)
current_id = (current_id + 1) % 100003
@@ -197,8 +198,7 @@ class ExpertService(object):
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
except Exception as e:
err_msg = "Error happend while insert task to engine: {}, {}.".format(
e, str(traceback.format_exc()))
err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
llm_logger.error(err_msg)
def split_mode_get_tasks(self):
@@ -212,15 +212,13 @@ class ExpertService(object):
try:
if len(waiting_requests) > 0:
for task in waiting_requests:
-if self.resource_manager.is_resource_sufficient(
-task.prompt_token_ids_len):
+if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
waiting_requests.remove(task)
else:
break
if not self.engine_worker_queue.disaggregate_queue_empty():
-items = self.engine_worker_queue.get_disaggregated_tasks(
-)
+items = self.engine_worker_queue.get_disaggregated_tasks()
for item in items:
role = item[0]
tasks = item[1]
@@ -231,7 +229,7 @@ class ExpertService(object):
self.insert_tasks(tasks)
elif role == "decode":
llm_logger.info(f"get decode tasks {tasks}")
-if hasattr(tasks[0], 'finished'):
+if hasattr(tasks[0], "finished"):
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
@@ -246,7 +244,8 @@ class ExpertService(object):
else:
for task in tasks:
if not self.resource_manager.is_resource_sufficient(
-task.prompt_token_ids_len):
+task.prompt_token_ids_len
+):
waiting_requests.append(task)
else:
self.insert_tasks([task])
@@ -274,8 +273,7 @@ class ExpertService(object):
self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
-del self.token_processor.tokens_counter[
-task.request_id]
+del self.token_processor.tokens_counter[task.request_id]
self.scheduler.put_results([task])
llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
@@ -285,8 +283,7 @@ class ExpertService(object):
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
self.token_processor.tokens_counter[task.request_id] = 1
current_tasks.append(cur_task)
-self.engine_worker_queue.put_tasks(
-(current_tasks, self.resource_manager.real_bsz))
+self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
return True
self.resource_manager.check_and_free_block_tables()
@@ -299,9 +296,7 @@ class ExpertService(object):
available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch:
-llm_logger.error(
-"Inserting batch:{} exceeds the available batch:{}.".format(
-len(tasks), available_batch))
+llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch]
@@ -325,8 +320,7 @@ class ExpertService(object):
is_decode = True
else:
is_prefill = True
-self.token_processor.number_of_input_tokens += tasks[
-i].prompt_token_ids_len
+self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
self.split_connector.send_cache_infos(tasks, current_id)
for task in tasks:
@@ -338,8 +332,7 @@ class ExpertService(object):
self.update_requests_chunk_size(tasks)
else:
self.update_mm_requests_chunk_size(tasks)
-self.engine_worker_queue.put_tasks(
-(tasks, self.resource_manager.real_bsz))
+self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
return True
def _exit_sub_services(self):
@@ -348,8 +341,7 @@ class ExpertService(object):
"""
if hasattr(self, "cache_manager_processes"):
-self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear(
-)
+self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
self.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}")