"""
|
||
数据加载模块,用于读取并处理语音情感数据集
|
||
支持CASIA、SAVEE和RAVDESS三种数据集
|
||
"""
|
||
|
||
import os
import glob
import numpy as np
import pandas as pd
import librosa
import re
from sklearn.model_selection import train_test_split
import warnings

# Constants
SAMPLE_RATE = 22050  # unified sampling rate (Hz)
MAX_DURATION = 5  # maximum audio length in seconds
MAX_SAMPLES = SAMPLE_RATE * MAX_DURATION  # maximum number of samples per clip
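
# Illustrative sketch (not used by the loaders below): every loader normalizes clip
# length the same way, zero-padding short clips and truncating long ones to
# MAX_SAMPLES (22050 * 5 = 110250 samples). The helper name `_pad_or_truncate` is
# not part of the original module; it is added only to document that shared step.
def _pad_or_truncate(audio):
    """Return `audio` zero-padded or truncated to exactly MAX_SAMPLES samples."""
    if len(audio) < MAX_SAMPLES:
        return np.pad(audio, (0, MAX_SAMPLES - len(audio)), 'constant')
    return audio[:MAX_SAMPLES]
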

# Emotion mapping dictionary
EMOTION_MAPPING = {
    # CASIA emotion labels
    'angry': 'angry',
    'fear': 'fear',
    'happy': 'happy',
    'neutral': 'neutral',
    'sad': 'sad',
    'surprise': 'surprise',

    # SAVEE emotion labels
    'a': 'angry',
    'f': 'fear',
    'h': 'happy',
    'n': 'neutral',
    'sa': 'sad',
    'su': 'surprise',
    'd': 'disgust',  # note: SAVEE has a disgust class, CASIA does not

    # RAVDESS emotion labels
    '01': 'neutral',
    '02': 'calm',  # note: RAVDESS has a calm class, CASIA does not
    '03': 'happy',
    '04': 'sad',
    '05': 'angry',
    '06': 'fear',
    '07': 'disgust',  # note: RAVDESS has a disgust class, CASIA does not
    '08': 'surprise'
}
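# Example lookups: the CASIA folder name 'happy', the SAVEE prefix 'h' and the
# RAVDESS code '03' all map to the same unified label 'happy'.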

# Language mapping
LANGUAGE_MAPPING = {
    'casia': 'zh',  # Chinese
    'savee': 'en',  # English
    'ravdess': 'en'  # English
}


def load_casia(data_path):
    """
    Load the CASIA Chinese emotional speech dataset.

    Args:
        data_path: path to the CASIA dataset

    Returns:
        data_list: list of (audio data, emotion label, language label) tuples
    """
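    # Expected directory layout (inferred from the loops below):
    #   <data_path>/<actor>/<emotion>/*.wav
    # where <emotion> is one of the CASIA keys of EMOTION_MAPPING (e.g. 'angry').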
    data_list = []

    # Make sure the path exists
    if not os.path.exists(data_path):
        print(f"Warning: CASIA data path does not exist: {data_path}")
        return data_list

    # Try to list the speaker directories
    try:
        actors = os.listdir(data_path)
    except Exception as e:
        print(f"Error: cannot read the CASIA directory: {e}")
        return data_list

    success_count = 0
    error_count = 0

    for actor in actors:
        # Skip hidden files and non-directories
        if actor.startswith('_') or not os.path.isdir(os.path.join(data_path, actor)):
            continue

        actor_path = os.path.join(data_path, actor)
        emotions = os.listdir(actor_path)

        for emotion in emotions:
            # Skip hidden files and non-directories
            if emotion.startswith('_') or not os.path.isdir(os.path.join(actor_path, emotion)):
                continue

            emotion_path = os.path.join(actor_path, emotion)
            audio_files = glob.glob(os.path.join(emotion_path, "*.wav"))

            for audio_file in audio_files:
                # Load the audio file
                try:
                    audio, sr = librosa.load(audio_file, sr=SAMPLE_RATE, res_type='kaiser_fast')

                    # Normalize the clip length
                    if len(audio) < MAX_SAMPLES:
                        # Clip too short: zero-pad
                        padding = MAX_SAMPLES - len(audio)
                        audio = np.pad(audio, (0, padding), 'constant')
                    else:
                        # Clip too long: truncate
                        audio = audio[:MAX_SAMPLES]

                    data_list.append((audio, EMOTION_MAPPING[emotion], LANGUAGE_MAPPING['casia']))
                    success_count += 1

                except Exception as e:
                    error_count += 1
                    if error_count < 10:  # only report the first few errors to keep the log short
                        print(f"Error loading {audio_file}: {e}")
                    elif error_count == 10:
                        print("Too many loading errors; further errors will not be shown...")

    print(f"CASIA dataset: {success_count} files loaded successfully, {error_count} failed")
    return data_list


def load_savee(data_path):
    """
    Load the SAVEE English emotional speech dataset.

    Args:
        data_path: path to the SAVEE dataset

    Returns:
        data_list: list of (audio data, emotion label, language label) tuples
    """
    data_list = []

    # Make sure the path exists
    if not os.path.exists(data_path):
        print(f"Warning: SAVEE data path does not exist: {data_path}")
        return data_list

    audio_path = os.path.join(data_path, "AudioData")
    if not os.path.exists(audio_path):
        print(f"Warning: SAVEE AudioData path does not exist: {audio_path}")
        return data_list

    # The four speakers in the SAVEE dataset
    actors = ['DC', 'JE', 'JK', 'KL']

    success_count = 0
    error_count = 0

    for actor in actors:
        actor_path = os.path.join(audio_path, actor)
        if not os.path.isdir(actor_path):
            print(f"Warning: SAVEE actor directory does not exist: {actor_path}")
            continue

        audio_files = glob.glob(os.path.join(actor_path, "*.wav"))

        for audio_file in audio_files:
            file_name = os.path.basename(audio_file)
            # Extract the emotion label: SAVEE encodes it as the first 1-2 letters of the file name
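            # e.g. a file named 'sa01.wav' yields 'sa' -> 'sad', 'su03.wav' yields
            # 'su' -> 'surprise', and a single-letter prefix such as in 'h05.wav'
            # yields 'h' -> 'happy'.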
if file_name.startswith("sa"):
|
||
emotion = "sa"
|
||
elif file_name.startswith("su"):
|
||
emotion = "su"
|
||
else:
|
||
emotion = file_name[0]
|
||
|
||
try:
|
||
audio, sr = librosa.load(audio_file, sr=SAMPLE_RATE, res_type='kaiser_fast')
|
||
|
||
# 统一音频长度
|
||
if len(audio) < MAX_SAMPLES:
|
||
padding = MAX_SAMPLES - len(audio)
|
||
audio = np.pad(audio, (0, padding), 'constant')
|
||
else:
|
||
audio = audio[:MAX_SAMPLES]
|
||
|
||
data_list.append((audio, EMOTION_MAPPING[emotion], LANGUAGE_MAPPING['savee']))
|
||
success_count += 1
|
||
|
||
except Exception as e:
|
||
error_count += 1
|
||
if error_count < 10: # 只显示前10个错误
|
||
print(f"Error loading {audio_file}: {e}")
|
||
elif error_count == 10:
|
||
print("过多加载错误,后续错误将不再显示...")
|
||
|
||
print(f"SAVEE数据集: 成功加载 {success_count} 个文件,失败 {error_count} 个文件")
|
||
return data_list
|
||
|
||

def load_ravdess(data_path):
    """
    Load the RAVDESS English emotional speech dataset.

    Args:
        data_path: path to the RAVDESS dataset

    Returns:
        data_list: list of (audio data, emotion label, language label) tuples
    """
    data_list = []

    # Make sure the path exists
    if not os.path.exists(data_path):
        print(f"Warning: RAVDESS data path does not exist: {data_path}")
        return data_list

    # Collect all actor directories
    try:
        actor_dirs = glob.glob(os.path.join(data_path, "Actor_*"))
    except Exception as e:
        print(f"Error: cannot list the RAVDESS actor directories: {e}")
        return data_list

    if not actor_dirs:
        print(f"Warning: no RAVDESS actor directories found in: {data_path}")

    success_count = 0
    error_count = 0

    for actor_dir in actor_dirs:
        audio_files = glob.glob(os.path.join(actor_dir, "*.wav"))

        for audio_file in audio_files:
            file_name = os.path.basename(audio_file)

            # RAVDESS file name format: 03-01-05-01-02-01-12.wav
            # the third field (here 05) encodes the emotion (angry)
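            # Per the published RAVDESS naming convention, the seven fields are
            # modality, vocal channel, emotion, intensity, statement, repetition
            # and actor, so 03-01-05-01-02-01-12.wav is an audio-only speech clip
            # of statement 2 at normal intensity, first repetition, by actor 12.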
            parts = file_name.split('-')
            if len(parts) < 3 or parts[2] not in EMOTION_MAPPING:
                # Skip files that do not follow the expected naming scheme
                continue
            emotion = parts[2]

            try:
                audio, sr = librosa.load(audio_file, sr=SAMPLE_RATE, res_type='kaiser_fast')

                # Normalize the clip length
                if len(audio) < MAX_SAMPLES:
                    padding = MAX_SAMPLES - len(audio)
                    audio = np.pad(audio, (0, padding), 'constant')
                else:
                    audio = audio[:MAX_SAMPLES]

                data_list.append((audio, EMOTION_MAPPING[emotion], LANGUAGE_MAPPING['ravdess']))
                success_count += 1

            except Exception as e:
                error_count += 1
                if error_count < 10:  # only report the first few errors
                    print(f"Error loading {audio_file}: {e}")
                elif error_count == 10:
                    print("Too many loading errors; further errors will not be shown...")

    print(f"RAVDESS dataset: {success_count} files loaded successfully, {error_count} failed")
    return data_list


def load_all_data(casia_path, savee_path, ravdess_path, selected_emotions=None):
    """
    Load all datasets.

    Args:
        casia_path: path to the CASIA dataset
        savee_path: path to the SAVEE dataset
        ravdess_path: path to the RAVDESS dataset
        selected_emotions: list of emotion labels to keep; if None, all loaded emotions are used

    Returns:
        X: list of audio clips
        y_emotion: list of emotion labels
        y_language: list of language labels
    """
    print("Loading the CASIA dataset...")
    casia_data = load_casia(casia_path)

    print("Loading the SAVEE dataset...")
    savee_data = load_savee(savee_path)

    print("Loading the RAVDESS dataset...")
    ravdess_data = load_ravdess(ravdess_path)

    # Merge all datasets
    all_data = casia_data + savee_data + ravdess_data

    # Make sure something was loaded
    if not all_data:
        raise ValueError("No data was loaded. Please check the dataset paths and file formats.")

    # If a list of emotions was requested, filter the data
    if selected_emotions:
        filtered_data = [item for item in all_data if item[1] in selected_emotions]
        if not filtered_data:
            print(f"Warning: no samples match the requested emotions. Available labels: {set(item[1] for item in all_data)}")
            print(f"Requested labels: {selected_emotions}")
            # Fall back to using all data
            filtered_data = all_data
        all_data = filtered_data

    print(f"Loaded {len(all_data)} valid audio files in total")

    # Show the per-emotion class distribution
    emotion_counts = {}
    for item in all_data:
        emotion = item[1]
        if emotion in emotion_counts:
            emotion_counts[emotion] += 1
        else:
            emotion_counts[emotion] = 1

    print("Class distribution:")
    for emotion, count in emotion_counts.items():
        print(f"  {emotion}: {count} samples")

    # Split the tuples into audio data, emotion labels and language labels
    X = [item[0] for item in all_data]
    y_emotion = [item[1] for item in all_data]
    y_language = [item[2] for item in all_data]

    return X, y_emotion, y_language


def prepare_data(X, y_emotion, y_language, test_size=0.2, val_size=0.2, random_state=42):
    """
    Split the data into training, validation and test sets.

    Args:
        X: audio data
        y_emotion: emotion labels
        y_language: language labels
        test_size: fraction of the data used for the test set
        val_size: fraction of the data used for the validation set
        random_state: random seed

    Returns:
        Data and labels for the training, validation and test sets.
    """
    # Make sure the data is not empty
    if len(X) == 0:
        raise ValueError("The dataset is empty and cannot be split. Make sure at least some valid audio files were loaded.")

    # First split off the test set
    X_train_val, X_test, y_emotion_train_val, y_emotion_test, y_language_train_val, y_language_test = train_test_split(
        X, y_emotion, y_language, test_size=test_size, random_state=random_state, stratify=y_emotion
    )

    # Split the remaining data into training and validation sets
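    # Worked example with the defaults: test_size=0.2 leaves 80% of the samples here,
    # so val_size=0.2 of the whole corresponds to 0.2 / 0.8 = 0.25 of the remainder,
    # giving a 60/20/20 train/val/test split overall.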
    val_ratio = val_size / (1 - test_size)
    X_train, X_val, y_emotion_train, y_emotion_val, y_language_train, y_language_val = train_test_split(
        X_train_val, y_emotion_train_val, y_language_train_val,
        test_size=val_ratio, random_state=random_state, stratify=y_emotion_train_val
    )

    # Report the split sizes
    print(f"Dataset split: {len(X_train)} training, {len(X_val)} validation, {len(X_test)} test samples")

    return (
        X_train, y_emotion_train, y_language_train,
        X_val, y_emotion_val, y_language_val,
        X_test, y_emotion_test, y_language_test
    )
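

if __name__ == "__main__":
    # Minimal usage sketch. The dataset paths below are placeholders, not paths
    # prescribed by this module; point them at your local copies of the datasets.
    X, y_emotion, y_language = load_all_data(
        casia_path="data/CASIA",
        savee_path="data/SAVEE",
        ravdess_path="data/RAVDESS",
        selected_emotions=['angry', 'fear', 'happy', 'neutral', 'sad', 'surprise'],
    )
    (X_train, y_emotion_train, y_language_train,
     X_val, y_emotion_val, y_language_val,
     X_test, y_emotion_test, y_language_test) = prepare_data(X, y_emotion, y_language)
    print(f"Each clip has {len(X_train[0])} samples at {SAMPLE_RATE} Hz")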