# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

# 669 lines
# 20 KiB

import os
import sys
# Extend the import search path with a machine-specific FireRedASR checkout.
# NOTE(review): nothing visible in this file imports from it — confirm it is still needed.
sys.path.append('G:/Work_data/workstation/Forwork-voice-txt/FireRedASR')
# Enable offline mode.
# NOTE(review): presumably set before the third-party imports below so that the
# Hugging Face libraries see the flags when they are imported — confirm.
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '1'
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.encoders import jsonable_encoder
from pyannote.audio import Pipeline
import torch
import torchaudio
from pydub import AudioSegment
import soundfile as sf
import numpy as np
import librosa
import time
import uuid
from pathlib import Path
app = FastAPI(title="音频分段 API")
# Configure CORS: every origin, method and header is allowed, with credentials.
# NOTE(review): confirm this wide-open policy is intended outside development.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Module-level state; both names are assigned by the startup handler `load_model`.
pipeline = None
device = None
# Create the output folder at import time (relative to the current working directory).
OUTPUT_FOLDER = Path("segmented_audio")
OUTPUT_FOLDER.mkdir(exist_ok=True)
def load_audio(audio_path):
    """
    Load one audio file twice: as a NumPy signal and as a pydub clip.

    The NumPy copy is down-mixed to mono and resampled to 16 kHz for energy
    analysis; the pydub copy is converted to 16 kHz mono for slicing/export.

    Args:
        audio_path: Path (str or path-like) of the audio file.

    Returns:
        Tuple ``(audio_data, sr, audio_pydub)``.
    """
    target_rate = 16000
    path_text = str(audio_path)

    # Sample data via soundfile.
    samples, rate = sf.read(path_text)

    # Multi-channel input: average the channels into one.
    if samples.ndim > 1:
        samples = np.mean(samples, axis=1)

    # Bring the signal to the target rate when needed.
    if rate != target_rate:
        samples = librosa.resample(samples, orig_sr=rate, target_sr=target_rate)
        rate = target_rate

    # Independent pydub load, used later for cutting and exporting segments.
    clip = AudioSegment.from_file(path_text)
    clip = clip.set_frame_rate(target_rate).set_channels(1)

    return samples, rate, clip
