"""
数据加载模块,用于读取并处理语音情感数据集
支持CASIA、SAVEE和RAVDESS三种数据集
"""
import os
import glob
import numpy as np
import pandas as pd
import librosa
import re
from sklearn.model_selection import train_test_split
import warnings
# Constants
SAMPLE_RATE = 22050  # unified sampling rate (Hz)
MAX_DURATION = 5  # maximum audio length in seconds
MAX_SAMPLES = SAMPLE_RATE * MAX_DURATION  # maximum number of samples per clip
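# With SAMPLE_RATE = 22050 and MAX_DURATION = 5 this gives 110,250 samples,
# i.e. every clip is padded or truncated to exactly 5 seconds of audio below.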
# Emotion label mapping
EMOTION_MAPPING = {
    # CASIA emotion labels (directory names)
    'angry': 'angry',
    'fear': 'fear',
    'happy': 'happy',
    'neutral': 'neutral',
    'sad': 'sad',
    'surprise': 'surprise',
    # SAVEE emotion labels (file name prefixes)
    'a': 'angry',
    'f': 'fear',
    'h': 'happy',
    'n': 'neutral',
    'sa': 'sad',
    'su': 'surprise',
    'd': 'disgust',  # note: SAVEE has disgust, CASIA does not
    # RAVDESS emotion labels (third field of the file name)
    '01': 'neutral',
    '02': 'calm',  # note: RAVDESS has calm, CASIA does not
    '03': 'happy',
    '04': 'sad',
    '05': 'angry',
    '06': 'fear',
    '07': 'disgust',  # note: RAVDESS has disgust, CASIA does not
    '08': 'surprise'
}
# Language mapping
LANGUAGE_MAPPING = {
    'casia': 'zh',  # Chinese
    'savee': 'en',  # English
    'ravdess': 'en'  # English
}
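# Note: only the six CASIA emotions (angry, fear, happy, neutral, sad, surprise)
# occur in all three corpora; 'disgust' (SAVEE/RAVDESS) and 'calm' (RAVDESS only)
# can be excluded later via the selected_emotions argument of load_all_data().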
def load_casia(data_path):
    """
    Load the CASIA Chinese emotional speech dataset.
    Args:
        data_path: path to the CASIA dataset
    Returns:
        data_list: list of (audio data, emotion label, language label) tuples
    """
    data_list = []
    # Make sure the path exists
    if not os.path.exists(data_path):
        print(f"Warning: CASIA data path does not exist: {data_path}")
        return data_list
    # Try to get the list of actors
    try:
        actors = os.listdir(data_path)
    except Exception as e:
        print(f"Error: cannot read CASIA directory: {e}")
        return data_list
    success_count = 0
    error_count = 0
    for actor in actors:
        # Skip hidden files and non-directories
        if actor.startswith('_') or not os.path.isdir(os.path.join(data_path, actor)):
            continue
        actor_path = os.path.join(data_path, actor)
        emotions = os.listdir(actor_path)
        for emotion in emotions:
            # Skip hidden files and non-directories
            if emotion.startswith('_') or not os.path.isdir(os.path.join(actor_path, emotion)):
                continue
            emotion_path = os.path.join(actor_path, emotion)
            audio_files = glob.glob(os.path.join(emotion_path, "*.wav"))
            for audio_file in audio_files:
                # Load the audio file
                try:
                    audio, sr = librosa.load(audio_file, sr=SAMPLE_RATE, res_type='kaiser_fast')
                    # Normalize the audio length
                    if len(audio) < MAX_SAMPLES:
                        # Audio too short: pad with zeros
                        padding = MAX_SAMPLES - len(audio)
                        audio = np.pad(audio, (0, padding), 'constant')
                    else:
                        # Audio too long: truncate
                        audio = audio[:MAX_SAMPLES]
                    data_list.append((audio, EMOTION_MAPPING[emotion], LANGUAGE_MAPPING['casia']))
                    success_count += 1
                except Exception as e:
                    error_count += 1
                    if error_count < 10:  # only show the first few errors to avoid flooding the log
                        print(f"Error loading {audio_file}: {e}")
                    elif error_count == 10:
                        print("Too many loading errors; further errors will not be shown...")
    print(f"CASIA dataset: loaded {success_count} files successfully, {error_count} failed")
    return data_list
def load_savee(data_path):
    """
    Load the SAVEE English emotional speech dataset.
    Args:
        data_path: path to the SAVEE dataset
    Returns:
        data_list: list of (audio data, emotion label, language label) tuples
    """
    data_list = []
    # Make sure the path exists
    if not os.path.exists(data_path):
        print(f"Warning: SAVEE data path does not exist: {data_path}")
        return data_list
    audio_path = os.path.join(data_path, "AudioData")
    if not os.path.exists(audio_path):
        print(f"Warning: SAVEE AudioData path does not exist: {audio_path}")
        return data_list
    # The four speakers in the SAVEE dataset
    actors = ['DC', 'JE', 'JK', 'KL']
    success_count = 0
    error_count = 0
    for actor in actors:
        actor_path = os.path.join(audio_path, actor)
        if not os.path.isdir(actor_path):
            print(f"Warning: SAVEE actor directory does not exist: {actor_path}")
            continue
        audio_files = glob.glob(os.path.join(actor_path, "*.wav"))
        for audio_file in audio_files:
            file_name = os.path.basename(audio_file)
            # Extract the emotion label: SAVEE uses the first 1-2 letters of the file name
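            # A name starting with 'sa' maps to sad and 'su' to surprise; otherwise
            # the single leading letter is used ('a' -> angry, 'f' -> fear,
            # 'h' -> happy, 'n' -> neutral, 'd' -> disgust), via EMOTION_MAPPING.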
if file_name.startswith("sa"):
emotion = "sa"
elif file_name.startswith("su"):
emotion = "su"
else:
emotion = file_name[0]
try:
audio, sr = librosa.load(audio_file, sr=SAMPLE_RATE, res_type='kaiser_fast')
# 统一音频长度
if len(audio) < MAX_SAMPLES:
padding = MAX_SAMPLES - len(audio)
audio = np.pad(audio, (0, padding), 'constant')
else:
audio = audio[:MAX_SAMPLES]
data_list.append((audio, EMOTION_MAPPING[emotion], LANGUAGE_MAPPING['savee']))
success_count += 1
except Exception as e:
error_count += 1
if error_count < 10: # 只显示前10个错误
print(f"Error loading {audio_file}: {e}")
elif error_count == 10:
print("过多加载错误,后续错误将不再显示...")
print(f"SAVEE数据集: 成功加载 {success_count} 个文件,失败 {error_count} 个文件")
return data_list
def load_ravdess(data_path):
    """
    Load the RAVDESS English emotional speech dataset.
    Args:
        data_path: path to the RAVDESS dataset
    Returns:
        data_list: list of (audio data, emotion label, language label) tuples
    """
    data_list = []
    # Make sure the path exists
    if not os.path.exists(data_path):
        print(f"Warning: RAVDESS data path does not exist: {data_path}")
        return data_list
    # Get all actor directories
    try:
        actor_dirs = glob.glob(os.path.join(data_path, "Actor_*"))
    except Exception as e:
        print(f"Error: cannot get RAVDESS actor directories: {e}")
        return data_list
    if not actor_dirs:
        print(f"Warning: no RAVDESS actor directories found in: {data_path}")
    success_count = 0
    error_count = 0
    for actor_dir in actor_dirs:
        audio_files = glob.glob(os.path.join(actor_dir, "*.wav"))
        for audio_file in audio_files:
            file_name = os.path.basename(audio_file)
            # RAVDESS file name format: 03-01-05-01-02-01-12.wav
            # where the third field ('05') encodes the emotion (here: angry)
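            # The seven dash-separated fields encode modality, vocal channel,
            # emotion, emotional intensity, statement, repetition and actor ID;
            # only the emotion field (index 2) is used here.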
            parts = file_name.split('-')
            if len(parts) >= 3:
                emotion = parts[2]
            else:
                # Unexpected file name format; skip it
                continue
            try:
                audio, sr = librosa.load(audio_file, sr=SAMPLE_RATE, res_type='kaiser_fast')
                # Normalize the audio length
                if len(audio) < MAX_SAMPLES:
                    padding = MAX_SAMPLES - len(audio)
                    audio = np.pad(audio, (0, padding), 'constant')
                else:
                    audio = audio[:MAX_SAMPLES]
                data_list.append((audio, EMOTION_MAPPING[emotion], LANGUAGE_MAPPING['ravdess']))
                success_count += 1
            except Exception as e:
                error_count += 1
                if error_count < 10:  # only show the first few errors
                    print(f"Error loading {audio_file}: {e}")
                elif error_count == 10:
                    print("Too many loading errors; further errors will not be shown...")
    print(f"RAVDESS dataset: loaded {success_count} files successfully, {error_count} failed")
    return data_list
def load_all_data(casia_path, savee_path, ravdess_path, selected_emotions=None):
    """
    Load all datasets.
    Args:
        casia_path: path to the CASIA dataset
        savee_path: path to the SAVEE dataset
        ravdess_path: path to the RAVDESS dataset
        selected_emotions: list of emotions to keep; if None, all loaded emotions are used
    Returns:
        X: list of audio data
        y_emotion: list of emotion labels
        y_language: list of language labels
    """
    print("Loading CASIA dataset...")
    casia_data = load_casia(casia_path)
    print("Loading SAVEE dataset...")
    savee_data = load_savee(savee_path)
    print("Loading RAVDESS dataset...")
    ravdess_data = load_ravdess(ravdess_path)
    # Merge all data
    all_data = casia_data + savee_data + ravdess_data
    # Check that some data was loaded
    if not all_data:
        raise ValueError("No data was loaded successfully! Please check the data paths and file formats.")
    # If a list of emotions was specified, filter the data
    if selected_emotions:
        filtered_data = [item for item in all_data if item[1] in selected_emotions]
        if not filtered_data:
            print(f"Warning: no data matches the selected emotions. Available emotion labels: {set(item[1] for item in all_data)}")
            print(f"Requested emotion labels: {selected_emotions}")
            # Fall back to using all data
            filtered_data = all_data
        all_data = filtered_data
    print(f"Loaded {len(all_data)} valid audio files in total")
    # Show the sample count per emotion class
    emotion_counts = {}
    for item in all_data:
        emotion = item[1]
        if emotion in emotion_counts:
            emotion_counts[emotion] += 1
        else:
            emotion_counts[emotion] = 1
    print("Data distribution:")
    for emotion, count in emotion_counts.items():
        print(f"  {emotion}: {count} samples")
    # Separate audio data, emotion labels and language labels
    X = [item[0] for item in all_data]
    y_emotion = [item[1] for item in all_data]
    y_language = [item[2] for item in all_data]
    return X, y_emotion, y_language
def prepare_data(X, y_emotion, y_language, test_size=0.2, val_size=0.2, random_state=42):
    """
    Prepare the training, validation and test sets.
    Args:
        X: audio data
        y_emotion: emotion labels
        y_language: language labels
        test_size: fraction of the data used for the test set
        val_size: fraction of the data used for the validation set
        random_state: random seed
    Returns:
        training, validation and test data and labels
    """
    # Make sure the data is not empty
    if len(X) == 0:
        raise ValueError("The dataset is empty and cannot be split! Make sure at least some valid audio files were loaded.")
    # First split off the test set
    X_train_val, X_test, y_emotion_train_val, y_emotion_test, y_language_train_val, y_language_test = train_test_split(
        X, y_emotion, y_language, test_size=test_size, random_state=random_state, stratify=y_emotion
    )
    # Split the remaining data into training and validation sets
    val_ratio = val_size / (1 - test_size)
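    # e.g. with test_size=0.2 and val_size=0.2, val_ratio = 0.2 / 0.8 = 0.25,
    # giving a 60% train / 20% validation / 20% test split of the full data.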
    X_train, X_val, y_emotion_train, y_emotion_val, y_language_train, y_language_val = train_test_split(
        X_train_val, y_emotion_train_val, y_language_train_val,
        test_size=val_ratio, random_state=random_state, stratify=y_emotion_train_val
    )
    # Print the dataset sizes
    print(f"Dataset split: {len(X_train)} training samples, {len(X_val)} validation samples, {len(X_test)} test samples")
    return (
        X_train, y_emotion_train, y_language_train,
        X_val, y_emotion_val, y_language_val,
        X_test, y_emotion_test, y_language_test
    )
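# Minimal usage sketch, assuming local copies of the three corpora; the paths
# and the restriction to the six shared emotions below are illustrative only.
if __name__ == "__main__":
    CASIA_PATH = "./datasets/CASIA"      # placeholder path
    SAVEE_PATH = "./datasets/SAVEE"      # placeholder path
    RAVDESS_PATH = "./datasets/RAVDESS"  # placeholder path
    # Keep only the six emotions present in all three corpora.
    shared_emotions = ['angry', 'fear', 'happy', 'neutral', 'sad', 'surprise']
    X, y_emotion, y_language = load_all_data(
        CASIA_PATH, SAVEE_PATH, RAVDESS_PATH, selected_emotions=shared_emotions
    )
    # 60/20/20 train/validation/test split, stratified by emotion label.
    (X_train, y_emotion_train, y_language_train,
     X_val, y_emotion_val, y_language_val,
     X_test, y_emotion_test, y_language_test) = prepare_data(X, y_emotion, y_language)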