[Precision] Change the lm_head layer to run in float32 (#3596)

* support lm_head in fp32, bf16, and fp16

* delete print

* code check

* check

* check

* code check

* check

* check
This commit is contained in:
chen
2025-08-26 20:20:06 +08:00
committed by GitHub
parent 2136990144
commit d233e3c97c
14 changed files with 85 additions and 49 deletions

View File

@@ -587,6 +587,11 @@ def parse_args():
action="store_true",
help="Enable output of token-level log probabilities.",
)
parser.add_argument(
"--lm_head_fp32",
action="store_true",
help="The data type of lm_head",
)
args = parser.parse_args()
return args