def calculate_frame_energy(audio_data, sr, time_point, frame_length=0.025):
    """
    Compute the RMS energy of the frame located at a given time.

    The frame is nominally centred on ``time_point``. When the centre is
    closer than half a frame to the beginning of the signal the window is
    clamped to start at sample 0; at the end of the signal it is truncated
    to the samples that exist.

    Args:
        audio_data: Sequence/array of audio samples (indexed by sample).
        sr: Sample rate in Hz.
        time_point: Time of the frame centre, in seconds.
        frame_length: Frame length in seconds.

    Returns:
        RMS energy as a Python float; 0.0 when the window holds no samples.
    """
    frame_samples = int(frame_length * sr)
    center_sample = int(time_point * sr)

    # Centre-aligned window, clamped to the valid sample range.
    start_sample = max(0, center_sample - frame_samples // 2)
    end_sample = min(len(audio_data), start_sample + frame_samples)

    # Work in float64: squaring integer PCM (e.g. int16) would silently
    # wrap around and produce a wrong energy.
    frame = np.asarray(audio_data[start_sample:end_sample], dtype=np.float64)
    if len(frame) == 0:
        return 0.0

    return float(np.sqrt(np.mean(frame ** 2)))
def find_optimal_boundary(audio_data, sr, start_time, end_time,
                          frame_length=0.025, hop_length=0.005):
    """
    Find the lowest-energy point inside a time range.

    A window of ``frame_length`` seconds slides over
    ``[start_time, end_time)`` in steps of ``hop_length`` seconds; the centre
    of the first window with the smallest RMS energy is returned.

    Args:
        audio_data: Sequence/array of audio samples (indexed by sample).
        sr: Sample rate in Hz.
        start_time: Start of the search range, in seconds.
        end_time: End of the search range, in seconds.
        frame_length: Analysis frame length, in seconds.
        hop_length: Step between successive frames, in seconds.

    Returns:
        Tuple ``(optimal_time, min_energy)``. ``(start_time, 0.0)`` is
        returned when the range is empty or lies outside the audio. When the
        range is shorter than one frame, the range start is returned together
        with the RMS of the samples that are available (always finite).
    """
    # Empty or inverted range: nothing to search.
    if start_time >= end_time:
        return start_time, 0.0

    # Clamp to sample 0 so a negative time cannot turn into a negative slice
    # index (which would read from the end of the array).
    base_time = max(0.0, start_time)
    first_sample = max(0, int(start_time * sr))
    last_sample = max(0, int(end_time * sr))

    # float64 avoids integer wrap-around when squaring integer PCM data.
    segment = np.asarray(audio_data[first_sample:last_sample], dtype=np.float64)
    if len(segment) == 0:
        return start_time, 0.0

    # At least one sample per frame/hop: a zero hop would make range() raise
    # and a zero-length frame has no defined energy.
    frame_samples = max(1, int(frame_length * sr))
    hop_samples = max(1, int(hop_length * sr))

    # Range shorter than a single frame: no window fits, so report the energy
    # of what is there instead of leaving the minimum at infinity.
    if len(segment) < frame_samples:
        return base_time, float(np.sqrt(np.mean(segment ** 2)))

    min_energy = float("inf")
    optimal_time = base_time
    for offset in range(0, len(segment) - frame_samples + 1, hop_samples):
        frame = segment[offset:offset + frame_samples]
        energy = float(np.sqrt(np.mean(frame ** 2)))
        # Strict comparison keeps the earliest frame among equal minima.
        if energy < min_energy:
            min_energy = energy
            # Timestamp of the frame centre.
            optimal_time = base_time + (offset + frame_samples / 2) / sr

    return optimal_time, min_energy
def optimize_segment_boundaries(segments, audio_data, sr,
                                small_gap_threshold=0.1,
                                min_gap_to_keep=0.05):
    """
    Snap segment boundaries to low-energy points (both edges of each segment).

    Three passes are made:
      1. Gaps of at least ``small_gap_threshold`` seconds: the start and end
         of each segment are moved independently to the lowest-energy point
         of the neighbouring gap (for the first/last segment the gap runs to
         the beginning/end of the audio).
      2. Positive gaps below ``small_gap_threshold``: one lowest-energy point
         is found and used as the shared boundary of both neighbours.
      3. Overlaps are resolved at their midpoint and positive gaps below
         ``min_gap_to_keep`` are closed by extending the previous segment.

    Args:
        segments: List of dicts with at least ``"start"`` and ``"end"``.
        audio_data: Audio samples.
        sr: Sample rate in Hz.
        small_gap_threshold: Gap size (seconds) separating "small" and "large".
        min_gap_to_keep: Gaps shorter than this (seconds) are removed.

    Returns:
        New list of adjusted segment dicts (shallow copies of the inputs).
    """
    if not segments:
        return []

    total_seconds = len(audio_data) / sr
    last_index = len(segments) - 1

    # Pass 1: independent optimisation across large gaps.
    refined = []
    for idx, source in enumerate(segments):
        adjusted = source.copy()

        # Leading gap: audio start (first segment) or previous segment's end.
        lead_from = 0.0 if idx == 0 else segments[idx - 1]["end"]
        lead_to = source["start"]
        lead_gap = lead_to - lead_from
        if lead_gap > 0 and lead_gap >= small_gap_threshold:
            new_start, _ = find_optimal_boundary(
                audio_data, sr, lead_from, lead_to
            )
            adjusted["start"] = new_start

        # Trailing gap: next segment's start, or audio end (last segment).
        trail_from = source["end"]
        trail_to = total_seconds if idx == last_index else segments[idx + 1]["start"]
        trail_gap = trail_to - trail_from
        if trail_gap > 0 and trail_gap >= small_gap_threshold:
            new_end, _ = find_optimal_boundary(
                audio_data, sr, trail_from, trail_to
            )
            adjusted["end"] = new_end

        refined.append(adjusted)

    # Pass 2: small gaps get a single boundary shared by both neighbours.
    for left, right in zip(refined, refined[1:]):
        between = right["start"] - left["end"]
        if not (0 < between < small_gap_threshold):
            continue
        shared_point, _ = find_optimal_boundary(
            audio_data, sr,
            left["end"],
            right["start"]
        )
        left["end"] = shared_point
        right["start"] = shared_point

    # Pass 3: fix overlaps and close tiny leftover gaps.
    result = [refined[0]]
    for current in refined[1:]:
        previous = result[-1]
        gap = current["start"] - previous["end"]

        if gap < 0:
            # Overlap: meet in the middle.
            midpoint = (previous["end"] + current["start"]) / 2
            previous["end"] = midpoint
            current["start"] = midpoint
            print(f"检测到重叠,修正为中点: {midpoint:.3f}s")
        elif 0 < gap < min_gap_to_keep:
            # Tiny gap: absorb it into the previous segment.
            previous["end"] = current["start"]
            print(f"消除极小间隙: {gap * 1000:.1f}ms")

        result.append(current)

    return result
def merge_segments_by_gap(segments, max_duration=50.0):
    """
    Greedily merge neighbouring segments, smallest gap first.

    On every round the adjacent pair with the smallest gap whose merged span
    does not exceed ``max_duration`` is merged (ties go to the earlier pair).
    The process stops when no pair qualifies.

    A merged group spans from its first start to the *latest* end among its
    members, so a segment nested inside (or overlapping) an earlier, longer
    one cannot shorten the group.

    Args:
        segments: List of dicts with ``"start"`` and ``"end"`` (seconds),
            expected in chronological order of ``"start"``.
        max_duration: Upper bound (seconds) for the span of a merged segment.

    Returns:
        List of merged segments. For two or more inputs each item is
        ``{"start", "end", "original_count"}``; a single input segment is
        returned as a copy with ``"original_count": 1`` added and its other
        keys kept.
    """
    if not segments:
        return []

    if len(segments) == 1:
        # Keep every original key, but make "original_count" always present.
        return [{**segments[0], "original_count": 1}]

    # Each group is [start, end, member_count].
    groups = [[seg["start"], seg["end"], 1] for seg in segments]

    while len(groups) > 1:
        # Smallest eligible gap; the strict "<" keeps the earliest pair on ties.
        best = None
        for idx, (left, right) in enumerate(zip(groups, groups[1:])):
            merged_duration = max(left[1], right[1]) - left[0]
            if merged_duration > max_duration:
                continue
            gap = right[0] - left[1]
            if best is None or gap < best[0]:
                best = (gap, idx, merged_duration)

        # Nothing can be merged without exceeding max_duration.
        if best is None:
            break

        gap, idx, merged_duration = best
        absorbed = groups.pop(idx + 1)
        target = groups[idx]
        target[1] = max(target[1], absorbed[1])
        target[2] += absorbed[2]
        print(f"合并间隔 {gap:.3f}s, 合并后时长 {merged_duration:.2f}s")

    return [
        {"start": start, "end": end, "original_count": count}
        for start, end, count in groups
    ]
def export_audio_segments(segments, audio_pydub, output_folder, session_id,
                          padding_ms=25):
    """
    Cut the audio into segments and write each one as a WAV file.

    Files are written to ``output_folder / session_id`` and named
    ``segment_001.wav``, ``segment_002.wav``, ... in input order.

    Args:
        segments: List of dicts with ``"start"`` and ``"end"`` in seconds.
        audio_pydub: Audio object that reports its length in milliseconds via
            ``len()``, supports millisecond slicing and offers
            ``export(path, format=...)`` (a pydub ``AudioSegment`` here).
        output_folder: Root output directory (str or ``Path``).
        session_id: Name of the per-request sub-folder.
        padding_ms: Milliseconds of audio added on each side of a segment
            (clamped to the audio bounds). Defaults to 25.

    Returns:
        List of dicts describing each exported file: ``index``, ``filename``,
        ``filepath``, ``start_time``, ``end_time`` and ``duration``. The times
        are the unpadded segment times rounded to milliseconds.
    """
    session_folder = Path(output_folder) / session_id
    # parents=True: do not fail if the root output folder is missing
    # (for example removed after the service started).
    session_folder.mkdir(parents=True, exist_ok=True)

    total_ms = len(audio_pydub)
    exported_files = []

    for index, seg in enumerate(segments, start=1):
        start_time = seg["start"]
        end_time = seg["end"]

        # Seconds -> milliseconds, padded and clamped to the audio.
        start_ms = max(0, int(start_time * 1000) - padding_ms)
        end_ms = min(total_ms, int(end_time * 1000) + padding_ms)

        filename = f"segment_{index:03d}.wav"
        filepath = session_folder / filename
        audio_pydub[start_ms:end_ms].export(str(filepath), format="wav")

        exported_files.append({
            "index": index,
            "filename": filename,
            "filepath": str(filepath),
            "start_time": round(start_time, 3),
            "end_time": round(end_time, 3),
            "duration": round(end_time - start_time, 3),
        })

    return exported_files
@app.on_event("startup")
async def load_model():
    """
    Startup hook: load the Pyannote pipeline and move it to the compute device.

    Assigns the module-level ``pipeline`` and ``device``. CUDA is used when
    ``torch.cuda.is_available()`` reports it, otherwise the CPU.
    """
    global pipeline, device

    print("正在加载 Pyannote 模型...")

    # Pipeline configuration is read from a local, machine-specific file.
    config_path = r"G:\Work_data\workstation\Audio_classification\classify_model\speaker-diarization-3.1\config.yaml"
    pipeline = Pipeline.from_pretrained(config_path)

    # Choose the device, then move the pipeline onto it.
    if torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    device = torch.device(device_name)
    pipeline.to(device)

    print(f"Pyannote 模型加载成功!使用设备: {device}")
    print("服务已就绪")
def _collect_diarization_segments(diarization_result):
    """Flatten a diarization result into a list of {start, end, speaker} dicts."""
    # Some result objects carry the annotation under `speaker_diarization`;
    # otherwise the result itself is iterated as the annotation.
    annotation = getattr(diarization_result, "speaker_diarization", diarization_result)
    return [
        {
            "start": round(segment.start, 3),
            "end": round(segment.end, 3),
            "speaker": f"speaker_{speaker}",
        }
        for segment, _track, speaker in annotation.itertracks(yield_label=True)
    ]


def _merge_reduction_percent(initial_count, merged_count):
    """Return the percentage of segments removed by merging (0.0 when there were none)."""
    if initial_count == 0:
        return 0.0
    return (1 - merged_count / initial_count) * 100


@app.post("/api/segment")
async def segment_audio(file: UploadFile = File(...)):
    """
    Audio segmentation endpoint.

    Pipeline:
      1. Pre-process the upload (mono, 16 kHz).
      2. Pyannote initial segmentation.
      3. Merge neighbouring segments by gap size.
      4. Energy-based boundary optimisation.
      5. Export every segment as a WAV file.

    Args:
        file: Uploaded audio file (.wav, .mp3, .flac, .m4a or .ogg).

    Returns:
        ``JSONResponse`` with ``data`` / ``message`` / ``code``. Processing
        failures are reported in-band: HTTP status 200 with ``code`` 500 and
        the error text in ``message``.

    Raises:
        HTTPException: 500 when the pipeline is not loaded; 400 when the file
            extension is missing or not supported.
    """
    if pipeline is None:
        raise HTTPException(status_code=500, detail="Pyannote 模型未加载")

    # Validate the extension. An upload may arrive without a filename; treat
    # that as "no extension" so it is rejected with a 400 instead of crashing.
    allowed_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
    file_ext = os.path.splitext(file.filename or "")[1].lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件格式。支持的格式: {', '.join(allowed_extensions)}"
        )

    temp_audio_path = None
    try:
        # Session id: timestamp plus a short random suffix.
        session_id = f"{int(time.time())}_{str(uuid.uuid4())[:8]}"

        # Persist the upload to a temporary file in the working directory.
        temp_audio_path = Path(f"temp_{session_id}{file_ext}")
        with open(temp_audio_path, "wb") as f:
            content = await file.read()
            f.write(content)

        print(f"\n{'=' * 60}")
        print(f"开始处理音频: {file.filename}")
        print(f"会话ID: {session_id}")
        total_start_time = time.time()

        # ===== Stage 1: audio pre-processing =====
        print(f"\n[阶段1] 音频预处理...")
        preprocess_start = time.time()
        audio_data, sr, audio_pydub = load_audio(temp_audio_path)
        duration = len(audio_data) / sr
        print(f" 音频时长: {duration:.1f}")
        preprocess_time = time.time() - preprocess_start
        print(f" 耗时: {preprocess_time:.2f}")

        # ===== Stage 2: Pyannote initial segmentation =====
        print(f"\n[阶段2] Pyannote 初始分段...")
        segment_start = time.time()

        # Separate load as a torch tensor for the pipeline.
        waveform, sample_rate = torchaudio.load(str(temp_audio_path))
        # Down-mix to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # Resample to 16 kHz.
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000
        waveform = waveform.to(device)

        audio = {"waveform": waveform, "sample_rate": sample_rate}
        diarization_result = pipeline(audio)

        initial_segments = _collect_diarization_segments(diarization_result)
        initial_count = len(initial_segments)
        print(f" 初始片段数: {initial_count}")
        segment_time = time.time() - segment_start
        print(f" 耗时: {segment_time:.2f}")

        # ===== Stage 3: merge by gap =====
        print(f"\n[阶段3] 按时间间隔合并...")
        merge_start = time.time()
        merged_segments = merge_segments_by_gap(initial_segments, max_duration=50.0)
        merged_count = len(merged_segments)
        # Safe when no segment was detected (previously a division by zero).
        merge_reduction = _merge_reduction_percent(initial_count, merged_count)
        print(f" 合并后片段数: {merged_count}")
        print(f" 减少比例: {merge_reduction:.1f}%")
        merge_time = time.time() - merge_start
        print(f" 耗时: {merge_time:.2f}")

        # ===== Stage 4: boundary optimisation =====
        print(f"\n[阶段4] 边界优化(能量分析)...")
        optimize_start = time.time()
        optimized_segments = optimize_segment_boundaries(
            merged_segments,
            audio_data,
            sr,
            small_gap_threshold=0.1,
            min_gap_to_keep=0.05
        )
        optimize_time = time.time() - optimize_start
        print(f" 优化完成")
        print(f" 耗时: {optimize_time:.2f}")

        # ===== Stage 5: export segments =====
        print(f"\n[阶段5] 导出音频片段...")
        export_start = time.time()
        exported_files = export_audio_segments(
            optimized_segments,
            audio_pydub,
            OUTPUT_FOLDER,
            session_id
        )
        export_time = time.time() - export_start
        print(f" 导出完成: {len(exported_files)} 个文件")
        print(f" 输出目录: {OUTPUT_FOLDER / session_id}")
        print(f" 耗时: {export_time:.2f}")

        # ===== Summary =====
        total_time = time.time() - total_start_time
        print(f"\n{'=' * 60}")
        print(f"处理完成!")
        print(f"总耗时: {total_time:.2f}")
        print(f"{'=' * 60}\n")

        response_data = {
            "session_id": session_id,
            "segments": exported_files,
            "statistics": {
                "original_segments": initial_count,
                "merged_segments": merged_count,
                "final_segments": len(optimized_segments),
                "merge_reduction": round(merge_reduction, 1),
                "audio_duration": round(duration, 2),
                "processing_time": round(total_time, 2),
                "time_breakdown": {
                    "preprocess": round(preprocess_time, 2),
                    "segmentation": round(segment_time, 2),
                    "merge": round(merge_time, 2),
                    "optimization": round(optimize_time, 2),
                    "export": round(export_time, 2)
                }
            },
            "output_folder": str(OUTPUT_FOLDER / session_id)
        }
        return JSONResponse(
            content={
                "data": jsonable_encoder(response_data),
                "message": "success",
                "code": 200
            },
            status_code=200
        )
    except Exception as e:
        print(f"\n处理错误: {str(e)}")
        import traceback
        traceback.print_exc()
        # The failure is reported inside the body while the HTTP status stays
        # 200; kept unchanged because clients may rely on this envelope.
        return JSONResponse(
            content={
                "data": "",
                "message": str(e),
                "code": 500
            },
            status_code=200
        )
    finally:
        # Always remove the temporary upload.
        if temp_audio_path and temp_audio_path.exists():
            temp_audio_path.unlink(missing_ok=True)
