You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
669 lines
20 KiB
669 lines
20 KiB
import os |
|
import sys |
|
|
|
# Add the local FireRedASR directory to the module search path.
sys.path.append('G:/Work_data/workstation/Forwork-voice-txt/FireRedASR')

# Enable offline mode for the Hugging Face hub and transformers libraries.
# These are set before the third-party imports that follow in this file.
os.environ.update({
    'HF_HUB_OFFLINE': '1',
    'TRANSFORMERS_OFFLINE': '1',
})
|
|
|
from fastapi import FastAPI, File, UploadFile, HTTPException |
|
from fastapi.responses import JSONResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.encoders import jsonable_encoder |
|
from pyannote.audio import Pipeline |
|
import torch |
|
import torchaudio |
|
from pydub import AudioSegment |
|
import soundfile as sf |
|
import numpy as np |
|
import librosa |
|
import time |
|
import uuid |
|
from pathlib import Path |
|
|
|
app = FastAPI(title="音频分段 API")

# CORS is fully open: every origin, method and header is accepted and
# credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Process-wide state; filled in by the startup hook (load_model).
pipeline = None
device = None

# Base folder for exported segment files (one sub-folder per session),
# created relative to the current working directory at import time.
OUTPUT_FOLDER = Path("segmented_audio")
OUTPUT_FOLDER.mkdir(exist_ok=True)
|
|
|
|
|
def load_audio(audio_path):
    """Load an audio file as 16 kHz mono, in two representations.

    Args:
        audio_path: path (str or Path) of the audio file.

    Returns:
        (audio_data, sr, audio_pydub):
            audio_data  -- 1-D numpy float array at 16 kHz, used for the
                           energy analysis;
            sr          -- always 16000;
            audio_pydub -- pydub AudioSegment converted to 16 kHz mono, used
                           to slice and export the segment files.
    """
    target_sr = 16000

    # pydub copy used for exporting segments (also the fallback decoder).
    audio_pydub = AudioSegment.from_file(str(audio_path))
    audio_pydub = audio_pydub.set_frame_rate(target_sr).set_channels(1)

    try:
        audio_data, sr = sf.read(str(audio_path))
    except RuntimeError:
        # soundfile raises a RuntimeError subclass for containers libsndfile
        # cannot decode (e.g. .m4a, which the API accepts).  Use the samples
        # pydub already decoded instead, scaled to [-1, 1) like sf.read does.
        samples = np.array(audio_pydub.get_array_of_samples(), dtype=np.float64)
        full_scale = float(1 << (8 * audio_pydub.sample_width - 1))
        return samples / full_scale, target_sr, audio_pydub

    # Multi-channel data comes back as (frames, channels): average to mono.
    if audio_data.ndim > 1:
        audio_data = np.mean(audio_data, axis=1)

    if sr != target_sr:
        audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=target_sr)
        sr = target_sr

    return audio_data, sr, audio_pydub
