You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
651 lines
19 KiB
651 lines
19 KiB
import sys |
|
|
|
sys.path.append('G:/Work_data/workstation/Forwork-voice-txt/FireRedASR') |
|
|
|
import os |
|
import time |
|
import torch |
|
import argparse |
|
import uvicorn |
|
import soundfile as sf |
|
import librosa |
|
import numpy as np |
|
from pathlib import Path |
|
from pydub import AudioSegment |
|
import re |
|
from contextlib import asynccontextmanager |
|
from fastapi import FastAPI |
|
from fastapi.responses import JSONResponse |
|
from fastapi.encoders import jsonable_encoder |
|
from pydantic import BaseModel |
|
|
|
# Global FireRedASR model handle and the device string it runs on.
# Both are filled in by the FastAPI lifespan hook and/or per request.
firered_model = None
device = None

# ============ Configuration ============
# Model loading mode:
#   True  -> resident mode: load once at startup (fast requests, more memory)
#   False -> on-demand mode: load per request and unload afterwards (slow, saves memory)
KEEP_MODEL_LOADED = True

# Directory of the pretrained FireRedASR-AED model. The default is the original
# developer path; set the FIRERED_MODEL_PATH environment variable to use
# another location without editing the source.
FIRERED_MODEL_PATH = os.environ.get(
    "FIRERED_MODEL_PATH",
    "G:/Work_data/workstation/Forwork-voice-txt/FireRedASR/pretrained_models/FireRedASR-AED-L",
)

# Scratch folder for preprocessed audio (relative to the working directory).
TEMP_SEGMENTS_DIR = Path("temp_segments")
# =======================================
|
|
|
|
|
class TranscriptionRequest(BaseModel):
    """JSON body accepted by POST /transcriptions_batch."""

    # Folder whose audio files (top level only) should be transcribed.
    folder_path: str
|
|
|
|
|
def setup_fireredasr_environment():
    """Make a local FireRedASR checkout importable.

    Checks a few relative locations in order. The first one that exists has its
    absolute path prepended to ``sys.path`` (unless it is already there).

    Returns:
        True if a checkout directory was found, otherwise False.
    """
    candidates = ("./FireRedASR", "../FireRedASR", "FireRedASR")

    for candidate in candidates:
        repo_dir = Path(candidate)
        if not repo_dir.exists():
            continue
        abs_dir = str(repo_dir.absolute())
        if abs_dir not in sys.path:
            sys.path.insert(0, abs_dir)
        return True

    return False
|
|
|
|
|
def load_fireredasr_model(model_dir):
    """Load the FireRedASR AED model from ``model_dir``.

    Args:
        model_dir: Directory passed to ``FireRedAsr.from_pretrained("aed", ...)``.

    Returns:
        ``(model, device_name)`` on success, where ``device_name`` is ``"cuda:0"``
        or ``"cpu"`` and reflects where the model can actually run. Returns
        ``(None, None)`` if importing or loading fails; the error is printed,
        not raised.
    """
    # Best effort: a missing local checkout is fine when fireredasr is
    # importable some other way, so the return value is not needed.
    setup_fireredasr_environment()

    try:
        # Allow-list argparse.Namespace for torch.load's weights_only unpickler.
        # add_safe_globals only exists in newer PyTorch releases; on older ones
        # there is nothing to register, so skip it instead of failing the load.
        add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
        if add_safe_globals is not None:
            add_safe_globals([argparse.Namespace])

        # Try the top-level package first, then the repo-prefixed layout, and
        # finally retry the top-level path after importing the package itself.
        try:
            from fireredasr.models.fireredasr import FireRedAsr
        except ImportError:
            try:
                from FireRedASR.fireredasr.models.fireredasr import FireRedAsr
            except ImportError:
                import fireredasr  # noqa: F401
                from fireredasr.models.fireredasr import FireRedAsr

        model = FireRedAsr.from_pretrained("aed", model_dir)

        device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        if device_name != "cpu" and hasattr(model, "to"):
            try:
                model = model.to(device_name)
            except Exception as move_err:
                # Keep the reported device truthful: callers derive their
                # use_gpu flag from device_name.
                print(f"⚠ 无法将模型移动到 {device_name},改用CPU: {move_err}")
                device_name = "cpu"

        if hasattr(model, 'eval'):
            model.eval()

        print(f"✓ 模型加载成功,使用设备: {device_name}")
        return model, device_name

    except Exception as e:
        print(f"✗ FireRedASR模型加载失败: {e}")
        return None, None
