import sys

sys.path.append('G:/Work_data/workstation/Forwork-voice-txt/FireRedASR')

import os
import time
import torch
import argparse
import uvicorn
import soundfile as sf
import librosa
import numpy as np
from pathlib import Path
from pydub import AudioSegment
import re
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel

# Globals holding the loaded FireRedASR model and the device it runs on.
firered_model = None
device = None

# ============ Configuration ============
# Memory-management mode:
#   True  = model stays resident in memory (fast, memory-hungry)
#   False = model is loaded per request and unloaded afterwards (slow, memory-light)
KEEP_MODEL_LOADED = True

# Directory of the pretrained FireRedASR-AED-L model.
FIRERED_MODEL_PATH = "G:/Work_data/workstation/Forwork-voice-txt/FireRedASR/pretrained_models/FireRedASR-AED-L"

# Working directory for preprocessed temporary wav segments.
TEMP_SEGMENTS_DIR = Path("temp_segments")
# =======================================


class TranscriptionRequest(BaseModel):
    """Batch-transcription request: path of a folder containing audio segments."""
    folder_path: str


def setup_fireredasr_environment():
    """Put a local FireRedASR checkout on sys.path if one can be found.

    Returns:
        True if a FireRedASR directory exists (and is now importable), else False.
    """
    possible_paths = ["./FireRedASR", "../FireRedASR", "FireRedASR"]
    for path in possible_paths:
        if Path(path).exists():
            abs_path = str(Path(path).absolute())
            if abs_path not in sys.path:
                sys.path.insert(0, abs_path)
            return True
    return False


def load_fireredasr_model(model_dir):
    """Load the FireRedASR AED model from *model_dir*.

    Args:
        model_dir: directory containing the pretrained model files.

    Returns:
        (model, device_name) on success, (None, None) on any failure.
    """
    # Best effort: the package may live in a sibling checkout.
    setup_fireredasr_environment()

    try:
        # PyTorch >= 2.6 defaults torch.load to weights_only=True; allow-list
        # argparse.Namespace so the checkpoint can be unpickled. Guarded because
        # older torch versions do not have add_safe_globals at all.
        if hasattr(torch.serialization, "add_safe_globals"):
            torch.serialization.add_safe_globals([argparse.Namespace])

        # The package may be importable directly or only via the repo directory.
        try:
            from fireredasr.models.fireredasr import FireRedAsr
        except ImportError:
            from FireRedASR.fireredasr.models.fireredasr import FireRedAsr

        model = FireRedAsr.from_pretrained("aed", model_dir)

        # Prefer GPU when available; fall back to CPU if the transfer fails.
        device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        if torch.cuda.is_available():
            try:
                model = model.to(device_name)
            except Exception:
                pass  # deliberate best-effort: keep the CPU copy

        if hasattr(model, 'eval'):
            model.eval()

        print(f"✓ 模型加载成功,使用设备: {device_name}")
        return model, device_name

    except Exception as e:
        print(f"✗ FireRedASR模型加载失败: {e}")
        return None, None


def unload_fireredasr_model():
    """Drop the global model reference and release memory (incl. GPU cache)."""
    global firered_model
    if firered_model is not None:
        print("正在卸载模型...")
        del firered_model
        firered_model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("✓ 模型已卸载,内存已释放")


def parse_speaker_from_filename(filename):
    """Extract the speaker label from a segment filename.

    Supported patterns (case-insensitive):
        001_spk0_0.00-5.23.wav -> speaker0
        002_spk1_5.50-8.90.wav -> speaker1
        segment_spk2.wav       -> speaker2

    Falls back to "speaker0" when no 'spk<N>' tag is present.
    """
    match = re.search(r'spk(\d+)', filename.lower())
    if match:
        return f"speaker{match.group(1)}"
    return "speaker0"


def scan_audio_folder(folder_path):
    """List the audio files in *folder_path*, sorted by filename.

    Args:
        folder_path: directory path (Windows or POSIX style; pathlib handles both).

    Returns:
        Sorted list of Path objects for the supported audio extensions.

    Raises:
        ValueError: if the path does not exist or is not a directory.
    """
    folder = Path(folder_path)

    if not folder.exists():
        raise ValueError(f"路径不存在: {folder_path}")
    if not folder.is_dir():
        raise ValueError(f"不是有效的文件夹: {folder_path}")

    audio_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
    audio_files = []
    for ext in audio_extensions:
        audio_files.extend(folder.glob(f"*{ext}"))

    return sorted(audio_files, key=lambda x: x.name)


