277 lines
10 KiB
Python
277 lines
10 KiB
Python
"""
Prediction script.

Runs emotion prediction on a single audio file.
"""
|
||
|
||
import os
|
||
import pickle
|
||
import numpy as np
|
||
import librosa
|
||
from tensorflow.keras.models import load_model
|
||
import sys
|
||
|
||
# 导入自定义模块
|
||
from data_utils.feature_extractor import (
|
||
extract_features, features_to_matrix,
|
||
normalize_features_with_params, reshape_for_lstm
|
||
)
|
||
|
||
def _load_pickle(path):
    """Deserialize and return the object stored in the pickle file at ``path``."""
    # NOTE(security): pickle.load can execute arbitrary code while
    # deserializing. Only point this at artifacts written by a trusted
    # training run, never at files from an untrusted source.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _rebuild_model_from_weights(model_dir, model_path):
    """Rebuild an ``EmotionModel`` from ``config.pkl`` and load ``best_model_weights.h5`` into it.

    ``model_path`` is only used to produce the error message when neither the
    full model file nor the weights file is available.
    """
    weights_path = os.path.join(model_dir, 'best_model_weights.h5')
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"无法找到模型文件: {model_path} 或权重文件: {weights_path}")

    config_path = os.path.join(model_dir, 'config.pkl')
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"找不到模型配置文件: {config_path}")

    config = _load_pickle(config_path)

    # Imported lazily: only needed when the full .h5 model is unavailable.
    from models.lstm_model import EmotionModel

    try:
        input_shape = (1, config.get('num_features', 193))
        num_emotions = config.get('num_emotions', 6)
        include_language = config.get('include_language', False)
        # Falls back to the previously hard-coded value of 2 when the config
        # carries no 'num_languages' entry.
        # NOTE(review): the key name is an assumption - confirm against the
        # training script that writes config.pkl.
        num_languages = config.get('num_languages', 2)

        model = EmotionModel(
            input_shape=input_shape,
            num_emotions=num_emotions,
            num_languages=num_languages if include_language else None,
            include_language=include_language
        )
        model.build_model()
        model.compile_model()
        model.model.load_weights(weights_path)
        print(f"已从权重文件加载模型: {weights_path}")
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise ValueError(f"从配置重建模型失败: {e}") from e

    return model


def load_model_files(model_dir):
    """
    Load the model and its companion artifacts from ``model_dir``.

    The model is read from ``emotion_model.h5`` when that file exists.
    Otherwise it is rebuilt from ``config.pkl`` and ``best_model_weights.h5``;
    in that case the returned object is the ``EmotionModel`` wrapper, whose
    ``model`` attribute is the network the weights were loaded into.

    Args:
        model_dir: Directory containing the saved model files.

    Returns:
        A 5-tuple ``(model, emotion_encoder, feature_scaler,
        language_encoder, feature_names)``. ``feature_scaler``,
        ``language_encoder`` and ``feature_names`` are ``None`` when their
        pickle files are absent.

    Raises:
        FileNotFoundError: If ``model_dir`` is not a directory, or a required
            file (model/weights, config, emotion encoder) is missing.
        ValueError: If rebuilding the model from the config fails.
    """
    try:
        # A loader must not create the directory it is asked to read from:
        # doing so only hides a mistyped path behind a later, less clear error.
        if not os.path.isdir(model_dir):
            raise FileNotFoundError(f"模型目录不存在: {model_dir}")

        model_path = os.path.join(model_dir, 'emotion_model.h5')
        if os.path.exists(model_path):
            model = load_model(model_path)
            print(f"已加载模型: {model_path}")
        else:
            model = _rebuild_model_from_weights(model_dir, model_path)

        # The emotion encoder is mandatory: without it class indices cannot
        # be mapped back to labels.
        encoder_path = os.path.join(model_dir, 'emotion_encoder.pkl')
        if not os.path.exists(encoder_path):
            raise FileNotFoundError(f"找不到情感编码器文件: {encoder_path}")
        emotion_encoder = _load_pickle(encoder_path)

        # Feature names are optional; they are passed to features_to_matrix
        # by the prediction code.
        feature_names_path = os.path.join(model_dir, 'feature_names.pkl')
        if os.path.exists(feature_names_path):
            feature_names = _load_pickle(feature_names_path)
        else:
            feature_names = None
            print("警告: 找不到特征名称文件,可能影响预测准确性")

        # The presence of a language encoder marks a multilingual model.
        language_encoder_path = os.path.join(model_dir, 'language_encoder.pkl')
        if os.path.exists(language_encoder_path):
            language_encoder = _load_pickle(language_encoder_path)
            print("检测到多语言模型")
        else:
            language_encoder = None

        # Scaling parameters are optional.
        scaler_path = os.path.join(model_dir, 'feature_scaler.pkl')
        if os.path.exists(scaler_path):
            feature_scaler = _load_pickle(scaler_path)
        else:
            print("警告: 找不到特征缩放参数文件,将使用默认缩放")
            feature_scaler = None

        return model, emotion_encoder, feature_scaler, language_encoder, feature_names

    except Exception as e:
        # Report and re-raise so callers still see the original exception type.
        print(f"加载模型文件时出错: {e}")
        raise
|
||
|
||
def _summarize_probs(probs, emotion_encoder):
    """Return ``(label, probability, {label: probability})`` for the arg-max class of ``probs``."""
    classes = emotion_encoder.classes_
    top_idx = np.argmax(probs)
    distribution = {label: float(probs[i]) for i, label in enumerate(classes)}
    return classes[top_idx], float(probs[top_idx]), distribution


