You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
233 lines
7.2 KiB
233 lines
7.2 KiB
import os

# Force offline mode BEFORE importing transformers/pyannote so the HF hub
# never attempts a network download at import or model-load time.
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '1'

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pyannote.audio import Pipeline
import torch
import torchaudio
import json
import time
import uuid
from pathlib import Path

app = FastAPI(title="音频切割 API")

# CORS: wide open (any origin/method/header) — development-only setting.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Globals populated by the startup hook (load_model).
pipeline = None  # pyannote speaker-diarization pipeline, None until loaded
device = None    # torch.device actually used for inference

# Output folder for uploaded audio and result JSON files.
APP_FOLDER = Path("APP")
APP_FOLDER.mkdir(exist_ok=True)
|
|
|
|
|
def merge_adjacent_segments(segments, max_duration=50.0):
    """
    Merge neighbouring segments regardless of speaker.

    Two segments are merged when:
    1. they are neighbours in start-time-sorted order, and
    2. the merged span is at most ``max_duration`` seconds long.

    Each output dict carries a ``merged_segments`` count; when segments
    from different speakers are fused, ``speaker`` becomes
    ``"mixed_speakers"``.
    """
    if not segments:
        return []

    # Process in chronological order.
    ordered = sorted(segments, key=lambda seg: seg["start"])

    merged = []
    active = ordered[0].copy()
    active["merged_segments"] = 1

    for candidate in ordered[1:]:
        # Span the merged segment would cover.
        span = candidate["end"] - active["start"]

        if span <= max_duration:
            # Absorb the candidate into the active segment.
            active["end"] = candidate["end"]
            active["merged_segments"] += 1
            # Fusing different speakers yields a mixed-speaker marker.
            if "speaker" in active and "speaker" in candidate:
                if active["speaker"] != candidate["speaker"]:
                    active["speaker"] = "mixed_speakers"
        else:
            # Would exceed the cap: flush and start a fresh segment.
            merged.append(active)
            active = candidate.copy()
            active["merged_segments"] = 1

    # Flush the final in-progress segment.
    merged.append(active)

    return merged
|
|
|
|
|
@app.on_event("startup")
async def load_model():
    """Load the local speaker-diarization model once at startup.

    Populates the module-level ``pipeline`` and ``device`` globals used by
    the upload endpoint. The config path can be overridden with the
    ``DIARIZATION_CONFIG`` environment variable; it defaults to the
    original hard-coded local path, so existing deployments are unchanged.
    """
    global pipeline, device

    print("正在加载本地模型...")

    # Local model config only (offline mode is forced at module import).
    # Env override generalizes the previously hard-coded Windows path.
    local_config = os.environ.get(
        "DIARIZATION_CONFIG",
        r"G:\Work_data\workstation\Audio_classification\classify_model\speaker-diarization-3.1\config.yaml",
    )

    # Load with the pipeline's default parameters.
    pipeline = Pipeline.from_pretrained(local_config)

    # Prefer GPU when available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline.to(device)
    print(f"模型加载成功!使用设备: {device}")
    print("使用默认参数进行音频切割")
|
|
|
|
|
def _extract_segments(diarization_result):
    """Flatten a pyannote diarization result into segment dicts.

    Handles both the newer API (result wraps an annotation in a
    ``speaker_diarization`` attribute) and the older API (the result is
    the annotation itself) with a single code path — the two extraction
    loops were previously duplicated verbatim.
    """
    annotation = getattr(diarization_result, "speaker_diarization", diarization_result)
    return [
        {
            "start": round(segment.start, 2),
            "end": round(segment.end, 2),
            "speaker": f"speaker_{speaker}",
        }
        for segment, _track, speaker in annotation.itertracks(yield_label=True)
    ]


@app.post("/api/upload")
async def upload_audio(file: UploadFile = File(...)):
    """Upload an audio file, run speaker diarization and return segments.

    Pipeline: validate extension -> persist upload under APP_FOLDER ->
    downmix/resample to 16 kHz mono -> diarize with default parameters ->
    merge adjacent segments (<= 50 s) -> persist JSON -> return summary.

    Raises:
        HTTPException 400: unsupported file extension.
        HTTPException 500: model not loaded, or any processing failure.
    """
    # `not pipeline` would invoke truthiness on the model object; the
    # intended check is "has the startup hook run yet".
    if pipeline is None:
        raise HTTPException(status_code=500, detail="模型未加载")

    # Validate the file extension against the supported audio formats.
    allowed_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
    file_ext = os.path.splitext(file.filename)[1].lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件格式。支持的格式: {', '.join(allowed_extensions)}"
        )

    try:
        # Unique output name: <unix-timestamp>_<short-uuid><ext>
        unique_id = str(uuid.uuid4())[:8]
        timestamp = int(time.time())
        audio_filename = f"{timestamp}_{unique_id}{file_ext}"
        audio_path = APP_FOLDER / audio_filename

        # Persist the upload to disk.
        # NOTE(review): saved uploads are never cleaned up — APP_FOLDER
        # grows without bound; consider a retention policy.
        content = await file.read()
        audio_path.write_bytes(content)

        print(f"音频文件已保存: {audio_filename}")

        # Load audio, downmix to mono, resample to the 16 kHz the model expects.
        waveform, sample_rate = torchaudio.load(str(audio_path))
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000

        # Move onto the inference device chosen at startup.
        waveform = waveform.to(device)
        audio = {"waveform": waveform, "sample_rate": sample_rate}

        print("正在进行音频切割...")
        start_time = time.time()

        # Run diarization with the pipeline's default parameters
        # (no speaker-count hints passed).
        diarization_result = pipeline(audio)

        result = _extract_segments(diarization_result)

        original_count = len(result)
        print(f"原始片段数: {original_count}")

        # Merge adjacent segments up to a 50-second cap.
        merged_result = merge_adjacent_segments(result, max_duration=50.0)
        final_count = len(merged_result)
        print(f"合并后最终片段数: {final_count}")

        processing_time = time.time() - start_time
        print(f"处理完成,耗时: {processing_time:.2f} 秒")
        print(f"检测到 {len(set(s['speaker'] for s in result))} 个说话人")

        # Persist the merged segments next to the audio file.
        json_filename = f"{timestamp}_{unique_id}_result.json"
        json_path = APP_FOLDER / json_filename
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(merged_result, f, ensure_ascii=False, indent=2)

        print(f"结果已保存: {json_filename}")

        return JSONResponse(content={
            "success": True,
            "audio_file": audio_filename,
            "json_file": json_filename,
            "processing_time": round(processing_time, 2),
            "original_segments_count": original_count,
            "final_segments_count": final_count,
            "speakers_detected": len(set(s["speaker"] for s in result)),
            "data": merged_result
        })

    except Exception as e:
        # Boundary handler: log the traceback, surface a 500 to the client.
        print(f"处理错误: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
|
|
|
|
|
@app.get("/")
async def root():
    """Landing endpoint: static metadata describing this service."""
    service_config = {
        "mode": "default_parameters",
        "max_segment_duration": "50s",
    }
    available_endpoints = {
        "upload": "/api/upload - POST 上传音频文件",
    }
    return {
        "message": "音频切割 API 服务",
        "version": "1.0",
        "config": service_config,
        "endpoints": available_endpoints,
    }
|
|
|
|
|
# Run a local development server when this module is executed directly.
if __name__ == "__main__":
    import uvicorn

    # Bind on all interfaces, port 8001.
    uvicorn.run(app, host="0.0.0.0", port=8001)