import sys
import os
import time
import torch
import argparse
import uvicorn
import soundfile as sf
import librosa
import numpy as np
import torchaudio
import uuid
from pathlib import Path
from pydub import AudioSegment
from contextlib import asynccontextmanager
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.encoders import jsonable_encoder
# ============ Global configuration ============
# Force offline mode for Hugging Face / Transformers
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '1'

# FireRedASR paths
FIRERED_ASR_PATH = 'G:/Work_data/workstation/Forwork-voice-txt/FireRedASR'
FIRERED_MODEL_PATH = "G:/Work_data/workstation/Forwork-voice-txt/FireRedASR/pretrained_models/FireRedASR-AED-L"

# Pyannote config path
PYANNOTE_CONFIG_PATH = r"G:\Work_data\workstation\Audio_classification\classify_model\speaker-diarization-3.1\config.yaml"

# Temporary working directory
TEMP_DIR = Path("temp_transcription")
TEMP_DIR.mkdir(exist_ok=True)

# Duration threshold (seconds): shorter audio is transcribed directly,
# longer audio is segmented first
DURATION_THRESHOLD = 50.0

# Global model handles, populated at startup
pyannote_pipeline = None
firered_model = None
device = None
# ============ FireRedASR helpers ============
def setup_fireredasr_environment():
    """Make the FireRedASR package importable."""
    if FIRERED_ASR_PATH not in sys.path:
        sys.path.insert(0, FIRERED_ASR_PATH)
    return True

def load_fireredasr_model(model_dir):
    """Load the FireRedASR model."""
    setup_fireredasr_environment()
    try:
        # PyTorch compatibility fix: allow argparse.Namespace in checkpoints
        torch.serialization.add_safe_globals([argparse.Namespace])
        try:
            from fireredasr.models.fireredasr import FireRedAsr
        except ImportError:
            try:
                from FireRedASR.fireredasr.models.fireredasr import FireRedAsr
            except ImportError:
                import fireredasr
                from fireredasr.models.fireredasr import FireRedAsr
        model = FireRedAsr.from_pretrained("aed", model_dir)
        device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        if torch.cuda.is_available():
            try:
                model = model.to(device_name)
            except Exception:
                pass  # some wrappers manage device placement themselves
        if hasattr(model, 'eval'):
            model.eval()
        print(f"✓ FireRedASR model loaded, device: {device_name}")
        return model, device_name
    except Exception as e:
        print(f"✗ Failed to load FireRedASR model: {e}")
        return None, None

def preprocess_audio_for_asr(audio_path, output_path):
    """
    Preprocess audio to the FireRedASR standard (16 kHz mono WAV).
    Args:
        audio_path: path to the source audio
        output_path: path for the processed output
    Returns:
        True on success
    """
    try:
        # Load audio
        audio_data, sr = sf.read(str(audio_path))
        # Downmix to mono
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)
        # Resample to 16 kHz
        if sr != 16000:
            audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
            sr = 16000
        # Normalize the container format with pydub
        temp_file = str(output_path) + ".temp.wav"
        sf.write(temp_file, audio_data, sr)
        audio_pydub = AudioSegment.from_file(temp_file)
        audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)
        audio_pydub.export(str(output_path), format="wav")
        if os.path.exists(temp_file):
            os.remove(temp_file)
        return True
    except Exception as e:
        print(f"  ✗ Preprocessing failed: {e}")
        return False

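# Illustrative call (a sketch, not invoked by the service itself; assumes a
# local file "raw.mp3" exists):
#
#   ok = preprocess_audio_for_asr("raw.mp3", TEMP_DIR / "clean.wav")
#   # on success, clean.wav is a 16 kHz mono WAV as FireRedASR expects
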
def transcribe_single_audio(audio_path, model, device_name):
    """
    Transcribe a single audio file.
    Args:
        audio_path: path to the audio file
        model: FireRedASR model
        device_name: device name
    Returns:
        result dict {"text": str, "confidence": float}
    """
    try:
        # Preprocess
        temp_path = TEMP_DIR / f"temp_asr_{uuid.uuid4().hex[:8]}.wav"
        success = preprocess_audio_for_asr(audio_path, temp_path)
        if not success:
            return {"text": "", "confidence": 0.0, "error": "Preprocessing failed"}
        # Transcribe
        batch_uttid = ["audio_001"]
        batch_wav_path = [str(temp_path)]
        config = {
            "use_gpu": 1 if device_name.startswith("cuda") else 0,
            "beam_size": 5,
            "nbest": 1,
            "decode_max_len": 0
        }
        with torch.no_grad():
            result = model.transcribe(batch_uttid, batch_wav_path, config)
        # Clean up the temp file
        if temp_path.exists():
            temp_path.unlink()
        if result and len(result) > 0:
            text = result[0].get('text', '').strip()
            confidence = result[0].get('confidence', result[0].get('score', 0.0))
            if isinstance(confidence, (list, tuple)) and len(confidence) > 0:
                confidence = float(confidence[0])
            else:
                confidence = float(confidence) if isinstance(confidence, (int, float)) else 0.0
            return {"text": text, "confidence": round(confidence, 3)}
        return {"text": "", "confidence": 0.0}
    except Exception as e:
        print(f"  ✗ Transcription failed: {e}")
        return {"text": "", "confidence": 0.0, "error": str(e)}

def transcribe_audio_segments(segment_files, model, device_name):
    """
    Transcribe a batch of audio segments.
    Args:
        segment_files: list of segment file info dicts
        model: FireRedASR model
        device_name: device name
    Returns:
        list of transcription results
    """
    results = []
    print(f"Transcribing {len(segment_files)} audio segments...")
    for i, file_info in enumerate(segment_files):
        try:
            # Preprocess
            temp_path = TEMP_DIR / f"temp_seg_{i:03d}.wav"
            success = preprocess_audio_for_asr(file_info['filepath'], temp_path)
            if not success:
                # Use the same "content" key as the success branch so the
                # response schema stays consistent
                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "content": "",
                    "confidence": 0.0,
                    "segment_type": file_info['segment_type'],
                    "error": "Preprocessing failed"
                })
                continue
            # Transcribe
            batch_uttid = [f"seg_{i:03d}"]
            batch_wav_path = [str(temp_path)]
            config = {
                "use_gpu": 1 if device_name.startswith("cuda") else 0,
                "beam_size": 5,
                "nbest": 1,
                "decode_max_len": 0
            }
            with torch.no_grad():
                transcription = model.transcribe(batch_uttid, batch_wav_path, config)
            # Clean up the temp file
            if temp_path.exists():
                temp_path.unlink()
            if transcription and len(transcription) > 0:
                text = transcription[0].get('text', '').strip()
                confidence = transcription[0].get('confidence', transcription[0].get('score', 0.0))
                if isinstance(confidence, (list, tuple)) and len(confidence) > 0:
                    confidence = float(confidence[0])
                else:
                    confidence = float(confidence) if isinstance(confidence, (int, float)) else 0.0
                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "content": text,
                    "confidence": round(confidence, 3),
                    "segment_type": file_info['segment_type']
                })
                print(f"  ✓ [{i + 1}/{len(segment_files)}] {text[:30]}...")
            else:
                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "content": "",
                    "confidence": 0.0,
                    "segment_type": file_info['segment_type']
                })
            # Free GPU cache between segments
            if device_name.startswith("cuda"):
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"  ✗ [{i + 1}/{len(segment_files)}] Transcription failed: {e}")
            results.append({
                "start": file_info['start_ms'],
                "end": file_info['end_ms'],
                "content": "",
                "confidence": 0.0,
                "segment_type": file_info['segment_type'],
                "error": str(e)
            })
    return results

# ============ Pyannote helpers ============
def load_pyannote_pipeline(config_path):
    """Load the Pyannote diarization pipeline."""
    try:
        from pyannote.audio import Pipeline
        pipeline = Pipeline.from_pretrained(config_path)
        device_obj = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipeline.to(device_obj)
        print(f"✓ Pyannote pipeline loaded, device: {device_obj}")
        return pipeline
    except Exception as e:
        print(f"✗ Failed to load Pyannote pipeline: {e}")
        return None

def load_audio_for_segmentation(audio_path):
    """Load audio for segmentation."""
    # Load with soundfile
    audio_data, sr = sf.read(str(audio_path))
    # Downmix to mono
    if len(audio_data.shape) > 1:
        audio_data = np.mean(audio_data, axis=1)
    # Resample to 16 kHz
    if sr != 16000:
        audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
        sr = 16000
    # Also load with pydub for millisecond-based slicing and export
    audio_pydub = AudioSegment.from_file(str(audio_path))
    audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)
    return audio_data, sr, audio_pydub

def calculate_frame_energy(audio_data, sr, time_point, frame_length=0.025):
    """RMS energy of a single frame centred on time_point."""
    frame_samples = int(frame_length * sr)
    center_sample = int(time_point * sr)
    start_sample = max(0, center_sample - frame_samples // 2)
    end_sample = min(len(audio_data), start_sample + frame_samples)
    frame = audio_data[start_sample:end_sample]
    if len(frame) == 0:
        return 0.0
    return np.sqrt(np.mean(frame ** 2))

def find_optimal_boundary(audio_data, sr, start_time, end_time,
                          frame_length=0.025, hop_length=0.005):
    """Find the minimum-energy point within the given time range."""
    if start_time >= end_time:
        return start_time, 0.0
    start_sample = int(start_time * sr)
    end_sample = int(end_time * sr)
    segment = audio_data[start_sample:end_sample]
    if len(segment) == 0:
        return start_time, 0.0
    frame_samples = int(frame_length * sr)
    hop_samples = int(hop_length * sr)
    min_energy = float('inf')
    optimal_time = start_time
    for i in range(0, len(segment) - frame_samples + 1, hop_samples):
        frame = segment[i:i + frame_samples]
        energy = np.sqrt(np.mean(frame ** 2))
        if energy < min_energy:
            min_energy = energy
            optimal_time = start_time + (i + frame_samples / 2) / sr
    return optimal_time, min_energy

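# Quick sanity check (a sketch with synthetic data, not part of the pipeline):
# with half a second of silence between two noise bursts, the returned
# boundary should land inside the silent middle third.
#
#   sr = 16000
#   loud = np.random.randn(sr // 2) * 0.5   # 0.5 s of noise
#   quiet = np.zeros(sr // 2)               # 0.5 s of silence
#   signal = np.concatenate([loud, quiet, loud])
#   t, e = find_optimal_boundary(signal, sr, 0.0, 1.5)
#   # expect 0.5 < t < 1.0 and e close to 0.0
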
def optimize_segment_boundaries(segments, audio_data, sr,
                                small_gap_threshold=0.1,
                                min_gap_to_keep=0.05):
    """Snap segment boundaries to low-energy points."""
    if not segments:
        return []
    audio_duration = len(audio_data) / sr
    optimized = []
    for i, seg in enumerate(segments):
        current_seg = seg.copy()
        # Optimize the start boundary
        if i == 0:
            search_start = 0.0
            search_end = seg["start"]
        else:
            prev_end = segments[i - 1]["end"]
            search_start = prev_end
            search_end = seg["start"]
        gap_size = search_end - search_start
        if gap_size > 0 and gap_size >= small_gap_threshold:
            optimal_start, _ = find_optimal_boundary(audio_data, sr, search_start, search_end)
            current_seg["start"] = optimal_start
        # Optimize the end boundary
        if i == len(segments) - 1:
            search_start = seg["end"]
            search_end = audio_duration
        else:
            next_start = segments[i + 1]["start"]
            search_start = seg["end"]
            search_end = next_start
        gap_size = search_end - search_start
        if gap_size > 0 and gap_size >= small_gap_threshold:
            optimal_end, _ = find_optimal_boundary(audio_data, sr, search_start, search_end)
            current_seg["end"] = optimal_end
        optimized.append(current_seg)
    # Small gaps: let adjacent segments share a single boundary point
    for i in range(len(optimized) - 1):
        gap = optimized[i + 1]["start"] - optimized[i]["end"]
        if 0 < gap < small_gap_threshold:
            optimal_point, _ = find_optimal_boundary(
                audio_data, sr,
                optimized[i]["end"],
                optimized[i + 1]["start"]
            )
            optimized[i]["end"] = optimal_point
            optimized[i + 1]["start"] = optimal_point
    # Resolve overlaps and tiny gaps
    final_optimized = []
    for i in range(len(optimized)):
        if i == 0:
            final_optimized.append(optimized[i])
            continue
        prev_seg = final_optimized[-1]
        curr_seg = optimized[i]
        gap = curr_seg["start"] - prev_seg["end"]
        if gap < 0:
            midpoint = (prev_seg["end"] + curr_seg["start"]) / 2
            prev_seg["end"] = midpoint
            curr_seg["start"] = midpoint
        elif 0 < gap < min_gap_to_keep:
            prev_seg["end"] = curr_seg["start"]
        final_optimized.append(curr_seg)
    return final_optimized

def merge_segments_by_gap(segments, max_duration=50.0):
    """Greedily merge segments across the smallest gaps first, keeping each
    merged span no longer than max_duration."""
    if not segments:
        return []
    if len(segments) == 1:
        return segments
    groups = [[seg] for seg in segments]
    while True:
        gaps = []
        for i in range(len(groups) - 1):
            group1_end = groups[i][-1]["end"]
            group2_start = groups[i + 1][0]["start"]
            gap = group2_start - group1_end
            merged_duration = groups[i + 1][-1]["end"] - groups[i][0]["start"]
            gaps.append({
                "index": i,
                "gap": gap,
                "merged_duration": merged_duration
            })
        if not gaps:
            break
        gaps.sort(key=lambda x: x["gap"])
        merged = False
        for gap_info in gaps:
            if gap_info["merged_duration"] <= max_duration:
                idx = gap_info["index"]
                groups[idx].extend(groups[idx + 1])
                groups.pop(idx + 1)
                merged = True
                break
        if not merged:
            break
    merged_segments = []
    for group in groups:
        merged_seg = {
            "start": group[0]["start"],
            "end": group[-1]["end"],
            "original_count": len(group),
            "is_merged": len(group) > 1
        }
        merged_segments.append(merged_seg)
    return merged_segments

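# Worked example (a sketch; numbers chosen purely for illustration):
# three segments [0-10], [10.5-20], [40-45] with max_duration=30.
# The 0.5 s gap is bridged first (merged span 20 s <= 30); bridging the
# 20 s gap would create a 45 s span > 30, so it is left alone:
#
#   merge_segments_by_gap(
#       [{"start": 0.0, "end": 10.0},
#        {"start": 10.5, "end": 20.0},
#        {"start": 40.0, "end": 45.0}],
#       max_duration=30.0,
#   )
#   # -> [{"start": 0.0, "end": 20.0, "original_count": 2, "is_merged": True},
#   #     {"start": 40.0, "end": 45.0, "original_count": 1, "is_merged": False}]
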
def perform_segmentation(audio_path, pipeline_model, device_obj):
    """
    Segment an audio file.
    Args:
        audio_path: path to the audio file
        pipeline_model: Pyannote pipeline
        device_obj: torch device
    Returns:
        (segment file list, session id)
    """
    print("\n[Pyannote segmentation] Processing...")
    # Load audio
    audio_data, sr, audio_pydub = load_audio_for_segmentation(audio_path)
    duration = len(audio_data) / sr
    print(f"  Audio duration: {duration:.1f}s")
    # Run Pyannote
    waveform, sample_rate = torchaudio.load(str(audio_path))
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        waveform = resampler(waveform)
        sample_rate = 16000
    waveform = waveform.to(device_obj)
    audio_dict = {"waveform": waveform, "sample_rate": sample_rate}
    diarization_result = pipeline_model(audio_dict)
    # Extract the initial segments
    initial_segments = []
    if hasattr(diarization_result, 'speaker_diarization'):
        annotation = diarization_result.speaker_diarization
        for segment, track, speaker in annotation.itertracks(yield_label=True):
            initial_segments.append({
                "start": round(segment.start, 3),
                "end": round(segment.end, 3)
            })
    else:
        for segment, track, speaker in diarization_result.itertracks(yield_label=True):
            initial_segments.append({
                "start": round(segment.start, 3),
                "end": round(segment.end, 3)
            })
    print(f"  Initial segments: {len(initial_segments)}")
    # Merge segments
    merged_segments = merge_segments_by_gap(initial_segments, max_duration=DURATION_THRESHOLD)
    print(f"  Segments after merging: {len(merged_segments)}")
    # Optimize boundaries
    optimized_segments = optimize_segment_boundaries(merged_segments, audio_data, sr)
    print("  Boundary optimization done")
    # Export the segments
    session_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    segment_folder = TEMP_DIR / session_id
    segment_folder.mkdir(exist_ok=True)
    segment_files = []
    for i, seg in enumerate(optimized_segments):
        start_time = seg["start"]
        end_time = seg["end"]
        # Pad each clip by 200 ms on both sides
        start_ms = max(0, int(start_time * 1000) - 200)
        end_ms = min(len(audio_pydub), int(end_time * 1000) + 200)
        segment_audio = audio_pydub[start_ms:end_ms]
        filename = f"segment_{i:03d}.wav"
        filepath = segment_folder / filename
        segment_audio.export(str(filepath), format="wav")
        # Classify the segment type
        segment_type = "forced" if seg.get("is_merged", False) else "natural"
        segment_files.append({
            "filepath": filepath,
            "start_ms": int(start_time * 1000),
            "end_ms": int(end_time * 1000),
            "segment_type": segment_type
        })
    print(f"  Exported {len(segment_files)} files")
    return segment_files, session_id

# ============ FastAPI application ============
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load models at startup."""
    global pyannote_pipeline, firered_model, device
    print("=" * 60)
    print("Starting audio transcription service...")
    print("=" * 60)
    # Load the Pyannote pipeline
    print("\n[1/2] Loading Pyannote pipeline...")
    pyannote_pipeline = load_pyannote_pipeline(PYANNOTE_CONFIG_PATH)
    # Load the FireRedASR model
    print("\n[2/2] Loading FireRedASR model...")
    firered_model, device = load_fireredasr_model(FIRERED_MODEL_PATH)
    if pyannote_pipeline is None or firered_model is None:
        print("\n✗ Model loading failed; the service cannot start properly")
    else:
        print("\n✓ All models loaded")
    print("=" * 60)
    yield
    print("\nShutting down...")

app = FastAPI(
    lifespan=lifespan,
    title="Audio Transcription Service",
    description="Audio transcription service (segmentation + ASR)"
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post("/transcriptions")
async def transcribe_audio(file: UploadFile = File(...)):
    """
    Audio transcription API.
    Flow:
    1. Check the audio duration
    2. < 50s: transcribe directly
    3. >= 50s: Pyannote segmentation -> FireRedASR batch transcription
    """
    global pyannote_pipeline, firered_model, device
    # Check the models
    if firered_model is None:
        raise HTTPException(status_code=500, detail="FireRedASR model not loaded")
    # Validate the file format
    allowed_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
    file_ext = os.path.splitext(file.filename)[1].lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file format. Supported: {', '.join(allowed_extensions)}"
        )
    temp_audio_path = None
    session_folder = None
    try:
        print(f"\n{'=' * 60}")
        print(f"Transcription request received: {file.filename}")
        total_start_time = time.time()
        # Save the uploaded file
        session_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
        temp_audio_path = TEMP_DIR / f"upload_{session_id}{file_ext}"
        with open(temp_audio_path, "wb") as f:
            content = await file.read()
            f.write(content)
        # Check the audio duration
        audio_info = sf.info(str(temp_audio_path))
        duration = audio_info.duration
        print(f"Audio duration: {duration:.2f}s")
        # Route by duration
        if duration < DURATION_THRESHOLD:
            # Direct transcription
            print(f"\nMode: direct transcription (< {DURATION_THRESHOLD}s)")
            result = transcribe_single_audio(temp_audio_path, firered_model, device)
            transcription = [{
                "start": 0,
                "end": int(duration * 1000),
                "content": result.get("text", ""),
                "confidence": result.get("confidence", 0.0),
                "segment_type": "direct"
            }]
            statistics = {
                "total_segments": 1,
                "processing_time": round(time.time() - total_start_time, 2),
                "segmentation_types": {
                    "natural": 0,
                    "forced": 0,
                    "direct": 1
                },
                "processing_method": "direct_asr"
            }
        else:
            # Segmentation + transcription
            print(f"\nMode: segmented transcription (>= {DURATION_THRESHOLD}s)")
            if pyannote_pipeline is None:
                raise HTTPException(status_code=500, detail="Pyannote pipeline not loaded")
            # Segment
            segment_files, seg_session_id = perform_segmentation(
                temp_audio_path,
                pyannote_pipeline,
                torch.device("cuda" if torch.cuda.is_available() else "cpu")
            )
            session_folder = TEMP_DIR / seg_session_id
            # Transcribe
            print("\n[FireRedASR] Transcribing segments...")
            transcription = transcribe_audio_segments(segment_files, firered_model, device)
            # Statistics
            seg_types = {"natural": 0, "forced": 0, "direct": 0}
            for seg in transcription:
                seg_types[seg["segment_type"]] += 1
            statistics = {
                "total_segments": len(transcription),
                "processing_time": round(time.time() - total_start_time, 2),
                "segmentation_types": seg_types,
                "processing_method": "pipeline_dual_model"
            }
        print(f"\n{'=' * 60}")
        print(f"Done, total time: {statistics['processing_time']}s")
        print(f"{'=' * 60}\n")
        # Build the response
        response_data = {
            "code": 200,
            "message": "success",
            "data": {
                "transcription": transcription,
                "statistics": statistics
            }
        }
        return JSONResponse(
            content=jsonable_encoder(response_data),
            status_code=200
        )
    except Exception as e:
        print(f"\nProcessing error: {str(e)}")
        import traceback
        traceback.print_exc()
        return JSONResponse(
            content={
                "code": 500,
                "message": f"Processing failed: {str(e)}",
                "data": {
                    "transcription": [],
                    "statistics": {}
                }
            },
            status_code=200
        )
    finally:
        # Clean up temp files (avoid shadowing the `file` parameter)
        if temp_audio_path and temp_audio_path.exists():
            try:
                temp_audio_path.unlink()
            except Exception:
                pass
        if session_folder and session_folder.exists():
            try:
                for wav_file in session_folder.glob("*.wav"):
                    wav_file.unlink()
                session_folder.rmdir()
            except Exception:
                pass

@app.get("/health")
async def health_check():
"""健康检查"""
return JSONResponse(
content={
"status": "healthy",
"pyannote_loaded": pyannote_pipeline is not None,
"firered_loaded": firered_model is not None,
"device": str(device) if device else "not initialized",
"duration_threshold": DURATION_THRESHOLD,
"message": "Service is running"
},
status_code=200
)
@app.get("/")
async def root():
"""服务信息"""
return {
"service": "Audio Transcription Service",
"version": "1.0",
"description": "音频转录服务(智能分段 + 语音识别)",
"features": [
f"短音频(<{DURATION_THRESHOLD}s)直接识别",
f"长音频(≥{DURATION_THRESHOLD}s)智能分段后识别",
"Pyannote高质量分段",
"FireRedASR高精度识别",
"自动边界优化"
],
"endpoints": {
"transcriptions": "POST /transcriptions - 音频转录",
"health": "GET /health - 健康检查",
"root": "GET / - 服务信息"
},
"models": {
"segmentation": "Pyannote Speaker Diarization 3.1",
"transcription": "FireRedASR-AED-L"
}
}
if __name__ == "__main__":
    print("=" * 60)
    print("Audio Transcription Service v1.0")
    print("=" * 60)
    print("Features:")
    print(f"  ✓ Smart routing ({DURATION_THRESHOLD}s threshold)")
    print("  ✓ Pyannote segmentation + FireRedASR transcription")
    print("  ✓ Automatic boundary optimization")
    print("  ✓ Speaker labeling removed")
    print("=" * 60)
    print("Starting server on port 9000...")
    print()
    uvicorn.run(app, host="0.0.0.0", port=9000)
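
# Example client call (a minimal sketch, assuming the server is running on
# localhost:9000 and a local "sample.wav" exists; requires the `requests`
# package, which this service does not itself depend on):
#
#   import requests
#   with open("sample.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:9000/transcriptions",
#           files={"file": ("sample.wav", f, "audio/wav")},
#       )
#   print(resp.json()["data"]["transcription"])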