def _predict_probs(model, inputs):
    """Run ``model.predict(inputs)`` and return the first row of the output.

    If that call fails and ``model`` has a ``model`` attribute (the custom
    ``EmotionModel`` wrapper case), the prediction is retried on that
    attribute.
    """
    try:
        return model.predict(inputs)[0]
    except Exception as e:
        print(f"预测过程中出错: {e}")
        wrapped = getattr(model, 'model', None)
        if wrapped is None:
            raise ValueError(f"模型预测失败: {e}") from e
        try:
            return wrapped.predict(inputs)[0]
        except Exception as fallback_error:
            # The message reports the first failure; the retry failure is
            # kept as the chained cause.
            raise ValueError(f"模型预测失败: {e}") from fallback_error


def predict_emotion(audio_file, model_dir):
    """
    Predict the emotion expressed in a single audio file.

    Args:
        audio_file: Path to the audio file.
        model_dir: Directory containing the saved model files.

    Returns:
        dict: Always contains ``'emotion'`` (predicted label),
        ``'probability'`` (its score) and ``'all_emotions'`` (label -> score).
        When a language encoder was found, it also contains ``'language'``
        (the language whose one-hot input gave the highest top score) and
        ``'language_probability'`` (a fixed placeholder of 1.0).

    Raises:
        FileNotFoundError: If ``audio_file`` does not exist, or model files
            are missing.
        ValueError: If the audio cannot be loaded or the prediction fails.
    """
    try:
        if not os.path.exists(audio_file):
            raise FileNotFoundError(f"音频文件不存在: {audio_file}")

        model, emotion_encoder, feature_scaler, language_encoder, feature_names = load_model_files(model_dir)

        print(f"加载音频文件: {audio_file}")
        try:
            audio, _ = librosa.load(audio_file, sr=22050, res_type='kaiser_fast')
        except Exception as e:
            raise ValueError(f"加载音频文件失败: {e}") from e

        print("提取音频特征...")
        features = extract_features(audio)

        # Build a feature matrix from this single sample.
        X, _ = features_to_matrix([features], feature_names=feature_names)

        if feature_scaler is not None:
            X_norm = normalize_features_with_params(X, feature_scaler)
        else:
            # No saved scaling parameters. Standardizing over axis 0 of a
            # single-sample matrix subtracts each value from itself and
            # yields all zeros, so standardize across the feature axis.
            # This is a best-effort fallback and does not reproduce the
            # training-time statistics.
            mean = np.mean(X, axis=-1, keepdims=True)
            std = np.std(X, axis=-1, keepdims=True)
            X_norm = (X - mean) / (std + 1e-10)

        X_reshaped = reshape_for_lstm(X_norm)

        print("进行预测...")
        if language_encoder is not None:
            # Multilingual model: the language of the clip is unknown, so try
            # every language and keep the one with the highest top score.
            languages = language_encoder.classes_
            if len(languages) == 0:
                raise ValueError("语言编码器不包含任何语言类别")

            best_prob = -1
            best_emotion = None
            best_language = None
            best_emotion_probs = None

            for lang in languages:
                lang_idx = language_encoder.transform([lang])[0]
                lang_one_hot = np.zeros((1, len(languages)))
                lang_one_hot[0, lang_idx] = 1

                probs = _predict_probs(model, [X_reshaped, lang_one_hot])
                emotion, emotion_prob, emotion_probs = _summarize_probs(probs, emotion_encoder)

                if emotion_prob > best_prob:
                    best_prob = emotion_prob
                    best_emotion = emotion
                    best_language = lang
                    best_emotion_probs = emotion_probs

            result = {
                'emotion': best_emotion,
                'probability': float(best_prob),
                'language': best_language,
                # The language is not predicted here; 1.0 is a placeholder.
                'language_probability': 1.0,
                'all_emotions': best_emotion_probs
            }
        else:
            # Single-language model.
            probs = _predict_probs(model, X_reshaped)
            emotion, probability, all_emotions = _summarize_probs(probs, emotion_encoder)
            result = {
                'emotion': emotion,
                'probability': probability,
                'all_emotions': all_emotions
            }

        return result

    except Exception as e:
        # Report with a traceback, then re-raise so callers see the original
        # exception type.
        print(f"预测过程中出错: {e}")
        import traceback
        traceback.print_exc()
        raise
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: predict the emotion of one audio file and
    # print a short report; exit with status 1 on any error.
    import argparse

    cli = argparse.ArgumentParser(description='预测音频情感')
    cli.add_argument('--audio', type=str, required=True, help='音频文件路径')
    cli.add_argument('--model', type=str, default='./output/emotion_model', help='模型目录')
    cli_args = cli.parse_args()

    try:
        outcome = predict_emotion(cli_args.audio, cli_args.model)

        print("\n预测结果:")
        print(f"情感: {outcome['emotion']}")
        print(f"置信度: {outcome['probability']:.2f}")

        # Language fields are only present for multilingual models.
        if 'language' in outcome:
            print(f"语言: {outcome['language']}")
            print(f"语言置信度: {outcome['language_probability']:.2f}")

        print("\n情感概率分布:")
        for label, score in outcome['all_emotions'].items():
            print(f" {label}: {score:.4f}")

    except Exception as err:
        print(f"错误: {err}")
        sys.exit(1)