import sys import os import time import torch import argparse import uvicorn import soundfile as sf import librosa import numpy as np import torchaudio import uuid import re from pathlib import Path from pydub import AudioSegment from contextlib import asynccontextmanager from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware from fastapi.encoders import jsonable_encoder from typing import List, Dict, Any # ============ 全局配置 ============ # 设置离线模式 os.environ['HF_HUB_OFFLINE'] = '1' os.environ['TRANSFORMERS_OFFLINE'] = '1' # FireRedASR路径配置 FIRERED_ASR_PATH = 'G:/Work_data/workstation/Forwork-voice-txt/FireRedASR' FIRERED_MODEL_PATH = "G:/Work_data/workstation/Forwork-voice-txt/FireRedASR/pretrained_models/FireRedASR-AED-L" # Pyannote路径配置 PYANNOTE_CONFIG_PATH = r"G:\Work_data\workstation\Audio_classification\classify_model\speaker-diarization-3.1\config.yaml" # 临时文件夹 TEMP_DIR = Path("temp_transcription") TEMP_DIR.mkdir(exist_ok=True) # 时长阈值(秒) DURATION_THRESHOLD = 50.0 # 全局模型变量 pyannote_pipeline = None firered_model = None device = None # ============ FireRedASR 相关函数 ============ def setup_fireredasr_environment(): """设置FireRedASR环境""" if FIRERED_ASR_PATH not in sys.path: sys.path.insert(0, FIRERED_ASR_PATH) return True def load_fireredasr_model(model_dir): """加载FireRedASR模型""" setup_fireredasr_environment() try: # 修复PyTorch兼容性 torch.serialization.add_safe_globals([argparse.Namespace]) try: from fireredasr.models.fireredasr import FireRedAsr except ImportError: try: from FireRedASR.fireredasr.models.fireredasr import FireRedAsr except ImportError: import fireredasr from fireredasr.models.fireredasr import FireRedAsr model = FireRedAsr.from_pretrained("aed", model_dir) device_name = "cuda:0" if torch.cuda.is_available() else "cpu" if torch.cuda.is_available(): try: model = model.to(device_name) except: pass if hasattr(model, 'eval'): model.eval() print(f"✓ FireRedASR模型加载成功,设备: {device_name}") return model, device_name except Exception as e: print(f"✗ FireRedASR模型加载失败: {e}") return None, None def preprocess_audio_for_asr(audio_path, output_path): """ 音频预处理(FireRedASR标准) Args: audio_path: 原始音频路径 output_path: 输出路径 Returns: True if success """ try: # 加载音频 audio_data, sr = sf.read(str(audio_path)) # 转换为单声道 if len(audio_data.shape) > 1: audio_data = np.mean(audio_data, axis=1) # 重采样到16kHz if sr != 16000: audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000) sr = 16000 # 使用pydub标准化 temp_file = str(output_path) + ".temp.wav" sf.write(temp_file, audio_data, sr) audio_pydub = AudioSegment.from_file(temp_file) audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1) audio_pydub.export(str(output_path), format="wav") if os.path.exists(temp_file): os.remove(temp_file) return True except Exception as e: print(f" ✗ 预处理失败: {e}") return False def transcribe_single_audio(audio_path, model, device_name): """ 识别单个音频文件 Args: audio_path: 音频文件路径 model: FireRedASR模型 device_name: 设备名称 Returns: 识别结果字典 {"text": str, "confidence": float} """ try: # 预处理 temp_path = TEMP_DIR / f"temp_asr_{uuid.uuid4().hex[:8]}.wav" success = preprocess_audio_for_asr(audio_path, temp_path) if not success: return {"text": "", "confidence": 0.0, "error": "Preprocessing failed"} # 识别 batch_uttid = ["audio_001"] batch_wav_path = [str(temp_path)] config = { "use_gpu": 1 if device_name.startswith("cuda") else 0, "beam_size": 5, "nbest": 1, "decode_max_len": 0 } with torch.no_grad(): result = model.transcribe(batch_uttid, batch_wav_path, config) # 清理临时文件 if temp_path.exists(): temp_path.unlink() if result and len(result) > 0: text = result[0].get('text', '').strip() confidence = result[0].get('confidence', result[0].get('score', 0.0)) if isinstance(confidence, (list, tuple)) and len(confidence) > 0: confidence = float(confidence[0]) else: confidence = float(confidence) if isinstance(confidence, (int, float)) else 0.0 return {"text": text, "confidence": round(confidence, 3)} return {"text": "", "confidence": 0.0} except Exception as e: print(f" ✗ 识别失败: {e}") return {"text": "", "confidence": 0.0, "error": str(e)} def transcribe_audio_segments(segment_files, model, device_name): """ 批量识别音频片段 Args: segment_files: 片段文件信息列表 model: FireRedASR模型 device_name: 设备名称 Returns: 识别结果列表 """ results = [] print(f"开始识别 {len(segment_files)} 个音频片段...") for i, file_info in enumerate(segment_files): try: # 预处理 temp_path = TEMP_DIR / f"temp_seg_{i:03d}.wav" success = preprocess_audio_for_asr(file_info['filepath'], temp_path) if not success: results.append({ "start": file_info['start_ms'], "end": file_info['end_ms'], "text": "", "confidence": 0.0, "segment_type": file_info['segment_type'], "error": "Preprocessing failed" }) continue # 识别 batch_uttid = [f"seg_{i:03d}"] batch_wav_path = [str(temp_path)] config = { "use_gpu": 1 if device_name.startswith("cuda") else 0, "beam_size": 5, "nbest": 1, "decode_max_len": 0 } with torch.no_grad(): transcription = model.transcribe(batch_uttid, batch_wav_path, config) # 清理临时文件 if temp_path.exists(): temp_path.unlink() if transcription and len(transcription) > 0: text = transcription[0].get('text', '').strip() confidence = transcription[0].get('confidence', transcription[0].get('score', 0.0)) if isinstance(confidence, (list, tuple)) and len(confidence) > 0: confidence = float(confidence[0]) else: confidence = float(confidence) if isinstance(confidence, (int, float)) else 0.0 results.append({ "start": file_info['start_ms'], "end": file_info['end_ms'], "content": text, "confidence": round(confidence, 3), "segment_type": file_info['segment_type'] }) print(f" ✓ [{i + 1}/{len(segment_files)}] {text[:30]}...") else: results.append({ "start": file_info['start_ms'], "end": file_info['end_ms'], "content": "", "confidence": 0.0, "segment_type": file_info['segment_type'] }) # 清理GPU缓存 if device_name.startswith("cuda"): torch.cuda.empty_cache() except Exception as e: print(f" ✗ [{i + 1}/{len(segment_files)}] 识别失败: {e}") results.append({ "start": file_info['start_ms'], "end": file_info['end_ms'], "content": "", "confidence": 0.0, "segment_type": file_info['segment_type'], "error": str(e) }) return results # ============ Pyannote 相关函数 ============ def load_pyannote_pipeline(config_path): """加载Pyannote模型""" try: from pyannote.audio import Pipeline pipeline = Pipeline.from_pretrained(config_path) device_obj = torch.device("cuda" if torch.cuda.is_available() else "cpu") pipeline.to(device_obj) print(f"✓ Pyannote模型加载成功,设备: {device_obj}") return pipeline except Exception as e: print(f"✗ Pyannote模型加载失败: {e}") return None def load_audio_for_segmentation(audio_path): """加载音频用于分段""" # 使用soundfile加载 audio_data, sr = sf.read(str(audio_path)) # 转换为单声道 if len(audio_data.shape) > 1: audio_data = np.mean(audio_data, axis=1) # 重采样到16kHz if sr != 16000: audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000) sr = 16000 # 使用pydub audio_pydub = AudioSegment.from_file(str(audio_path)) audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1) return audio_data, sr, audio_pydub def calculate_frame_energy(audio_data, sr, time_point, frame_length=0.025): """计算帧能量""" frame_samples = int(frame_length * sr) center_sample = int(time_point * sr) start_sample = max(0, center_sample - frame_samples // 2) end_sample = min(len(audio_data), start_sample + frame_samples) frame = audio_data[start_sample:end_sample] if len(frame) == 0: return 0.0 return np.sqrt(np.mean(frame ** 2)) def find_optimal_boundary(audio_data, sr, start_time, end_time, frame_length=0.025, hop_length=0.005): """在指定时间范围内找到能量最低点""" if start_time >= end_time: return start_time, 0.0 start_sample = int(start_time * sr) end_sample = int(end_time * sr) segment = audio_data[start_sample:end_sample] if len(segment) == 0: return start_time, 0.0 frame_samples = int(frame_length * sr) hop_samples = int(hop_length * sr) min_energy = float('inf') optimal_time = start_time for i in range(0, len(segment) - frame_samples + 1, hop_samples): frame = segment[i:i + frame_samples] energy = np.sqrt(np.mean(frame ** 2)) if energy < min_energy: min_energy = energy optimal_time = start_time + (i + frame_samples / 2) / sr return optimal_time, min_energy def optimize_segment_boundaries(segments, audio_data, sr, small_gap_threshold=0.1, min_gap_to_keep=0.05): """优化片段边界""" if not segments: return [] audio_duration = len(audio_data) / sr optimized = [] for i, seg in enumerate(segments): current_seg = seg.copy() # 优化起始边界 if i == 0: search_start = 0.0 search_end = seg["start"] else: prev_end = segments[i - 1]["end"] search_start = prev_end search_end = seg["start"] gap_size = search_end - search_start if gap_size > 0 and gap_size >= small_gap_threshold: optimal_start, _ = find_optimal_boundary(audio_data, sr, search_start, search_end) current_seg["start"] = optimal_start # 优化结束边界 if i == len(segments) - 1: search_start = seg["end"] search_end = audio_duration else: next_start = segments[i + 1]["start"] search_start = seg["end"] search_end = next_start gap_size = search_end - search_start if gap_size > 0 and gap_size >= small_gap_threshold: optimal_end, _ = find_optimal_boundary(audio_data, sr, search_start, search_end) current_seg["end"] = optimal_end optimized.append(current_seg) # 处理小间隙(共享边界) for i in range(len(optimized) - 1): gap = optimized[i + 1]["start"] - optimized[i]["end"] if 0 < gap < small_gap_threshold: optimal_point, _ = find_optimal_boundary( audio_data, sr, optimized[i]["end"], optimized[i + 1]["start"] ) optimized[i]["end"] = optimal_point optimized[i + 1]["start"] = optimal_point # 处理重叠和极小间隙 final_optimized = [] for i in range(len(optimized)): if i == 0: final_optimized.append(optimized[i]) continue prev_seg = final_optimized[-1] curr_seg = optimized[i] gap = curr_seg["start"] - prev_seg["end"] if gap < 0: midpoint = (prev_seg["end"] + curr_seg["start"]) / 2 prev_seg["end"] = midpoint curr_seg["start"] = midpoint elif 0 < gap < min_gap_to_keep: prev_seg["end"] = curr_seg["start"] final_optimized.append(curr_seg) return final_optimized def merge_segments_by_gap(segments, max_duration=50.0): """按时间间隔合并片段""" if not segments: return [] if len(segments) == 1: return segments groups = [[seg] for seg in segments] while True: gaps = [] for i in range(len(groups) - 1): group1_end = groups[i][-1]["end"] group2_start = groups[i + 1][0]["start"] gap = group2_start - group1_end merged_duration = groups[i + 1][-1]["end"] - groups[i][0]["start"] gaps.append({ "index": i, "gap": gap, "merged_duration": merged_duration }) if not gaps: break gaps.sort(key=lambda x: x["gap"]) merged = False for gap_info in gaps: if gap_info["merged_duration"] <= max_duration: idx = gap_info["index"] groups[idx].extend(groups[idx + 1]) groups.pop(idx + 1) merged = True break if not merged: break merged_segments = [] for group in groups: merged_seg = { "start": group[0]["start"], "end": group[-1]["end"], "original_count": len(group), "is_merged": len(group) > 1 } merged_segments.append(merged_seg) return merged_segments def perform_segmentation(audio_path, pipeline_model, device_obj): """ 执行音频分段 Args: audio_path: 音频文件路径 pipeline_model: Pyannote模型 device_obj: 设备 Returns: 片段列表 """ print("\n[Pyannote分段] 开始处理...") # 加载音频 audio_data, sr, audio_pydub = load_audio_for_segmentation(audio_path) duration = len(audio_data) / sr print(f" 音频时长: {duration:.1f}秒") # Pyannote处理 waveform, sample_rate = torchaudio.load(str(audio_path)) if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) if sample_rate != 16000: resampler = torchaudio.transforms.Resample(sample_rate, 16000) waveform = resampler(waveform) sample_rate = 16000 waveform = waveform.to(device_obj) audio_dict = {"waveform": waveform, "sample_rate": sample_rate} diarization_result = pipeline_model(audio_dict) # 提取初始片段 initial_segments = [] if hasattr(diarization_result, 'speaker_diarization'): annotation = diarization_result.speaker_diarization for segment, track, speaker in annotation.itertracks(yield_label=True): initial_segments.append({ "start": round(segment.start, 3), "end": round(segment.end, 3) }) else: for segment, track, speaker in diarization_result.itertracks(yield_label=True): initial_segments.append({ "start": round(segment.start, 3), "end": round(segment.end, 3) }) print(f" 初始片段数: {len(initial_segments)}") # 合并片段 merged_segments = merge_segments_by_gap(initial_segments, max_duration=DURATION_THRESHOLD) print(f" 合并后片段数: {len(merged_segments)}") # 边界优化 optimized_segments = optimize_segment_boundaries(merged_segments, audio_data, sr) print(f" 边界优化完成") # 导出片段 session_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}" segment_folder = TEMP_DIR / session_id segment_folder.mkdir(exist_ok=True) segment_files = [] for i, seg in enumerate(optimized_segments): start_time = seg["start"] end_time = seg["end"] start_ms = max(0, int(start_time * 1000) - 200) end_ms = min(len(audio_pydub), int(end_time * 1000) + 200) segment_audio = audio_pydub[start_ms:end_ms] filename = f"segment_{i:03d}.wav" filepath = segment_folder / filename segment_audio.export(str(filepath), format="wav") # 判断segment_type segment_type = "forced" if seg.get("is_merged", False) else "natural" segment_files.append({ "filepath": filepath, "start_ms": int(start_time * 1000), "end_ms": int(end_time * 1000), "segment_type": segment_type }) print(f" 导出完成: {len(segment_files)} 个文件") return segment_files, session_id # ============ FastAPI 应用 ============ @asynccontextmanager async def lifespan(app: FastAPI): """应用生命周期管理""" global pyannote_pipeline, firered_model, device print("=" * 60) print("正在启动音频转录服务...") print("=" * 60) # 加载Pyannote模型 print("\n[1/2] 加载Pyannote模型...") pyannote_pipeline = load_pyannote_pipeline(PYANNOTE_CONFIG_PATH) # 加载FireRedASR模型 print("\n[2/2] 加载FireRedASR模型...") firered_model, device = load_fireredasr_model(FIRERED_MODEL_PATH) if pyannote_pipeline is None or firered_model is None: print("\n✗ 模型加载失败,服务无法启动") else: print("\n✓ 所有模型加载成功") print("=" * 60) yield print("\n服务关闭中...") app = FastAPI( lifespan=lifespan, title="Audio Transcription Service", description="音频转录服务(分段+识别)" ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.post("/transcriptions") async def transcribe_audio(file: UploadFile = File(...)): """ 音频转录API 处理流程: 1. 检查音频时长 2. < 50s: 直接识别 3. >= 50s: Pyannote分段 → FireRedASR批量识别 """ global pyannote_pipeline, firered_model, device # 检查模型 if firered_model is None: raise HTTPException(status_code=500, detail="FireRedASR模型未加载") # 验证文件格式 allowed_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg'] file_ext = os.path.splitext(file.filename)[1].lower() if file_ext not in allowed_extensions: raise HTTPException( status_code=400, detail=f"不支持的文件格式。支持: {', '.join(allowed_extensions)}" ) temp_audio_path = None session_folder = None try: print(f"\n{'=' * 60}") print(f"收到转录请求: {file.filename}") total_start_time = time.time() # 保存上传的文件 session_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}" temp_audio_path = TEMP_DIR / f"upload_{session_id}{file_ext}" with open(temp_audio_path, "wb") as f: content = await file.read() f.write(content) # 检查音频时长 audio_info = sf.info(str(temp_audio_path)) duration = audio_info.duration print(f"音频时长: {duration:.2f}秒") # 路由选择 if duration < DURATION_THRESHOLD: # 直接识别 print(f"\n模式: 直接识别(< {DURATION_THRESHOLD}秒)") result = transcribe_single_audio(temp_audio_path, firered_model, device) transcription = [{ "start": 0, "end": int(duration * 1000), "content": result.get("text", ""), "confidence": result.get("confidence", 0.0), "segment_type": "direct" }] statistics = { "total_segments": 1, "processing_time": round(time.time() - total_start_time, 2), "segmentation_types": { "natural": 0, "forced": 0, "direct": 1 }, "processing_method": "direct_asr" } else: # 分段 + 识别 print(f"\n模式: 分段识别(>= {DURATION_THRESHOLD}秒)") if pyannote_pipeline is None: raise HTTPException(status_code=500, detail="Pyannote模型未加载") # 分段 segment_files, seg_session_id = perform_segmentation( temp_audio_path, pyannote_pipeline, torch.device("cuda" if torch.cuda.is_available() else "cpu") ) session_folder = TEMP_DIR / seg_session_id # 识别 print(f"\n[FireRedASR识别] 开始处理...") transcription = transcribe_audio_segments(segment_files, firered_model, device) # 统计 seg_types = {"natural": 0, "forced": 0, "direct": 0} for seg in transcription: seg_types[seg["segment_type"]] += 1 statistics = { "total_segments": len(transcription), "processing_time": round(time.time() - total_start_time, 2), "segmentation_types": seg_types, "processing_method": "pipeline_dual_model" } print(f"\n{'=' * 60}") print(f"处理完成,总耗时: {statistics['processing_time']}秒") print(f"{'=' * 60}\n") # 构建响应 response_data = { "code": 200, "message": "success", "data": { "transcription": transcription, "statistics": statistics } } return JSONResponse( content=jsonable_encoder(response_data), status_code=200 ) except Exception as e: print(f"\n处理错误: {str(e)}") import traceback traceback.print_exc() return JSONResponse( content={ "code": 500, "message": f"处理失败: {str(e)}", "data": { "transcription": [], "statistics": {} } }, status_code=200 ) finally: # 清理临时文件 if temp_audio_path and temp_audio_path.exists(): try: temp_audio_path.unlink() except: pass if session_folder and session_folder.exists(): try: for file in session_folder.glob("*.wav"): file.unlink() session_folder.rmdir() except: pass @app.get("/health") async def health_check(): """健康检查""" return JSONResponse( content={ "status": "healthy", "pyannote_loaded": pyannote_pipeline is not None, "firered_loaded": firered_model is not None, "device": str(device) if device else "not initialized", "duration_threshold": DURATION_THRESHOLD, "message": "Service is running" }, status_code=200 ) @app.get("/") async def root(): """服务信息""" return { "service": "Audio Transcription Service", "version": "1.0", "description": "音频转录服务(智能分段 + 语音识别)", "features": [ f"短音频(<{DURATION_THRESHOLD}s)直接识别", f"长音频(≥{DURATION_THRESHOLD}s)智能分段后识别", "Pyannote高质量分段", "FireRedASR高精度识别", "自动边界优化" ], "endpoints": { "transcriptions": "POST /transcriptions - 音频转录", "health": "GET /health - 健康检查", "root": "GET / - 服务信息" }, "models": { "segmentation": "Pyannote Speaker Diarization 3.1", "transcription": "FireRedASR-AED-L" } } if __name__ == "__main__": print("=" * 60) print("音频转录服务 v1.0") print("=" * 60) print("特性:") print(f" ✓ 智能路由({DURATION_THRESHOLD}秒阈值)") print(" ✓ Pyannote分段 + FireRedASR识别") print(" ✓ 自动边界优化") print(" ✓ 移除说话人功能") print("=" * 60) print("启动服务器(端口9000)...") print() uvicorn.run(app, host="0.0.0.0", port=9000)