@app.get("/health")
async def health_check():
    """Health probe: reports a static status plus whether the pipeline is loaded."""
    model_loaded = pipeline is not None
    payload = {
        "status": "healthy",
        "pyannote_model_loaded": model_loaded,
        "architecture": "pyannote_segment_only",
        "description": "音频分段服务(含边界优化)",
        "message": "Service is running",
    }
    return JSONResponse(content=payload, status_code=200)
@app.get("/")
async def root():
    """Root endpoint: static description of the service, its config and endpoints."""
    config = {
        "mode": "pyannote + gap-based merge + boundary optimization",
        "max_segment_duration": "50s",
        "boundary_optimization": "enabled (RMS energy)",
        "merge_strategy": "gap-priority",
    }
    endpoints = {
        "segment": "/api/segment - POST 上传音频文件进行分段",
        "health": "/health - GET 健康检查",
    }
    return {
        "message": "音频分段 API 服务",
        "version": "3.0",
        "config": config,
        "endpoints": endpoints,
    }
if __name__ == "__main__":
    import uvicorn

    # Startup banner, printed line by line before the server starts.
    rule = "=" * 60
    banner = (
        rule,
        "音频分段服务启动中...",
        rule,
        "功能特性:",
        " ✓ Pyannote 高质量初始分段",
        " ✓ 按时间间隔优先合并(智能合并策略)",
        " ✓ 边界能量优化(RMS短时能量分析)",
        " ✓ 双边独立优化(小间隙共享边界)",
        " ✓ 自动处理重叠和极小间隙",
        rule,
        "启动服务器...",
        rule,
    )
    for line in banner:
        print(line)

    # Serve on all interfaces, port 8001.
    uvicorn.run(app, host="0.0.0.0", port=8001)