Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[v1 loader]qwen Offline fp8 (#4036)
* support offline fp8
* update ut
* update ut
* update ut
* fix
* update
* update
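For context before the diff: the dtype-matching block that each weight_loader below gains handles offline-quantized fp8 checkpoints whose weights arrive as raw int8 buffers; those are bit-reinterpreted (view) rather than value-cast. A minimal standalone sketch of that logic follows; the helper name is illustrative, not a FastDeploy API.

import paddle

def match_param_dtype(loaded_weight: paddle.Tensor, param: paddle.Tensor) -> paddle.Tensor:
    # Nothing to do if the checkpoint tensor already has the parameter's dtype.
    if loaded_weight.dtype == param.dtype:
        return loaded_weight
    if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
        # Offline fp8 weights arrive as int8 byte buffers: reinterpret the bits,
        # do not convert the values.
        return loaded_weight.view(param.dtype)
    # Any other mismatch: value-preserving cast to the parameter dtype.
    return loaded_weight.cast(param.dtype)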
@@ -57,7 +57,7 @@ class UnquantizedLinearMethod(QuantMethodBase):
            {
                **extra_weight_attrs,
                "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
-                "model_format": extra_weight_attrs.get("model_format", ""),
+                "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
            },
        )

@@ -341,10 +341,10 @@ class MergedReplicatedLinear(ReplicatedLinear):
        self.output_sizes = output_sizes

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        model_format = getattr(param, "model_format", "")
+        weight_need_transpose = getattr(param, "weight_need_transpose", False)
        loaded_weight = get_tensor(loaded_weight)

-        if model_format == "torch":
+        if weight_need_transpose:
            loaded_weight = loaded_weight.transpose([1, 0])

        assert loaded_shard_id in ["q_a", "kv_a"]
@@ -365,6 +365,12 @@ class MergedReplicatedLinear(ReplicatedLinear):
        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
        )
+        # Ensure loaded weight dtype matches model param dtype
+        if loaded_weight.dtype != param.dtype:
+            if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
+                loaded_weight = loaded_weight.view(param.dtype)
+            else:
+                loaded_weight = loaded_weight.cast(param.dtype)
        param.copy_(loaded_weight, False)

@@ -483,15 +489,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
        )

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        model_format = getattr(param, "model_format", "")
-        if model_format == "torch":
-            loaded_weight = get_tensor(loaded_weight)
-            loaded_weight = loaded_weight.transpose([1, 0])
+        weight_need_transpose = getattr(param, "weight_need_transpose", False)
        output_dim = getattr(param, "output_dim", None)
        assert output_dim is not None
        shard_dim = -1 if output_dim else 0
        output_size = param.shape[shard_dim]
        if loaded_shard_id is None:
+            if weight_need_transpose:
+                loaded_weight = get_tensor(loaded_weight)
+                loaded_weight = loaded_weight.transpose([1, 0])
+                param.weight_need_transpose = False
            # Loaded weight is already fused on disk.
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
@@ -506,6 +513,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
        else:
            # split gate up
            assert loaded_shard_id in ["gate", "up"]
+            if weight_need_transpose:
+                loaded_weight = get_tensor(loaded_weight)
+                loaded_weight = loaded_weight.transpose([1, 0])
            # Tensor parallelism splits the weight along the output_dim
            if self.nranks != 1:
                dim = -1 if output_dim else 0
@@ -532,6 +542,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
        )
+        # Ensure loaded weight dtype matches model param dtype
+        if loaded_weight.dtype != param.dtype:
+            if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
+                loaded_weight = loaded_weight.view(param.dtype)
+            else:
+                loaded_weight = loaded_weight.cast(param.dtype)
        param.copy_(loaded_weight, False)

    def load_state_dict(self, state_dict: dict):
@@ -604,11 +620,11 @@ class QKVParallelLinear(ColumnParallelLinear):
            add_bias=add_bias,
        )

-    def _get_shard_size_mapping(self, loaded_shard_id: str):
+    def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
        shard_size_mapping = {
-            "q": self.num_heads_per_rank * self.head_dim,
-            "k": self.kv_num_heads_per_rank * self.head_dim,
-            "v": self.kv_num_heads_per_rank * self.head_dim,
+            "q": self.num_heads_per_rank * head_dim,
+            "k": self.kv_num_heads_per_rank * head_dim,
+            "v": self.kv_num_heads_per_rank * head_dim,
        }
        return shard_size_mapping.get(loaded_shard_id)

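Why head_dim became a parameter: the caller already derives head_dim from the allocated parameter's shape (see the next hunk), and the PR now threads that value into _get_shard_size_mapping instead of reading self.head_dim, presumably so the q/k/v shard sizes stay consistent with however the (possibly quantized) parameter is laid out. A self-contained illustration of the arithmetic with made-up sizes:

num_heads_per_rank = 8          # assumed number of query heads on this TP rank
kv_num_heads_per_rank = 2       # assumed number of key/value heads on this TP rank
fused_output_size = 1536        # assumed param.shape[output_dim] of the fused qkv weight

# head_dim recovered from the allocated parameter, mirroring the next hunk
head_dim = fused_output_size // (num_heads_per_rank + 2 * kv_num_heads_per_rank)  # 128

shard_sizes = {
    "q": num_heads_per_rank * head_dim,     # 1024
    "k": kv_num_heads_per_rank * head_dim,  # 256
    "v": kv_num_heads_per_rank * head_dim,  # 256
}
assert sum(shard_sizes.values()) == fused_output_size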
@@ -617,11 +633,12 @@ class QKVParallelLinear(ColumnParallelLinear):
        assert output_dim is not None
        dim = -1 if output_dim else 0
        head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
-        model_format = getattr(param, "model_format", "")
-        if model_format == "torch":
-            loaded_weight = get_tensor(loaded_weight)
-            loaded_weight = loaded_weight.transpose([1, 0])
+        weight_need_transpose = getattr(param, "weight_need_transpose", False)
        if loaded_shard_id is None:
+            if weight_need_transpose:
+                loaded_weight = get_tensor(loaded_weight)
+                loaded_weight = loaded_weight.transpose([1, 0])
+                param.weight_need_transpose = False
            # Loaded weight is already fused on disk
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
@@ -637,13 +654,16 @@ class QKVParallelLinear(ColumnParallelLinear):
        else:
            # split q k v
            assert loaded_shard_id in ["q", "k", "v"]
+            if weight_need_transpose:
+                loaded_weight = get_tensor(loaded_weight)
+                loaded_weight = loaded_weight.transpose([1, 0])
            # Tensor parallelism splits the weight along the output_dim
            if self.nranks != 1:
-                block_size = self._get_shard_size_mapping(loaded_shard_id)
+                block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
                shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
                shard_offset = shard_id * block_size
-                shard_size = (shard_id + 1) * block_size
-                loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
+                shard_size = block_size
+                loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)

            loaded_weight = get_tensor(loaded_weight)

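The slicing change in the hunk above is a naming cleanup rather than a behavioral one: shard_offset + block_size equals (shard_id + 1) * block_size, so both versions select the same range; shard_size now really denotes a size instead of an end index. A quick check with illustrative numbers:

block_size = 256   # assumed per-shard width along output_dim
shard_id = 3       # assumed rank-derived shard index

# old form: the variable named shard_size held the end index
old_start, old_end = shard_id * block_size, (shard_id + 1) * block_size

# new form: shard_size is the size, and the end index is derived from it
shard_offset, shard_size = shard_id * block_size, block_size
new_start, new_end = shard_offset, shard_offset + shard_size

assert (old_start, old_end) == (new_start, new_end)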
@@ -663,10 +683,17 @@ class QKVParallelLinear(ColumnParallelLinear):
            param_shard_size = self.kv_num_heads_per_rank * head_dim
        if hasattr(param, "tensor_track"):
            param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)

        param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
        )
+        # Ensure loaded weight dtype matches model param dtype
+        if loaded_weight.dtype != param.dtype:
+            if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
+                loaded_weight = loaded_weight.view(param.dtype)
+            else:
+                loaded_weight = loaded_weight.cast(param.dtype)
        param.copy_(loaded_weight, False)

    def load_weight(self, state_dict: dict):