mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -13,13 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def check_safetensors_model(model_dir: str):
|
||||
"""
|
||||
model_dir : the directory of the model
|
||||
Check whther the model is safetensors format
|
||||
model_dir : the directory of the model
|
||||
Check whther the model is safetensors format
|
||||
"""
|
||||
model_files = list()
|
||||
all_files = os.listdir(model_dir)
|
||||
@@ -35,8 +36,7 @@ def check_safetensors_model(model_dir: str):
|
||||
return True
|
||||
try:
|
||||
# check all the file exists
|
||||
safetensors_num = int(
|
||||
model_files[0].strip(".safetensors").split("-")[-1])
|
||||
safetensors_num = int(model_files[0].strip(".safetensors").split("-")[-1])
|
||||
flags = [0] * safetensors_num
|
||||
for x in model_files:
|
||||
current_index = int(x.strip(".safetensors").split("-")[1])
|
||||
|
Reference in New Issue
Block a user