This commit is contained in:
lzc
2025-07-02 13:54:05 +08:00
commit 4b3870440c
6351 changed files with 282880 additions and 0 deletions

277
predict.py Normal file
View File

@@ -0,0 +1,277 @@
"""
预测脚本
用于对单个音频文件进行情感预测
"""
import os
import pickle
import numpy as np
import librosa
from tensorflow.keras.models import load_model
import sys
# 导入自定义模块
from data_utils.feature_extractor import (
extract_features, features_to_matrix,
normalize_features_with_params, reshape_for_lstm
)
def load_model_files(model_dir):
"""
加载模型和相关文件
Args:
model_dir: 模型目录
Returns:
model: Keras模型
emotion_encoder: 情感标签编码器
feature_scaler: 特征缩放器
language_encoder: 语言标签编码器(如果有)
feature_names: 特征名称列表
"""
try:
# 确保模型目录存在
if not os.path.exists(model_dir):
os.makedirs(model_dir, exist_ok=True)
print(f"创建模型目录: {model_dir}")
# 加载模型
model_path = os.path.join(model_dir, 'emotion_model.h5')
if not os.path.exists(model_path):
# 尝试加载权重文件
weights_path = os.path.join(model_dir, 'best_model_weights.h5')
if not os.path.exists(weights_path):
raise FileNotFoundError(f"无法找到模型文件: {model_path} 或权重文件: {weights_path}")
# 尝试从配置重建模型,然后加载权重
config_path = os.path.join(model_dir, 'config.pkl')
if os.path.exists(config_path):
with open(config_path, 'rb') as f:
config = pickle.load(f)
from models.lstm_model import EmotionModel
# 尝试获取必要的参数
try:
input_shape = (1, config.get('num_features', 193)) # 默认特征数
num_emotions = config.get('num_emotions', 6) # 默认情感类别数
include_language = config.get('include_language', False)
num_languages = 2 # 默认语言数
model = EmotionModel(
input_shape=input_shape,
num_emotions=num_emotions,
num_languages=num_languages if include_language else None,
include_language=include_language
)
model.build_model()
model.compile_model()
model.model.load_weights(weights_path)
print(f"已从权重文件加载模型: {weights_path}")
except Exception as e:
raise ValueError(f"从配置重建模型失败: {e}")
else:
raise FileNotFoundError(f"找不到模型配置文件: {config_path}")
else:
model = load_model(model_path)
print(f"已加载模型: {model_path}")
# 加载情感编码器
encoder_path = os.path.join(model_dir, 'emotion_encoder.pkl')
if not os.path.exists(encoder_path):
raise FileNotFoundError(f"找不到情感编码器文件: {encoder_path}")
with open(encoder_path, 'rb') as f:
emotion_encoder = pickle.load(f)
# 加载特征名称,用于确保特征提取一致性
feature_names_path = os.path.join(model_dir, 'feature_names.pkl')
if os.path.exists(feature_names_path):
with open(feature_names_path, 'rb') as f:
feature_names = pickle.load(f)
else:
feature_names = None
print("警告: 找不到特征名称文件,可能影响预测准确性")
# 检查是否有语言编码器
language_encoder_path = os.path.join(model_dir, 'language_encoder.pkl')
if os.path.exists(language_encoder_path):
with open(language_encoder_path, 'rb') as f:
language_encoder = pickle.load(f)
print("检测到多语言模型")
else:
language_encoder = None
# 加载特征缩放参数
scaler_path = os.path.join(model_dir, 'feature_scaler.pkl')
if os.path.exists(scaler_path):
with open(scaler_path, 'rb') as f:
feature_scaler = pickle.load(f)
else:
print("警告: 找不到特征缩放参数文件,将使用默认缩放")
feature_scaler = None
return model, emotion_encoder, feature_scaler, language_encoder, feature_names
except Exception as e:
print(f"加载模型文件时出错: {e}")
raise
def predict_emotion(audio_file, model_dir):
"""
预测单个音频文件的情感
Args:
audio_file: 音频文件路径
model_dir: 模型目录
Returns:
dict: 预测结果,包含情感标签和概率
"""
try:
# 检查文件是否存在
if not os.path.exists(audio_file):
raise FileNotFoundError(f"音频文件不存在: {audio_file}")
# 加载模型和相关文件
model, emotion_encoder, feature_scaler, language_encoder, feature_names = load_model_files(model_dir)
# 加载音频
print(f"加载音频文件: {audio_file}")
try:
audio, sr = librosa.load(audio_file, sr=22050, res_type='kaiser_fast')
except Exception as e:
raise ValueError(f"加载音频文件失败: {e}")
# 提取特征
print("提取音频特征...")
features = extract_features(audio)
# 转换为特征矩阵
X, _ = features_to_matrix([features], feature_names=feature_names)
# 标准化特征
if feature_scaler:
X_norm = normalize_features_with_params(X, feature_scaler)
else:
# 如果没有保存的缩放参数,使用简单标准化
X_norm = (X - np.mean(X, axis=0)) / (np.std(X, axis=0) + 1e-10)
# 重塑为LSTM输入格式
X_reshaped = reshape_for_lstm(X_norm)
# 预测
print("进行预测...")
if language_encoder:
# 如果是多语言模型,需要提供语言特征
# 这里假设无法确定语言,将尝试所有可能的语言
languages = language_encoder.classes_
best_prob = -1
best_emotion = None
best_language = None
best_emotion_probs = None
best_language_prob = None
for lang in languages:
# 为当前语言创建独热编码
lang_idx = language_encoder.transform([lang])[0]
lang_one_hot = np.zeros((1, len(languages)))
lang_one_hot[0, lang_idx] = 1
# 预测
probs = model.predict([X_reshaped, lang_one_hot])[0]
# 找出最可能的情感
emotion_idx = np.argmax(probs)
emotion_prob = probs[emotion_idx]
if emotion_prob > best_prob:
best_prob = emotion_prob
best_emotion = emotion_encoder.classes_[emotion_idx]
best_language = lang
best_emotion_probs = {emotion: float(probs[i]) for i, emotion in enumerate(emotion_encoder.classes_)}
best_language_prob = 1.0 # 假设语言预测的置信度
result = {
'emotion': best_emotion,
'probability': float(best_prob),
'language': best_language,
'language_probability': float(best_language_prob),
'all_emotions': best_emotion_probs
}
else:
# 单语言模型
try:
probs = model.predict(X_reshaped)[0]
emotion_idx = np.argmax(probs)
emotion = emotion_encoder.classes_[emotion_idx]
probability = probs[emotion_idx]
# 构建结果
all_emotions = {emotion: float(probs[i]) for i, emotion in enumerate(emotion_encoder.classes_)}
result = {
'emotion': emotion,
'probability': float(probability),
'all_emotions': all_emotions
}
except Exception as e:
print(f"预测过程中出错: {e}")
# 尝试使用模型的model属性如果是自定义EmotionModel类
try:
if hasattr(model, 'model'):
probs = model.model.predict(X_reshaped)[0]
emotion_idx = np.argmax(probs)
emotion = emotion_encoder.classes_[emotion_idx]
probability = probs[emotion_idx]
all_emotions = {emotion: float(probs[i]) for i, emotion in enumerate(emotion_encoder.classes_)}
result = {
'emotion': emotion,
'probability': float(probability),
'all_emotions': all_emotions
}
else:
raise
except:
raise ValueError(f"模型预测失败: {e}")
return result
except Exception as e:
print(f"预测过程中出错: {e}")
import traceback
traceback.print_exc()
raise
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='预测音频情感')
parser.add_argument('--audio', type=str, required=True,
help='音频文件路径')
parser.add_argument('--model', type=str, default='./output/emotion_model',
help='模型目录')
args = parser.parse_args()
try:
result = predict_emotion(args.audio, args.model)
print("\n预测结果:")
print(f"情感: {result['emotion']}")
print(f"置信度: {result['probability']:.2f}")
if 'language' in result:
print(f"语言: {result['language']}")
print(f"语言置信度: {result['language_probability']:.2f}")
print("\n情感概率分布:")
for emotion, prob in result['all_emotions'].items():
print(f" {emotion}: {prob:.4f}")
except Exception as e:
print(f"错误: {e}")
sys.exit(1)