|
|
|
|
|
def unload_fireredasr_model():
    """Drop the global FireRedASR model and release the memory it held.

    Does nothing when no model is loaded. Otherwise it:
    1. clears the module-level ``firered_model`` reference,
    2. runs the garbage collector,
    3. empties PyTorch's CUDA cache when CUDA is available.
    """
    global firered_model

    if firered_model is None:
        return

    import gc

    print("正在卸载模型...")
    firered_model = None

    # Collect reference cycles that may still hold the model's tensors before
    # asking PyTorch to return cached GPU memory. empty_cache cannot release
    # blocks that live tensors still occupy.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("✓ 模型已卸载,内存已释放")
|
|
|
|
|
def parse_speaker_from_filename(filename):
    """Derive a speaker label from an audio segment filename.

    The first case-insensitive ``spk<digits>`` token becomes ``speaker<digits>``
    (the digits are kept as written). Filenames without such a token map to
    ``"speaker0"``.

    Examples:
        001_spk0_0.00-5.23.wav -> speaker0
        002_spk1_5.50-8.90.wav -> speaker1
        segment_spk2.wav       -> speaker2
    """
    found = re.search(r'spk(\d+)', filename.lower())
    return "speaker0" if found is None else f"speaker{found.group(1)}"
|
|
|
|
|
def scan_audio_folder(folder_path):
    """List the audio files directly inside ``folder_path``, sorted by file name.

    Extensions are matched case-insensitively (``.wav``, ``.mp3``, ``.flac``,
    ``.m4a``, ``.ogg``), so ``A.WAV`` is found on case-sensitive filesystems
    just as on Windows. Only regular files are returned; subfolders are not
    searched.

    Args:
        folder_path: Folder path as ``str`` or ``Path`` (Windows or POSIX style).

    Returns:
        List of ``Path`` objects sorted by ``name``.

    Raises:
        ValueError: if the path is empty, does not exist, or is not a folder.
    """
    # An empty string would become Path(".") and silently scan the server's
    # working directory.
    if not str(folder_path).strip():
        raise ValueError("路径不能为空")

    folder = Path(folder_path)

    if not folder.exists():
        raise ValueError(f"路径不存在: {folder_path}")

    if not folder.is_dir():
        raise ValueError(f"不是有效的文件夹: {folder_path}")

    audio_extensions = {'.wav', '.mp3', '.flac', '.m4a', '.ogg'}

    audio_files = [
        entry
        for entry in folder.iterdir()
        if entry.is_file() and entry.suffix.lower() in audio_extensions
    ]

    return sorted(audio_files, key=lambda p: p.name)
