Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 00:06:38 +08:00)
add w4afp8 offline script (#3636)
@@ -226,8 +226,8 @@ __global__ void permute_scale_kernel(
 }
 
 void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
-  const int row = scale.dims()[0];
-  const int col = scale.dims()[1];
+  const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1;
+  const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0];
  if (col % 16 != 0) {
    PD_THROW("Only supported when col is divisible by 16.");
  }
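
The updated entry point accepts both 1-D and 2-D scale tensors (a 1-D tensor is treated as a single row) and still requires the column count to be a multiple of 16. Below is a minimal sketch of calling the op through the Python binding that the offline script imports later in this commit; the shapes and dtype are illustrative assumptions, not values taken from the commit.

```python
# Sketch only: assumes a CUDA build of FastDeploy with this custom op compiled,
# and that the scale tensor is bfloat16 (the dtype here is an assumption).
import paddle

from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute

# 2-D scale: [rows, cols] with cols % 16 == 0.
scale_2d = paddle.ones([8, 256], dtype="bfloat16").cuda()
permuted_2d = w4afp8_gemm_scale_permute(scale_2d)  # result assigned, as offline_w4a8.py does below

# 1-D scale: treated as a single row of length cols after this change.
scale_1d = paddle.ones([256], dtype="bfloat16").cuda()
permuted_1d = w4afp8_gemm_scale_permute(scale_1d)
```
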
@@ -83,10 +83,18 @@ void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
 }}
 """
 
-gemm_case = [[256, 256, 1, 0]]
+gemm_case = [
+    [8192, 3584, 8, 0],  # eb45T ffn1
+    [8192, 3584, 8, 2048],  # eb45T ffn1
+    [7168, 8192, 8, 0],  # eb45T ffn2
+    [7168, 8192, 8, 2048],  # eb45T ffn2
+]
 
 dtype = ["BF16"]
 
+use_fast_compile = True
+n_range = [256] if use_fast_compile else [i for i in range(16, 257, 16)]
+
 
 def get_cutlass_type(type):
     if type == "BF16":
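
For reference, the new `n_range` switch trades kernel coverage for compile time. The standalone snippet below (plain Python, independent of the generator script) shows what it expands to under each setting:

```python
# What n_range evaluates to for each value of use_fast_compile.
for use_fast_compile in (True, False):
    n_range = [256] if use_fast_compile else [i for i in range(16, 257, 16)]
    print(use_fast_compile, n_range)
# True  -> [256]                     (one kBlockN tile per gemm_case entry)
# False -> [16, 32, ..., 240, 256]   (sixteen tiles per entry, as before)
```
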
@@ -100,7 +108,7 @@ template_head_file.write(gemm_template_head)
 
 for type in dtype:
     for case in gemm_case:
-        for n in range(16, 257, 16):
+        for n in n_range:
             template_head_file.write(
                 gemm_template_case.format(
                     M=case[0],
@@ -176,7 +184,7 @@ for type in dtype:
     template_head_file.write("\n")
 
     for case in gemm_case:
-        for n in range(16, 257, 16):
+        for n in n_range:
             template_head_file.write(
                 """ }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
    w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
@@ -60,12 +60,13 @@ curl -i http://0.0.0.0:8180/health
 Send requests to the service with the following command:
 
 ```shell
-curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
+curl -X POST "http://0.0.0.0:1822/v1/chat/completions" \
 -H "Content-Type: application/json" \
 -d '{
   "messages": [
     {"role": "user", "content": "Write me a poem about large language model."}
-  ]
+  ],
+  "stream": true
 }'
 ```
@@ -52,7 +52,7 @@ class EngineSevice:
     Base class containing common engine functionality
     """
 
-    def __init__(self, cfg):
+    def __init__(self, cfg, start_queue=True):
         """
         Initializes the LLMEngine with the provided configuration.
 
@@ -84,7 +84,7 @@ class EngineSevice:
             cfg.parallel_config.local_data_parallel_id,
         )
 
-        self.start_worker_queue_service()
+        self.start_worker_queue_service(start_queue)
 
         os.environ["INFERENCE_MSG_QUEUE_ID"] = self.cfg.engine_worker_queue_port[
             self.cfg.parallel_config.local_data_parallel_id
@@ -181,7 +181,7 @@ class EngineSevice:
             create=True,
         )
 
-    def start_worker_queue_service(self):
+    def start_worker_queue_service(self, start_queue):
         """
        start queue service for engine worker communication
        """
@@ -189,7 +189,8 @@ class EngineSevice:
             self.cfg.master_ip,
             int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]),
         )
-        if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
+
+        if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"):
             llm_logger.info(f"Starting engine worker queue server service at {address}")
             self.engine_worker_queue_server = EngineWorkerQueue(
                 address=address,
@@ -38,7 +38,7 @@ from fastdeploy.engine.common_engine import EngineSevice
 from fastdeploy.engine.expert_service import start_data_parallel_service
 from fastdeploy.engine.request import Request
 from fastdeploy.input.preprocess import InputPreprocessor
-from fastdeploy.inter_communicator import IPCSignal
+from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
 from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
 
 
@@ -362,7 +362,10 @@ class LLMEngine:
             self.zmq_server.close()
         if hasattr(self, "dp_processed"):
             for p in self.dp_processed:
                 console_logger.info(f"Waiting for worker {p.pid} to exit")
                 p.join()
+            for p in self.dp_engine_worker_queue_server:
+                p.cleanup()
 
     def _setting_environ_variables(self):
         """
@@ -610,11 +613,26 @@ class LLMEngine:
 
         if not envs.FD_ENABLE_MULTI_API_SERVER:
             if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
+                self.launched_expert_service_signal.value[0] = 1
                 self.dp_processed = []
+                self.dp_engine_worker_queue_server = []
                 for i in range(
                     1,
                     self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
                 ):
+                    address = (
+                        self.cfg.master_ip,
+                        int(self.cfg.engine_worker_queue_port[i]),
+                    )
+                    llm_logger.info(f"dp start queue service {address}")
+                    self.dp_engine_worker_queue_server.append(
+                        EngineWorkerQueue(
+                            address=address,
+                            is_server=True,
+                            num_client=self.cfg.parallel_config.tensor_parallel_size,
+                            local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
+                        )
+                    )
                     self.dp_processed.append(
                         multiprocessing.Process(
                             target=start_data_parallel_service,
@@ -625,7 +643,7 @@ class LLMEngine:
                         )
                     )
                     llm_logger.info(
-                        f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
+                        f"Engine is initialized successfully with {self.cfg.parallel_config.tensor_parallel_size}"
                         + f" data parallel id {i}"
                     )
                     self.dp_processed[-1].start()
@@ -39,7 +39,7 @@ class ExpertService:
         local_data_parallel_id (int): Local data parallel ID.
     """
 
-    def __init__(self, cfg, local_data_parallel_id):
+    def __init__(self, cfg, local_data_parallel_id, start_queue=True):
         """
         Initializes the LLMEngine with the provided configuration.
 
@@ -64,8 +64,7 @@ class ExpertService:
         else:
             self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]
         self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
-
-        self.engine = EngineSevice(self.cfg)
+        self.engine = EngineSevice(self.cfg, start_queue)
         if self.cfg.scheduler_config.name == "splitwise":
             self.engine.scheduler.reset_nodeid(f"{self.engine.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
 
@@ -149,7 +148,7 @@ def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=N
     """
     Start expert service
     """
-    expert_service = ExpertService(cfg, local_data_parallel_id)
+    expert_service = ExpertService(cfg, local_data_parallel_id, start_queue=False)
 
     try:
         expert_service.start(ipc_signal_suffix, local_data_parallel_id)
@@ -160,6 +159,5 @@ def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=N
 
         t_deamon = threading.Thread(target=deamon_thread, daemon=True)
         t_deamon.start()
-
     except Exception as e:
         llm_logger.exception(f"Expert service failed to start: {e}, {str(traceback.format_exc())}")
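
Taken together, these changes move queue ownership into the launching process: `LLMEngine` now creates the per-rank `EngineWorkerQueue` servers itself (see the hunk above), and each spawned `ExpertService` is constructed with `start_queue=False` so it only attaches to an existing queue instead of binding the same port again. A toy illustration of that pattern, not FastDeploy code:

```python
# Toy sketch of the start_queue ownership pattern (all names here are illustrative).
import multiprocessing as mp


class Service:
    def __init__(self, queue=None, start_queue=True):
        # Only the owning process creates the shared queue; children reuse it.
        self.queue = mp.Queue() if start_queue else queue


def child(queue):
    svc = Service(queue, start_queue=False)  # mirrors ExpertService(cfg, ..., start_queue=False)
    svc.queue.put("ready")


if __name__ == "__main__":
    parent = Service(start_queue=True)       # mirrors the engine owning the queue server
    worker = mp.Process(target=child, args=(parent.queue,))
    worker.start()
    print(parent.queue.get())                # -> "ready"
    worker.join()
```
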
@@ -428,9 +428,9 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
         down_proj_weight = paddle.stack(down_proj_weights, axis=0)
         up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0).cast(paddle.get_default_dtype())
         down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).cast(paddle.get_default_dtype())
-        up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).unsqueeze()
-        up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).unsqueeze()
-        down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).unsqueeze()
+        up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).squeeze()
+        up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).squeeze()
+        down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).squeeze()
 
         name_tensor_map = {
             "up_gate_proj_weight": up_gate_proj_weight,
@@ -283,8 +283,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         name_tensor_map = {
             "up_gate_proj_weight": up_gate_proj_weight,
             "down_proj_weight": down_proj_weight,
-            "up_gate_proj_weight_scale": up_gate_proj_weight_scale,
-            "down_proj_weight_scale": down_proj_weight_scale,
+            "up_gate_proj_weight_scale_inv": up_gate_proj_weight_scale,
+            "down_proj_weight_scale_inv": down_proj_weight_scale,
         }
         for name, tensor in name_tensor_map.items():
             getattr(layer, name).set_value(tensor)
@@ -165,8 +165,22 @@ class PaddleDisWorkerProc:
             exist_swapped_task_signal:
             model_weights_status:
         """
-        # init worker_ready_signal
         self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
+        if self.parallel_config.data_parallel_size > 1 and not envs.FD_ENABLE_MULTI_API_SERVER:
+            launched_expert_service_signal_data = np.zeros(
+                shape=[min(self.parallel_config.data_parallel_size, self.max_chips_per_node)], dtype=np.int32
+            )
+            self.launched_expert_service_signal = IPCSignal(
+                name="launched_expert_service_signal",
+                array=launched_expert_service_signal_data,
+                dtype=np.int32,
+                suffix=self.parallel_config.engine_pid,
+                create=False,
+            )
+            while self.launched_expert_service_signal.value[self.local_rank % self.max_chips_per_node] == 0:
+                pass
+
+        # init worker_ready_signal
         array_size = min(
             self.max_chips_per_node,
             self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size,
@@ -242,6 +256,7 @@ class PaddleDisWorkerProc:
         mp_num_per_node = self.parallel_config.tensor_parallel_size // self.nnode
         req_ids = []
         num_running_requests = 0
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
         while True:
             if self.local_rank == 0:
                 if self.model_weights_status.value[0] != 0:
@@ -255,7 +270,6 @@ class PaddleDisWorkerProc:
 
             self.insert_step = False
             req_dicts = None
-            local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
             self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())
 
             # The first worker detects whether there are tasks in the task queue
@@ -18,6 +18,7 @@ from fastdeploy.model_executor.load_weight_utils import (
     get_all_safetensors,
     safetensors_weights_iterator,
 )
+from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
 
 
 def parse_arguments():
@@ -47,14 +48,21 @@ def parse_arguments():
         help="Whether merge the model into safetensors format.",
     )
 
+    parser.add_argument(
+        "--moe_quant_type",
+        default="w4a8",
+        choices=["w4a8", "w4afp8"],
+        help="The moe quant type of the model.",
+    )
+
     return parser.parse_args()
 
 
 def reorder():
-    def fn(weight):
+    def fn(weight, moe_quant_type):
         from paddle.nn.quant import weight_quantize
 
-        quant_weight, _ = weight_quantize(weight.cuda(), algo="w4a8", arch=80)
+        quant_weight, _ = weight_quantize(weight.cuda(), algo=moe_quant_type, arch=80)
         return quant_weight.cpu()
 
     return fn
@@ -69,22 +77,27 @@ def deal_in_scale():
 
 
 def deal_weight_scale():
-    def fn(weight_scale, processed_in_scale):
-        processed_weight_scale = weight_scale / (127 * 112) / processed_in_scale
-        return processed_weight_scale
+    def fn(weight_scale, processed_in_scale, moe_quant_type):
+        if moe_quant_type == "w4a8":
+            processed_weight_scale = weight_scale / (127 * 112) / processed_in_scale
+            return processed_weight_scale
+        elif moe_quant_type == "w4afp8":
+            processed_weight_scale = weight_scale / (448 * 7 * 2 ** (-9)) / processed_in_scale
+            processed_weight_scale = w4afp8_gemm_scale_permute(processed_weight_scale.cuda())
+            return processed_weight_scale
 
     return fn
 
 
 # tmp support w4a8
-def deal_quant(state_dict, save_state_dict):
-    w4a8_mapping = [
+def deal_quant(state_dict, save_state_dict, moe_quant_type):
+    param_mapping = [
         # pattern,fn
         (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.activation_scale", deal_in_scale()),
         (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.weight_scale", deal_weight_scale()),
         (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.quant_weight", reorder()),
     ]
-    for pattern, fn in w4a8_mapping:
+    for pattern, fn in param_mapping:
         for key in list(state_dict.keys()):
             # print(f"deal {key}")
             match = re.search(pattern, key)
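
For reference, the two divisors in `deal_weight_scale` evaluate as shown below (my arithmetic; beyond 127 being the int8 maximum and 448 the float8-E4M3 maximum, the meaning of the remaining factors is an assumption about the kernel's weight layout):

```python
# Scale divisors used above, evaluated for reference.
w4a8_divisor = 127 * 112              # 14224 (int8 activation path)
w4afp8_divisor = 448 * 7 * 2 ** (-9)  # 6.125 (fp8 activation path)
print(w4a8_divisor, w4afp8_divisor)
```
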
@@ -94,9 +107,11 @@ def deal_quant(state_dict, save_state_dict):
             if "weight_scale" in key:
                 in_scale_key = key.replace("weight_scale", "activation_scale")
                 in_scale = save_state_dict[in_scale_key]
-                save_state_dict[key] = fn(weight_or_scale, in_scale)
-            else:
+                save_state_dict[key] = fn(weight_or_scale, in_scale, moe_quant_type)
+            elif "activation_scale" in key:
                 save_state_dict[key] = fn(weight_or_scale)
+            else:
+                save_state_dict[key] = fn(weight_or_scale, moe_quant_type)
 
 
 def save_safetensors(state_dict, args):
@@ -153,7 +168,7 @@ def main():
         end = time.perf_counter()
         logger.info("Finish Quantize.")
         logger.info(f"load and quantize took : {end - start:.6f} seconds")
-    deal_quant(state_dict, save_state_dict)
+    deal_quant(state_dict, save_state_dict, args.moe_quant_type)
     for key in list(state_dict.keys()):
         save_state_dict[key] = state_dict.pop(key)
     logger.info("Begin to save model")
@@ -19,6 +19,7 @@ export devices=0
 export CUDA_VISIBLE_DEVICES=${devices}
 model_path=${1:-"/PATH/MODEL_PATH"}
 output_path=${2:-"/PATH/OUTPUT_MODEL"}
+moe_quant_type=${3:-"w4a8"}
 for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do
     unset ${name}
 done
@@ -31,4 +32,5 @@ self_ip=`hostname -i`
 python offline_w4a8.py \
     --model_name_or_path ${model_path} \
     --output_dir ${output_path} \
-    --safe_serialization "True"
+    --safe_serialization "True" \
+    --moe_quant_type ${moe_quant_type}
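
A hypothetical invocation of the wrapper script above; its file name is not shown in this diff, so `run_offline_w4a8.sh` is only a placeholder:

```shell
# Placeholder script name; positional args map to model_path, output_path, moe_quant_type.
bash run_offline_w4a8.sh /PATH/MODEL_PATH /PATH/OUTPUT_MODEL w4afp8
```
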
tests/model_loader/test_w4a8_model.py (new file, 69 lines)
@@ -0,0 +1,69 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import os
+import weakref
+
+import pytest
+
+from fastdeploy.entrypoints.llm import LLM
+
+bash_path = os.getenv("MODEL_PATH")
+FD_ENGINE_QUEUE_PORTS = [
+    [9961, 9962, 9963, 9964, 9965, 9966, 9967, 9968],
+    [9971, 9972, 9973, 9974, 9975, 9976, 9977, 9978],
+    [9981, 9982, 9983, 9984, 9985, 9986, 9987, 9988],
+    [9991, 9992, 9993, 9994, 9995, 9996, 9997, 9998],
+]
+
+
+models = [
+    "ernie-4_5-fake-w4a8-unpermuted",
+    "ernie-4_5-fake-w4a8-permuted",
+    "ernie-4_5-fake-w4afp8-unpermuted",
+    "ernie-4_5-fake-w4afp8-permuted",
+]
+
+prompts = ["解释下“温故而知新"]
+
+
+@pytest.fixture(scope="module", params=models)
+def llm(request):
+    """LLM test fixture"""
+    model_path = os.path.join(bash_path, request.param)
+    try:
+        port_index = models.index(request.param) % len(FD_ENGINE_QUEUE_PORTS)
+        llm_instance = LLM(
+            model=model_path,
+            tensor_parallel_size=1,
+            data_parallel_size=8,
+            max_model_len=8192,
+            num_gpu_blocks_override=1024,
+            engine_worker_queue_port=FD_ENGINE_QUEUE_PORTS[port_index],
+            load_choices="default",
+            enable_expert_parallel=True,
+        )
+        yield weakref.proxy(llm_instance)
+    except Exception as e:
+        pytest.skip(f"LLM initialization failed: {e}")
+
+
+@pytest.mark.timeout(60)
+def test_generation(llm):
+    print(f"testing generation with model: {llm}")
+    # topp_params = SamplingParams(temperature=0.1, top_p=0, max_tokens=20)
+    # output = llm.generate(prompts=prompts, sampling_params=topp_params)
+    # print(output)
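
The generation check in the new test is still commented out; a minimal way to enable it might look like the sketch below. The `SamplingParams` import path is an assumption (only the call pattern is taken from the commented lines above), and the assertion threshold is illustrative:

```python
# Sketch only: the import path for SamplingParams is an assumption, not shown in this diff.
from fastdeploy.engine.sampling_params import SamplingParams


def test_generation(llm):
    topp_params = SamplingParams(temperature=0.1, top_p=0, max_tokens=20)
    output = llm.generate(prompts=prompts, sampling_params=topp_params)
    assert output, "expected a non-empty generation result"
```
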