Sync v2.0 version of code to GitHub repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions


@@ -14,32 +14,37 @@
# limitations under the License.
"""
from typing import Tuple
from typing import Tuple, Union
import numpy as np
import paddle
from paddle import Tensor
from paddle import Tensor, nn
from paddle.framework import in_dynamic_mode
from scipy.linalg import block_diag
from fastdeploy.platforms import current_platform
if current_platform.is_cuda() and current_platform.available():
try:
from fastdeploy.model_executor.ops.gpu import (
get_padding_offset,
speculate_get_padding_offset,
)
get_padding_offset, speculate_get_padding_offset)
except Exception:
raise ImportError(
f"Verify environment consistency between compilation and FastDeploy installation. "
f"And ensure the Paddle version supports FastDeploy's custom operators"
"Verify environment consistency between compilation and FastDeploy installation. "
"And ensure the Paddle version supports FastDeploy's custom operators"
)
import re
import os
cache_params = os.getenv("CACHE_PARAMS", "none")
from fastdeploy import envs
cache_params = envs.FD_CACHE_PARAMS
if cache_params != "none":
c8_state_dict = paddle.load(cache_params, return_numpy=True)
def per_block_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
def per_block_cast_to_fp8(x: Tensor,
block_size: list = [128,
128]) -> Tuple[Tensor, Tensor]:
"""
Only used for deep_gemm block-wise weight quantization.
Copied from FastDeploy/custom_ops/gpu_ops/fp8_deep_gemm/tests/test_core.py.
@@ -48,10 +53,13 @@ def per_block_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = paddle.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
x_padded = paddle.zeros((ceil_div(m, block_size[0]) * block_size[0],
ceil_div(n, block_size[1]) * block_size[1]),
dtype=x.dtype)
x_padded[:m, :n] = x
x_view = paddle.view(x_padded, (-1, 128, x_padded.shape[1] // 128, 128))
x_view = paddle.view(
x_padded,
(-1, block_size[0], x_padded.shape[1] // block_size[1], block_size[1]))
x_abs = paddle.abs(x_view).astype(paddle.float32)
x_amax = paddle.amax(x_abs, axis=(1, 3), keepdim=True)
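A rough usage sketch of the block-wise cast (the exact return contents, a quantized tensor plus per-block scales, are inferred from the docstring and signature and may differ from the final implementation; the weight shape below is illustrative):
# Hypothetical call, assuming this module's imports.
w = paddle.randn([2048, 4096])
w_fp8, w_scales = per_block_cast_to_fp8(w, block_size=[128, 128])
# w_fp8 holds the block-quantized weight; w_scales is expected to carry
# one scale per 128x128 block, i.e. roughly a [2048/128, 4096/128] grid.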
@@ -63,15 +71,15 @@ def per_block_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
# for distributed tensor model parallel
def _set_var_distributed(var, split_axis):
def _set_var_distributed(var: Tensor, split_axis: int):
"""
Set whether the variable is distributed. If the variable is None, no operation will be performed.
Args:
var (Variable, Optional): A Variable object, which can be None. The default value is None.
The Variable object should have an attribute 'is_distributed' to indicate whether
the variable has been processed in a distributed manner.
split_axis (Integer): the sharding dimension of dist tensors
var (Tensor): The variable to mark as distributed; may be None, in which case
no operation is performed. The variable is expected to expose an
'is_distributed' attribute indicating whether it has been processed in a distributed manner.
split_axis (int): the sharding dimension of dist tensors.
Returns:
None. No return value.
@@ -91,10 +99,16 @@ def _set_var_distributed(var, split_axis):
main_block._find_var_recursive(var.name).is_distributed = True
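A minimal usage sketch (the layer and split axis are illustrative, not taken from this file): marking a tensor-parallel weight as distributed so downstream logic can see the sharding.
# Hypothetical example: a column-parallel Linear split along its output axis.
linear = nn.Linear(1024, 4096)
_set_var_distributed(linear.weight, split_axis=1)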
def get_tensor(input):
def get_tensor(input: Union[paddle.Tensor, np.ndarray, str]) -> paddle.Tensor:
"""
In EP parallelism, weights are stored distributed per layer. To save peak GPU memory, the state_dict handling only keeps
the layer name and the path of the corresponding weight, so the weight needs to be converted to paddle.Tensor here.
Return a corresponding PaddlePaddle tensor based on the type and content of the input.
Args:
input (Union[paddle.Tensor, np.ndarray, str]): The input data.
Returns:
paddle.Tensor: Returns a PaddlePaddle tensor.
"""
if isinstance(input, paddle.Tensor):
if input.place.is_cpu_place():
@@ -104,7 +118,6 @@ def get_tensor(input):
return paddle.to_tensor(input)
elif isinstance(input, str):
if ".safetensors" in input:
match = re.match(r"\[(.*?)\](.*)", input)
if match:
key_name = match.group(1)
@@ -116,12 +129,11 @@ def get_tensor(input):
weight = f.get_tensor(key_name)
weight = paddle.Tensor(weight, zero_copy=True)
weight = weight._copy_to(
paddle.framework._current_expected_place(), False
)
paddle.framework._current_expected_place(), False)
return weight
else:
return None
else:
else:
if cache_params != "none":
tmp_key = input.split("/")[-1]
if tmp_key in c8_state_dict:
@@ -129,25 +141,134 @@ def get_tensor(input):
return paddle.to_tensor(c8_state_dict.pop(tmp_key))
return paddle.load(input)
else:
# In theory this branch is never hit.
return input
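A rough sketch of the three input forms get_tensor accepts (the safetensors key and file name below are made up for illustration, and the sketch assumes this module's imports):
# 1) Already a paddle.Tensor: moved off CPU to the current device if needed.
t = get_tensor(paddle.ones([2, 2]))
# 2) A numpy array: converted with paddle.to_tensor.
t = get_tensor(np.zeros([2, 2], dtype="float32"))
# 3) A string: "[key]file.safetensors" loads a single tensor from a
#    safetensors file; any other path is handed to paddle.load.
t = get_tensor("[model.embed_tokens.weight]model-00001.safetensors")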
def matmul_hadU(X: Tensor) -> paddle.Tensor:
"""
Perform matrix multiplication using the Hadamard matrix.
Args:
X (Tensor): The tensor to be multiplied.
Returns:
Tensor: The tensor after Hadamard matrix multiplication, with the same shape as the input tensor X.
"""
input = X.clone().reshape((-1, X.shape[-1], 1))
output = input.clone()
while input.shape[1] > 1:
input = input.reshape(
(input.shape[0], input.shape[1] // 2, 2, input.shape[2]))
output = output.reshape(input.shape)
output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
output = output.reshape((input.shape[0], input.shape[1], -1))
(input, output) = (output, input)
del output
return input.reshape(X.shape)
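matmul_hadU effectively applies an unnormalized fast Walsh-Hadamard transform along the last axis via repeated add/subtract butterflies, so the last dimension should be a power of two. A small sanity sketch (size chosen for illustration):
# Transforming the 4x4 identity is expected to reproduce the 4x4
# Sylvester Hadamard matrix (entries +1/-1, no normalization):
# [[1, 1, 1, 1], [1, -1, 1, -1], [1, 1, -1, -1], [1, -1, -1, 1]]
eye = paddle.eye(4, dtype="float32")
had = matmul_hadU(eye)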
def random_hadamard_matrix(block_size: int,
dtype: Union[paddle.dtype, str]) -> paddle.Tensor:
"""
Generate a random Hadamard matrix.
Args:
block_size (int): The size of the block, i.e., the number of rows and columns of the matrix.
dtype (str): The data type, for example 'float32'.
Returns:
paddle.Tensor: The generated random Hadamard matrix.
"""
Q = paddle.diag(paddle.ones((block_size), dtype=dtype))
block = matmul_hadU(Q)
return block
def create_hadamard_matrix(hidden_size: int) -> paddle.Tensor:
"""
Generate a Hadamard matrix.
Args:
hidden_size (int): The size of the hidden layer.
Returns:
paddle.Tensor: The generated Hadamard matrix.
"""
hadamard_block_size = 32
h = random_hadamard_matrix(hadamard_block_size, "float32")
block_num = hidden_size // hadamard_block_size
hadamard_matrix = paddle.to_tensor(
block_diag(*[h for i in range(block_num)]))
return hadamard_matrix
create_hadamard_matrix_map = {}
# Zkk: the keys below are used in 4.5T fp8.
create_hadamard_matrix_map[8192] = create_hadamard_matrix(8192)
create_hadamard_matrix_map[448] = create_hadamard_matrix(448)
create_hadamard_matrix_map[1024] = create_hadamard_matrix(1024)
create_hadamard_matrix_map[3584] = create_hadamard_matrix(3584)
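Each cached matrix is block-diagonal, built from identical 32x32 Hadamard blocks, so hidden sizes must be multiples of 32. A hedged sketch of applying one of the precomputed matrices to activations (the activation shape is illustrative; the real call sites live elsewhere in FastDeploy):
# Hypothetical transform of a [batch, hidden] activation with the 1024 entry.
x = paddle.randn([8, 1024])
h = create_hadamard_matrix_map[1024]
x_transformed = paddle.matmul(x, h)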
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
"""
Ensure the numerator is divisible by the denominator.
Args:
numerator (int): The numerator.
denominator (int): The denominator.
Returns:
None
Raises:
AssertionError: If the numerator cannot be evenly divided by the denominator, an assertion error is raised.
"""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
def divide(numerator: int, denominator: int):
"""
Calculate the division result of two numbers.
Args:
numerator (int): The dividend.
denominator (int): The divisor.
Returns:
int: The result of the division, which is the quotient of the dividend divided by the divisor.
"""
ensure_divisibility(numerator, denominator)
return numerator // denominator
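A quick usage sketch for the two helpers (the numbers are illustrative):
ensure_divisibility(4096, 8)    # passes silently
heads_per_rank = divide(32, 8)  # 4
# divide(32, 5) would raise AssertionError: "32 is not divisible by 5"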
def remove_padding(max_len, input_ids, seq_lens_this_time):
def remove_padding(
max_len: paddle.Tensor, input_ids: paddle.Tensor,
seq_lens_this_time: paddle.Tensor
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
paddle.Tensor]:
"""
remove_padding
Remove padding from the input sequences.
Args:
max_len (paddle.Tensor): The maximum length of the input sequences.
input_ids (paddle.Tensor): The IDs of the input sequences.
seq_lens_this_time (paddle.Tensor): The actual length of each sequence.
Returns:
tuple: A tuple containing:
- The sequence IDs with padding removed (paddle.Tensor).
- The padding offsets (paddle.Tensor).
- The cumulative offsets (paddle.Tensor).
- The query sequence lengths (paddle.Tensor).
- The key sequence lengths (paddle.Tensor).
"""
if current_platform.is_cuda():
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time)
@@ -159,7 +280,7 @@ def remove_padding(max_len, input_ids, seq_lens_this_time):
cu_seqlens_q,
cu_seqlens_k,
) = get_padding_offset(input_ids, cum_offsets_now, token_num,
seq_lens_this_time)
seq_lens_this_time)
return (
ids_remove_padding,
padding_offset,
@@ -168,10 +289,30 @@ def remove_padding(max_len, input_ids, seq_lens_this_time):
cu_seqlens_k,
)
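A hedged shape sketch (CUDA only, since the helper relies on the get_padding_offset custom op; the dtypes and values below are assumptions for illustration):
# Hypothetical batch of two prompts padded to max_len=8, real lengths 3 and 5.
max_len = paddle.to_tensor(8, dtype="int64")
input_ids = paddle.zeros([2, 8], dtype="int64")
seq_lens_this_time = paddle.to_tensor([3, 5], dtype="int32")
(ids_remove_padding, padding_offset, cum_offsets,
 cu_seqlens_q, cu_seqlens_k) = remove_padding(max_len, input_ids,
                                              seq_lens_this_time)
# ids_remove_padding is expected to hold 3 + 5 = 8 tokens in total,
# and cu_seqlens_q / cu_seqlens_k to look like [0, 3, 8].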
def speculate_remove_padding(max_len, input_ids, seq_lens_this_time,
draft_tokens, seq_lens_encoder):
def speculate_remove_padding(
max_len: paddle.Tensor, input_ids: paddle.Tensor,
seq_lens_this_time: paddle.Tensor, draft_tokens: paddle.Tensor,
seq_lens_encoder: paddle.Tensor
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
paddle.Tensor]:
"""
remove_padding
Remove padding from sequences.
Args:
max_len (paddle.Tensor): The maximum length of the sequences.
input_ids (paddle.Tensor): The IDs of the input sequences.
seq_lens_this_time (paddle.Tensor): The lengths of the sequences in the current batch.
draft_tokens (paddle.Tensor): The draft tokens.
seq_lens_encoder (paddle.Tensor): The lengths of the encoder sequences.
Returns:
tuple: A tuple containing:
- The input sequence IDs with padding removed (paddle.Tensor).
- Padding offsets (paddle.Tensor).
- Cumulative offsets (paddle.Tensor).
- Query sequence lengths (paddle.Tensor).
- Key sequence lengths (paddle.Tensor).
"""
if current_platform.is_cuda():
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time)
@@ -197,3 +338,43 @@ def speculate_remove_padding(max_len, input_ids, seq_lens_this_time,
cu_seqlens_q,
cu_seqlens_k,
)
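The speculative variant takes the same inputs plus the draft tokens and encoder sequence lengths, which it forwards to the speculate_get_padding_offset custom op. Reusing the tensors from the previous sketch (the draft-token shape and dtypes are assumptions):
# Hypothetical speculative-decoding batch: two sequences, up to 4 draft tokens each.
draft_tokens = paddle.zeros([2, 4], dtype="int64")
seq_lens_encoder = paddle.to_tensor([3, 0], dtype="int32")
outputs = speculate_remove_padding(max_len, input_ids, seq_lens_this_time,
                                   draft_tokens, seq_lens_encoder)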
class CpuGuard:
"""CpuGuard"""
def __init__(self):
"""init"""
pass
def __enter__(self):
"""enter"""
self.ori_device = paddle.device.get_device()
paddle.device.set_device("cpu")
def __exit__(self, exc_type, exc_val, exc_tb):
"""exit"""
paddle.device.set_device(self.ori_device)
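Typical usage, as a minimal sketch: temporarily switch the default device to CPU for an allocation-heavy step, then restore whichever device was active before.
with CpuGuard():
    # built on CPU to avoid a transient GPU allocation (illustrative)
    big_buffer = paddle.zeros([1 << 20], dtype="float32")
# the previous device (e.g. "gpu:0") is restored on exit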
def create_and_set_parameter(layer: nn.Layer, name: str,
tensor: paddle.Tensor):
"""
Create a parameter for a specified layer and set its value to the given tensor.
Args:
layer (nn.Layer): The layer object to which the parameter will be added.
name (str): The name of the parameter to be created.
tensor (paddle.Tensor): The tensor to set as the value of the parameter.
Returns:
None
"""
setattr(
layer, name,
layer.create_parameter(
shape=tensor.shape,
dtype=tensor.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
))
getattr(layer, name).set_value(tensor)
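A small usage sketch (the layer and parameter name are illustrative): registering an extra tensor, such as a quantization scale, on a layer so it behaves like a regular parameter.
# Hypothetical example: attach a per-channel scale tensor to a Linear layer.
layer = nn.Linear(1024, 1024)
scales = paddle.ones([1024], dtype="float32")
create_and_set_parameter(layer, "weight_scale", scales)
assert layer.weight_scale.shape == [1024]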