def preprocess_audio_file(audio_path, output_path):
    """Normalize one audio file for recognition (max-accuracy pipeline).

    Steps:
        1. Load with soundfile.
        2. Down-mix to mono (channel average).
        3. Resample to 16 kHz.
        4. Round-trip through pydub for a canonical wav container.

    Args:
        audio_path: source audio file path (Path).
        output_path: destination wav path.

    Returns:
        True on success, False on any failure (failure is logged, not raised).
    """
    try:
        audio_data, sr = sf.read(str(audio_path))

        # Down-mix multi-channel audio to mono.
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)

        if sr != 16000:
            audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
            sr = 16000

        # Write an intermediate wav, then re-encode via pydub so the final
        # file has a standard 16 kHz mono wav container.
        temp_file = str(output_path) + ".temp.wav"
        sf.write(temp_file, audio_data, sr)
        try:
            audio_pydub = AudioSegment.from_file(temp_file)
            audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)
            audio_pydub.export(str(output_path), format="wav")
        finally:
            # Always remove the intermediate file, even if the export fails.
            if os.path.exists(temp_file):
                os.remove(temp_file)

        return True

    except Exception as e:
        print(f" ✗ 预处理失败 ({audio_path.name}): {e}")
        return False


def create_temp_segments(audio_files):
    """Preprocess a batch of audio files into TEMP_SEGMENTS_DIR.

    Args:
        audio_files: list of source audio paths.

    Returns:
        One info dict per input file (same order). Failed entries carry
        'preprocess_error': True and 'temp_file': None so the batch continues.
    """
    TEMP_SEGMENTS_DIR.mkdir(exist_ok=True)

    temp_files = []
    print(f"正在预处理 {len(audio_files)} 个音频文件...")

    for i, audio_file in enumerate(audio_files):
        temp_filepath = TEMP_SEGMENTS_DIR / f"preprocessed_{i:03d}.wav"

        if preprocess_audio_file(audio_file, temp_filepath):
            temp_files.append({
                'temp_file': temp_filepath,
                'original_file': audio_file,
                'filename': audio_file.name,
                'index': i
            })
            print(f" ✓ [{i + 1}/{len(audio_files)}] {audio_file.name}")
        else:
            # Record the failure but do not abort the rest of the batch.
            temp_files.append({
                'temp_file': None,
                'original_file': audio_file,
                'filename': audio_file.name,
                'index': i,
                'preprocess_error': True
            })

    ok_count = len([f for f in temp_files if f.get('temp_file')])
    print(f"预处理完成: {ok_count} / {len(audio_files)} 成功")
    return temp_files


def transcribe_audio_files(temp_files, model, device_name):
    """Run FireRedASR over the preprocessed segments.

    Args:
        temp_files: info dicts produced by create_temp_segments().
        model: loaded FireRedASR model.
        device_name: "cuda:..." or "cpu".

    Returns:
        One result dict per input file: filename, speaker, text, confidence,
        plus an 'error' key for items that failed at any stage.
    """
    results = []
    use_gpu = device_name.startswith("cuda")

    print(f"开始识别 {len(temp_files)} 个音频文件...")

    for i, file_info in enumerate(temp_files):
        filename = file_info['filename']
        temp_file = file_info.get('temp_file')

        # Skip segments that failed preprocessing.
        if file_info.get('preprocess_error'):
            # BUGFIX: log the actual filename (was the literal "(unknown)").
            print(f" ✗ [{i + 1}/{len(temp_files)}] {filename}: 跳过(预处理失败)")
            results.append({
                'filename': filename,
                'speaker': parse_speaker_from_filename(filename),
                'text': '',
                'confidence': 0.0,
                'error': 'Preprocessing failed'
            })
            continue

        try:
            batch_uttid = [f"file_{i:03d}"]
            batch_wav_path = [str(temp_file)]

            # Reference high-accuracy decoding configuration.
            config = {
                "use_gpu": 1 if use_gpu else 0,
                "beam_size": 5,
                "nbest": 1,
                "decode_max_len": 0
            }

            with torch.no_grad():
                transcription_result = model.transcribe(
                    batch_uttid, batch_wav_path, config
                )

            if transcription_result and len(transcription_result) > 0:
                result = transcription_result[0]
                text = result.get('text', '').strip()

                # Confidence may be absent, a scalar, or a sequence;
                # normalize it to a plain float.
                confidence = result.get('confidence', result.get('score', 0.0))
                if isinstance(confidence, (list, tuple)) and len(confidence) > 0:
                    confidence = float(confidence[0])
                elif not isinstance(confidence, (int, float)):
                    confidence = 0.0
                else:
                    confidence = float(confidence)

                speaker = parse_speaker_from_filename(filename)

                if text:
                    results.append({
                        'filename': filename,
                        'speaker': speaker,
                        'text': text,
                        'confidence': round(confidence, 3)
                    })
                    print(f" ✓ [{i + 1}/{len(temp_files)}] {filename}: {speaker} - {text[:30]}...")
                else:
                    # Recognition produced an empty string.
                    results.append({
                        'filename': filename,
                        'speaker': speaker,
                        'text': '',
                        'confidence': 0.0,
                        'error': 'Empty transcription'
                    })
            else:
                # BUGFIX: previously a falsy result was silently dropped,
                # leaving the results list shorter than the input list.
                results.append({
                    'filename': filename,
                    'speaker': parse_speaker_from_filename(filename),
                    'text': '',
                    'confidence': 0.0,
                    'error': 'No transcription result'
                })

            # Keep GPU memory bounded across a long batch.
            if use_gpu:
                torch.cuda.empty_cache()

        except Exception as e:
            print(f" ✗ [{i + 1}/{len(temp_files)}] {filename}: 识别失败 - {e}")
            results.append({
                'filename': filename,
                'speaker': parse_speaker_from_filename(filename),
                'text': '',
                'confidence': 0.0,
                'error': str(e)
            })
            continue

    successful_count = len([r for r in results if r.get('text') and not r.get('error')])
    print(f"识别完成: {successful_count}/{len(temp_files)} 个文件成功")
    return results


