Files
yuyinfenxi/predict.py
2025-07-02 13:54:05 +08:00

277 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
预测脚本
用于对单个音频文件进行情感预测
"""
import os
import pickle
import numpy as np
import librosa
from tensorflow.keras.models import load_model
import sys
# 导入自定义模块
from data_utils.feature_extractor import (
extract_features, features_to_matrix,
normalize_features_with_params, reshape_for_lstm
)
def _read_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    # NOTE(security): unpickling can execute arbitrary code. Only load
    # artifacts from a model directory produced by a trusted training run.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _rebuild_keras_model(model_dir, weights_path):
    """Rebuild the network from ``config.pkl`` and load ``weights_path`` into it.

    Args:
        model_dir: Directory expected to contain ``config.pkl``.
        weights_path: Path of the weights file to load into the rebuilt model.

    Returns:
        The ``.model`` attribute of the rebuilt ``EmotionModel`` (the object
        the weights were loaded into).

    Raises:
        FileNotFoundError: If ``config.pkl`` is missing.
        ValueError: If building the model or loading the weights fails.
    """
    config_path = os.path.join(model_dir, 'config.pkl')
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"找不到模型配置文件: {config_path}")
    config = _read_pickle(config_path)

    # Imported lazily: the project model code is only needed on this path.
    from models.lstm_model import EmotionModel

    try:
        include_language = config.get('include_language', False)
        wrapper = EmotionModel(
            # Fall back to the historical defaults when the config lacks a key.
            input_shape=(1, config.get('num_features', 193)),
            num_emotions=config.get('num_emotions', 6),
            num_languages=(
                config.get('num_languages', 2) if include_language else None
            ),
            include_language=include_language,
        )
        wrapper.build_model()
        wrapper.compile_model()
        wrapper.model.load_weights(weights_path)
    except Exception as e:
        # Keep the original cause attached for debugging.
        raise ValueError(f"从配置重建模型失败: {e}") from e

    print(f"已从权重文件加载模型: {weights_path}")
    # Return the underlying model so both loading paths hand back an object
    # with the same prediction interface.
    return wrapper.model


def load_model_files(model_dir):
    """Load the trained model and its companion artifacts from ``model_dir``.

    ``emotion_model.h5`` is loaded when present; otherwise the model is rebuilt
    from ``config.pkl`` and ``best_model_weights.h5``. This function only
    reads: a missing ``model_dir`` is reported as an error, never created.

    Args:
        model_dir: Directory holding the model and pickle artifacts.

    Returns:
        A tuple ``(model, emotion_encoder, feature_scaler, language_encoder,
        feature_names)``. ``feature_scaler``, ``language_encoder`` and
        ``feature_names`` are ``None`` when their files are absent.

    Raises:
        FileNotFoundError: If neither model file, the config file (when
            rebuilding), or the emotion encoder file can be found.
        ValueError: If rebuilding the model from config/weights fails.
    """
    try:
        model_path = os.path.join(model_dir, 'emotion_model.h5')
        if os.path.exists(model_path):
            model = load_model(model_path)
            print(f"已加载模型: {model_path}")
        else:
            # No full model saved: fall back to config + weights.
            weights_path = os.path.join(model_dir, 'best_model_weights.h5')
            if not os.path.exists(weights_path):
                raise FileNotFoundError(
                    f"无法找到模型文件: {model_path} 或权重文件: {weights_path}"
                )
            model = _rebuild_keras_model(model_dir, weights_path)

        # The emotion label encoder is mandatory.
        encoder_path = os.path.join(model_dir, 'emotion_encoder.pkl')
        if not os.path.exists(encoder_path):
            raise FileNotFoundError(f"找不到情感编码器文件: {encoder_path}")
        emotion_encoder = _read_pickle(encoder_path)

        # Feature names keep extraction consistent with training (optional).
        feature_names_path = os.path.join(model_dir, 'feature_names.pkl')
        if os.path.exists(feature_names_path):
            feature_names = _read_pickle(feature_names_path)
        else:
            feature_names = None
            print("警告: 找不到特征名称文件,可能影响预测准确性")

        # A language encoder file marks a multi-language model (optional).
        language_encoder_path = os.path.join(model_dir, 'language_encoder.pkl')
        if os.path.exists(language_encoder_path):
            language_encoder = _read_pickle(language_encoder_path)
            print("检测到多语言模型")
        else:
            language_encoder = None

        # Saved feature-scaling parameters (optional).
        scaler_path = os.path.join(model_dir, 'feature_scaler.pkl')
        if os.path.exists(scaler_path):
            feature_scaler = _read_pickle(scaler_path)
        else:
            print("警告: 找不到特征缩放参数文件,将使用默认缩放")
            feature_scaler = None

        return model, emotion_encoder, feature_scaler, language_encoder, feature_names
    except Exception as e:
        # Log for the console user, then let the caller handle the error.
        print(f"加载模型文件时出错: {e}")
        raise
def _predict_probabilities(model, inputs):
    """Return the model's output for the first (only) sample in ``inputs``.

    ``model.predict`` is tried first. If it fails and ``model`` exposes a
    ``model`` attribute (e.g. a wrapper around the real network), the call is
    retried on that attribute.

    Raises:
        ValueError: If no prediction could be obtained.
    """
    try:
        return model.predict(inputs)[0]
    except Exception as e:
        print(f"预测过程中出错: {e}")
        inner = getattr(model, 'model', None)
        if inner is None:
            raise ValueError(f"模型预测失败: {e}") from e
        try:
            return inner.predict(inputs)[0]
        except Exception as inner_error:
            # Report the first failure, as before, but keep the chain intact.
            raise ValueError(f"模型预测失败: {e}") from inner_error


def predict_emotion(audio_file, model_dir):
    """Predict the emotion expressed in a single audio file.

    Args:
        audio_file: Path to the audio file.
        model_dir: Directory containing the model and its artifacts.

    Returns:
        dict: ``'emotion'`` (predicted label), ``'probability'`` (its score)
        and ``'all_emotions'`` (label -> score). For multi-language models the
        dict also has ``'language'`` (the language input that produced the
        highest emotion score) and ``'language_probability'`` (a fixed
        placeholder of 1.0, not a real confidence).

    Raises:
        FileNotFoundError: If ``audio_file`` does not exist, or model files
            are missing.
        ValueError: If the audio cannot be loaded, the language encoder has no
            classes, or the model fails to predict.
    """
    try:
        if not os.path.exists(audio_file):
            raise FileNotFoundError(f"音频文件不存在: {audio_file}")

        model, emotion_encoder, feature_scaler, language_encoder, feature_names = (
            load_model_files(model_dir)
        )

        print(f"加载音频文件: {audio_file}")
        try:
            # The returned sample rate is not needed downstream.
            audio, _ = librosa.load(audio_file, sr=22050, res_type='kaiser_fast')
        except Exception as e:
            raise ValueError(f"加载音频文件失败: {e}") from e

        print("提取音频特征...")
        features = extract_features(audio)
        X, _ = features_to_matrix([features], feature_names=feature_names)

        # Explicit None checks: truthiness is ambiguous for array-like or
        # empty containers and says nothing about "was a file loaded".
        if feature_scaler is not None:
            X_norm = normalize_features_with_params(X, feature_scaler)
        else:
            # NOTE(review): X is built from one feature set; if it has a
            # single row, the per-column mean equals that row and this
            # produces all zeros. Predictions without a saved scaler are
            # therefore unreliable — confirm and ship the scaler file.
            X_norm = (X - np.mean(X, axis=0)) / (np.std(X, axis=0) + 1e-10)

        X_reshaped = reshape_for_lstm(X_norm)
        emotion_classes = emotion_encoder.classes_

        print("进行预测...")
        if language_encoder is not None:
            # The spoken language is unknown, so try every language input and
            # keep the one yielding the most confident emotion.
            languages = language_encoder.classes_
            if len(languages) == 0:
                raise ValueError("语言编码器不包含任何语言类别")

            best_prob = -1
            best_emotion = None
            best_language = None
            best_probs = None
            for lang in languages:
                lang_idx = language_encoder.transform([lang])[0]
                lang_one_hot = np.zeros((1, len(languages)))
                lang_one_hot[0, lang_idx] = 1

                probs = _predict_probabilities(model, [X_reshaped, lang_one_hot])
                emotion_idx = np.argmax(probs)
                if probs[emotion_idx] > best_prob:
                    best_prob = probs[emotion_idx]
                    best_emotion = emotion_classes[emotion_idx]
                    best_language = lang
                    best_probs = probs

            if best_probs is None:
                all_emotions = None
            else:
                all_emotions = {
                    label: float(best_probs[i])
                    for i, label in enumerate(emotion_classes)
                }
            result = {
                'emotion': best_emotion,
                'probability': float(best_prob),
                'language': best_language,
                # Placeholder: the language is not actually predicted.
                'language_probability': 1.0,
                'all_emotions': all_emotions
            }
        else:
            probs = _predict_probabilities(model, X_reshaped)
            emotion_idx = np.argmax(probs)
            result = {
                'emotion': emotion_classes[emotion_idx],
                'probability': float(probs[emotion_idx]),
                'all_emotions': {
                    label: float(probs[i])
                    for i, label in enumerate(emotion_classes)
                }
            }
        return result
    except Exception as e:
        # Log with a traceback for the console user, then re-raise.
        print(f"预测过程中出错: {e}")
        import traceback
        traceback.print_exc()
        raise
if __name__ == "__main__":
    import argparse

    # Command-line interface: --audio is mandatory, --model has a default.
    cli = argparse.ArgumentParser(description='预测音频情感')
    cli.add_argument('--audio', required=True, type=str, help='音频文件路径')
    cli.add_argument(
        '--model',
        default='./output/emotion_model',
        type=str,
        help='模型目录',
    )
    cli_args = cli.parse_args()

    # Any failure, during prediction or while printing, is reported and
    # turned into exit status 1.
    try:
        result = predict_emotion(cli_args.audio, cli_args.model)

        print("\n预测结果:")
        print(f"情感: {result['emotion']}")
        print(f"置信度: {result['probability']:.2f}")

        # Language details are only present for multi-language models.
        if 'language' in result:
            print(f"语言: {result['language']}")
            print(f"语言置信度: {result['language_probability']:.2f}")

        print("\n情感概率分布:")
        for emotion, prob in result['all_emotions'].items():
            print(f" {emotion}: {prob:.4f}")
    except Exception as e:
        print(f"错误: {e}")
        sys.exit(1)