mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
update flake8 version to support pre-commit in python3.12 (#3000)
* update flake8 version to support pre-commit in python3.12 * polish code
This commit is contained in:
@@ -17,7 +17,11 @@ dcu backend methods
|
||||
"""
|
||||
|
||||
from .fused_moe_triton_backends import DCUTritonWeightOnlyMoEMethod
|
||||
from .weight_only import DCUWeightOnlyLinearMethod
|
||||
from .top_p_sampling import native_top_p_sampling
|
||||
from .weight_only import DCUWeightOnlyLinearMethod
|
||||
|
||||
__all__ = ["DCUTritonWeightOnlyMoEMethod", "DCUWeightOnlyLinearMethod", "native_top_p_sampling"]
|
||||
__all__ = [
|
||||
"DCUTritonWeightOnlyMoEMethod",
|
||||
"DCUWeightOnlyLinearMethod",
|
||||
"native_top_p_sampling",
|
||||
]
|
||||
|
@@ -13,13 +13,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
|
||||
|
||||
def native_top_p_sampling(
|
||||
probs: paddle.Tensor,
|
||||
top_p: paddle.Tensor
|
||||
) -> tuple[paddle.Tensor, paddle.Tensor]:
|
||||
def native_top_p_sampling(probs: paddle.Tensor, top_p: paddle.Tensor) -> tuple[paddle.Tensor, paddle.Tensor]:
|
||||
sorted_indices = paddle.argsort(probs, descending=True)
|
||||
sorted_probs = paddle.sort(probs, descending=True)
|
||||
cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)
|
||||
@@ -30,7 +28,9 @@ def native_top_p_sampling(
|
||||
sorted_indices = sorted_indices + paddle.arange(probs.shape[0], dtype="int64").unsqueeze(-1) * probs.shape[-1]
|
||||
|
||||
condition = paddle.scatter(
|
||||
sorted_indices_to_remove.flatten(), sorted_indices.flatten(), sorted_indices_to_remove.flatten()
|
||||
sorted_indices_to_remove.flatten(),
|
||||
sorted_indices.flatten(),
|
||||
sorted_indices_to_remove.flatten(),
|
||||
)
|
||||
|
||||
condition = paddle.cast(condition, "bool").reshape(probs.shape)
|
||||
|
Reference in New Issue
Block a user