217 lines
7.3 KiB
Python
217 lines
7.3 KiB
Python
|
||
"""
|
||
多语言语音情感识别系统主程序
|
||
"""
|
||
|
||
import os
|
||
import argparse
|
||
import sys
|
||
import importlib
|
||
|
||
def check_dependencies():
|
||
"""
|
||
检查所需的Python依赖是否安装
|
||
|
||
Returns:
|
||
bool: 如果所有依赖都已安装,则返回True
|
||
"""
|
||
required_packages = [
|
||
'numpy', 'pandas', 'matplotlib', 'librosa',
|
||
'sklearn', 'tensorflow', 'tqdm', 'soundfile',
|
||
'resampy' # 添加resampy作为必需依赖
|
||
]
|
||
|
||
missing_packages = []
|
||
|
||
for package in required_packages:
|
||
try:
|
||
importlib.import_module(package)
|
||
except ImportError:
|
||
missing_packages.append(package)
|
||
|
||
if missing_packages:
|
||
print("错误: 缺少以下Python依赖包:")
|
||
for pkg in missing_packages:
|
||
print(f" - {pkg}")
|
||
print("\n请使用以下命令安装缺少的依赖:")
|
||
print(f"pip install {' '.join(missing_packages)}")
|
||
return False
|
||
|
||
return True
|
||
|
||
def parse_args():
|
||
"""
|
||
解析命令行参数
|
||
"""
|
||
parser = argparse.ArgumentParser(description='多语言语音情感识别系统')
|
||
|
||
# 模式选择
|
||
parser.add_argument('--mode', type=str, default='train',
|
||
choices=['train', 'predict'],
|
||
help='运行模式: train-训练模型, predict-使用模型预测')
|
||
|
||
# 数据路径
|
||
parser.add_argument('--casia_path', type=str, default='./CAISA',
|
||
help='CASIA数据集路径')
|
||
parser.add_argument('--savee_path', type=str, default='./SAVEE',
|
||
help='SAVEE数据集路径')
|
||
parser.add_argument('--ravdess_path', type=str, default='./RAVDESS',
|
||
help='RAVDESS数据集路径')
|
||
|
||
# 模型参数
|
||
parser.add_argument('--lstm_units', type=int, default=128,
|
||
help='LSTM层的单元数')
|
||
parser.add_argument('--dropout', type=float, default=0.3,
|
||
help='Dropout比率')
|
||
parser.add_argument('--learning_rate', type=float, default=0.001,
|
||
help='学习率')
|
||
parser.add_argument('--batch_size', type=int, default=32,
|
||
help='批量大小')
|
||
parser.add_argument('--epochs', type=int, default=100,
|
||
help='训练轮数')
|
||
|
||
# 预测参数
|
||
parser.add_argument('--audio_file', type=str, default=None,
|
||
help='要预测的音频文件路径')
|
||
parser.add_argument('--model_path', type=str, default='./output/emotion_model',
|
||
help='模型路径')
|
||
|
||
# 其他参数
|
||
parser.add_argument('--include_language', action='store_true',
|
||
help='是否将语言作为额外特征')
|
||
parser.add_argument('--output_dir', type=str, default='./output',
|
||
help='输出目录')
|
||
parser.add_argument('--seed', type=int, default=42,
|
||
help='随机种子')
|
||
parser.add_argument('--regularization', type=float, default=0.001,
|
||
help='L2正则化系数')
|
||
parser.add_argument('--model_name', type=str, default='emotion_model',
|
||
help='模型名称')
|
||
|
||
args = parser.parse_args()
|
||
return args
|
||
|
||
def verify_file_exists(file_path, description="文件"):
|
||
"""
|
||
验证文件是否存在
|
||
|
||
Args:
|
||
file_path: 文件路径
|
||
description: 文件描述
|
||
|
||
Returns:
|
||
bool: 文件存在则返回True,否则返回False
|
||
"""
|
||
if not os.path.exists(file_path):
|
||
print(f"错误: {description}不存在: {file_path}")
|
||
return False
|
||
return True
|
||
|
||
def main():
|
||
"""
|
||
主函数
|
||
"""
|
||
# 检查依赖
|
||
if not check_dependencies():
|
||
print("请安装必要的依赖后再运行程序。")
|
||
sys.exit(1)
|
||
|
||
# 解析命令行参数
|
||
args = parse_args()
|
||
|
||
# 根据模式选择执行的功能
|
||
if args.mode == 'train':
|
||
# 导入训练模块
|
||
try:
|
||
from train import train_model
|
||
except ImportError as e:
|
||
print(f"错误: 无法导入训练模块: {e}")
|
||
sys.exit(1)
|
||
|
||
print("="*50)
|
||
print("多语言语音情感识别系统 - 训练模式")
|
||
print("="*50)
|
||
print(f"CASIA路径: {args.casia_path}")
|
||
print(f"SAVEE路径: {args.savee_path}")
|
||
print(f"RAVDESS路径: {args.ravdess_path}")
|
||
print(f"输出目录: {args.output_dir}")
|
||
print("-"*50)
|
||
|
||
# 检查至少有一个数据集路径存在
|
||
has_valid_path = False
|
||
if os.path.exists(args.casia_path):
|
||
has_valid_path = True
|
||
if os.path.exists(args.savee_path):
|
||
has_valid_path = True
|
||
if os.path.exists(args.ravdess_path):
|
||
has_valid_path = True
|
||
|
||
if not has_valid_path:
|
||
print("错误: 未找到任何有效的数据集路径,请检查--casia_path、--savee_path、--ravdess_path参数")
|
||
print("提示: 路径可以是绝对路径,例如 C:/数据集/CASIA")
|
||
sys.exit(1)
|
||
|
||
# 开始训练
|
||
try:
|
||
train_model(args)
|
||
except Exception as e:
|
||
print(f"训练过程中出现错误: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
sys.exit(1)
|
||
|
||
elif args.mode == 'predict':
|
||
# 导入预测模块
|
||
try:
|
||
from predict import predict_emotion
|
||
except ImportError as e:
|
||
print(f"错误: 无法导入预测模块: {e}")
|
||
sys.exit(1)
|
||
|
||
# 检查音频文件是否存在
|
||
if not args.audio_file:
|
||
print("错误: 预测模式下需要指定音频文件,请使用--audio_file参数")
|
||
sys.exit(1)
|
||
|
||
if not verify_file_exists(args.audio_file, "音频文件"):
|
||
sys.exit(1)
|
||
|
||
# 检查模型目录是否存在
|
||
model_path = os.path.join(args.output_dir, args.model_name)
|
||
if not os.path.exists(model_path):
|
||
print(f"警告: 模型目录不存在: {model_path},将尝试使用指定的model_path参数: {args.model_path}")
|
||
model_path = args.model_path
|
||
|
||
print("="*50)
|
||
print("多语言语音情感识别系统 - 预测模式")
|
||
print("="*50)
|
||
print(f"音频文件: {args.audio_file}")
|
||
print(f"模型路径: {model_path}")
|
||
print("-"*50)
|
||
|
||
# 开始预测
|
||
try:
|
||
result = predict_emotion(args.audio_file, model_path)
|
||
print("\n预测结果:")
|
||
print(f"情感: {result['emotion']}")
|
||
print(f"置信度: {result['probability']:.2f}")
|
||
if 'language' in result:
|
||
print(f"语言: {result['language']}")
|
||
print(f"语言置信度: {result['language_probability']:.2f}")
|
||
|
||
print("\n情感概率分布:")
|
||
for emotion, prob in result['all_emotions'].items():
|
||
print(f" {emotion}: {prob:.4f}")
|
||
|
||
except Exception as e:
|
||
print(f"预测过程中出现错误: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
sys.exit(1)
|
||
|
||
else:
|
||
print(f"错误: 未知的模式 '{args.mode}'")
|
||
sys.exit(1)
|
||
|
||
if __name__ == "__main__":
|
||
main() |