|
|
|
|
|
def preprocess_audio_file(audio_path, output_path):
    """Normalize one audio file to a 16 kHz mono WAV for FireRedASR.

    Steps:
        1. Read samples with soundfile.
        2. Down-mix multi-channel audio to mono by averaging the channels.
        3. Resample to 16 kHz with librosa when the source rate differs.
        4. Write an intermediate WAV, reload it with pydub and force
           16 kHz / 1 channel.
        5. Export the result as WAV to ``output_path``.

    Args:
        audio_path: Source audio file (``str`` or ``Path``).
        output_path: Destination WAV path (``str`` or ``Path``).

    Returns:
        True on success, False on any failure. The error is printed, not raised.
    """
    # Intermediate file written by soundfile and read back by pydub. It is
    # removed in ``finally`` so a failure midway does not leave it behind.
    temp_file = str(output_path) + ".temp.wav"

    try:
        audio_data, sr = sf.read(str(audio_path))

        # Multi-channel data comes back 2-D (frames x channels); average the channels.
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)

        if sr != 16000:
            audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
            sr = 16000

        sf.write(temp_file, audio_data, sr)

        audio_pydub = AudioSegment.from_file(temp_file)
        audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)

        # pydub's export() returns the opened output file object. Close it so
        # the handle is not leaked; an open handle also blocks deleting the
        # file later on Windows.
        exported = audio_pydub.export(str(output_path), format="wav")
        if hasattr(exported, "close"):
            exported.close()

        return True

    except Exception as e:
        # Path(...) so the log line also works when a plain string was passed.
        print(f" ✗ 预处理失败 ({Path(audio_path).name}): {e}")
        return False

    finally:
        if os.path.exists(temp_file):
            try:
                os.remove(temp_file)
            except OSError:
                pass
|
|
|
|
|
def create_temp_segments(audio_files):
    """Preprocess every source file into TEMP_SEGMENTS_DIR.

    Each input gets a target named ``preprocessed_<index:03d>.wav``. Files that
    fail preprocessing stay in the returned list with ``temp_file=None`` and
    ``preprocess_error=True``, so later stages can still report them.

    Args:
        audio_files: Source audio paths (``Path`` objects; ``.name`` is read).

    Returns:
        One dict per input, in input order, with keys 'temp_file',
        'original_file', 'filename', 'index' and, on failure,
        'preprocess_error'.
    """
    TEMP_SEGMENTS_DIR.mkdir(exist_ok=True)

    total = len(audio_files)
    print(f"正在预处理 {total} 个音频文件...")

    prepared = []
    for index, source in enumerate(audio_files):
        target = TEMP_SEGMENTS_DIR / f"preprocessed_{index:03d}.wav"

        entry = {
            'temp_file': None,
            'original_file': source,
            'filename': source.name,
            'index': index,
        }

        if preprocess_audio_file(source, target):
            entry['temp_file'] = target
            prepared.append(entry)
            print(f" ✓ [{index + 1}/{total}] {source.name}")
        else:
            # Keep the failed file in the list instead of aborting the batch.
            entry['preprocess_error'] = True
            prepared.append(entry)

    ok_count = sum(1 for item in prepared if item.get('temp_file'))
    print(f"预处理完成: {ok_count} / {total} 成功")

    return prepared
|
|
|
|
|
def _extract_confidence(result):
    """Return a float confidence for one transcription result dict (0.0 if absent).

    Reads ``result['confidence']`` and falls back to ``result['score']``.
    - A non-empty list/tuple contributes its first element.
    - Plain ``int``/``float`` values are converted with ``float()``.
    - Anything else yields 0.0.
    Conversion failures also yield 0.0 instead of raising, so a malformed
    optional field cannot discard an otherwise valid transcription.
    """
    raw = result.get('confidence', result.get('score', 0.0))
    if isinstance(raw, (list, tuple)) and len(raw) > 0:
        try:
            return float(raw[0])
        except Exception:  # best effort: confidence is optional metadata
            return 0.0
    if isinstance(raw, (int, float)):
        return float(raw)
    return 0.0


