import sys
import os
import time
import torch
import argparse
import uvicorn
import soundfile as sf
import librosa
import numpy as np
import torchaudio
import uuid
import requests
from pathlib import Path
from pydub import AudioSegment
from contextlib import asynccontextmanager
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.encoders import jsonable_encoder

# ============ Global configuration ============
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '1'

FIRERED_ASR_PATH = 'G:/Work_data/workstation/Forwork-voice-txt/FireRedASR'
FIRERED_MODEL_PATH = "G:/Work_data/workstation/Forwork-voice-txt/FireRedASR/pretrained_models/FireRedASR-AED-L"
PYANNOTE_CONFIG_PATH = r"G:\Work_data\workstation\Audio_classification\classify_model\speaker-diarization-3.1\config.yaml"

TEMP_DIR = Path("temp_transcription")
TEMP_DIR.mkdir(exist_ok=True)

# Audio shorter than this (seconds) is transcribed directly; longer audio is segmented first.
DURATION_THRESHOLD = 50.0

pyannote_pipeline = None
firered_model = None
device = None


# ============ Text post-processing ============
def call_ai_model(sentence):
    """Ask the LLM endpoint to punctuate a sentence; fall back to the raw text on any failure."""
    url = "http://36.158.183.88:7777/v1/chat/completions"
    # Prompt (kept in Chinese to match the transcribed text): add punctuation and convert
    # Chinese numerals to Arabic numerals without changing the wording.
    prompt = f"""对以下文本添加标点符号,中文数字转阿拉伯数字。不修改文字内容。句末可以是冒号、逗号、问号、感叹号和句号等任意合适标点。
{sentence}"""
    payload = {
        "model": "Qwen3-32B",
        "messages": [{"role": "user", "content": prompt}],
        "chat_template_kwargs": {"enable_thinking": False}
    }
    try:
        response = requests.post(url, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
        return result["choices"][0]["message"]["content"].strip()
    except Exception:
        return sentence


def post_process_transcription(transcription_results):
    """Punctuate each non-empty, error-free segment and report the time spent."""
    if not transcription_results:
        return [], 0.0
    start_time = time.time()
    processed_results = []
    for segment in transcription_results:
        if 'error' in segment or not segment.get('text', '').strip():
            continue
        processed_text = call_ai_model(segment['text'])
        processed_results.append({
            "start": segment['start'],
            "end": segment['end'],
            "content": processed_text,
            "confidence": segment['confidence'],
            "segment_type": segment['segment_type']
        })
    return processed_results, time.time() - start_time


# ============ FireRedASR ============
def setup_fireredasr_environment():
    if FIRERED_ASR_PATH not in sys.path:
        sys.path.insert(0, FIRERED_ASR_PATH)
    return True


def load_fireredasr_model(model_dir):
    """Load the FireRedASR AED model and move it to GPU when available."""
    setup_fireredasr_environment()
    try:
        torch.serialization.add_safe_globals([argparse.Namespace])
        try:
            from fireredasr.models.fireredasr import FireRedAsr
        except ImportError:
            try:
                from FireRedASR.fireredasr.models.fireredasr import FireRedAsr
            except ImportError:
                import fireredasr
                from fireredasr.models.fireredasr import FireRedAsr
        model = FireRedAsr.from_pretrained("aed", model_dir)
        device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        if torch.cuda.is_available():
            try:
                model = model.to(device_name)
            except Exception:
                pass
        if hasattr(model, 'eval'):
            model.eval()
        return model, device_name
    except Exception as e:
        print(f"Failed to load FireRedASR: {e}")
        return None, None


def preprocess_audio_for_asr(audio_path, output_path):
    """Convert audio to 16 kHz mono WAV for ASR; return False if conversion fails."""
    try:
        audio_data, sr = sf.read(str(audio_path))
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)
        if sr != 16000:
            audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
            sr = 16000
        temp_file = str(output_path) + ".temp.wav"
        sf.write(temp_file, audio_data, sr)
        audio_pydub = AudioSegment.from_file(temp_file)
        audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)
        audio_pydub.export(str(output_path), format="wav")
        if os.path.exists(temp_file):
            os.remove(temp_file)
        return True
    except Exception:
        return False
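# Both transcription helpers below drive FireRedAsr.transcribe() with the same beam-search
# settings: GPU decoding when a CUDA device name is active, beam_size=5, nbest=1 and
# decode_max_len=0 (no explicit cap on the decoded length).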
def transcribe_single_audio(audio_path, model, device_name):
    """Transcribe one audio file as a single utterance."""
    try:
        temp_path = TEMP_DIR / f"temp_asr_{uuid.uuid4().hex[:8]}.wav"
        if not preprocess_audio_for_asr(audio_path, temp_path):
            return {"text": "", "confidence": 0.0, "error": "Preprocessing failed"}
        config = {
            "use_gpu": 1 if device_name.startswith("cuda") else 0,
            "beam_size": 5,
            "nbest": 1,
            "decode_max_len": 0
        }
        with torch.no_grad():
            result = model.transcribe(["audio_001"], [str(temp_path)], config)
        if temp_path.exists():
            temp_path.unlink()
        if result and len(result) > 0:
            text = result[0].get('text', '').strip()
            confidence = result[0].get('confidence', result[0].get('score', 0.0))
            if isinstance(confidence, (list, tuple)) and len(confidence) > 0:
                confidence = float(confidence[0])
            else:
                confidence = float(confidence) if isinstance(confidence, (int, float)) else 0.0
            return {"text": text, "confidence": round(confidence, 3)}
        return {"text": "", "confidence": 0.0}
    except Exception as e:
        return {"text": "", "confidence": 0.0, "error": str(e)}


def transcribe_audio_segments(segment_files, model, device_name):
    """Transcribe each diarization segment in order, keeping its timing and type."""
    results = []
    for i, file_info in enumerate(segment_files):
        try:
            temp_path = TEMP_DIR / f"temp_seg_{i:03d}.wav"
            if not preprocess_audio_for_asr(file_info['filepath'], temp_path):
                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "text": "",
                    "confidence": 0.0,
                    "segment_type": file_info['segment_type'],
                    "error": "Preprocessing failed"
                })
                continue
            config = {
                "use_gpu": 1 if device_name.startswith("cuda") else 0,
                "beam_size": 5,
                "nbest": 1,
                "decode_max_len": 0
            }
            with torch.no_grad():
                transcription = model.transcribe([f"seg_{i:03d}"], [str(temp_path)], config)
            if temp_path.exists():
                temp_path.unlink()
            if transcription and len(transcription) > 0:
                text = transcription[0].get('text', '').strip()
                confidence = transcription[0].get('confidence', transcription[0].get('score', 0.0))
                if isinstance(confidence, (list, tuple)) and len(confidence) > 0:
                    confidence = float(confidence[0])
                else:
                    confidence = float(confidence) if isinstance(confidence, (int, float)) else 0.0
                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "text": text,
                    "confidence": round(confidence, 3),
                    "segment_type": file_info['segment_type']
                })
            else:
                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "text": "",
                    "confidence": 0.0,
                    "segment_type": file_info['segment_type']
                })
            if device_name.startswith("cuda"):
                torch.cuda.empty_cache()
        except Exception as e:
            results.append({
                "start": file_info['start_ms'],
                "end": file_info['end_ms'],
                "text": "",
                "confidence": 0.0,
                "segment_type": file_info['segment_type'],
                "error": str(e)
            })
    return results


# ============ Pyannote ============
def load_pyannote_pipeline(config_path):
    """Load the local pyannote speaker-diarization pipeline and move it to GPU when available."""
    try:
        from pyannote.audio import Pipeline
        pipeline = Pipeline.from_pretrained(config_path)
        device_obj = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipeline.to(device_obj)
        return pipeline
    except Exception as e:
        print(f"Failed to load Pyannote: {e}")
        return None


def load_audio_for_segmentation(audio_path):
    """Return the audio as a 16 kHz mono numpy array plus a matching pydub AudioSegment."""
    audio_data, sr = sf.read(str(audio_path))
    if len(audio_data.shape) > 1:
        audio_data = np.mean(audio_data, axis=1)
    if sr != 16000:
        audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
        sr = 16000
    audio_pydub = AudioSegment.from_file(str(audio_path))
    audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)
    return audio_data, sr, audio_pydub
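# Boundary refinement: within each silence gap reported by the diarizer, the helpers below
# slide a 25 ms window in 5 ms steps and move the segment boundary to the frame with the
# lowest RMS energy, so cuts land on the quietest point rather than on speech.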
def find_optimal_boundary(audio_data, sr, start_time, end_time,
                          frame_length=0.025, hop_length=0.005):
    """Find the lowest-energy point between start_time and end_time (seconds)."""
    if start_time >= end_time:
        return start_time, 0.0
    start_sample = int(start_time * sr)
    end_sample = int(end_time * sr)
    segment = audio_data[start_sample:end_sample]
    if len(segment) == 0:
        return start_time, 0.0
    frame_samples = int(frame_length * sr)
    hop_samples = int(hop_length * sr)
    min_energy = float('inf')
    optimal_time = start_time
    for i in range(0, len(segment) - frame_samples + 1, hop_samples):
        frame = segment[i:i + frame_samples]
        energy = np.sqrt(np.mean(frame ** 2))
        if energy < min_energy:
            min_energy = energy
            optimal_time = start_time + (i + frame_samples / 2) / sr
    return optimal_time, min_energy


def optimize_segment_boundaries(segments, audio_data, sr,
                                small_gap_threshold=0.1, min_gap_to_keep=0.05):
    """Snap segment boundaries to low-energy points and resolve overlaps and tiny gaps."""
    if not segments:
        return []
    audio_duration = len(audio_data) / sr
    optimized = []
    for i, seg in enumerate(segments):
        current_seg = seg.copy()
        # Move the start into the preceding gap if the gap is large enough.
        if i == 0:
            search_start = 0.0
            search_end = seg["start"]
        else:
            search_start = segments[i - 1]["end"]
            search_end = seg["start"]
        gap_size = search_end - search_start
        if gap_size > 0 and gap_size >= small_gap_threshold:
            optimal_start, _ = find_optimal_boundary(audio_data, sr, search_start, search_end)
            current_seg["start"] = optimal_start
        # Move the end into the following gap if the gap is large enough.
        if i == len(segments) - 1:
            search_start = seg["end"]
            search_end = audio_duration
        else:
            search_start = seg["end"]
            search_end = segments[i + 1]["start"]
        gap_size = search_end - search_start
        if gap_size > 0 and gap_size >= small_gap_threshold:
            optimal_end, _ = find_optimal_boundary(audio_data, sr, search_start, search_end)
            current_seg["end"] = optimal_end
        optimized.append(current_seg)
    # Close small gaps by splitting them at the quietest point.
    for i in range(len(optimized) - 1):
        gap = optimized[i + 1]["start"] - optimized[i]["end"]
        if 0 < gap < small_gap_threshold:
            optimal_point, _ = find_optimal_boundary(audio_data, sr,
                                                     optimized[i]["end"],
                                                     optimized[i + 1]["start"])
            optimized[i]["end"] = optimal_point
            optimized[i + 1]["start"] = optimal_point
    # Remove overlaps and gaps too small to keep.
    final_optimized = []
    for i in range(len(optimized)):
        if i == 0:
            final_optimized.append(optimized[i])
            continue
        prev_seg = final_optimized[-1]
        curr_seg = optimized[i]
        gap = curr_seg["start"] - prev_seg["end"]
        if gap < 0:
            midpoint = (prev_seg["end"] + curr_seg["start"]) / 2
            prev_seg["end"] = midpoint
            curr_seg["start"] = midpoint
        elif 0 < gap < min_gap_to_keep:
            prev_seg["end"] = curr_seg["start"]
        final_optimized.append(curr_seg)
    return final_optimized


def merge_segments_by_gap(segments, max_duration=50.0):
    """Greedily merge neighbouring segments, smallest gap first, without exceeding max_duration."""
    if not segments:
        return []
    if len(segments) == 1:
        return segments
    groups = [[seg] for seg in segments]
    while True:
        gaps = []
        for i in range(len(groups) - 1):
            group1_end = groups[i][-1]["end"]
            group2_start = groups[i + 1][0]["start"]
            gap = group2_start - group1_end
            merged_duration = groups[i + 1][-1]["end"] - groups[i][0]["start"]
            gaps.append({
                "index": i,
                "gap": gap,
                "merged_duration": merged_duration
            })
        if not gaps:
            break
        gaps.sort(key=lambda x: x["gap"])
        merged = False
        for gap_info in gaps:
            if gap_info["merged_duration"] <= max_duration:
                idx = gap_info["index"]
                groups[idx].extend(groups[idx + 1])
                groups.pop(idx + 1)
                merged = True
                break
        if not merged:
            break
    merged_segments = []
    for group in groups:
        merged_seg = {
            "start": group[0]["start"],
            "end": group[-1]["end"],
            "original_count": len(group),
            "is_merged": len(group) > 1
        }
        merged_segments.append(merged_seg)
    return merged_segments
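# Segmentation pipeline: run pyannote speaker diarization, greedily merge neighbouring turns
# into chunks of at most DURATION_THRESHOLD seconds, snap the chunk boundaries to quiet frames,
# then export each chunk (with 25 ms of padding on both sides) as a 16 kHz mono WAV for ASR.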
def perform_segmentation(audio_path, pipeline_model, device_obj):
    """Diarize the audio, merge and refine the segments, and export one WAV per segment."""
    audio_data, sr, audio_pydub = load_audio_for_segmentation(audio_path)
    waveform, sample_rate = torchaudio.load(str(audio_path))
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        waveform = resampler(waveform)
    waveform = waveform.to(device_obj)
    audio_dict = {"waveform": waveform, "sample_rate": 16000}
    diarization_result = pipeline_model(audio_dict)
    initial_segments = []
    if hasattr(diarization_result, 'speaker_diarization'):
        annotation = diarization_result.speaker_diarization
        for segment, track, speaker in annotation.itertracks(yield_label=True):
            initial_segments.append({
                "start": round(segment.start, 3),
                "end": round(segment.end, 3)
            })
    else:
        for segment, track, speaker in diarization_result.itertracks(yield_label=True):
            initial_segments.append({
                "start": round(segment.start, 3),
                "end": round(segment.end, 3)
            })
    merged_segments = merge_segments_by_gap(initial_segments, max_duration=DURATION_THRESHOLD)
    optimized_segments = optimize_segment_boundaries(merged_segments, audio_data, sr)
    session_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    segment_folder = TEMP_DIR / session_id
    segment_folder.mkdir(exist_ok=True)
    segment_files = []
    for i, seg in enumerate(optimized_segments):
        start_time = seg["start"]
        end_time = seg["end"]
        start_ms = max(0, int(start_time * 1000) - 25)
        end_ms = min(len(audio_pydub), int(end_time * 1000) + 25)
        segment_audio = audio_pydub[start_ms:end_ms]
        filepath = segment_folder / f"segment_{i:03d}.wav"
        segment_audio.export(str(filepath), format="wav")
        segment_type = "forced" if seg.get("is_merged", False) else "natural"
        segment_files.append({
            "filepath": filepath,
            "start_ms": int(start_time * 1000),
            "end_ms": int(end_time * 1000),
            "segment_type": segment_type
        })
    return segment_files, session_id


# ============ FastAPI ============
@asynccontextmanager
async def lifespan(app: FastAPI):
    global pyannote_pipeline, firered_model, device
    print("Loading Pyannote model...")
    pyannote_pipeline = load_pyannote_pipeline(PYANNOTE_CONFIG_PATH)
    print("Loading FireRedASR model...")
    firered_model, device = load_fireredasr_model(FIRERED_MODEL_PATH)
    if pyannote_pipeline and firered_model:
        print("Service ready")
    else:
        print("Model loading failed")
    yield


app = FastAPI(lifespan=lifespan, title="Audio Transcription Service")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
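# POST /transcriptions flow: uploads shorter than DURATION_THRESHOLD seconds go to FireRedASR
# in one pass ("direct"); longer uploads are diarized with pyannote, split into "natural" or
# "forced" (merged) segments, transcribed segment by segment, and finally punctuated by the
# LLM post-processor. Errors are returned with HTTP 200 and code=500 in the response body.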
@app.post("/transcriptions")
async def transcribe_audio(file: UploadFile = File(...)):
    global pyannote_pipeline, firered_model, device
    if firered_model is None:
        raise HTTPException(status_code=500, detail="FireRedASR model not loaded")
    allowed_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
    file_ext = os.path.splitext(file.filename)[1].lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(status_code=400, detail=f"Unsupported file format: {file_ext}")
    temp_audio_path = None
    session_folder = None
    try:
        total_start_time = time.time()
        session_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
        temp_audio_path = TEMP_DIR / f"upload_{session_id}{file_ext}"
        with open(temp_audio_path, "wb") as f:
            content = await file.read()
            f.write(content)
        audio_info = sf.info(str(temp_audio_path))
        duration = audio_info.duration
        if duration < DURATION_THRESHOLD:
            # Short audio: one-shot ASR without diarization.
            result = transcribe_single_audio(temp_audio_path, firered_model, device)
            temp_transcription = [{
                "start": 0,
                "end": int(duration * 1000),
                "text": result.get("text", ""),
                "confidence": result.get("confidence", 0.0),
                "segment_type": "direct"
            }]
            if 'error' in result:
                temp_transcription[0]['error'] = result['error']
            transcription, post_process_time = post_process_transcription(temp_transcription)
            statistics = {
                "total_segments": len(transcription),
                "processing_time": round(time.time() - total_start_time, 2),
                "post_processing_time": round(post_process_time, 2),
                "segmentation_types": {
                    "natural": 0,
                    "forced": 0,
                    "direct": len(transcription)
                },
                "processing_method": "direct_asr"
            }
        else:
            # Long audio: diarize, segment, then transcribe each segment.
            if pyannote_pipeline is None:
                raise HTTPException(status_code=500, detail="Pyannote model not loaded")
            segment_files, seg_session_id = perform_segmentation(
                temp_audio_path,
                pyannote_pipeline,
                torch.device("cuda" if torch.cuda.is_available() else "cpu")
            )
            session_folder = TEMP_DIR / seg_session_id
            temp_transcription = transcribe_audio_segments(segment_files, firered_model, device)
            transcription, post_process_time = post_process_transcription(temp_transcription)
            seg_types = {"natural": 0, "forced": 0, "direct": 0}
            for seg in transcription:
                seg_types[seg["segment_type"]] += 1
            statistics = {
                "total_segments": len(transcription),
                "processing_time": round(time.time() - total_start_time, 2),
                "post_processing_time": round(post_process_time, 2),
                "segmentation_types": seg_types,
                "processing_method": "pipeline_dual_model"
            }
        response_data = {
            "code": 200,
            "message": "success",
            "data": {
                "transcription": transcription,
                "statistics": statistics
            }
        }
        return JSONResponse(content=jsonable_encoder(response_data), status_code=200)
    except Exception as e:
        return JSONResponse(
            content={
                "code": 500,
                "message": f"Processing failed: {str(e)}",
                "data": {"transcription": [], "statistics": {}}
            },
            status_code=200
        )
    finally:
        if temp_audio_path and temp_audio_path.exists():
            try:
                temp_audio_path.unlink()
            except Exception:
                pass
        if session_folder and session_folder.exists():
            try:
                for seg_file in session_folder.glob("*.wav"):
                    seg_file.unlink()
                session_folder.rmdir()
            except Exception:
                pass


@app.get("/health")
async def health_check():
    return JSONResponse(
        content={
            "status": "healthy",
            "pyannote_loaded": pyannote_pipeline is not None,
            "firered_loaded": firered_model is not None,
            "device": str(device) if device else "not initialized",
            "message": "Service is running"
        },
        status_code=200
    )


@app.get("/")
async def root():
    return {
        "service": "Audio Transcription Service",
        "version": "1.0",
        "endpoints": {
            "transcriptions": "POST /transcriptions",
            "health": "GET /health"
        }
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=9000)
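# Example client call (assumes the service is reachable at localhost:9000; the file path
# below is a placeholder for a real recording):
#   curl -X POST "http://localhost:9000/transcriptions" -F "file=@/path/to/meeting.wav"
# A successful response has the shape:
#   {"code": 200, "message": "success",
#    "data": {"transcription": [{"start": ..., "end": ..., "content": ..., ...}],
#             "statistics": {...}}}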