support mm mtp (#4013)

This commit is contained in:
xiaoxiaohehe001
2025-09-09 13:55:45 +08:00
committed by GitHub
parent c753f1fc9e
commit 5223065d59
11 changed files with 278 additions and 54 deletions

View File

@@ -1210,21 +1210,20 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["image_features"],
self.forward_meta,
)
hidden_states = model_output
else:
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta,
)
hidden_states = rebuild_padding(
model_output,
self.share_inputs["cum_offsets"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"],
(self.share_inputs["output_padding_offset"] if self.speculative_decoding else None),
self.parallel_config.max_model_len,
)
hidden_states = rebuild_padding(
model_output,
self.share_inputs["cum_offsets"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"],
(self.share_inputs["output_padding_offset"] if self.speculative_decoding else None),
self.parallel_config.max_model_len,
)
# 4. Compute logits, Sample
logits = self.model.compute_logits(hidden_states)