mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Adapt for iluvatar gpu (#2684)
This commit is contained in:
@@ -75,11 +75,17 @@ void UpdateInputes(const paddle::Tensor &stop_flags,
 const paddle::Tensor &stop_nums,
 const paddle::Tensor &next_tokens,
 const paddle::Tensor &is_block_step) {
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
+auto cu_stream = dev_ctx->stream();
+#else
+auto cu_stream = input_ids.stream();
+#endif
 const int max_bsz = stop_flags.shape()[0];
 const int now_bsz = seq_lens_this_time.shape()[0];
 const int input_ids_stride = input_ids.shape()[1];
 auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
-update_inputs_kernel<1024><<<1, 1024, 0, input_ids.stream()>>>(
+update_inputs_kernel<1024><<<1, 1024, 0, cu_stream>>>(
 const_cast<bool *>(not_need_stop_gpu.data<bool>()),
 const_cast<int *>(seq_lens_this_time.data<int>()),
 const_cast<int *>(seq_lens_encoder.data<int>()),
Reference in New Issue
Block a user