import os
# Force offline mode BEFORE importing any Hugging Face-backed libraries,
# so pyannote/transformers never try to reach the Hub.
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '1'
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pyannote.audio import Pipeline
import torch
import torchaudio
import json
import time
import uuid
from pathlib import Path

app = FastAPI(title="音频切割 API")

# CORS: wide open (any origin/method/header) — fine for a local tool,
# NOTE(review): tighten before exposing publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Globals populated by the startup hook: the diarization pipeline and
# the torch device it runs on.
pipeline = None
device = None

# Output folder for uploaded audio and result JSON files.
APP_FOLDER = Path("APP")
APP_FOLDER.mkdir(exist_ok=True)
def merge_adjacent_segments(segments, max_duration=50.0):
    """Merge neighboring segments regardless of speaker.

    Segments are sorted by start time; a segment is folded into the one
    before it whenever the combined span (first start to last end) stays
    within ``max_duration`` seconds.  When segments with different
    speakers are folded together, the speaker is tagged
    ``"mixed_speakers"``.  Every output segment carries a
    ``merged_segments`` count of how many inputs it absorbed.
    """
    if not segments:
        return []

    ordered = sorted(segments, key=lambda seg: seg["start"])
    merged = []
    active = dict(ordered[0], merged_segments=1)

    for candidate in ordered[1:]:
        # Span the merged segment would cover if we absorb the candidate.
        span = candidate["end"] - active["start"]
        if span <= max_duration:
            # Absorb: stretch the active segment over the candidate.
            active["end"] = candidate["end"]
            active["merged_segments"] += 1
            both_labeled = "speaker" in active and "speaker" in candidate
            if both_labeled and active["speaker"] != candidate["speaker"]:
                active["speaker"] = "mixed_speakers"
        else:
            # Too long to absorb: flush and start over from the candidate.
            merged.append(active)
            active = dict(candidate, merged_segments=1)

    merged.append(active)
    return merged
@app.on_event("startup")
async def load_model():
    """Load the local pyannote speaker-diarization pipeline at startup.

    The config path defaults to the original hard-coded local path but
    can now be overridden with the ``DIARIZATION_CONFIG`` environment
    variable (generalization; backward compatible).  The pipeline is
    moved to GPU when CUDA is available, CPU otherwise.
    """
    global pipeline, device
    print("正在加载本地模型...")
    # Local model config path; env var wins so deployments aren't tied
    # to one machine's directory layout.
    local_config = os.environ.get(
        "DIARIZATION_CONFIG",
        r"G:\Work_data\workstation\Audio_classification\classify_model\speaker-diarization-3.1\config.yaml",
    )
    # Load the local model (default pipeline parameters).
    pipeline = Pipeline.from_pretrained(local_config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline.to(device)
    print(f"模型加载成功!使用设备: {device}")
    print("使用默认参数进行音频切割")
@app.post("/api/upload")
async def upload_audio(file: UploadFile = File(...)):
    """Upload an audio file and run speaker diarization on it.

    Saves the upload under ``APP/``, converts it to 16 kHz mono, runs
    the pyannote pipeline with default parameters, merges adjacent
    segments (max 50 s per merged segment), stores the merged result as
    JSON and returns it.

    Raises:
        HTTPException 400: unsupported (or missing) file extension.
        HTTPException 500: model not loaded, or any processing failure.
    """
    if not pipeline:
        raise HTTPException(status_code=500, detail="模型未加载")

    # Validate the extension; filename may be None on some clients, so
    # fall back to "" instead of crashing splitext.
    allowed_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
    file_ext = os.path.splitext(file.filename or "")[1].lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件格式。支持的格式: {', '.join(allowed_extensions)}"
        )

    try:
        # Unique on-disk name: <unix-timestamp>_<short-uuid><ext>.
        unique_id = str(uuid.uuid4())[:8]
        timestamp = int(time.time())
        audio_filename = f"{timestamp}_{unique_id}{file_ext}"
        audio_path = APP_FOLDER / audio_filename

        # Persist the upload to disk.
        with open(audio_path, "wb") as f:
            content = await file.read()
            f.write(content)
        print(f"音频文件已保存: {audio_filename}")

        # Load audio, downmix to mono, resample to the 16 kHz the
        # diarization model expects.
        waveform, sample_rate = torchaudio.load(str(audio_path))
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000

        # Move onto the model's device and run diarization (defaults).
        waveform = waveform.to(device)
        audio = {"waveform": waveform, "sample_rate": sample_rate}
        print("正在进行音频切割...")
        start_time = time.time()
        diarization_result = pipeline(audio)

        # Newer pyannote versions wrap the annotation in
        # `.speaker_diarization`; older ones return it directly.  The
        # extraction loop was duplicated verbatim across both branches —
        # resolve the annotation once and loop once.
        annotation = getattr(
            diarization_result, 'speaker_diarization', diarization_result
        )
        result = [
            {
                "start": round(segment.start, 2),
                "end": round(segment.end, 2),
                "speaker": f"speaker_{speaker}"
            }
            for segment, track, speaker in annotation.itertracks(yield_label=True)
        ]

        original_count = len(result)
        print(f"原始片段数: {original_count}")

        # Merge adjacent segments (<= 50 s per merged segment).
        merged_result = merge_adjacent_segments(result, max_duration=50.0)
        final_count = len(merged_result)
        print(f"合并后最终片段数: {final_count}")

        processing_time = time.time() - start_time
        print(f"处理完成,耗时: {processing_time:.2f}")
        print(f"检测到 {len(set(s['speaker'] for s in result))} 个说话人")

        # Persist the merged segments as JSON next to the audio file.
        json_filename = f"{timestamp}_{unique_id}_result.json"
        json_path = APP_FOLDER / json_filename
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(merged_result, f, ensure_ascii=False, indent=2)
        print(f"结果已保存: {json_filename}")

        return JSONResponse(content={
            "success": True,
            "audio_file": audio_filename,
            "json_file": json_filename,
            "processing_time": round(processing_time, 2),
            "original_segments_count": original_count,
            "final_segments_count": final_count,
            "speakers_detected": len(set(s["speaker"] for s in result)),
            "data": merged_result
        })
    except HTTPException:
        # Don't wrap deliberate HTTP errors in a generic 500.
        raise
    except Exception as e:
        print(f"处理错误: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
@app.get("/")
async def root():
    """Landing endpoint: static metadata describing this API service."""
    info = {
        "message": "音频切割 API 服务",
        "version": "1.0",
    }
    info["config"] = {
        "mode": "default_parameters",
        "max_segment_duration": "50s",
    }
    info["endpoints"] = {
        "upload": "/api/upload - POST 上传音频文件",
    }
    return info
# Script entry point: serve the app directly with uvicorn on all
# interfaces, port 8001 (no reload; single worker).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)