def transcribe_audio_files(temp_files, model, device_name):
    """Run FireRedASR on every preprocessed file and collect per-file results.

    Args:
        temp_files: Entries from ``create_temp_segments`` (dicts with
            'filename', 'temp_file' and optionally 'preprocess_error').
        model: Loaded model exposing ``transcribe(uttids, wav_paths, config)``.
        device_name: Device string such as "cuda:0" or "cpu". A "cuda" prefix
            enables GPU decoding.

    Returns:
        Exactly one dict per input entry, in input order, with 'filename',
        'speaker', 'text' and 'confidence'. Failed entries also carry an
        'error' message and empty text.
    """
    results = []
    use_gpu = str(device_name or "").startswith("cuda")

    # Decoding settings are the same for every file: beam search, 1-best output.
    config = {
        "use_gpu": 1 if use_gpu else 0,
        "beam_size": 5,
        "nbest": 1,
        "decode_max_len": 0,
    }

    print(f"开始识别 {len(temp_files)} 个音频文件...")

    for i, file_info in enumerate(temp_files):
        filename = file_info['filename']
        temp_file = file_info.get('temp_file')

        if file_info.get('preprocess_error'):
            print(f" ✗ [{i + 1}/{len(temp_files)}] {filename}: 跳过(预处理失败)")
            results.append({
                'filename': filename,
                'speaker': parse_speaker_from_filename(filename),
                'text': '',
                'confidence': 0.0,
                'error': 'Preprocessing failed',
            })
            continue

        try:
            with torch.no_grad():
                transcription_result = model.transcribe(
                    [f"file_{i:03d}"], [str(temp_file)], config
                )

            speaker = parse_speaker_from_filename(filename)

            # An empty/None result is treated like an empty hypothesis, so every
            # input still produces exactly one result entry.
            result = transcription_result[0] if transcription_result else {}
            text = (result.get('text') or '').strip()

            if text:
                confidence = _extract_confidence(result)
                results.append({
                    'filename': filename,
                    'speaker': speaker,
                    'text': text,
                    'confidence': round(confidence, 3),
                })
                print(f" ✓ [{i + 1}/{len(temp_files)}] {filename}: {speaker} - {text[:30]}...")
            else:
                results.append({
                    'filename': filename,
                    'speaker': speaker,
                    'text': '',
                    'confidence': 0.0,
                    'error': 'Empty transcription',
                })

        except Exception as e:
            print(f" ✗ [{i + 1}/{len(temp_files)}] {filename}: 识别失败 - {e}")
            results.append({
                'filename': filename,
                'speaker': parse_speaker_from_filename(filename),
                'text': '',
                'confidence': 0.0,
                'error': str(e),
            })

        finally:
            # Release cached GPU blocks after every file, including failed ones
            # (e.g. after an out-of-memory error).
            if use_gpu:
                torch.cuda.empty_cache()

    successful_count = len([r for r in results if r.get('text') and not r.get('error')])
    print(f"识别完成: {successful_count}/{len(temp_files)} 个文件成功")

    return results
