[bugfix]fix blockwisefp8 and all_reduce (#3243)

* fix

* update

* fix linear for prequant loader
Author: bukejiyu
Date: 2025-08-06 23:54:33 +08:00
Committer: GitHub
parent 3a15e0c53e
commit 9408e667a5
4 changed files with 37 additions and 24 deletions


@@ -83,7 +83,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
     def create_weights(self, layer, **extra_weight_attrs):
         layer.weight_shape.reverse()
+        layer.weight_dtype = "float8_e4m3fn"
         layer.weight = layer.create_parameter(
             shape=layer.weight_shape,
             dtype=layer.weight_dtype,
@@ -101,7 +101,6 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
dtype="float32",
is_bias=False,
)
layer.weight_dtype = "float8_e4m3fn"
def process_loaded_weights(self, layer, weights) -> None:
weight_tensor = weights.transpose([1, 0])
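
The hunks above move the `layer.weight_dtype = "float8_e4m3fn"` assignment ahead of the `create_parameter` call, so the weight tensor is allocated as fp8 instead of the layer's default dtype. Below is a minimal, self-contained sketch of why the ordering matters; `FakeLayer`, `create_weights_buggy`, and `create_weights_fixed` are hypothetical stand-ins (not the FastDeploy code), and `create_parameter` here mimics Paddle's `Layer.create_parameter` only in that it captures the dtype at call time:

```python
# Hypothetical stand-in for the quantized linear layer; not FastDeploy code.
class FakeLayer:
    def __init__(self, weight_shape, default_dtype="bfloat16"):
        self.weight_shape = weight_shape
        self.weight_dtype = default_dtype  # dtype before quantization setup

    def create_parameter(self, shape, dtype, is_bias=False):
        # Like paddle.nn.Layer.create_parameter, the dtype is fixed here,
        # at creation time; later reassignments of weight_dtype are too late.
        return {"shape": shape, "dtype": dtype, "is_bias": is_bias}


def create_weights_buggy(layer):
    # Old order: the weight is allocated with the layer's default dtype ...
    layer.weight = layer.create_parameter(
        shape=layer.weight_shape, dtype=layer.weight_dtype
    )
    # ... and weight_dtype is only switched to fp8 afterwards, which cannot
    # affect the already-created parameter.
    layer.weight_dtype = "float8_e4m3fn"


def create_weights_fixed(layer):
    # New order: set the fp8 dtype first, then create the parameter with it.
    layer.weight_dtype = "float8_e4m3fn"
    layer.weight = layer.create_parameter(
        shape=layer.weight_shape, dtype=layer.weight_dtype
    )


buggy, fixed = FakeLayer([4096, 4096]), FakeLayer([4096, 4096])
create_weights_buggy(buggy)
create_weights_fixed(fixed)
print(buggy.weight["dtype"])  # bfloat16 -> weight not actually fp8
print(fixed.weight["dtype"])  # float8_e4m3fn
```

This presumably is what "fix linear for prequant loader" refers to: a loader for pre-quantized checkpoints would expect the created weight to already carry the fp8 dtype, which only holds with the corrected ordering.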