|
|
|
|
|
def calculate_frame_energy(audio_data, sr, time_point, frame_length=0.025):
    """
    Compute the RMS energy of one frame around ``time_point``.

    Args:
        audio_data: audio samples; indexed along the first axis.
        sr: sample rate in Hz.
        time_point: frame centre, in seconds.
        frame_length: frame length, in seconds.

    Returns:
        RMS energy of the frame, or 0.0 if the frame lies entirely outside
        the audio.  Near the start the frame is shifted to begin at sample 0;
        near the end it is truncated at the last sample.
    """
    frame_samples = int(frame_length * sr)
    center_sample = int(time_point * sr)

    # Centre-aligned frame, clipped to the audio bounds.
    start_sample = max(0, center_sample - frame_samples // 2)
    end_sample = min(len(audio_data), start_sample + frame_samples)

    # float64 so squaring integer PCM samples cannot overflow.
    frame = np.asarray(audio_data[start_sample:end_sample], dtype=np.float64)
    if frame.size == 0:
        return 0.0

    return np.sqrt(np.mean(np.square(frame)))
|
|
|
|
|
def find_optimal_boundary(audio_data, sr, start_time, end_time,
                          frame_length=0.025, hop_length=0.005):
    """
    Find the lowest-energy point inside ``[start_time, end_time]``.

    The span is cut into frames of ``frame_length`` seconds taken every
    ``hop_length`` seconds.  The frame with the smallest RMS energy wins
    (the earliest one on ties), and the time of its centre is returned.

    Args:
        audio_data: audio samples; indexed along the first axis.
        sr: sample rate in Hz.
        start_time: start of the search range, in seconds.
        end_time: end of the search range, in seconds.
        frame_length: frame length, in seconds.
        hop_length: hop between frames, in seconds.

    Returns:
        (optimal_time, min_energy): the winning frame's centre time in
        seconds and its RMS energy.
        - ``(start_time, 0.0)`` when the range is empty or outside the audio.
        - When the range is shorter than one frame, the whole range is
          treated as a single frame, and its centre and RMS are returned.
    """
    if start_time >= end_time:
        return start_time, 0.0

    segment = audio_data[int(start_time * sr):int(end_time * sr)]
    if len(segment) == 0:
        return start_time, 0.0

    # Square in float64 so integer PCM input cannot overflow.
    power = np.square(np.asarray(segment, dtype=np.float64))

    # max(1, ...) avoids a zero-length frame / zero hop (range() step 0)
    # at very low sample rates.
    frame_samples = max(1, int(frame_length * sr))
    hop_samples = max(1, int(hop_length * sr))

    if len(power) < frame_samples:
        return start_time + len(power) / 2 / sr, float(np.sqrt(power.mean()))

    # Row k covers power[k*hop : k*hop + frame_samples], i.e. the same frames
    # as a loop over range(0, len - frame_samples + 1, hop).
    frames = np.lib.stride_tricks.sliding_window_view(power, frame_samples)[::hop_samples]
    energies = np.sqrt(frames.mean(axis=1))
    best = int(np.argmin(energies))  # first minimum, like a strict "<" scan

    optimal_time = start_time + (best * hop_samples + frame_samples / 2) / sr
    return optimal_time, float(energies[best])
|
|
|
|
|
def optimize_segment_boundaries(segments, audio_data, sr,
                                small_gap_threshold=0.1,
                                min_gap_to_keep=0.05):
    """
    Snap segment boundaries to low-energy (RMS) points.

    Rules, applied to ``segments`` (assumed sorted by start time):

    * Every positive gap between two consecutive segments becomes one shared
      cut, placed by ``find_optimal_boundary`` over the whole gap, whatever
      the gap size.  The cut is computed once per gap.
    * The first segment's start moves to the lowest-energy point in
      ``[0, start]`` when that lead-in is at least ``small_gap_threshold``
      seconds long.
    * The last segment's end moves to the lowest-energy point in
      ``[end, audio end]`` under the same condition.
    * Overlapping neighbours are both cut at the midpoint between the earlier
      segment's end and the later one's start.
    * A remaining gap shorter than ``min_gap_to_keep`` is closed by extending
      the earlier segment.  This is a safety net; the gap pass normally
      leaves no positive gaps.

    Args:
        segments: list of dicts with "start"/"end" in seconds.
        audio_data: audio samples; only ``len()`` is read here, the samples
            themselves are passed to ``find_optimal_boundary``.
        sr: sample rate in Hz.
        small_gap_threshold: minimum lead-in/tail length, in seconds, before
            the first start / last end is moved.
        min_gap_to_keep: gaps shorter than this (seconds) are closed.

    Returns:
        A new list of shallow-copied segment dicts; the input is not mutated.
    """
    if not segments:
        return []

    audio_duration = len(audio_data) / sr
    optimized = [seg.copy() for seg in segments]

    # Lead-in before the first segment.
    lead_in = segments[0]["start"] - 0.0
    if lead_in > 0 and lead_in >= small_gap_threshold:
        optimized[0]["start"], _ = find_optimal_boundary(
            audio_data, sr, 0.0, segments[0]["start"]
        )

    # Tail after the last segment.
    tail = audio_duration - segments[-1]["end"]
    if tail > 0 and tail >= small_gap_threshold:
        optimized[-1]["end"], _ = find_optimal_boundary(
            audio_data, sr, segments[-1]["end"], audio_duration
        )

    # One shared, energy-optimal cut per positive gap between neighbours.
    for i in range(len(segments) - 1):
        left_end = segments[i]["end"]
        right_start = segments[i + 1]["start"]
        if right_start - left_end > 0:
            cut, _ = find_optimal_boundary(audio_data, sr, left_end, right_start)
            optimized[i]["end"] = cut
            optimized[i + 1]["start"] = cut

    # Resolve overlaps and close very small gaps.
    final_optimized = optimized[:1]
    for curr_seg in optimized[1:]:
        prev_seg = final_optimized[-1]
        gap = curr_seg["start"] - prev_seg["end"]

        if gap < 0:
            # Overlap: both sides meet at the midpoint.
            midpoint = (prev_seg["end"] + curr_seg["start"]) / 2
            prev_seg["end"] = midpoint
            curr_seg["start"] = midpoint
            print(f"检测到重叠,修正为中点: {midpoint:.3f}s")
        elif 0 < gap < min_gap_to_keep:
            # Tiny gap: absorb it into the earlier segment.
            prev_seg["end"] = curr_seg["start"]
            print(f"消除极小间隙: {gap * 1000:.1f}ms")

        final_optimized.append(curr_seg)

    return final_optimized
|
|
|
|
|
def merge_segments_by_gap(segments, max_duration=50.0):
    """
    Greedily merge adjacent segments, smallest gap first.

    Strategy:
        1. Compute the gap between every pair of adjacent groups (negative
           when they overlap).
        2. Merge the pair with the smallest gap whose merged span is at most
           ``max_duration``.  If one of the two groups is already longer than
           that, the merge is still allowed when it does not make the result
           longer than that group (e.g. absorbing a nested segment).
        3. Repeat until no pair qualifies.

    A group spans from its earliest start to its latest end.  Overlapping or
    nested input segments therefore never shorten a merged segment.  Input is
    assumed sorted by start time.

    Args:
        segments: list of dicts with "start"/"end" in seconds.
        max_duration: maximum merged duration, in seconds.

    Returns:
        One dict per group: {"start", "end", "original_count"}.  This shape
        is also used for single-segment input.
    """
    if not segments:
        return []

    groups = [[seg] for seg in segments]
    spans = [(seg["start"], seg["end"]) for seg in segments]

    while len(groups) > 1:
        # Find the smallest-gap mergeable pair; strict "<" keeps the lowest
        # index on ties.
        best = None
        for i in range(len(groups) - 1):
            (a_start, a_end), (b_start, b_end) = spans[i], spans[i + 1]
            gap = b_start - a_end
            merged_start = min(a_start, b_start)
            merged_end = max(a_end, b_end)
            merged_duration = merged_end - merged_start

            limit = max(max_duration, a_end - a_start, b_end - b_start)
            if merged_duration > limit:
                continue
            if best is None or gap < best[1]:
                best = (i, gap, merged_duration, merged_start, merged_end)

        if best is None:
            break

        idx, gap, merged_duration, merged_start, merged_end = best
        groups[idx].extend(groups.pop(idx + 1))
        spans[idx] = (merged_start, merged_end)
        spans.pop(idx + 1)
        print(f"合并间隔 {gap:.3f}s, 合并后时长 {merged_duration:.2f}s")

    return [
        {"start": span_start, "end": span_end, "original_count": len(group)}
        for group, (span_start, span_end) in zip(groups, spans)
    ]
|
|
|
|
|
def export_audio_segments(segments, audio_pydub, output_folder, session_id,
                          padding_ms=25):
    """
    Write each segment to ``<output_folder>/<session_id>/segment_NNN.wav``.

    Args:
        segments: list of dicts with "start"/"end" in seconds.
        audio_pydub: pydub AudioSegment to slice; slice indices are
            milliseconds (see the ``* 1000`` conversions below).
        output_folder: pathlib.Path of the base output folder.
        session_id: name of the per-request sub-folder.
        padding_ms: extra audio kept before and after each segment, in ms,
            clipped to the audio bounds.  The default of 25 matches the
            previous fixed padding.

    Returns:
        One dict per exported file with index (1-based), filename, filepath,
        start_time, end_time and duration.  Times are the unpadded segment
        bounds, rounded to milliseconds.
    """
    session_folder = output_folder / session_id
    session_folder.mkdir(parents=True, exist_ok=True)

    total_ms = len(audio_pydub)
    exported_files = []

    for index, seg in enumerate(segments, start=1):
        start_time, end_time = seg["start"], seg["end"]

        start_ms = max(0, int(start_time * 1000) - padding_ms)
        end_ms = min(total_ms, int(end_time * 1000) + padding_ms)

        filename = f"segment_{index:03d}.wav"
        filepath = session_folder / filename

        # pydub's export() returns the file object it wrote to; close it
        # explicitly instead of relying on garbage collection.
        audio_pydub[start_ms:end_ms].export(str(filepath), format="wav").close()

        exported_files.append({
            "index": index,
            "filename": filename,
            "filepath": str(filepath),
            "start_time": round(start_time, 3),
            "end_time": round(end_time, 3),
            "duration": round(end_time - start_time, 3),
        })

    return exported_files
|
|
|
|
|
@app.on_event("startup")
async def load_model():
    """Load the pyannote pipeline once at startup.

    Sets the module-level ``pipeline`` and ``device`` globals.  The pipeline
    config path can be overridden with the ``PYANNOTE_CONFIG`` environment
    variable; otherwise the local speaker-diarization-3.1 config is used.

    Raises:
        RuntimeError: if ``Pipeline.from_pretrained`` returns None instead of
            a pipeline.  Startup then fails with a clear message rather than
            an AttributeError on ``.to()``.

    NOTE(review): newer FastAPI versions deprecate ``on_event`` in favour of
    lifespan handlers; migrating requires changing the ``FastAPI(...)``
    construction as well.
    """
    global pipeline, device

    print("正在加载 Pyannote 模型...")

    default_config = r"G:\Work_data\workstation\Audio_classification\classify_model\speaker-diarization-3.1\config.yaml"
    local_config = os.environ.get("PYANNOTE_CONFIG", default_config)

    loaded = Pipeline.from_pretrained(local_config)
    if loaded is None:
        raise RuntimeError(f"Pyannote 模型加载失败: {local_config}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loaded.to(device)
    pipeline = loaded
    print(f"Pyannote 模型加载成功!使用设备: {device}")

    print("服务已就绪")
|
|
|
|
|
def _run_diarization(audio_path):
    """Run the global pyannote ``pipeline`` on an audio file and return its result.

    The file is loaded with torchaudio, averaged to one channel if it has
    several, resampled to 16 kHz if needed and moved to the global ``device``.
    """
    waveform, sample_rate = torchaudio.load(str(audio_path))

    # Dimension 0 is the channel axis: average it down to mono.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
        sample_rate = 16000

    return pipeline({"waveform": waveform.to(device), "sample_rate": sample_rate})


def _collect_initial_segments(diarization_result):
    """Flatten a diarization result into [{"start", "end", "speaker"}, ...].

    If the result has a ``speaker_diarization`` attribute, that annotation is
    used; otherwise the result itself is iterated with ``itertracks``.  Times
    are rounded to milliseconds.
    """
    annotation = getattr(diarization_result, "speaker_diarization", diarization_result)
    return [
        {
            "start": round(segment.start, 3),
            "end": round(segment.end, 3),
            "speaker": f"speaker_{speaker}",
        }
        for segment, _track, speaker in annotation.itertracks(yield_label=True)
    ]


def _reduction_percent(initial_count, merged_count):
    """Percentage by which merging reduced the segment count (0.0 if there were none)."""
    if initial_count == 0:
        return 0.0
    return (1 - merged_count / initial_count) * 100


@app.post("/api/segment")
async def segment_audio(file: UploadFile = File(...)):
    """
    Audio segmentation API.

    Pipeline:
        1. Initial pyannote segmentation
        2. Merge by time gap
        3. Boundary optimization
        4. Export audio segments

    Processing errors are reported as HTTP 200 with ``"code": 500`` in the
    JSON body.  A missing model (500) or an unsupported extension (400)
    raises HTTPException.

    NOTE(review): the processing below runs synchronously inside this async
    handler, so it blocks the event loop for the duration of the request.
    """
    if pipeline is None:
        raise HTTPException(status_code=500, detail="Pyannote 模型未加载")

    # Validate the file extension (a missing filename yields "" and is rejected).
    allowed_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
    file_ext = os.path.splitext(file.filename or "")[1].lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件格式。支持的格式: {', '.join(allowed_extensions)}"
        )

    temp_audio_path = None

    try:
        session_id = f"{int(time.time())}_{str(uuid.uuid4())[:8]}"

        # Save the upload to a temporary file in the working directory.
        temp_audio_path = Path(f"temp_{session_id}{file_ext}")
        content = await file.read()
        temp_audio_path.write_bytes(content)

        print(f"\n{'=' * 60}")
        print(f"开始处理音频: {file.filename}")
        print(f"会话ID: {session_id}")

        total_start_time = time.time()

        # ===== Stage 1: preprocessing =====
        print("\n[阶段1] 音频预处理...")
        preprocess_start = time.time()

        audio_data, sr, audio_pydub = load_audio(temp_audio_path)
        duration = len(audio_data) / sr
        print(f" 音频时长: {duration:.1f}秒")

        preprocess_time = time.time() - preprocess_start
        print(f" 耗时: {preprocess_time:.2f}秒")

        # ===== Stage 2: initial pyannote segmentation =====
        print("\n[阶段2] Pyannote 初始分段...")
        segment_start = time.time()

        initial_segments = _collect_initial_segments(_run_diarization(temp_audio_path))
        initial_count = len(initial_segments)
        print(f" 初始片段数: {initial_count}")

        segment_time = time.time() - segment_start
        print(f" 耗时: {segment_time:.2f}秒")

        # ===== Stage 3: gap-based merge =====
        print("\n[阶段3] 按时间间隔合并...")
        merge_start = time.time()

        merged_segments = merge_segments_by_gap(initial_segments, max_duration=50.0)
        merged_count = len(merged_segments)
        # Guarded: no detected speech (initial_count == 0) used to raise
        # ZeroDivisionError here.
        reduction = _reduction_percent(initial_count, merged_count)
        print(f" 合并后片段数: {merged_count}")
        print(f" 减少比例: {reduction:.1f}%")

        merge_time = time.time() - merge_start
        print(f" 耗时: {merge_time:.2f}秒")

        # ===== Stage 4: boundary optimization =====
        print("\n[阶段4] 边界优化(能量分析)...")
        optimize_start = time.time()

        optimized_segments = optimize_segment_boundaries(
            merged_segments,
            audio_data,
            sr,
            small_gap_threshold=0.1,
            min_gap_to_keep=0.05
        )

        optimize_time = time.time() - optimize_start
        print(" 优化完成")
        print(f" 耗时: {optimize_time:.2f}秒")

        # ===== Stage 5: export =====
        print("\n[阶段5] 导出音频片段...")
        export_start = time.time()

        exported_files = export_audio_segments(
            optimized_segments,
            audio_pydub,
            OUTPUT_FOLDER,
            session_id
        )

        export_time = time.time() - export_start
        print(f" 导出完成: {len(exported_files)} 个文件")
        print(f" 输出目录: {OUTPUT_FOLDER / session_id}")
        print(f" 耗时: {export_time:.2f}秒")

        # ===== Summary =====
        total_time = time.time() - total_start_time
        print(f"\n{'=' * 60}")
        print("处理完成!")
        print(f"总耗时: {total_time:.2f}秒")
        print(f"{'=' * 60}\n")

        response_data = {
            "session_id": session_id,
            "segments": exported_files,
            "statistics": {
                "original_segments": initial_count,
                "merged_segments": merged_count,
                "final_segments": len(optimized_segments),
                "merge_reduction": round(reduction, 1),
                "audio_duration": round(duration, 2),
                "processing_time": round(total_time, 2),
                "time_breakdown": {
                    "preprocess": round(preprocess_time, 2),
                    "segmentation": round(segment_time, 2),
                    "merge": round(merge_time, 2),
                    "optimization": round(optimize_time, 2),
                    "export": round(export_time, 2)
                }
            },
            "output_folder": str(OUTPUT_FOLDER / session_id)
        }

        return JSONResponse(
            content={
                "data": jsonable_encoder(response_data),
                "message": "success",
                "code": 200
            },
            status_code=200
        )

    except Exception as e:
        # Top-level boundary: log with traceback, report in the JSON body.
        print(f"\n处理错误: {str(e)}")
        import traceback
        traceback.print_exc()

        return JSONResponse(
            content={
                "data": "",
                "message": str(e),
                "code": 500
            },
            status_code=200
        )

    finally:
        # Always remove the temporary upload.
        if temp_audio_path and temp_audio_path.exists():
            temp_audio_path.unlink(missing_ok=True)
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Health probe; also reports whether the pyannote pipeline is loaded."""
    payload = {
        "status": "healthy",
        "pyannote_model_loaded": pipeline is not None,
        "architecture": "pyannote_segment_only",
        "description": "音频分段服务(含边界优化)",
        "message": "Service is running",
    }
    return JSONResponse(content=payload, status_code=200)
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the service: version, processing configuration and endpoints."""
    config = {
        "mode": "pyannote + gap-based merge + boundary optimization",
        "max_segment_duration": "50s",
        "boundary_optimization": "enabled (RMS energy)",
        "merge_strategy": "gap-priority",
    }
    endpoints = {
        "segment": "/api/segment - POST 上传音频文件进行分段",
        "health": "/health - GET 健康检查",
    }
    return {
        "message": "音频分段 API 服务",
        "version": "3.0",
        "config": config,
        "endpoints": endpoints,
    }
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Startup banner, printed line by line before the server starts.
    rule = "=" * 60
    banner_lines = (
        rule,
        "音频分段服务启动中...",
        rule,
        "功能特性:",
        " ✓ Pyannote 高质量初始分段",
        " ✓ 按时间间隔优先合并(智能合并策略)",
        " ✓ 边界能量优化(RMS短时能量分析)",
        " ✓ 双边独立优化(小间隙共享边界)",
        " ✓ 自动处理重叠和极小间隙",
        rule,
        "启动服务器...",
        rule,
    )
    for banner_line in banner_lines:
        print(banner_line)

    uvicorn.run(app, host="0.0.0.0", port=8001)