|
|
|
|
|
def cleanup_temp_files():
    """Best-effort removal of preprocessing scratch files and their folder.

    - Deletes every ``preprocessed_*.wav`` entry in TEMP_SEGMENTS_DIR. The glob
      also matches ``preprocessed_*.wav.temp.wav`` leftovers.
    - Removes the folder itself if it is empty afterwards.

    Filesystem errors (``OSError``) are ignored so cleanup never masks the
    caller's real outcome. Other exceptions, such as KeyboardInterrupt, still
    propagate.
    """
    if not TEMP_SEGMENTS_DIR.exists():
        return

    for leftover in TEMP_SEGMENTS_DIR.glob("preprocessed_*.wav"):
        try:
            leftover.unlink()
        except OSError:
            pass

    # Fails (and is ignored) if unrelated files remain in the folder.
    try:
        TEMP_SEGMENTS_DIR.rmdir()
    except OSError:
        pass
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: initialize the device/model on startup, clean up on shutdown.

    Startup:
        Picks "cuda:0" when CUDA is available, else "cpu". In resident mode
        (KEEP_MODEL_LOADED) it also loads the model from FIRERED_MODEL_PATH.

    Shutdown:
        Removes leftover temp files and unloads the model. It runs in
        ``finally`` so it also happens when the application exits with an error.
    """
    global firered_model, device

    print("=" * 60)
    print("正在启动FireRedASR服务...")
    print("=" * 60)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print(f"设备: {device}")

    if KEEP_MODEL_LOADED:
        print("\n内存管理模式: 常驻模式(模型始终驻留内存)")
        if Path(FIRERED_MODEL_PATH).exists():
            loaded_model, loaded_device = load_fireredasr_model(FIRERED_MODEL_PATH)
            firered_model = loaded_model
            # A failed load returns (None, None); keep the device chosen above
            # instead of resetting it to None.
            if loaded_device is not None:
                device = loaded_device
            if firered_model is not None:
                print("✓ 模型已加载到内存")
            else:
                print("✗ 模型加载失败")
        else:
            print(f"✗ 模型路径不存在: {FIRERED_MODEL_PATH}")
    else:
        print("\n内存管理模式: 按需模式(每次处理时才加载模型)")
        print("✓ 设备已初始化,模型将在首次请求时加载")
        firered_model = None

    print("=" * 60)

    try:
        yield  # application runs here
    finally:
        print("服务正在关闭...")
        cleanup_temp_files()
        unload_fireredasr_model()
|
|
|
|
|
# FastAPI application instance; startup/shutdown is handled by ``lifespan``.
app = FastAPI(
    title="FireRedASR Batch Transcription Service (High Precision)",
    lifespan=lifespan,
)
|
|
|
|
|
def _batch_error_response(message, code, statistics=None):
    """Build the error envelope returned by /transcriptions_batch.

    The HTTP status is always 200. The outcome is reported in the body's
    ``code`` field, like every other response of this endpoint.
    """
    return JSONResponse(
        content={
            "data": [],
            "statistics": {} if statistics is None else statistics,
            "message": message,
            "code": code,
        },
        status_code=200,
    )


@app.post("/transcriptions_batch")
async def transcribe_batch(request: TranscriptionRequest):
    """Transcribe every audio file in ``request.folder_path`` (high-precision pipeline).

    Pipeline:
        1. Scan the folder.
        2. Preprocess each file to a 16 kHz mono WAV in the temp folder.
        3. Run FireRedASR beam-search decoding.
        4. Clean up the temp files.

    Returns:
        A JSONResponse (always HTTP 200) whose body contains:
        - ``data``: one entry per file (filename, speaker, text, confidence,
          optional error)
        - ``statistics``
        - ``message``
        - a business ``code`` (200 / 400 / 500)

    In on-demand mode (KEEP_MODEL_LOADED is False), the model is loaded for this
    request and unloaded on every exit path.

    NOTE(review): this is an ``async`` handler that calls only blocking code. It
    occupies the event loop for the whole batch, so requests are handled one at
    a time. The fixed temp-file names and the on-demand load/unload depend on
    that serialization; revisit both before moving the work to a thread pool.
    """
    global firered_model, device

    try:
        # On-demand mode: load the model just for this request.
        if not KEEP_MODEL_LOADED and firered_model is None:
            print("\n临时加载模型...")
            if not Path(FIRERED_MODEL_PATH).exists():
                return _batch_error_response(f"模型路径不存在: {FIRERED_MODEL_PATH}", 500)
            loaded_model, loaded_device = load_fireredasr_model(FIRERED_MODEL_PATH)
            firered_model = loaded_model
            # A failed load returns (None, None); keep the initialized device.
            if loaded_device is not None:
                device = loaded_device

        if firered_model is None:
            return _batch_error_response("FireRedASR模型未加载,请检查模型路径并重启服务", 500)

        print(f"\n{'=' * 60}")
        print("收到批量转录请求")
        print(f"文件夹路径: {request.folder_path}")
        print(f"处理模式: {'常驻模式' if KEEP_MODEL_LOADED else '按需模式'}")

        start_time = time.time()

        # Step 1: scan the folder.
        try:
            audio_files = scan_audio_folder(request.folder_path)
        except ValueError as e:
            return _batch_error_response(str(e), 400)

        if not audio_files:
            return _batch_error_response(
                "文件夹中没有找到音频文件",
                400,
                statistics={
                    "total_files": 0,
                    "successful": 0,
                    "failed": 0,
                    "processing_time": 0,
                },
            )

        print(f"找到 {len(audio_files)} 个音频文件")
        print(f"{'=' * 60}\n")

        # Step 2: preprocess into temp files.
        temp_files = create_temp_segments(audio_files)
        print()

        # Step 3: transcribe.
        results = transcribe_audio_files(temp_files, firered_model, device)
        print()

        # Step 4: remove temp files.
        print("清理临时文件...")
        cleanup_temp_files()

        successful = len([r for r in results if r.get('text') and not r.get('error')])
        failed = len(results) - successful
        elapsed_time = time.time() - start_time

        print(f"{'=' * 60}")
        print(f"处理完成,总耗时: {elapsed_time:.2f}秒")
        print(f"成功: {successful}/{len(results)}")
        print(f"失败: {failed}/{len(results)}")
        print(f"{'=' * 60}\n")

        if not KEEP_MODEL_LOADED:
            unload_fireredasr_model()
            print()

        response_data = {
            "data": results,
            "statistics": {
                "total_files": len(audio_files),
                "successful": successful,
                "failed": failed,
                "processing_time": round(elapsed_time, 2),
            },
            "message": "success",
            "code": 200,
        }

        return JSONResponse(
            content=jsonable_encoder(response_data),
            status_code=200,
        )

    except Exception as e:
        print(f"处理错误: {str(e)}")
        import traceback
        traceback.print_exc()

        # Make sure no scratch files survive a failed batch.
        cleanup_temp_files()

        return _batch_error_response(f"处理失败: {str(e)}", 500)

    finally:
        # On-demand mode must release the model on every exit path, including
        # the early 400 responses above. This is a no-op when the model was
        # already unloaded or never loaded.
        if not KEEP_MODEL_LOADED:
            unload_fireredasr_model()
|
|
|
|
|
@app.get("/")
async def root():
    """Return static service metadata plus the current memory mode and model state."""
    features = [
        "完整的音频预处理流程",
        "强制标准化为16kHz单声道wav",
        "临时文件机制确保格式一致",
        "参考代码的所有精度配置",
        "说话人标识自动解析",
        "跨平台路径支持(Windows/Linux)",
        "灵活的内存管理模式",
    ]
    endpoints = {
        "transcriptions_batch": "POST /transcriptions_batch - 批量转录",
        "health": "GET /health - 健康检查",
        "root": "GET / - 服务信息",
    }
    info = {
        "service": "FireRedASR Batch Transcription Service",
        "version": "3.0 (High Precision Edition)",
        "description": "批量音频转录服务(高精度版本)",
        "memory_mode": "常驻模式" if KEEP_MODEL_LOADED else "按需模式",
        "model_loaded": firered_model is not None,
        "features": features,
        "endpoints": endpoints,
    }
    return info
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Health probe: reports memory mode, whether the model is loaded, and the device."""
    mode_label = "常驻模式" if KEEP_MODEL_LOADED else "按需模式"
    device_label = str(device) if device else "not initialized"

    body = {
        "status": "healthy",
        "memory_mode": mode_label,
        "firered_model_loaded": firered_model is not None,
        "device": device_label,
        "description": "High precision audio transcription with full preprocessing",
        "message": "Service is running",
    }
    return JSONResponse(content=body, status_code=200)
|
|
|
|
|
if __name__ == '__main__':
    # Print the startup banner, then serve the app on all interfaces, port 7777.
    separator = "=" * 60
    mode_label = '常驻模式' if KEEP_MODEL_LOADED else '按需模式'
    banner_lines = [
        separator,
        "FireRedASR 批量转录服务 v3.0 (高精度版)",
        separator,
        "特性:",
        f" ✓ 内存管理: {mode_label}",
        " ✓ 完整音频预处理(soundfile + librosa + pydub)",
        " ✓ 强制标准化为16kHz单声道wav",
        " ✓ 临时文件机制确保格式统一",
        " ✓ beam_size=5 高精度配置",
        " ✓ 说话人信息自动解析",
        " ✓ 跨平台路径支持",
        separator,
        "启动服务器...",
        "",
    ]
    for banner_line in banner_lines:
        print(banner_line)

    uvicorn.run(app, host='0.0.0.0', port=7777)