Update flake8 version to support pre-commit on Python 3.12 (#3000)

* Update flake8 version to support pre-commit on Python 3.12

* polish code
This commit is contained in:
Zero Rains
2025-07-24 16:43:31 +08:00
committed by GitHub
parent 5151bc92c8
commit 0fb37ab7e4
30 changed files with 324 additions and 275 deletions

View File

@@ -17,7 +17,11 @@ dcu backend methods
"""
from .fused_moe_triton_backends import DCUTritonWeightOnlyMoEMethod
from .weight_only import DCUWeightOnlyLinearMethod
from .top_p_sampling import native_top_p_sampling
from .weight_only import DCUWeightOnlyLinearMethod
__all__ = ["DCUTritonWeightOnlyMoEMethod", "DCUWeightOnlyLinearMethod", "native_top_p_sampling"]
__all__ = [
"DCUTritonWeightOnlyMoEMethod",
"DCUWeightOnlyLinearMethod",
"native_top_p_sampling",
]

View File

@@ -13,13 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
def native_top_p_sampling(
probs: paddle.Tensor,
top_p: paddle.Tensor
) -> tuple[paddle.Tensor, paddle.Tensor]:
def native_top_p_sampling(probs: paddle.Tensor, top_p: paddle.Tensor) -> tuple[paddle.Tensor, paddle.Tensor]:
sorted_indices = paddle.argsort(probs, descending=True)
sorted_probs = paddle.sort(probs, descending=True)
cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)
@@ -30,7 +28,9 @@ def native_top_p_sampling(
sorted_indices = sorted_indices + paddle.arange(probs.shape[0], dtype="int64").unsqueeze(-1) * probs.shape[-1]
condition = paddle.scatter(
sorted_indices_to_remove.flatten(), sorted_indices.flatten(), sorted_indices_to_remove.flatten()
sorted_indices_to_remove.flatten(),
sorted_indices.flatten(),
sorted_indices_to_remove.flatten(),
)
condition = paddle.cast(condition, "bool").reshape(probs.shape)