Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Intel HPU] Support intel hpu platform (#4161)
* [Intel HPU] Support intel hpu platform
* fix some issues
* apply precommit and move AttentionBackend_HPU
* fix format issue
* correct ops import
* fix ci issue
* update code in layers
* fix code style issue
* remove dense tp moe ep mode
* fix enc_dec_block_num
* fix rebase issue
* rename hpu to gaudi in readme
* rename ForwardMeta_HPU to HPUForwardMeta
@@ -82,6 +82,10 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
         from fastdeploy.worker.metax_worker import MetaxWorker
 
         return MetaxWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
+    if current_platform.is_intel_hpu():
+        from fastdeploy.worker.hpu_worker import HpuWorker
+
+        return HpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
 
 
 def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
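The new branch in the hunk above follows the same dispatch pattern as the existing backends in get_worker(): probe the platform via current_platform, lazily import the matching worker module, and return the constructed worker. Below is a minimal, self-contained Python sketch of that pattern; everything in it is an illustrative stand-in except the is_intel_hpu() probe and the HpuWorker name, which come from the diff (FD_PLATFORM is a hypothetical knob used only by this demo, not a FastDeploy setting).

# Sketch of the platform-dispatch pattern in get_worker(); stand-ins, not FastDeploy internals.
import os
from dataclasses import dataclass


@dataclass
class HpuWorker:
    fd_config: dict
    local_rank: int
    rank: int


class _Platform:
    """Stand-in for FastDeploy's current_platform probe object."""

    def __init__(self) -> None:
        # FD_PLATFORM is a hypothetical environment variable for this demo only.
        self.name = os.environ.get("FD_PLATFORM", "intel_hpu")

    def is_intel_hpu(self) -> bool:
        return self.name == "intel_hpu"


current_platform = _Platform()


def get_worker(fd_config: dict, local_rank: int, rank: int):
    # The real function has one probe-and-return branch per supported backend
    # (e.g. the MetaxWorker branch shown as context above); the Intel HPU branch
    # added by this commit lazily imports fastdeploy.worker.hpu_worker first.
    if current_platform.is_intel_hpu():
        return HpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
    raise NotImplementedError(f"unsupported platform in this sketch: {current_platform.name}")


print(get_worker({"model": "demo"}, local_rank=0, rank=0))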
@@ -89,21 +93,22 @@ def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
     # Global rank
     ranks = dist.get_world_size()
     dist_strategy = fleet.DistributedStrategy()
+    if ranks > 0:
-    dist_strategy.hybrid_configs = {
-        "dp_degree": 1,
-        "mp_degree": ranks,
-        "pp_degree": 1,
-        "sharding_degree": 1,
-    }
+        dist_strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": ranks,
+            "pp_degree": 1,
+            "sharding_degree": 1,
+        }
 
-    # Set control in tensor parallel
-    dist_strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
-    fleet.init(is_collective=True, strategy=dist_strategy)
-
-    # Local rank
-    local_rank = fleet.worker_index()
+        # Set control in tensor parallel
+        dist_strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
+        fleet.init(is_collective=True, strategy=dist_strategy)
+
+        # Local rank
+        local_rank = fleet.worker_index()
+    else:
+        local_rank = 0
     return ranks, local_rank
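The second hunk wraps the Fleet setup in "if ranks > 0:", so the hybrid strategy (pure tensor parallelism: mp_degree equal to the world size, every other degree 1) and fleet.init() only run when a collective world is actually reported; otherwise the function skips Fleet entirely and returns local_rank = 0, which matters for single-card launches. Below is a rough, self-contained sketch of that control flow with the paddle.distributed / fleet calls replaced by stubs so it runs without a Paddle launch context; the stub classes are assumptions for illustration, not Paddle APIs.

# Control-flow sketch of the guarded init_distributed_environment(); _Dist and
# _Fleet are fake stand-ins for paddle.distributed and paddle.distributed.fleet.
from typing import Tuple

WORLD_SIZE = 0  # pretend we were launched without a collective world


class _Dist:
    def get_world_size(self) -> int:
        return WORLD_SIZE


class _Fleet:
    def init(self, is_collective: bool, strategy: dict) -> None:
        print("fleet.init with strategy:", strategy)

    def worker_index(self) -> int:
        return 0


dist, fleet = _Dist(), _Fleet()


def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
    ranks = dist.get_world_size()
    strategy: dict = {}  # stand-in for fleet.DistributedStrategy()
    if ranks > 0:
        # Pure tensor parallelism: shard the model across every rank.
        strategy["hybrid_configs"] = {
            "dp_degree": 1,
            "mp_degree": ranks,
            "pp_degree": 1,
            "sharding_degree": 1,
        }
        strategy["tensor_parallel_configs"] = {"tensor_init_seed": seed}
        fleet.init(is_collective=True, strategy=strategy)
        local_rank = fleet.worker_index()
    else:
        # No collective world reported: skip Fleet and run as a single local rank.
        local_rank = 0
    return ranks, local_rank


print(init_distributed_environment())  # -> (0, 0) with the stubbed world size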