def cleanup_temp_files():
    """Delete preprocessed temp wavs and, if then empty, the temp directory."""
    if TEMP_SEGMENTS_DIR.exists():
        for file in TEMP_SEGMENTS_DIR.glob("preprocessed_*.wav"):
            try:
                file.unlink()
            except OSError:
                pass  # best effort: a locked file must not abort cleanup
        try:
            TEMP_SEGMENTS_DIR.rmdir()
        except OSError:
            pass  # directory not empty or in use; leave it in place


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: load (or defer) the model on startup, clean up on shutdown."""
    global firered_model, device

    print("=" * 60)
    print("正在启动FireRedASR服务...")
    print("=" * 60)

    # Initialize the target device up front.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print(f"设备: {device}")

    if KEEP_MODEL_LOADED:
        print("\n内存管理模式: 常驻模式(模型始终驻留内存)")
        if Path(FIRERED_MODEL_PATH).exists():
            # BUGFIX: only overwrite the globals on success, so a failed load
            # does not reset `device` to None.
            loaded_model, loaded_device = load_fireredasr_model(FIRERED_MODEL_PATH)
            if loaded_model is not None:
                firered_model, device = loaded_model, loaded_device
                print("✓ 模型已加载到内存")
            else:
                print("✗ 模型加载失败")
        else:
            print(f"✗ 模型路径不存在: {FIRERED_MODEL_PATH}")
    else:
        print("\n内存管理模式: 按需模式(每次处理时才加载模型)")
        print("✓ 设备已初始化,模型将在首次请求时加载")
        firered_model = None

    print("=" * 60)

    yield  # application serves requests here

    print("服务正在关闭...")
    cleanup_temp_files()


# FastAPI application instance.
app = FastAPI(lifespan=lifespan, title="FireRedASR Batch Transcription Service (High Precision)")


@app.post("/transcriptions_batch")
async def transcribe_batch(request: TranscriptionRequest):
    """Batch audio transcription endpoint (high-precision pipeline).

    Features:
    - Full audio preprocessing (soundfile + librosa + pydub)
    - Forced normalization to 16 kHz mono wav
    - Temp-file mechanism to guarantee consistent input format
    - Reference decoding configuration (beam_size=5)

    Input: folder path containing audio segments.
    Output: per-segment recognition results (speaker + text + confidence).
    """
    global firered_model, device

    try:
        # On-demand mode: load the model just for this request.
        if not KEEP_MODEL_LOADED:
            if firered_model is None:
                print("\n临时加载模型...")
                if Path(FIRERED_MODEL_PATH).exists():
                    loaded_model, loaded_device = load_fireredasr_model(FIRERED_MODEL_PATH)
                    if loaded_model is not None:
                        firered_model, device = loaded_model, loaded_device
                else:
                    return JSONResponse(
                        content={
                            "data": [],
                            "statistics": {},
                            "message": f"模型路径不存在: {FIRERED_MODEL_PATH}",
                            "code": 500
                        },
                        status_code=200
                    )

        # Either mode: refuse to proceed without a loaded model.
        if firered_model is None:
            return JSONResponse(
                content={
                    "data": [],
                    "statistics": {},
                    "message": "FireRedASR模型未加载,请检查模型路径并重启服务",
                    "code": 500
                },
                status_code=200
            )

        print(f"\n{'=' * 60}")
        print(f"收到批量转录请求")
        print(f"文件夹路径: {request.folder_path}")
        print(f"处理模式: {'常驻模式' if KEEP_MODEL_LOADED else '按需模式'}")

        start_time = time.time()

        # Step 1: enumerate audio files in the requested folder.
        try:
            audio_files = scan_audio_folder(request.folder_path)
        except ValueError as e:
            return JSONResponse(
                content={
                    "data": [],
                    "statistics": {},
                    "message": str(e),
                    "code": 400
                },
                status_code=200
            )

        if not audio_files:
            return JSONResponse(
                content={
                    "data": [],
                    "statistics": {
                        "total_files": 0,
                        "successful": 0,
                        "failed": 0,
                        "processing_time": 0
                    },
                    "message": "文件夹中没有找到音频文件",
                    "code": 400
                },
                status_code=200
            )

        print(f"找到 {len(audio_files)} 个音频文件")
        print(f"{'=' * 60}\n")

        # Step 2: preprocess into temporary 16 kHz mono wavs.
        temp_files = create_temp_segments(audio_files)
        print()

        # Step 3: batch recognition.
        results = transcribe_audio_files(temp_files, firered_model, device)
        print()

        # Step 4: remove the temporary files.
        print("清理临时文件...")
        cleanup_temp_files()

        # Summary statistics.
        successful = len([r for r in results if r.get('text') and not r.get('error')])
        failed = len(results) - successful
        elapsed_time = time.time() - start_time

        print(f"{'=' * 60}")
        print(f"处理完成,总耗时: {elapsed_time:.2f}秒")
        print(f"成功: {successful}/{len(results)}")
        print(f"失败: {failed}/{len(results)}")
        print(f"{'=' * 60}\n")

        # On-demand mode: release the model after the request completes.
        if not KEEP_MODEL_LOADED:
            unload_fireredasr_model()
            print()

        response_data = {
            "data": results,
            "statistics": {
                "total_files": len(audio_files),
                "successful": successful,
                "failed": failed,
                "processing_time": round(elapsed_time, 2)
            },
            "message": "success",
            "code": 200
        }

        return JSONResponse(
            content=jsonable_encoder(response_data),
            status_code=200
        )

    except Exception as e:
        print(f"处理错误: {str(e)}")
        import traceback
        traceback.print_exc()

        # Even on error: clean up temp files, and in on-demand mode
        # release the model.
        cleanup_temp_files()
        if not KEEP_MODEL_LOADED:
            unload_fireredasr_model()

        return JSONResponse(
            content={
                "data": [],
                "statistics": {},
                "message": f"处理失败: {str(e)}",
                "code": 500
            },
            status_code=200
        )


@app.get("/")
async def root():
    """Service-info endpoint."""
    return {
        "service": "FireRedASR Batch Transcription Service",
        "version": "3.0 (High Precision Edition)",
        "description": "批量音频转录服务(高精度版本)",
        "memory_mode": "常驻模式" if KEEP_MODEL_LOADED else "按需模式",
        "model_loaded": firered_model is not None,
        "features": [
            "完整的音频预处理流程",
            "强制标准化为16kHz单声道wav",
            "临时文件机制确保格式一致",
            "参考代码的所有精度配置",
            "说话人标识自动解析",
            "跨平台路径支持(Windows/Linux)",
            "灵活的内存管理模式"
        ],
        "endpoints": {
            "transcriptions_batch": "POST /transcriptions_batch - 批量转录",
            "health": "GET /health - 健康检查",
            "root": "GET / - 服务信息"
        }
    }


@app.get("/health")
async def health_check():
    """Health-check endpoint."""
    return JSONResponse(
        content={
            "status": "healthy",
            "memory_mode": "常驻模式" if KEEP_MODEL_LOADED else "按需模式",
            "firered_model_loaded": firered_model is not None,
            "device": str(device) if device else "not initialized",
            "description": "High precision audio transcription with full preprocessing",
            "message": "Service is running"
        },
        status_code=200
    )


if __name__ == '__main__':
    print("=" * 60)
    print("FireRedASR 批量转录服务 v3.0 (高精度版)")
    print("=" * 60)
    print("特性:")
    print(f" ✓ 内存管理: {'常驻模式' if KEEP_MODEL_LOADED else '按需模式'}")
    print(" ✓ 完整音频预处理(soundfile + librosa + pydub)")
    print(" ✓ 强制标准化为16kHz单声道wav")
    print(" ✓ 临时文件机制确保格式统一")
    print(" ✓ beam_size=5 高精度配置")
    print(" ✓ 说话人信息自动解析")
    print(" ✓ 跨平台路径支持")
    print("=" * 60)
    print("启动服务器...")
    print()
    uvicorn.run(app, host='0.0.0.0', port=7777)