"""Prediction script.

Predicts the emotion expressed in a single audio file using a trained
model and the preprocessing artifacts stored next to it.
"""
import argparse
import os
import pickle
import sys
import traceback

import numpy as np
import librosa
from tensorflow.keras.models import load_model

# Project-local feature extraction helpers.
from data_utils.feature_extractor import (
    extract_features,
    features_to_matrix,
    normalize_features_with_params,
    reshape_for_lstm
)


def _load_pickle(path):
    """Return the object unpickled from ``path``.

    NOTE(security): ``pickle.load`` can execute arbitrary code. Only point
    this script at model directories that come from a trusted source.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _rebuild_model_from_weights(model_dir, model_path):
    """Rebuild the model from ``config.pkl`` and load ``best_model_weights.h5``.

    Args:
        model_dir: Directory holding the weights and the pickled config.
        model_path: Path of the missing full-model file; only used in the
            error message.

    Returns:
        The project's ``EmotionModel`` wrapper with weights loaded into its
        ``.model`` attribute.

    Raises:
        FileNotFoundError: If the weights file or the config file is missing.
        ValueError: If the model cannot be rebuilt from the config.
    """
    weights_path = os.path.join(model_dir, 'best_model_weights.h5')
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"无法找到模型文件: {model_path} 或权重文件: {weights_path}")

    config_path = os.path.join(model_dir, 'config.pkl')
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"找不到模型配置文件: {config_path}")

    config = _load_pickle(config_path)
    # Imported lazily: only needed when no full-model file is available.
    from models.lstm_model import EmotionModel

    try:
        # Defaults are used when the config lacks the corresponding key.
        input_shape = (1, config.get('num_features', 193))
        num_emotions = config.get('num_emotions', 6)
        include_language = config.get('include_language', False)
        num_languages = 2  # hard-coded default number of languages

        model = EmotionModel(
            input_shape=input_shape,
            num_emotions=num_emotions,
            num_languages=num_languages if include_language else None,
            include_language=include_language
        )
        model.build_model()
        model.compile_model()
        model.model.load_weights(weights_path)
        print(f"已从权重文件加载模型: {weights_path}")
    except Exception as e:
        raise ValueError(f"从配置重建模型失败: {e}") from e
    return model


def load_model_files(model_dir):
    """Load the model and the preprocessing artifacts stored in ``model_dir``.

    The full model file ``emotion_model.h5`` is preferred. If it is absent,
    the model is rebuilt from ``config.pkl`` and ``best_model_weights.h5``.
    A missing directory is not created; it results in the same
    ``FileNotFoundError`` as a missing model file.

    Args:
        model_dir: Directory containing the model and its artifacts.

    Returns:
        A tuple ``(model, emotion_encoder, feature_scaler, language_encoder,
        feature_names)``:
            model: The object returned by ``load_model``, or the
                ``EmotionModel`` wrapper when rebuilt from weights.
            emotion_encoder: Unpickled emotion label encoder.
            feature_scaler: Unpickled scaling parameters, or ``None`` when
                ``feature_scaler.pkl`` is missing.
            language_encoder: Unpickled language label encoder, or ``None``
                when ``language_encoder.pkl`` is missing.
            feature_names: Unpickled feature names, or ``None`` when
                ``feature_names.pkl`` is missing.

    Raises:
        FileNotFoundError: If the model/weights/config files or the emotion
            encoder are missing.
        ValueError: If the model cannot be rebuilt from its config.
    """
    try:
        # Load the model, preferring the full-model file over bare weights.
        model_path = os.path.join(model_dir, 'emotion_model.h5')
        if os.path.exists(model_path):
            model = load_model(model_path)
            print(f"已加载模型: {model_path}")
        else:
            model = _rebuild_model_from_weights(model_dir, model_path)

        # The emotion encoder is mandatory.
        encoder_path = os.path.join(model_dir, 'emotion_encoder.pkl')
        if not os.path.exists(encoder_path):
            raise FileNotFoundError(f"找不到情感编码器文件: {encoder_path}")
        emotion_encoder = _load_pickle(encoder_path)

        # Feature names keep feature extraction consistent with the saved model.
        feature_names_path = os.path.join(model_dir, 'feature_names.pkl')
        if os.path.exists(feature_names_path):
            feature_names = _load_pickle(feature_names_path)
        else:
            feature_names = None
            print("警告: 找不到特征名称文件,可能影响预测准确性")

        # A language encoder is only present for multi-language models.
        language_encoder_path = os.path.join(model_dir, 'language_encoder.pkl')
        if os.path.exists(language_encoder_path):
            language_encoder = _load_pickle(language_encoder_path)
            print("检测到多语言模型")
        else:
            language_encoder = None

        # Feature scaling parameters are optional.
        scaler_path = os.path.join(model_dir, 'feature_scaler.pkl')
        if os.path.exists(scaler_path):
            feature_scaler = _load_pickle(scaler_path)
        else:
            print("警告: 找不到特征缩放参数文件,将使用默认缩放")
            feature_scaler = None

        return model, emotion_encoder, feature_scaler, language_encoder, feature_names

    except Exception as e:
        print(f"加载模型文件时出错: {e}")
        raise


def _standardize_without_scaler(X):
    """Standardize ``X`` when no saved scaling parameters are available.

    With several rows, each column is standardized across rows (axis 0).
    With a single row that would subtract the row from itself and produce
    all zeros for any input, so the row is standardized across its own
    values instead.

    NOTE(review): this is only an approximation; it does not reproduce the
    statistics used at training time.
    """
    single_row = np.ndim(X) >= 2 and np.shape(X)[0] == 1
    axis = 1 if single_row else 0
    mean = np.mean(X, axis=axis, keepdims=True)
    std = np.std(X, axis=axis, keepdims=True)
    return (X - mean) / (std + 1e-10)


def _predict_probabilities(model, inputs):
    """Return the first row of ``model.predict(inputs)``.

    If calling ``predict`` on ``model`` fails and ``model`` has a ``model``
    attribute (as the ``EmotionModel`` wrapper does), prediction is retried
    on that attribute.

    Raises:
        ValueError: If prediction fails and no fallback succeeds.
    """
    try:
        return model.predict(inputs)[0]
    except Exception as e:
        print(f"预测过程中出错: {e}")
        inner_model = getattr(model, 'model', None)
        if inner_model is None:
            raise ValueError(f"模型预测失败: {e}") from e
        try:
            return inner_model.predict(inputs)[0]
        except Exception as fallback_error:
            raise ValueError(f"模型预测失败: {e}") from fallback_error


def _emotion_distribution(emotion_classes, probs):
    """Map each emotion label to its probability as a plain ``float``."""
    return {label: float(probs[i]) for i, label in enumerate(emotion_classes)}


def predict_emotion(audio_file, model_dir):
    """Predict the emotion of a single audio file.

    Args:
        audio_file: Path to the audio file.
        model_dir: Directory containing the model and its artifacts.

    Returns:
        dict: ``'emotion'`` (predicted label), ``'probability'`` (its
        probability) and ``'all_emotions'`` (label -> probability). For
        multi-language models it also contains ``'language'`` and
        ``'language_probability'``.

    Raises:
        FileNotFoundError: If the audio file or required model files are
            missing.
        ValueError: If the audio cannot be loaded or prediction fails.
    """
    try:
        if not os.path.exists(audio_file):
            raise FileNotFoundError(f"音频文件不存在: {audio_file}")

        (model, emotion_encoder, feature_scaler,
         language_encoder, feature_names) = load_model_files(model_dir)

        print(f"加载音频文件: {audio_file}")
        try:
            # The returned sample rate is not used below.
            audio, _ = librosa.load(audio_file, sr=22050, res_type='kaiser_fast')
        except Exception as e:
            raise ValueError(f"加载音频文件失败: {e}") from e

        print("提取音频特征...")
        features = extract_features(audio)

        # Build the feature matrix for this single sample.
        X, _ = features_to_matrix([features], feature_names=feature_names)

        # `is not None` rather than truthiness: the scaler may be an
        # array-like object whose truth value is ambiguous.
        if feature_scaler is not None:
            X_norm = normalize_features_with_params(X, feature_scaler)
        else:
            X_norm = _standardize_without_scaler(X)

        # Reshape into the input format expected by the LSTM model.
        X_reshaped = reshape_for_lstm(X_norm)

        print("进行预测...")
        emotion_classes = emotion_encoder.classes_

        if language_encoder is not None:
            # Multi-language model: the language is unknown, so every
            # language is tried and the one giving the most confident
            # emotion prediction is kept.
            languages = language_encoder.classes_
            best_prob = -1
            best_emotion = None
            best_language = None
            best_emotion_probs = None
            best_language_prob = None

            for lang in languages:
                # One-hot encode the current candidate language.
                lang_idx = language_encoder.transform([lang])[0]
                lang_one_hot = np.zeros((1, len(languages)))
                lang_one_hot[0, lang_idx] = 1

                probs = _predict_probabilities(model, [X_reshaped, lang_one_hot])

                emotion_idx = np.argmax(probs)
                emotion_prob = probs[emotion_idx]
                if emotion_prob > best_prob:
                    best_prob = emotion_prob
                    best_emotion = emotion_classes[emotion_idx]
                    best_language = lang
                    best_emotion_probs = _emotion_distribution(emotion_classes, probs)
                    # Placeholder: no real language confidence is computed.
                    best_language_prob = 1.0

            result = {
                'emotion': best_emotion,
                'probability': float(best_prob),
                'language': best_language,
                'language_probability': float(best_language_prob),
                'all_emotions': best_emotion_probs
            }
        else:
            # Single-language model.
            probs = _predict_probabilities(model, X_reshaped)
            emotion_idx = np.argmax(probs)
            result = {
                'emotion': emotion_classes[emotion_idx],
                'probability': float(probs[emotion_idx]),
                'all_emotions': _emotion_distribution(emotion_classes, probs)
            }

        return result

    except Exception as e:
        print(f"预测过程中出错: {e}")
        traceback.print_exc()
        raise


def main():
    """Parse command-line arguments, run the prediction and print the result."""
    parser = argparse.ArgumentParser(description='预测音频情感')
    parser.add_argument('--audio', type=str, required=True, help='音频文件路径')
    parser.add_argument('--model', type=str, default='./output/emotion_model', help='模型目录')
    args = parser.parse_args()

    try:
        result = predict_emotion(args.audio, args.model)

        print("\n预测结果:")
        print(f"情感: {result['emotion']}")
        print(f"置信度: {result['probability']:.2f}")

        if 'language' in result:
            print(f"语言: {result['language']}")
            print(f"语言置信度: {result['language_probability']:.2f}")

        print("\n情感概率分布:")
        for emotion, prob in result['all_emotions'].items():
            print(f"  {emotion}: {prob:.4f}")

    except Exception as e:
        print(f"错误: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()