Files
yuyinfenxi/main.py
2025-07-02 13:54:05 +08:00

217 lines
7.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
多语言语音情感识别系统主程序
"""
import os
import argparse
import sys
import importlib
def check_dependencies():
"""
检查所需的Python依赖是否安装
Returns:
bool: 如果所有依赖都已安装则返回True
"""
required_packages = [
'numpy', 'pandas', 'matplotlib', 'librosa',
'sklearn', 'tensorflow', 'tqdm', 'soundfile',
'resampy' # 添加resampy作为必需依赖
]
missing_packages = []
for package in required_packages:
try:
importlib.import_module(package)
except ImportError:
missing_packages.append(package)
if missing_packages:
print("错误: 缺少以下Python依赖包:")
for pkg in missing_packages:
print(f" - {pkg}")
print("\n请使用以下命令安装缺少的依赖:")
print(f"pip install {' '.join(missing_packages)}")
return False
return True
def parse_args():
"""
解析命令行参数
"""
parser = argparse.ArgumentParser(description='多语言语音情感识别系统')
# 模式选择
parser.add_argument('--mode', type=str, default='train',
choices=['train', 'predict'],
help='运行模式: train-训练模型, predict-使用模型预测')
# 数据路径
parser.add_argument('--casia_path', type=str, default='./CAISA',
help='CASIA数据集路径')
parser.add_argument('--savee_path', type=str, default='./SAVEE',
help='SAVEE数据集路径')
parser.add_argument('--ravdess_path', type=str, default='./RAVDESS',
help='RAVDESS数据集路径')
# 模型参数
parser.add_argument('--lstm_units', type=int, default=128,
help='LSTM层的单元数')
parser.add_argument('--dropout', type=float, default=0.3,
help='Dropout比率')
parser.add_argument('--learning_rate', type=float, default=0.001,
help='学习率')
parser.add_argument('--batch_size', type=int, default=32,
help='批量大小')
parser.add_argument('--epochs', type=int, default=100,
help='训练轮数')
# 预测参数
parser.add_argument('--audio_file', type=str, default=None,
help='要预测的音频文件路径')
parser.add_argument('--model_path', type=str, default='./output/emotion_model',
help='模型路径')
# 其他参数
parser.add_argument('--include_language', action='store_true',
help='是否将语言作为额外特征')
parser.add_argument('--output_dir', type=str, default='./output',
help='输出目录')
parser.add_argument('--seed', type=int, default=42,
help='随机种子')
parser.add_argument('--regularization', type=float, default=0.001,
help='L2正则化系数')
parser.add_argument('--model_name', type=str, default='emotion_model',
help='模型名称')
args = parser.parse_args()
return args
def verify_file_exists(file_path, description="文件"):
"""
验证文件是否存在
Args:
file_path: 文件路径
description: 文件描述
Returns:
bool: 文件存在则返回True否则返回False
"""
if not os.path.exists(file_path):
print(f"错误: {description}不存在: {file_path}")
return False
return True
def main():
"""
主函数
"""
# 检查依赖
if not check_dependencies():
print("请安装必要的依赖后再运行程序。")
sys.exit(1)
# 解析命令行参数
args = parse_args()
# 根据模式选择执行的功能
if args.mode == 'train':
# 导入训练模块
try:
from train import train_model
except ImportError as e:
print(f"错误: 无法导入训练模块: {e}")
sys.exit(1)
print("="*50)
print("多语言语音情感识别系统 - 训练模式")
print("="*50)
print(f"CASIA路径: {args.casia_path}")
print(f"SAVEE路径: {args.savee_path}")
print(f"RAVDESS路径: {args.ravdess_path}")
print(f"输出目录: {args.output_dir}")
print("-"*50)
# 检查至少有一个数据集路径存在
has_valid_path = False
if os.path.exists(args.casia_path):
has_valid_path = True
if os.path.exists(args.savee_path):
has_valid_path = True
if os.path.exists(args.ravdess_path):
has_valid_path = True
if not has_valid_path:
print("错误: 未找到任何有效的数据集路径,请检查--casia_path、--savee_path、--ravdess_path参数")
print("提示: 路径可以是绝对路径,例如 C:/数据集/CASIA")
sys.exit(1)
# 开始训练
try:
train_model(args)
except Exception as e:
print(f"训练过程中出现错误: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
elif args.mode == 'predict':
# 导入预测模块
try:
from predict import predict_emotion
except ImportError as e:
print(f"错误: 无法导入预测模块: {e}")
sys.exit(1)
# 检查音频文件是否存在
if not args.audio_file:
print("错误: 预测模式下需要指定音频文件,请使用--audio_file参数")
sys.exit(1)
if not verify_file_exists(args.audio_file, "音频文件"):
sys.exit(1)
# 检查模型目录是否存在
model_path = os.path.join(args.output_dir, args.model_name)
if not os.path.exists(model_path):
print(f"警告: 模型目录不存在: {model_path}将尝试使用指定的model_path参数: {args.model_path}")
model_path = args.model_path
print("="*50)
print("多语言语音情感识别系统 - 预测模式")
print("="*50)
print(f"音频文件: {args.audio_file}")
print(f"模型路径: {model_path}")
print("-"*50)
# 开始预测
try:
result = predict_emotion(args.audio_file, model_path)
print("\n预测结果:")
print(f"情感: {result['emotion']}")
print(f"置信度: {result['probability']:.2f}")
if 'language' in result:
print(f"语言: {result['language']}")
print(f"语言置信度: {result['language_probability']:.2f}")
print("\n情感概率分布:")
for emotion, prob in result['all_emotions'].items():
print(f" {emotion}: {prob:.4f}")
except Exception as e:
print(f"预测过程中出现错误: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
else:
print(f"错误: 未知的模式 '{args.mode}'")
sys.exit(1)
if __name__ == "__main__":
main()