add w4afp8 offline script (#3636)
@@ -226,8 +226,8 @@ __global__ void permute_scale_kernel(
 }

 void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
-  const int row = scale.dims()[0];
-  const int col = scale.dims()[1];
+  const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1;
+  const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0];
   if (col % 16 != 0) {
     PD_THROW("Only supported when col is divisible by 16.");
   }
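Note: a minimal Python sketch (not part of the patch) of the dimension handling the updated `W4AFp8GemmScalePermute` applies, so the 1-D versus 2-D scale cases are easy to see; the example shapes are made up.

```python
# Sketch only: mirrors the row/col selection in the updated W4AFp8GemmScalePermute.
def derive_row_col(dims):
    row = dims[0] if len(dims) == 2 else 1       # 1-D scales are treated as a single row
    col = dims[1] if len(dims) == 2 else dims[0]
    if col % 16 != 0:
        raise ValueError("Only supported when col is divisible by 16.")
    return row, col

print(derive_row_col([8, 256]))  # (8, 256) for a 2-D scale
print(derive_row_col([256]))     # (1, 256) for a 1-D scale
```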
@@ -83,10 +83,18 @@ void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
 }}
 """

-gemm_case = [[256, 256, 1, 0]]
+gemm_case = [
+    [8192, 3584, 8, 0],  # eb45T ffn1
+    [8192, 3584, 8, 2048],  # eb45T ffn1
+    [7168, 8192, 8, 0],  # eb45T ffn2
+    [7168, 8192, 8, 2048],  # eb45T ffn2
+]

 dtype = ["BF16"]

+use_fast_compile = True
+n_range = [256] if use_fast_compile else [i for i in range(16, 257, 16)]
+

 def get_cutlass_type(type):
     if type == "BF16":
@@ -100,7 +108,7 @@ template_head_file.write(gemm_template_head)

 for type in dtype:
     for case in gemm_case:
-        for n in range(16, 257, 16):
+        for n in n_range:
             template_head_file.write(
                 gemm_template_case.format(
                     M=case[0],
@@ -176,7 +184,7 @@ for type in dtype:
     template_head_file.write("\n")

     for case in gemm_case:
-        for n in range(16, 257, 16):
+        for n in n_range:
             template_head_file.write(
                 """ }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
    w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
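Note: a small sketch (not FastDeploy code) of how many kernel instantiations the loops above emit per dtype, assuming one instantiation per `(case, n)` pair; this is the build time the new `use_fast_compile` switch trades away.

```python
# Sketch only: instantiation counts implied by gemm_case and n_range above.
gemm_case = [
    [8192, 3584, 8, 0],
    [8192, 3584, 8, 2048],
    [7168, 8192, 8, 0],
    [7168, 8192, 8, 2048],
]
full_n_range = [i for i in range(16, 257, 16)]  # 16 candidate kBlockN values
fast_n_range = [256]                            # single value when use_fast_compile is True

print(len(gemm_case) * len(full_n_range))  # 64 kernels per dtype without fast compile
print(len(gemm_case) * len(fast_n_range))  # 4 kernels per dtype with fast compile
```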
@@ -60,12 +60,13 @@ curl -i http://0.0.0.0:8180/health
 Send requests to the service with the following command:

 ```shell
-curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
+curl -X POST "http://0.0.0.0:1822/v1/chat/completions" \
   -H "Content-Type: application/json" \
   -d '{
     "messages": [
       {"role": "user", "content": "Write me a poem about large language model."}
-    ]
+    ],
+    "stream": true
   }'
 ```
@@ -52,7 +52,7 @@ class EngineSevice:
     Base class containing common engine functionality
     """

-    def __init__(self, cfg):
+    def __init__(self, cfg, start_queue=True):
        """
        Initializes the LLMEngine with the provided configuration.

@@ -84,7 +84,7 @@ class EngineSevice:
             cfg.parallel_config.local_data_parallel_id,
         )

-        self.start_worker_queue_service()
+        self.start_worker_queue_service(start_queue)

         os.environ["INFERENCE_MSG_QUEUE_ID"] = self.cfg.engine_worker_queue_port[
             self.cfg.parallel_config.local_data_parallel_id
@@ -181,7 +181,7 @@ class EngineSevice:
             create=True,
         )

-    def start_worker_queue_service(self):
+    def start_worker_queue_service(self, start_queue):
         """
         start queue service for engine worker communication
         """
@@ -189,7 +189,8 @@ class EngineSevice:
             self.cfg.master_ip,
             int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]),
         )
-        if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
+
+        if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"):
             llm_logger.info(f"Starting engine worker queue server service at {address}")
             self.engine_worker_queue_server = EngineWorkerQueue(
                 address=address,
@@ -38,7 +38,7 @@ from fastdeploy.engine.common_engine import EngineSevice
 from fastdeploy.engine.expert_service import start_data_parallel_service
 from fastdeploy.engine.request import Request
 from fastdeploy.input.preprocess import InputPreprocessor
-from fastdeploy.inter_communicator import IPCSignal
+from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
 from fastdeploy.utils import EngineError, console_logger, envs, llm_logger

@@ -362,7 +362,10 @@ class LLMEngine:
         self.zmq_server.close()
         if hasattr(self, "dp_processed"):
             for p in self.dp_processed:
+                console_logger.info(f"Waiting for worker {p.pid} to exit")
                 p.join()
+            for p in self.dp_engine_worker_queue_server:
+                p.cleanup()

     def _setting_environ_variables(self):
         """
@@ -610,11 +613,26 @@ class LLMEngine:

         if not envs.FD_ENABLE_MULTI_API_SERVER:
             if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
+                self.launched_expert_service_signal.value[0] = 1
                 self.dp_processed = []
+                self.dp_engine_worker_queue_server = []
                 for i in range(
                     1,
                     self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
                 ):
+                    address = (
+                        self.cfg.master_ip,
+                        int(self.cfg.engine_worker_queue_port[i]),
+                    )
+                    llm_logger.info(f"dp start queue service {address}")
+                    self.dp_engine_worker_queue_server.append(
+                        EngineWorkerQueue(
+                            address=address,
+                            is_server=True,
+                            num_client=self.cfg.parallel_config.tensor_parallel_size,
+                            local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
+                        )
+                    )
                     self.dp_processed.append(
                         multiprocessing.Process(
                             target=start_data_parallel_service,
@@ -625,7 +643,7 @@ class LLMEngine:
                         )
                     )
                     llm_logger.info(
-                        f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
+                        f"Engine is initialized successfully with {self.cfg.parallel_config.tensor_parallel_size}"
                         + f" data parallel id {i}"
                     )
                     self.dp_processed[-1].start()
@@ -39,7 +39,7 @@ class ExpertService:
         local_data_parallel_id (int): Local data parallel ID.
     """

-    def __init__(self, cfg, local_data_parallel_id):
+    def __init__(self, cfg, local_data_parallel_id, start_queue=True):
        """
        Initializes the LLMEngine with the provided configuration.

@@ -64,8 +64,7 @@ class ExpertService:
         else:
             self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]
         self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
-
-        self.engine = EngineSevice(self.cfg)
+        self.engine = EngineSevice(self.cfg, start_queue)
         if self.cfg.scheduler_config.name == "splitwise":
             self.engine.scheduler.reset_nodeid(f"{self.engine.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
@@ -149,7 +148,7 @@ def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=N
     """
     Start expert service
     """
-    expert_service = ExpertService(cfg, local_data_parallel_id)
+    expert_service = ExpertService(cfg, local_data_parallel_id, start_queue=False)

     try:
         expert_service.start(ipc_signal_suffix, local_data_parallel_id)
@@ -160,6 +159,5 @@ def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=N

         t_deamon = threading.Thread(target=deamon_thread, daemon=True)
         t_deamon.start()
-
     except Exception as e:
         llm_logger.exception(f"Expert service failed to start: {e}, {str(traceback.format_exc())}")
@@ -428,9 +428,9 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
         down_proj_weight = paddle.stack(down_proj_weights, axis=0)
         up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0).cast(paddle.get_default_dtype())
         down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).cast(paddle.get_default_dtype())
-        up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).unsqueeze()
-        up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).unsqueeze()
-        down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).unsqueeze()
+        up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).squeeze()
+        up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).squeeze()
+        down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).squeeze()

         name_tensor_map = {
             "up_gate_proj_weight": up_gate_proj_weight,
@@ -283,8 +283,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         name_tensor_map = {
             "up_gate_proj_weight": up_gate_proj_weight,
             "down_proj_weight": down_proj_weight,
-            "up_gate_proj_weight_scale": up_gate_proj_weight_scale,
-            "down_proj_weight_scale": down_proj_weight_scale,
+            "up_gate_proj_weight_scale_inv": up_gate_proj_weight_scale,
+            "down_proj_weight_scale_inv": down_proj_weight_scale,
         }
         for name, tensor in name_tensor_map.items():
             getattr(layer, name).set_value(tensor)
@@ -165,8 +165,22 @@ class PaddleDisWorkerProc:
             exist_swapped_task_signal:
             model_weights_status:
         """
-        # init worker_ready_signal
         self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
+        if self.parallel_config.data_parallel_size > 1 and not envs.FD_ENABLE_MULTI_API_SERVER:
+            launched_expert_service_signal_data = np.zeros(
+                shape=[min(self.parallel_config.data_parallel_size, self.max_chips_per_node)], dtype=np.int32
+            )
+            self.launched_expert_service_signal = IPCSignal(
+                name="launched_expert_service_signal",
+                array=launched_expert_service_signal_data,
+                dtype=np.int32,
+                suffix=self.parallel_config.engine_pid,
+                create=False,
+            )
+            while self.launched_expert_service_signal.value[self.local_rank % self.max_chips_per_node] == 0:
+                pass
+
+        # init worker_ready_signal
         array_size = min(
             self.max_chips_per_node,
             self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size,
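Note: the wait loop added above pairs with `self.launched_expert_service_signal.value[0] = 1` in the LLMEngine hunk earlier. A minimal sketch of that handshake using plain `multiprocessing` (IPCSignal itself is not reproduced here):

```python
# Sketch only: the engine sets a shared flag after launching expert services;
# each worker spins until the flag is non-zero before finishing its init.
import multiprocessing as mp
import time

def worker(flag):
    while flag[0] == 0:  # mirrors: while launched_expert_service_signal.value[...] == 0: pass
        time.sleep(0.01)
    print("worker: expert services launched, continuing init")

if __name__ == "__main__":
    flag = mp.Array("i", [0])
    p = mp.Process(target=worker, args=(flag,))
    p.start()
    time.sleep(0.1)
    flag[0] = 1  # mirrors: self.launched_expert_service_signal.value[0] = 1
    p.join()
```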
@@ -242,6 +256,7 @@ class PaddleDisWorkerProc:
         mp_num_per_node = self.parallel_config.tensor_parallel_size // self.nnode
         req_ids = []
         num_running_requests = 0
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
         while True:
             if self.local_rank == 0:
                 if self.model_weights_status.value[0] != 0:
@@ -255,7 +270,6 @@ class PaddleDisWorkerProc:

             self.insert_step = False
             req_dicts = None
-            local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
             self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())

             # The first worker detects whether there are tasks in the task queue
@@ -18,6 +18,7 @@ from fastdeploy.model_executor.load_weight_utils import (
     get_all_safetensors,
     safetensors_weights_iterator,
 )
+from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute


 def parse_arguments():
@@ -47,14 +48,21 @@ def parse_arguments():
         help="Whether merge the model into safetensors format.",
     )
+
+    parser.add_argument(
+        "--moe_quant_type",
+        default="w4a8",
+        choices=["w4a8", "w4afp8"],
+        help="The moe quant type of the model.",
+    )

     return parser.parse_args()


 def reorder():
-    def fn(weight):
+    def fn(weight, moe_quant_type):
         from paddle.nn.quant import weight_quantize

-        quant_weight, _ = weight_quantize(weight.cuda(), algo="w4a8", arch=80)
+        quant_weight, _ = weight_quantize(weight.cuda(), algo=moe_quant_type, arch=80)
         return quant_weight.cpu()

     return fn
@@ -69,22 +77,27 @@ def deal_in_scale():


 def deal_weight_scale():
-    def fn(weight_scale, processed_in_scale):
-        processed_weight_scale = weight_scale / (127 * 112) / processed_in_scale
-        return processed_weight_scale
+    def fn(weight_scale, processed_in_scale, moe_quant_type):
+        if moe_quant_type == "w4a8":
+            processed_weight_scale = weight_scale / (127 * 112) / processed_in_scale
+            return processed_weight_scale
+        elif moe_quant_type == "w4afp8":
+            processed_weight_scale = weight_scale / (448 * 7 * 2 ** (-9)) / processed_in_scale
+            processed_weight_scale = w4afp8_gemm_scale_permute(processed_weight_scale.cuda())
+            return processed_weight_scale

     return fn


 # tmp support w4a8
-def deal_quant(state_dict, save_state_dict):
-    w4a8_mapping = [
+def deal_quant(state_dict, save_state_dict, moe_quant_type):
+    param_mapping = [
         # pattern,fn
         (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.activation_scale", deal_in_scale()),
         (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.weight_scale", deal_weight_scale()),
         (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.quant_weight", reorder()),
     ]
-    for pattern, fn in w4a8_mapping:
+    for pattern, fn in param_mapping:
         for key in list(state_dict.keys()):
             # print(f"deal {key}")
             match = re.search(pattern, key)
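Note: for reference, the two divisors applied to `weight_scale` above evaluate as below; reading 127 as the int8 maximum, 448 as the FP8 E4M3 maximum and 7 as the signed int4 maximum is an inference, not something the patch states.

```python
# Sketch only: numeric values of the normalization constants in deal_weight_scale.
w4a8_divisor = 127 * 112              # 14224
w4afp8_divisor = 448 * 7 * 2 ** (-9)  # 6.125
print(w4a8_divisor, w4afp8_divisor)
```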
@@ -94,9 +107,11 @@ def deal_quant(state_dict, save_state_dict):
             if "weight_scale" in key:
                 in_scale_key = key.replace("weight_scale", "activation_scale")
                 in_scale = save_state_dict[in_scale_key]
-                save_state_dict[key] = fn(weight_or_scale, in_scale)
-            else:
+                save_state_dict[key] = fn(weight_or_scale, in_scale, moe_quant_type)
+            elif "activation_scale" in key:
                 save_state_dict[key] = fn(weight_or_scale)
+            else:
+                save_state_dict[key] = fn(weight_or_scale, moe_quant_type)


 def save_safetensors(state_dict, args):
@@ -153,7 +168,7 @@ def main():
     end = time.perf_counter()
     logger.info("Finish Quantize.")
     logger.info(f"load and quantize took : {end - start:.6f} seconds")
-    deal_quant(state_dict, save_state_dict)
+    deal_quant(state_dict, save_state_dict, args.moe_quant_type)
     for key in list(state_dict.keys()):
         save_state_dict[key] = state_dict.pop(key)
     logger.info("Begin to save model")
@@ -19,6 +19,7 @@ export devices=0
 export CUDA_VISIBLE_DEVICES=${devices}
 model_path=${1:-"/PATH/MODEL_PATH"}
 output_path=${2:-"/PATH/OUTPUT_MODEL"}
+moe_quant_type=${3:-"w4a8"}
 for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do
     unset ${name}
 done
@@ -31,4 +32,5 @@ self_ip=`hostname -i`
 python offline_w4a8.py \
     --model_name_or_path ${model_path} \
     --output_dir ${output_path} \
-    --safe_serialization "True"
+    --safe_serialization "True" \
+    --moe_quant_type ${moe_quant_type}
tests/model_loader/test_w4a8_model.py (new file, 69 lines)
@@ -0,0 +1,69 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import os
+import weakref
+
+import pytest
+
+from fastdeploy.entrypoints.llm import LLM
+
+bash_path = os.getenv("MODEL_PATH")
+FD_ENGINE_QUEUE_PORTS = [
+    [9961, 9962, 9963, 9964, 9965, 9966, 9967, 9968],
+    [9971, 9972, 9973, 9974, 9975, 9976, 9977, 9978],
+    [9981, 9982, 9983, 9984, 9985, 9986, 9987, 9988],
+    [9991, 9992, 9993, 9994, 9995, 9996, 9997, 9998],
+]
+
+
+models = [
+    "ernie-4_5-fake-w4a8-unpermuted",
+    "ernie-4_5-fake-w4a8-permuted",
+    "ernie-4_5-fake-w4afp8-unpermuted",
+    "ernie-4_5-fake-w4afp8-permuted",
+]
+
+prompts = ["解释下“温故而知新"]
+
+
+@pytest.fixture(scope="module", params=models)
+def llm(request):
+    """LLM test fixture"""
+    model_path = os.path.join(bash_path, request.param)
+    try:
+        port_index = models.index(request.param) % len(FD_ENGINE_QUEUE_PORTS)
+        llm_instance = LLM(
+            model=model_path,
+            tensor_parallel_size=1,
+            data_parallel_size=8,
+            max_model_len=8192,
+            num_gpu_blocks_override=1024,
+            engine_worker_queue_port=FD_ENGINE_QUEUE_PORTS[port_index],
+            load_choices="default",
+            enable_expert_parallel=True,
+        )
+        yield weakref.proxy(llm_instance)
+    except Exception as e:
+        pytest.skip(f"LLM initialization failed: {e}")
+
+
+@pytest.mark.timeout(60)
+def test_generation(llm):
+    print(f"testing generation with model: {llm}")
+    # topp_params = SamplingParams(temperature=0.1, top_p=0, max_tokens=20)
+    # output = llm.generate(prompts=prompts, sampling_params=topp_params)
+    # print(output)