[v1 loader]qwen Offline fp8 (#4036)

* support offline fp8

* update ut

* update ut

* update ut

* fix

* update

* update
Authored by bukejiyu on 2025-09-15 13:44:11 +08:00
Committed by GitHub
parent b1a5b756a3
commit 29ed617f0f
21 changed files with 440 additions and 138 deletions

@@ -57,7 +57,7 @@ class UnquantizedLinearMethod(QuantMethodBase):
             {
                 **extra_weight_attrs,
                 "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
-                "model_format": extra_weight_attrs.get("model_format", ""),
+                "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
             },
         )
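
Note: the weight_need_transpose attribute replaces the per-call model_format string check. Torch-format checkpoints store a Linear weight as [out_features, in_features], while Paddle linear layers keep [in_features, out_features], so the flag is resolved once from model_format and the loader transposes on load. A minimal numpy sketch of that idea (helper name and shapes are illustrative, not the FastDeploy API):

# Minimal sketch, not the FastDeploy loader: shows why a torch-format weight
# must be transposed before it can fill a Paddle linear parameter.
import numpy as np

def load_linear_weight(param: np.ndarray, loaded: np.ndarray, weight_need_transpose: bool) -> np.ndarray:
    """Return `loaded` laid out like `param` ([in_features, out_features])."""
    if weight_need_transpose:
        # torch stores Linear weights as [out_features, in_features]
        loaded = loaded.T
    assert param.shape == loaded.shape, f"{loaded.shape} vs {param.shape}"
    return loaded.astype(param.dtype, copy=False)

param = np.zeros((128, 256), dtype=np.float16)               # Paddle layout: [in, out]
torch_style = np.random.rand(256, 128).astype(np.float16)    # torch layout: [out, in]
print(load_linear_weight(param, torch_style, weight_need_transpose=True).shape)  # (128, 256)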
@@ -341,10 +341,10 @@ class MergedReplicatedLinear(ReplicatedLinear):
         self.output_sizes = output_sizes
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        model_format = getattr(param, "model_format", "")
+        weight_need_transpose = getattr(param, "weight_need_transpose", False)
         loaded_weight = get_tensor(loaded_weight)
-        if model_format == "torch":
+        if weight_need_transpose:
             loaded_weight = loaded_weight.transpose([1, 0])
         assert loaded_shard_id in ["q_a", "kv_a"]
@@ -365,6 +365,12 @@ class MergedReplicatedLinear(ReplicatedLinear):
         assert param.shape == loaded_weight.shape, (
             f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
         )
+        # Ensure loaded weight dtype matches model param dtype
+        if loaded_weight.dtype != param.dtype:
+            if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
+                loaded_weight = loaded_weight.view(param.dtype)
+            else:
+                loaded_weight = loaded_weight.cast(param.dtype)
         param.copy_(loaded_weight, False)
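
Note: the added block separates a bitwise reinterpretation from a numeric conversion. Offline-quantized fp8 checkpoints typically serialize float8_e4m3fn tensors with int8 storage, so an int8 tensor destined for a float8_e4m3fn parameter is reinterpreted with view, while any other mismatch (for example a higher-precision tensor loaded into a lower-precision parameter) is converted by value with cast. A standalone sketch of the same rule, assuming a Paddle build that exposes paddle.float8_e4m3fn and dtype view; the helper name is hypothetical:

# Sketch of the dtype-reconciliation rule; `match_param_dtype` is not a FastDeploy API.
import paddle

def match_param_dtype(loaded_weight: paddle.Tensor, param: paddle.Tensor) -> paddle.Tensor:
    if loaded_weight.dtype == param.dtype:
        return loaded_weight
    if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
        # Same byte width: reinterpret the raw fp8 bytes, no numeric conversion.
        return loaded_weight.view(param.dtype)
    # Different numeric type: convert by value.
    return loaded_weight.cast(param.dtype)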
@@ -483,15 +489,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         )
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        model_format = getattr(param, "model_format", "")
-        if model_format == "torch":
-            loaded_weight = get_tensor(loaded_weight)
-            loaded_weight = loaded_weight.transpose([1, 0])
+        weight_need_transpose = getattr(param, "weight_need_transpose", False)
         output_dim = getattr(param, "output_dim", None)
         assert output_dim is not None
         shard_dim = -1 if output_dim else 0
         output_size = param.shape[shard_dim]
         if loaded_shard_id is None:
+            if weight_need_transpose:
+                loaded_weight = get_tensor(loaded_weight)
+                loaded_weight = loaded_weight.transpose([1, 0])
+                param.weight_need_transpose = False
             # Loaded weight is already fused on disk.
             shard_offsets = [
                 # (shard_id, shard_offset, shard_size)
@@ -506,6 +513,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         else:
             # split gate up
             assert loaded_shard_id in ["gate", "up"]
+            if weight_need_transpose:
+                loaded_weight = get_tensor(loaded_weight)
+                loaded_weight = loaded_weight.transpose([1, 0])
             # Tensor parallelism splits the weight along the output_dim
             if self.nranks != 1:
                 dim = -1 if output_dim else 0
@@ -532,6 +542,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         assert param.shape == loaded_weight.shape, (
             f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
         )
+        # Ensure loaded weight dtype matches model param dtype
+        if loaded_weight.dtype != param.dtype:
+            if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
+                loaded_weight = loaded_weight.view(param.dtype)
+            else:
+                loaded_weight = loaded_weight.cast(param.dtype)
         param.copy_(loaded_weight, False)
     def load_state_dict(self, state_dict: dict):
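
Note: in MergedColumnParallelLinear a gate_up weight that is already fused on disk is split by the (shard_id, shard_offset, shard_size) table, and under tensor parallelism each rank additionally slices its portion along output_dim. A numpy sketch of that slicing with illustrative sizes (not the actual layer code):

# Numpy sketch: split a fused gate_up weight into per-rank (gate, up) shards.
import numpy as np

hidden, intermediate, nranks, rank = 64, 256, 4, 1
fused = np.random.rand(hidden, 2 * intermediate).astype(np.float32)

# (shard_id, shard_offset, shard_size) over the fused output dimension
shard_offsets = [("gate", 0, intermediate), ("up", intermediate, intermediate)]
per_rank = intermediate // nranks
for shard_id, offset, size in shard_offsets:
    shard = fused[:, offset:offset + size]                        # split gate / up
    local = shard[:, rank * per_rank:(rank + 1) * per_rank]       # tensor-parallel slice
    print(shard_id, local.shape)                                  # (64, 64) each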
@@ -604,11 +620,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             add_bias=add_bias,
         )
-    def _get_shard_size_mapping(self, loaded_shard_id: str):
+    def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
         shard_size_mapping = {
-            "q": self.num_heads_per_rank * self.head_dim,
-            "k": self.kv_num_heads_per_rank * self.head_dim,
-            "v": self.kv_num_heads_per_rank * self.head_dim,
+            "q": self.num_heads_per_rank * head_dim,
+            "k": self.kv_num_heads_per_rank * head_dim,
+            "v": self.kv_num_heads_per_rank * head_dim,
         }
         return shard_size_mapping.get(loaded_shard_id)
@@ -617,11 +633,12 @@ class QKVParallelLinear(ColumnParallelLinear):
         assert output_dim is not None
         dim = -1 if output_dim else 0
         head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
-        model_format = getattr(param, "model_format", "")
-        if model_format == "torch":
-            loaded_weight = get_tensor(loaded_weight)
-            loaded_weight = loaded_weight.transpose([1, 0])
+        weight_need_transpose = getattr(param, "weight_need_transpose", False)
         if loaded_shard_id is None:
+            if weight_need_transpose:
+                loaded_weight = get_tensor(loaded_weight)
+                loaded_weight = loaded_weight.transpose([1, 0])
+                param.weight_need_transpose = False
             # Loaded weight is already fused on disk
             shard_offsets = [
                 # (shard_id, shard_offset, shard_size)
@@ -637,13 +654,16 @@ class QKVParallelLinear(ColumnParallelLinear):
         else:
             # split q k v
             assert loaded_shard_id in ["q", "k", "v"]
+            if weight_need_transpose:
+                loaded_weight = get_tensor(loaded_weight)
+                loaded_weight = loaded_weight.transpose([1, 0])
             # Tensor parallelism splits the weight along the output_dim
             if self.nranks != 1:
-                block_size = self._get_shard_size_mapping(loaded_shard_id)
+                block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
                 shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
                 shard_offset = shard_id * block_size
-                shard_size = (shard_id + 1) * block_size
-                loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
+                shard_size = block_size
+                loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)
             loaded_weight = get_tensor(loaded_weight)
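
Note: with head_dim now derived from the parameter's own shape, the per-rank slice for shard index r covers [r * block_size, r * block_size + block_size); the rewritten lines keep shard_size equal to one block and pass shard_offset + shard_size as the slice end instead of overloading shard_size to mean the end index. A numpy sketch of the arithmetic with illustrative sizes:

# Numpy sketch of the per-rank q-shard slice arithmetic (illustrative sizes).
import numpy as np

num_heads, head_dim, nranks, rank = 16, 8, 4, 2
full_q = np.random.rand(64, num_heads * head_dim).astype(np.float32)

block_size = (num_heads // nranks) * head_dim    # shard size for "q" on one rank
shard_offset = rank * block_size
shard_size = block_size                          # a size, not an end index
local_q = full_q[:, shard_offset:shard_offset + shard_size]
print(local_q.shape)                             # (64, 32)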
@@ -663,10 +683,17 @@ class QKVParallelLinear(ColumnParallelLinear):
                 param_shard_size = self.kv_num_heads_per_rank * head_dim
             if hasattr(param, "tensor_track"):
                 param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
+            param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
             assert param.shape == loaded_weight.shape, (
                 f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
             )
+            # Ensure loaded weight dtype matches model param dtype
+            if loaded_weight.dtype != param.dtype:
+                if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
+                    loaded_weight = loaded_weight.view(param.dtype)
+                else:
+                    loaded_weight = loaded_weight.cast(param.dtype)
             param.copy_(loaded_weight, False)
     def load_weight(self, state_dict: dict):
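
Note: in the per-shard branch the destination parameter is also narrowed to the q/k/v sub-range before param.copy_, so loading one projection only writes its own columns. A numpy sketch of that pattern with illustrative sizes (a numpy slice view stands in for slice_fn on a Paddle parameter):

# Numpy sketch: write only the "k" columns of a fused qkv parameter.
import numpy as np

num_heads_per_rank, kv_heads_per_rank, head_dim = 4, 1, 8
qkv_param = np.zeros((64, (num_heads_per_rank + 2 * kv_heads_per_rank) * head_dim), dtype=np.float32)

# destination range for the "k" shard
param_shard_offset = num_heads_per_rank * head_dim
param_shard_size = kv_heads_per_rank * head_dim
dst = qkv_param[:, param_shard_offset:param_shard_offset + param_shard_size]  # a view

loaded_k = np.ones((64, param_shard_size), dtype=np.float32)
dst[:] = loaded_k   # analogous to param.copy_(loaded_weight, False) on the sliced param
print(qkv_param[:, :param_shard_offset].any())   # False: q columns untouched
print(dst.all())                                 # True: k columns written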