add w4afp8 offline script (#3636)

Yuan Xiaolan
2025-08-29 17:56:05 +08:00
committed by GitHub
parent f677c032c0
commit c71ee0831c
12 changed files with 163 additions and 37 deletions

View File

@@ -226,8 +226,8 @@ __global__ void permute_scale_kernel(
}
void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
- const int row = scale.dims()[0];
- const int col = scale.dims()[1];
+ const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1;
+ const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0];
if (col % 16 != 0) {
PD_THROW("Only supported when col is divisible by 16.");
}
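The shape handling above lets the permute op accept a 1-D scale (treated as a single row) as well as the original 2-D layout. A minimal sketch of calling the exposed Python op, mirroring how the offline script further down assigns its result; the shapes here are illustrative and only need a trailing dimension divisible by 16:

```python
import paddle

from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute

scale_2d = paddle.rand([8, 128], dtype="float32")   # row = 8, col = 128
scale_1d = paddle.rand([128], dtype="float32")      # row = 1, col = 128 (newly accepted)

# Assigned as in the offline script below; a col not divisible by 16 raises via PD_THROW.
permuted_2d = w4afp8_gemm_scale_permute(scale_2d.cuda())
permuted_1d = w4afp8_gemm_scale_permute(scale_1d.cuda())
```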

View File

@@ -83,10 +83,18 @@ void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
}}
"""
- gemm_case = [[256, 256, 1, 0]]
+ gemm_case = [
+ [8192, 3584, 8, 0], # eb45T ffn1
+ [8192, 3584, 8, 2048], # eb45T ffn1
+ [7168, 8192, 8, 0], # eb45T ffn2
+ [7168, 8192, 8, 2048], # eb45T ffn2
+ ]
dtype = ["BF16"]
+ use_fast_compile = True
+ n_range = [256] if use_fast_compile else [i for i in range(16, 257, 16)]
def get_cutlass_type(type):
if type == "BF16":
@@ -100,7 +108,7 @@ template_head_file.write(gemm_template_head)
for type in dtype:
for case in gemm_case:
- for n in range(16, 257, 16):
+ for n in n_range:
template_head_file.write(
gemm_template_case.format(
M=case[0],
@@ -176,7 +184,7 @@ for type in dtype:
template_head_file.write("\n")
for case in gemm_case:
- for n in range(16, 257, 16):
+ for n in n_range:
template_head_file.write(
""" }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
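With `use_fast_compile = True` the generator only instantiates the kBlockN = 256 variant per case, which keeps compile time manageable; flipping the flag restores the full sweep over multiples of 16. A rough sketch of how the combinations multiply out (reading each case as [M, K, BATCH, TokenPaddingSize] is my interpretation of the format call above):

```python
gemm_case = [
    [8192, 3584, 8, 0],      # read as [M, K, BATCH, TokenPaddingSize]
    [8192, 3584, 8, 2048],
    [7168, 8192, 8, 0],
    [7168, 8192, 8, 2048],
]
dtype = ["BF16"]
use_fast_compile = True
n_range = [256] if use_fast_compile else list(range(16, 257, 16))

# One generated kernel declaration and one dispatch branch per (type, case, kBlockN).
print(len(dtype) * len(gemm_case) * len(n_range))   # 4 with fast compile, 64 for the full sweep
```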

View File

@@ -60,12 +60,13 @@ curl -i http://0.0.0.0:8180/health
Send requests to the service with the following command:
```shell
- curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
+ curl -X POST "http://0.0.0.0:1822/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "Write me a poem about large language model."}
- ]
+ ],
+ "stream": true
}'
```
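For reference, a minimal Python client for the same streaming request; it assumes the endpoint uses the usual OpenAI-compatible SSE framing (`data: {...}` chunks terminated by `data: [DONE]`), which this snippet does not spell out:

```python
import json

import requests

resp = requests.post(
    "http://0.0.0.0:1822/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Write me a poem about large language model."}],
        "stream": True,
    },
    stream=True,
)
for raw in resp.iter_lines():
    if not raw:
        continue
    payload = raw.decode("utf-8").removeprefix("data: ")
    if payload.strip() == "[DONE]":
        break
    delta = json.loads(payload)["choices"][0]["delta"]
    print(delta.get("content", ""), end="", flush=True)
print()
```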

View File

@@ -52,7 +52,7 @@ class EngineSevice:
Base class containing common engine functionality
"""
- def __init__(self, cfg):
+ def __init__(self, cfg, start_queue=True):
"""
Initializes the LLMEngine with the provided configuration.
@@ -84,7 +84,7 @@ class EngineSevice:
cfg.parallel_config.local_data_parallel_id,
)
- self.start_worker_queue_service()
+ self.start_worker_queue_service(start_queue)
os.environ["INFERENCE_MSG_QUEUE_ID"] = self.cfg.engine_worker_queue_port[
self.cfg.parallel_config.local_data_parallel_id
@@ -181,7 +181,7 @@ class EngineSevice:
create=True,
)
- def start_worker_queue_service(self):
+ def start_worker_queue_service(self, start_queue):
"""
start queue service for engine worker communication
"""
@@ -189,7 +189,8 @@ class EngineSevice:
self.cfg.master_ip,
int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]),
)
- if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
+ if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"):
llm_logger.info(f"Starting engine worker queue server service at {address}")
self.engine_worker_queue_server = EngineWorkerQueue(
address=address,
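The effect of the new `start_queue` flag is that the worker queue server is only bound when the caller opts in and the process sits on the master node. A condensed, self-contained sketch of that predicate, with a toy config object standing in for the real FastDeploy config:

```python
class _ToyCfg:
    host_ip = "10.0.0.2"
    master_ip = "10.0.0.1"

def should_bind_queue_server(cfg, start_queue=True):
    # Mirrors the condition above: opt in via start_queue, and only on the
    # master node (or when master_ip is the 0.0.0.0 wildcard).
    return start_queue and (cfg.host_ip == cfg.master_ip or cfg.master_ip == "0.0.0.0")

print(should_bind_queue_server(_ToyCfg()))                     # False: not the master node
print(should_bind_queue_server(_ToyCfg(), start_queue=False))  # False: caller reuses an existing server
```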

View File

@@ -38,7 +38,7 @@ from fastdeploy.engine.common_engine import EngineSevice
from fastdeploy.engine.expert_service import start_data_parallel_service
from fastdeploy.engine.request import Request
from fastdeploy.input.preprocess import InputPreprocessor
- from fastdeploy.inter_communicator import IPCSignal
+ from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
@@ -362,7 +362,10 @@ class LLMEngine:
self.zmq_server.close()
if hasattr(self, "dp_processed"):
for p in self.dp_processed:
+ console_logger.info(f"Waiting for worker {p.pid} to exit")
p.join()
+ for p in self.dp_engine_worker_queue_server:
+ p.cleanup()
def _setting_environ_variables(self):
"""
@@ -610,11 +613,26 @@ class LLMEngine:
if not envs.FD_ENABLE_MULTI_API_SERVER:
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
self.launched_expert_service_signal.value[0] = 1
self.dp_processed = []
+ self.dp_engine_worker_queue_server = []
for i in range(
1,
self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
):
+ address = (
+ self.cfg.master_ip,
+ int(self.cfg.engine_worker_queue_port[i]),
+ )
+ llm_logger.info(f"dp start queue service {address}")
+ self.dp_engine_worker_queue_server.append(
+ EngineWorkerQueue(
+ address=address,
+ is_server=True,
+ num_client=self.cfg.parallel_config.tensor_parallel_size,
+ local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
+ )
+ )
self.dp_processed.append(
multiprocessing.Process(
target=start_data_parallel_service,
@@ -625,7 +643,7 @@ class LLMEngine:
)
)
llm_logger.info(
- f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
+ f"Engine is initialized successfully with {self.cfg.parallel_config.tensor_parallel_size}"
+ f" data parallel id {i}"
)
self.dp_processed[-1].start()
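A condensed sketch of the ownership model this hunk introduces: the parent LLMEngine binds one EngineWorkerQueue server per extra data-parallel rank before spawning the expert-service processes (which now attach with start_queue=False), and tears the servers down in _exit_sub_services. The ports and sizes below are illustrative, not taken from the commit:

```python
from fastdeploy.inter_communicator import EngineWorkerQueue

dp_queue_servers = []
for port in (6678, 6679, 6680):                  # one port per extra DP rank (illustrative)
    dp_queue_servers.append(
        EngineWorkerQueue(
            address=("127.0.0.1", port),
            is_server=True,
            num_client=1,                        # tensor_parallel_size in the real code
            local_data_parallel_size=4,          # data_parallel_size in the real code
        )
    )

# ... spawn the start_data_parallel_service processes here; they connect as clients only ...

for server in dp_queue_servers:
    server.cleanup()
```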

View File

@@ -39,7 +39,7 @@ class ExpertService:
local_data_parallel_id (int): Local data parallel ID.
"""
- def __init__(self, cfg, local_data_parallel_id):
+ def __init__(self, cfg, local_data_parallel_id, start_queue=True):
"""
Initializes the LLMEngine with the provided configuration.
@@ -64,8 +64,7 @@ class ExpertService:
else:
self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
- self.engine = EngineSevice(self.cfg)
+ self.engine = EngineSevice(self.cfg, start_queue)
if self.cfg.scheduler_config.name == "splitwise":
self.engine.scheduler.reset_nodeid(f"{self.engine.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
@@ -149,7 +148,7 @@ def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=N
"""
Start expert service
"""
- expert_service = ExpertService(cfg, local_data_parallel_id)
+ expert_service = ExpertService(cfg, local_data_parallel_id, start_queue=False)
try:
expert_service.start(ipc_signal_suffix, local_data_parallel_id)
@@ -160,6 +159,5 @@ def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=N
t_deamon = threading.Thread(target=deamon_thread, daemon=True)
t_deamon.start()
except Exception as e:
llm_logger.exception(f"Expert service failed to start: {e}, {str(traceback.format_exc())}")

View File

@@ -428,9 +428,9 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0).cast(paddle.get_default_dtype())
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).cast(paddle.get_default_dtype())
- up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).unsqueeze()
- up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).unsqueeze()
- down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).unsqueeze()
+ up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).squeeze()
+ up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).squeeze()
+ down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).squeeze()
name_tensor_map = {
"up_gate_proj_weight": up_gate_proj_weight,

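The unsqueeze-to-squeeze switch above is a shape fix: stacking per-expert scale tensors of shape [1] yields [E, 1], and squeeze() flattens that to a plain [E] vector (presumably what the quantized MoE kernels consume), whereas unsqueeze would have added yet another singleton axis. A tiny illustration with made-up values:

```python
import paddle

per_expert_in_scale = [paddle.to_tensor([0.50]), paddle.to_tensor([0.25])]  # one scale per expert
stacked = paddle.stack(per_expert_in_scale, axis=0)   # shape [2, 1]
print(stacked.squeeze().shape)                        # [2]
```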
View File

@@ -283,8 +283,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
name_tensor_map = {
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
- "up_gate_proj_weight_scale": up_gate_proj_weight_scale,
- "down_proj_weight_scale": down_proj_weight_scale,
+ "up_gate_proj_weight_scale_inv": up_gate_proj_weight_scale,
+ "down_proj_weight_scale_inv": down_proj_weight_scale,
}
for name, tensor in name_tensor_map.items():
getattr(layer, name).set_value(tensor)

View File

@@ -165,8 +165,22 @@ class PaddleDisWorkerProc:
exist_swapped_task_signal:
model_weights_status:
"""
# init worker_ready_signal
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
+ if self.parallel_config.data_parallel_size > 1 and not envs.FD_ENABLE_MULTI_API_SERVER:
+ launched_expert_service_signal_data = np.zeros(
+ shape=[min(self.parallel_config.data_parallel_size, self.max_chips_per_node)], dtype=np.int32
+ )
+ self.launched_expert_service_signal = IPCSignal(
+ name="launched_expert_service_signal",
+ array=launched_expert_service_signal_data,
+ dtype=np.int32,
+ suffix=self.parallel_config.engine_pid,
+ create=False,
+ )
+ while self.launched_expert_service_signal.value[self.local_rank % self.max_chips_per_node] == 0:
+ pass
+ # init worker_ready_signal
array_size = min(
self.max_chips_per_node,
self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size,
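A sketch of the handshake being added here: the engine side owns the shared array (create=True) and flips one slot per launched expert service, while each worker attaches with create=False and busy-waits on its own slot before registering itself as ready. The array size and suffix value below are illustrative; the real code derives them from the per-node chip count and engine_pid:

```python
import numpy as np

from fastdeploy.inter_communicator import IPCSignal

ENGINE_PID = 12345   # stands in for parallel_config.engine_pid

producer = IPCSignal(                       # engine / expert-service side
    name="launched_expert_service_signal",
    array=np.zeros(shape=[8], dtype=np.int32),
    dtype=np.int32,
    suffix=ENGINE_PID,
    create=True,
)
producer.value[0] = 1                       # mark DP rank 0 as launched

consumer = IPCSignal(                       # worker side, attaches to the same shared array
    name="launched_expert_service_signal",
    array=np.zeros(shape=[8], dtype=np.int32),
    dtype=np.int32,
    suffix=ENGINE_PID,
    create=False,
)
while consumer.value[0] == 0:               # busy-wait, as in the hunk above
    pass
```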
@@ -242,6 +256,7 @@ class PaddleDisWorkerProc:
mp_num_per_node = self.parallel_config.tensor_parallel_size // self.nnode
req_ids = []
num_running_requests = 0
+ local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
while True:
if self.local_rank == 0:
if self.model_weights_status.value[0] != 0:
@@ -255,7 +270,6 @@ class PaddleDisWorkerProc:
self.insert_step = False
req_dicts = None
- local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())
# The first worker detects whether there are tasks in the task queue

View File

@@ -18,6 +18,7 @@ from fastdeploy.model_executor.load_weight_utils import (
get_all_safetensors,
safetensors_weights_iterator,
)
+ from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
def parse_arguments():
@@ -47,14 +48,21 @@ def parse_arguments():
help="Whether merge the model into safetensors format.",
)
+ parser.add_argument(
+ "--moe_quant_type",
+ default="w4a8",
+ choices=["w4a8", "w4afp8"],
+ help="The moe quant type of the model.",
+ )
return parser.parse_args()
def reorder():
- def fn(weight):
+ def fn(weight, moe_quant_type):
from paddle.nn.quant import weight_quantize
- quant_weight, _ = weight_quantize(weight.cuda(), algo="w4a8", arch=80)
+ quant_weight, _ = weight_quantize(weight.cuda(), algo=moe_quant_type, arch=80)
return quant_weight.cpu()
return fn
@@ -69,22 +77,27 @@ def deal_in_scale():
def deal_weight_scale():
- def fn(weight_scale, processed_in_scale):
+ def fn(weight_scale, processed_in_scale, moe_quant_type):
+ if moe_quant_type == "w4a8":
processed_weight_scale = weight_scale / (127 * 112) / processed_in_scale
return processed_weight_scale
+ elif moe_quant_type == "w4afp8":
+ processed_weight_scale = weight_scale / (448 * 7 * 2 ** (-9)) / processed_in_scale
+ processed_weight_scale = w4afp8_gemm_scale_permute(processed_weight_scale.cuda())
+ return processed_weight_scale
return fn
# tmp support w4a8
- def deal_quant(state_dict, save_state_dict):
- w4a8_mapping = [
+ def deal_quant(state_dict, save_state_dict, moe_quant_type):
+ param_mapping = [
# pattern,fn
(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.activation_scale", deal_in_scale()),
(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.weight_scale", deal_weight_scale()),
(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.quant_weight", reorder()),
]
- for pattern, fn in w4a8_mapping:
+ for pattern, fn in param_mapping:
for key in list(state_dict.keys()):
# print(f"deal {key}")
match = re.search(pattern, key)
@@ -94,9 +107,11 @@ def deal_quant(state_dict, save_state_dict):
if "weight_scale" in key:
in_scale_key = key.replace("weight_scale", "activation_scale")
in_scale = save_state_dict[in_scale_key]
- save_state_dict[key] = fn(weight_or_scale, in_scale)
- else:
+ save_state_dict[key] = fn(weight_or_scale, in_scale, moe_quant_type)
+ elif "activation_scale" in key:
save_state_dict[key] = fn(weight_or_scale)
+ else:
+ save_state_dict[key] = fn(weight_or_scale, moe_quant_type)
def save_safetensors(state_dict, args):
@@ -153,7 +168,7 @@ def main():
end = time.perf_counter()
logger.info("Finish Quantize.")
logger.info(f"load and quantize took : {end - start:.6f} seconds")
- deal_quant(state_dict, save_state_dict)
+ deal_quant(state_dict, save_state_dict, args.moe_quant_type)
for key in list(state_dict.keys()):
save_state_dict[key] = state_dict.pop(key)
logger.info("Begin to save model")
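The w4afp8 branch in deal_weight_scale folds a fixed divisor into the checkpoint's weight scale before permuting it for the GEMM kernel. A small numeric illustration of that constant; the input values are made up, and reading 448 as the float8_e4m3 maximum and 7 as the symmetric int4 magnitude is my interpretation rather than something stated in the commit:

```python
fp8_max = 448                 # largest float8_e4m3 value (interpretation)
int4_max = 7                  # symmetric int4 weight magnitude (interpretation)
shift = 2 ** (-9)

divisor = fp8_max * int4_max * shift          # 448 * 7 * 2**-9 = 6.125

weight_scale = 0.02           # made-up per-channel scale from the checkpoint
processed_in_scale = 0.5      # made-up activation scale
processed_weight_scale = weight_scale / divisor / processed_in_scale
print(divisor, processed_weight_scale)        # 6.125 and roughly 0.00653
```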

View File

@@ -19,6 +19,7 @@ export devices=0
export CUDA_VISIBLE_DEVICES=${devices}
model_path=${1:-"/PATH/MODEL_PATH"}
output_path=${2:-"/PATH/OUTPUT_MODEL"}
+ moe_quant_type=${3:-"w4a8"}
for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do
unset ${name}
done
@@ -31,4 +32,5 @@ self_ip=`hostname -i`
python offline_w4a8.py \
--model_name_or_path ${model_path} \
--output_dir ${output_path} \
- --safe_serialization "True"
+ --safe_serialization "True" \
+ --moe_quant_type ${moe_quant_type}

View File

@@ -0,0 +1,69 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import weakref
import pytest
from fastdeploy.entrypoints.llm import LLM
bash_path = os.getenv("MODEL_PATH")
FD_ENGINE_QUEUE_PORTS = [
[9961, 9962, 9963, 9964, 9965, 9966, 9967, 9968],
[9971, 9972, 9973, 9974, 9975, 9976, 9977, 9978],
[9981, 9982, 9983, 9984, 9985, 9986, 9987, 9988],
[9991, 9992, 9993, 9994, 9995, 9996, 9997, 9998],
]
models = [
"ernie-4_5-fake-w4a8-unpermuted",
"ernie-4_5-fake-w4a8-permuted",
"ernie-4_5-fake-w4afp8-unpermuted",
"ernie-4_5-fake-w4afp8-permuted",
]
prompts = ["解释下“温故而知新"]  # roughly: "Explain the saying 'review the old to learn the new'"
@pytest.fixture(scope="module", params=models)
def llm(request):
"""LLM test fixture, parameterized over the fake w4a8/w4afp8 models."""
model_path = os.path.join(bash_path, request.param)
try:
port_index = models.index(request.param) % len(FD_ENGINE_QUEUE_PORTS)
llm_instance = LLM(
model=model_path,
tensor_parallel_size=1,
data_parallel_size=8,
max_model_len=8192,
num_gpu_blocks_override=1024,
engine_worker_queue_port=FD_ENGINE_QUEUE_PORTS[port_index],
load_choices="default",
enable_expert_parallel=True,
)
yield weakref.proxy(llm_instance)
except Exception as e:
pytest.skip(f"LLM initialization failed: {e}")
@pytest.mark.timeout(60)
def test_generation(llm):
print(f"testing generation with model: {llm}")
# topp_params = SamplingParams(temperature=0.1, top_p=0, max_tokens=20)
# output = llm.generate(prompts=prompts, sampling_params=topp_params)
# print(output)