This commit is contained in:
lzc
2025-07-02 13:54:05 +08:00
commit 4b3870440c
6351 changed files with 282880 additions and 0 deletions

217
main.py Normal file
View File

@@ -0,0 +1,217 @@
"""
多语言语音情感识别系统主程序
"""
import os
import argparse
import sys
import importlib
def check_dependencies():
"""
检查所需的Python依赖是否安装
Returns:
bool: 如果所有依赖都已安装则返回True
"""
required_packages = [
'numpy', 'pandas', 'matplotlib', 'librosa',
'sklearn', 'tensorflow', 'tqdm', 'soundfile',
'resampy' # 添加resampy作为必需依赖
]
missing_packages = []
for package in required_packages:
try:
importlib.import_module(package)
except ImportError:
missing_packages.append(package)
if missing_packages:
print("错误: 缺少以下Python依赖包:")
for pkg in missing_packages:
print(f" - {pkg}")
print("\n请使用以下命令安装缺少的依赖:")
print(f"pip install {' '.join(missing_packages)}")
return False
return True
def parse_args():
"""
解析命令行参数
"""
parser = argparse.ArgumentParser(description='多语言语音情感识别系统')
# 模式选择
parser.add_argument('--mode', type=str, default='train',
choices=['train', 'predict'],
help='运行模式: train-训练模型, predict-使用模型预测')
# 数据路径
parser.add_argument('--casia_path', type=str, default='./CAISA',
help='CASIA数据集路径')
parser.add_argument('--savee_path', type=str, default='./SAVEE',
help='SAVEE数据集路径')
parser.add_argument('--ravdess_path', type=str, default='./RAVDESS',
help='RAVDESS数据集路径')
# 模型参数
parser.add_argument('--lstm_units', type=int, default=128,
help='LSTM层的单元数')
parser.add_argument('--dropout', type=float, default=0.3,
help='Dropout比率')
parser.add_argument('--learning_rate', type=float, default=0.001,
help='学习率')
parser.add_argument('--batch_size', type=int, default=32,
help='批量大小')
parser.add_argument('--epochs', type=int, default=100,
help='训练轮数')
# 预测参数
parser.add_argument('--audio_file', type=str, default=None,
help='要预测的音频文件路径')
parser.add_argument('--model_path', type=str, default='./output/emotion_model',
help='模型路径')
# 其他参数
parser.add_argument('--include_language', action='store_true',
help='是否将语言作为额外特征')
parser.add_argument('--output_dir', type=str, default='./output',
help='输出目录')
parser.add_argument('--seed', type=int, default=42,
help='随机种子')
parser.add_argument('--regularization', type=float, default=0.001,
help='L2正则化系数')
parser.add_argument('--model_name', type=str, default='emotion_model',
help='模型名称')
args = parser.parse_args()
return args
def verify_file_exists(file_path, description="文件"):
"""
验证文件是否存在
Args:
file_path: 文件路径
description: 文件描述
Returns:
bool: 文件存在则返回True否则返回False
"""
if not os.path.exists(file_path):
print(f"错误: {description}不存在: {file_path}")
return False
return True
def main():
"""
主函数
"""
# 检查依赖
if not check_dependencies():
print("请安装必要的依赖后再运行程序。")
sys.exit(1)
# 解析命令行参数
args = parse_args()
# 根据模式选择执行的功能
if args.mode == 'train':
# 导入训练模块
try:
from train import train_model
except ImportError as e:
print(f"错误: 无法导入训练模块: {e}")
sys.exit(1)
print("="*50)
print("多语言语音情感识别系统 - 训练模式")
print("="*50)
print(f"CASIA路径: {args.casia_path}")
print(f"SAVEE路径: {args.savee_path}")
print(f"RAVDESS路径: {args.ravdess_path}")
print(f"输出目录: {args.output_dir}")
print("-"*50)
# 检查至少有一个数据集路径存在
has_valid_path = False
if os.path.exists(args.casia_path):
has_valid_path = True
if os.path.exists(args.savee_path):
has_valid_path = True
if os.path.exists(args.ravdess_path):
has_valid_path = True
if not has_valid_path:
print("错误: 未找到任何有效的数据集路径,请检查--casia_path、--savee_path、--ravdess_path参数")
print("提示: 路径可以是绝对路径,例如 C:/数据集/CASIA")
sys.exit(1)
# 开始训练
try:
train_model(args)
except Exception as e:
print(f"训练过程中出现错误: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
elif args.mode == 'predict':
# 导入预测模块
try:
from predict import predict_emotion
except ImportError as e:
print(f"错误: 无法导入预测模块: {e}")
sys.exit(1)
# 检查音频文件是否存在
if not args.audio_file:
print("错误: 预测模式下需要指定音频文件,请使用--audio_file参数")
sys.exit(1)
if not verify_file_exists(args.audio_file, "音频文件"):
sys.exit(1)
# 检查模型目录是否存在
model_path = os.path.join(args.output_dir, args.model_name)
if not os.path.exists(model_path):
print(f"警告: 模型目录不存在: {model_path}将尝试使用指定的model_path参数: {args.model_path}")
model_path = args.model_path
print("="*50)
print("多语言语音情感识别系统 - 预测模式")
print("="*50)
print(f"音频文件: {args.audio_file}")
print(f"模型路径: {model_path}")
print("-"*50)
# 开始预测
try:
result = predict_emotion(args.audio_file, model_path)
print("\n预测结果:")
print(f"情感: {result['emotion']}")
print(f"置信度: {result['probability']:.2f}")
if 'language' in result:
print(f"语言: {result['language']}")
print(f"语言置信度: {result['language_probability']:.2f}")
print("\n情感概率分布:")
for emotion, prob in result['all_emotions'].items():
print(f" {emotion}: {prob:.4f}")
except Exception as e:
print(f"预测过程中出现错误: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
else:
print(f"错误: 未知的模式 '{args.mode}'")
sys.exit(1)
if __name__ == "__main__":
main()