import argparse
import os
import sys
import time
import uuid
from contextlib import asynccontextmanager
from pathlib import Path

import librosa
import numpy as np
import requests
import soundfile as sf
import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydub import AudioSegment

# ============ Global configuration ============
# Run Hugging Face / transformers strictly from local files (no network lookups).
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '1'

# Local paths to the FireRedASR code base, its pretrained AED-L model, and the
# pyannote speaker-diarization pipeline config.
FIRERED_ASR_PATH = 'G:/Work_data/workstation/Forwork-voice-txt/FireRedASR'
FIRERED_MODEL_PATH = "G:/Work_data/workstation/Forwork-voice-txt/FireRedASR/pretrained_models/FireRedASR-AED-L"
PYANNOTE_CONFIG_PATH = r"G:\Work_data\workstation\Audio_classification\classify_model\speaker-diarization-3.1\config.yaml"

# Working directory for uploaded files and intermediate segment WAVs.
TEMP_DIR = Path("temp_transcription")
TEMP_DIR.mkdir(exist_ok=True)

# Audio shorter than this many seconds is transcribed in one pass;
# longer audio is diarized into segments first.
DURATION_THRESHOLD = 50.0

# Models are loaded once at startup (see lifespan below).
pyannote_pipeline = None
firered_model = None
device = None


# ============ Text post-processing ============
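# Send raw ASR output to an external OpenAI-compatible chat endpoint (Qwen3-32B)
# that adds punctuation and converts Chinese numerals to Arabic numerals without
# changing the wording. On any error the original sentence is returned unchanged.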
def call_ai_model(sentence):
    url = "http://36.158.183.88:7777/v1/chat/completions"
    # Prompt (kept in Chinese, as the service handles Chinese text): add punctuation,
    # convert Chinese numerals to Arabic numerals, do not modify the wording.
    prompt = f"""对以下文本添加标点符号,中文数字转阿拉伯数字。不修改文字内容。句末可以是冒号、逗号、问号、感叹号和句号等任意合适标点。
{sentence}"""

    payload = {
        "model": "Qwen3-32B",
        "messages": [{"role": "user", "content": prompt}],
        "chat_template_kwargs": {"enable_thinking": False}
    }

    try:
        response = requests.post(url, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
        return result["choices"][0]["message"]["content"].strip()
    except Exception:
        # If the post-processing service is unreachable or returns bad data,
        # fall back to the unpunctuated text.
        return sentence


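# Run every transcribed segment through call_ai_model. Segments that failed or
# produced empty text are dropped. Returns (processed segments, elapsed seconds).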
def post_process_transcription(transcription_results):
    if not transcription_results:
        return [], 0.0

    start_time = time.time()
    processed_results = []

    for segment in transcription_results:
        if 'error' in segment or not segment.get('text', '').strip():
            continue

        processed_text = call_ai_model(segment['text'])

        processed_results.append({
            "start": segment['start'],
            "end": segment['end'],
            "content": processed_text,
            "confidence": segment['confidence'],
            "segment_type": segment['segment_type']
        })

    return processed_results, time.time() - start_time


# ============ FireRedASR ============
def setup_fireredasr_environment():
    if FIRERED_ASR_PATH not in sys.path:
        sys.path.insert(0, FIRERED_ASR_PATH)
    return True


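# Load the FireRedASR AED model from model_dir. The import is attempted from a few
# module paths because the repository may sit on sys.path at different levels.
# Returns (model, device string) on success or (None, None) on failure.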
def load_fireredasr_model(model_dir):
    setup_fireredasr_environment()

    try:
        # Allow argparse.Namespace objects inside the checkpoint when torch loads
        # with weights_only enabled.
        torch.serialization.add_safe_globals([argparse.Namespace])

        try:
            from fireredasr.models.fireredasr import FireRedAsr
        except ImportError:
            try:
                from FireRedASR.fireredasr.models.fireredasr import FireRedAsr
            except ImportError:
                import fireredasr
                from fireredasr.models.fireredasr import FireRedAsr

        model = FireRedAsr.from_pretrained("aed", model_dir)
        device_name = "cuda:0" if torch.cuda.is_available() else "cpu"

        if torch.cuda.is_available():
            try:
                model = model.to(device_name)
            except Exception:
                # Keep the model on CPU if it cannot be moved to the GPU.
                pass

        if hasattr(model, 'eval'):
            model.eval()

        return model, device_name
    except Exception as e:
        print(f"Failed to load FireRedASR: {e}")
        return None, None


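# Convert arbitrary input audio to 16 kHz mono WAV, the format FireRedASR expects.
# Returns True on success, False on failure.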
def preprocess_audio_for_asr(audio_path, output_path):
    try:
        audio_data, sr = sf.read(str(audio_path))

        # Downmix multi-channel audio to mono.
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)

        if sr != 16000:
            audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
            sr = 16000

        temp_file = str(output_path) + ".temp.wav"
        sf.write(temp_file, audio_data, sr)

        audio_pydub = AudioSegment.from_file(temp_file)
        audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)
        audio_pydub.export(str(output_path), format="wav")

        if os.path.exists(temp_file):
            os.remove(temp_file)

        return True
    except Exception:
        return False


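# Transcribe one complete audio file with FireRedASR (beam search, 1-best).
# Returns a dict with "text" and "confidence"; an "error" key is added on failure.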
def transcribe_single_audio(audio_path, model, device_name):
    try:
        temp_path = TEMP_DIR / f"temp_asr_{uuid.uuid4().hex[:8]}.wav"

        if not preprocess_audio_for_asr(audio_path, temp_path):
            return {"text": "", "confidence": 0.0, "error": "Preprocessing failed"}

        config = {
            "use_gpu": 1 if device_name.startswith("cuda") else 0,
            "beam_size": 5,
            "nbest": 1,
            "decode_max_len": 0
        }

        with torch.no_grad():
            result = model.transcribe(["audio_001"], [str(temp_path)], config)

        if temp_path.exists():
            temp_path.unlink()

        if result and len(result) > 0:
            text = result[0].get('text', '').strip()
            confidence = result[0].get('confidence', result[0].get('score', 0.0))

            if isinstance(confidence, (list, tuple)) and len(confidence) > 0:
                confidence = float(confidence[0])
            else:
                confidence = float(confidence) if isinstance(confidence, (int, float)) else 0.0

            return {"text": text, "confidence": round(confidence, 3)}

        return {"text": "", "confidence": 0.0}
    except Exception as e:
        return {"text": "", "confidence": 0.0, "error": str(e)}


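# Transcribe each diarization segment in order. Failures are recorded per segment
# (an "error" key) instead of aborting the whole batch, and the GPU cache is
# cleared between segments.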
def transcribe_audio_segments(segment_files, model, device_name):
    results = []

    for i, file_info in enumerate(segment_files):
        try:
            temp_path = TEMP_DIR / f"temp_seg_{i:03d}.wav"

            if not preprocess_audio_for_asr(file_info['filepath'], temp_path):
                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "text": "",
                    "confidence": 0.0,
                    "segment_type": file_info['segment_type'],
                    "error": "Preprocessing failed"
                })
                continue

            config = {
                "use_gpu": 1 if device_name.startswith("cuda") else 0,
                "beam_size": 5,
                "nbest": 1,
                "decode_max_len": 0
            }

            with torch.no_grad():
                transcription = model.transcribe([f"seg_{i:03d}"], [str(temp_path)], config)

            if temp_path.exists():
                temp_path.unlink()

            if transcription and len(transcription) > 0:
                text = transcription[0].get('text', '').strip()
                confidence = transcription[0].get('confidence', transcription[0].get('score', 0.0))

                if isinstance(confidence, (list, tuple)) and len(confidence) > 0:
                    confidence = float(confidence[0])
                else:
                    confidence = float(confidence) if isinstance(confidence, (int, float)) else 0.0

                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "text": text,
                    "confidence": round(confidence, 3),
                    "segment_type": file_info['segment_type']
                })
            else:
                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "text": "",
                    "confidence": 0.0,
                    "segment_type": file_info['segment_type']
                })

            if device_name.startswith("cuda"):
                torch.cuda.empty_cache()

        except Exception as e:
            results.append({
                "start": file_info['start_ms'],
                "end": file_info['end_ms'],
                "text": "",
                "confidence": 0.0,
                "segment_type": file_info['segment_type'],
                "error": str(e)
            })

    return results


# ============ Pyannote ============
def load_pyannote_pipeline(config_path):
    try:
        from pyannote.audio import Pipeline
        pipeline = Pipeline.from_pretrained(config_path)
        device_obj = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipeline.to(device_obj)
        return pipeline
    except Exception as e:
        print(f"Failed to load pyannote pipeline: {e}")
        return None


def load_audio_for_segmentation(audio_path):
    audio_data, sr = sf.read(str(audio_path))

    if len(audio_data.shape) > 1:
        audio_data = np.mean(audio_data, axis=1)

    if sr != 16000:
        audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
        sr = 16000

    audio_pydub = AudioSegment.from_file(str(audio_path))
    audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)

    return audio_data, sr, audio_pydub


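# Scan the window [start_time, end_time] with 25 ms frames (5 ms hop) and return the
# centre of the lowest-RMS-energy frame, i.e. the quietest point to cut at, together
# with that energy value.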
def find_optimal_boundary(audio_data, sr, start_time, end_time, frame_length=0.025, hop_length=0.005):
    if start_time >= end_time:
        return start_time, 0.0

    start_sample = int(start_time * sr)
    end_sample = int(end_time * sr)
    segment = audio_data[start_sample:end_sample]

    if len(segment) == 0:
        return start_time, 0.0

    frame_samples = int(frame_length * sr)
    hop_samples = int(hop_length * sr)

    min_energy = float('inf')
    optimal_time = start_time

    for i in range(0, len(segment) - frame_samples + 1, hop_samples):
        frame = segment[i:i + frame_samples]
        energy = np.sqrt(np.mean(frame ** 2))

        if energy < min_energy:
            min_energy = energy
            optimal_time = start_time + (i + frame_samples / 2) / sr

    return optimal_time, min_energy


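# Refine segment boundaries: for every gap of at least small_gap_threshold seconds,
# move the neighbouring start/end to the quietest point inside the gap; remaining
# tiny gaps are closed by snapping both segments to a common boundary, and overlaps
# are resolved at the midpoint.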
def optimize_segment_boundaries(segments, audio_data, sr, small_gap_threshold=0.1, min_gap_to_keep=0.05):
    if not segments:
        return []

    audio_duration = len(audio_data) / sr
    optimized = []

    for i, seg in enumerate(segments):
        current_seg = seg.copy()

        if i == 0:
            search_start = 0.0
            search_end = seg["start"]
        else:
            search_start = segments[i - 1]["end"]
            search_end = seg["start"]

        gap_size = search_end - search_start
        if gap_size > 0 and gap_size >= small_gap_threshold:
            optimal_start, _ = find_optimal_boundary(audio_data, sr, search_start, search_end)
            current_seg["start"] = optimal_start

        if i == len(segments) - 1:
            search_start = seg["end"]
            search_end = audio_duration
        else:
            search_start = seg["end"]
            search_end = segments[i + 1]["start"]

        gap_size = search_end - search_start
        if gap_size > 0 and gap_size >= small_gap_threshold:
            optimal_end, _ = find_optimal_boundary(audio_data, sr, search_start, search_end)
            current_seg["end"] = optimal_end

        optimized.append(current_seg)

    for i in range(len(optimized) - 1):
        gap = optimized[i + 1]["start"] - optimized[i]["end"]
        if 0 < gap < small_gap_threshold:
            optimal_point, _ = find_optimal_boundary(audio_data, sr, optimized[i]["end"], optimized[i + 1]["start"])
            optimized[i]["end"] = optimal_point
            optimized[i + 1]["start"] = optimal_point

    final_optimized = []
    for i in range(len(optimized)):
        if i == 0:
            final_optimized.append(optimized[i])
            continue

        prev_seg = final_optimized[-1]
        curr_seg = optimized[i]
        gap = curr_seg["start"] - prev_seg["end"]

        if gap < 0:
            midpoint = (prev_seg["end"] + curr_seg["start"]) / 2
            prev_seg["end"] = midpoint
            curr_seg["start"] = midpoint
        elif 0 < gap < min_gap_to_keep:
            prev_seg["end"] = curr_seg["start"]

        final_optimized.append(curr_seg)

    return final_optimized


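# Greedily merge neighbouring segments, always closing the smallest gap first, as
# long as the merged span stays within max_duration seconds. Merged entries are
# flagged with is_merged so they can be labelled "forced" later.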
def merge_segments_by_gap(segments, max_duration=50.0):
    if not segments:
        return []

    if len(segments) == 1:
        return segments

    groups = [[seg] for seg in segments]

    while True:
        gaps = []
        for i in range(len(groups) - 1):
            group1_end = groups[i][-1]["end"]
            group2_start = groups[i + 1][0]["start"]
            gap = group2_start - group1_end
            merged_duration = groups[i + 1][-1]["end"] - groups[i][0]["start"]

            gaps.append({
                "index": i,
                "gap": gap,
                "merged_duration": merged_duration
            })

        if not gaps:
            break

        gaps.sort(key=lambda x: x["gap"])

        merged = False
        for gap_info in gaps:
            if gap_info["merged_duration"] <= max_duration:
                idx = gap_info["index"]
                groups[idx].extend(groups[idx + 1])
                groups.pop(idx + 1)
                merged = True
                break

        if not merged:
            break

    merged_segments = []
    for group in groups:
        merged_seg = {
            "start": group[0]["start"],
            "end": group[-1]["end"],
            "original_count": len(group),
            "is_merged": len(group) > 1
        }
        merged_segments.append(merged_seg)

    return merged_segments


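# Full segmentation pass for long audio: run the pyannote diarization pipeline on a
# 16 kHz mono waveform, merge nearby speech regions up to DURATION_THRESHOLD, snap
# boundaries to quiet points, then export each segment (with 25 ms padding on both
# sides) as a WAV file under a per-session folder in TEMP_DIR.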
def perform_segmentation(audio_path, pipeline_model, device_obj):
    audio_data, sr, audio_pydub = load_audio_for_segmentation(audio_path)

    waveform, sample_rate = torchaudio.load(str(audio_path))

    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        waveform = resampler(waveform)

    waveform = waveform.to(device_obj)
    audio_dict = {"waveform": waveform, "sample_rate": 16000}

    diarization_result = pipeline_model(audio_dict)

    # Collect (start, end) speech regions; the pipeline result may expose the
    # annotation directly or via a speaker_diarization attribute.
    initial_segments = []
    if hasattr(diarization_result, 'speaker_diarization'):
        annotation = diarization_result.speaker_diarization
        for segment, track, speaker in annotation.itertracks(yield_label=True):
            initial_segments.append({
                "start": round(segment.start, 3),
                "end": round(segment.end, 3)
            })
    else:
        for segment, track, speaker in diarization_result.itertracks(yield_label=True):
            initial_segments.append({
                "start": round(segment.start, 3),
                "end": round(segment.end, 3)
            })

    merged_segments = merge_segments_by_gap(initial_segments, max_duration=DURATION_THRESHOLD)
    optimized_segments = optimize_segment_boundaries(merged_segments, audio_data, sr)

    session_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    segment_folder = TEMP_DIR / session_id
    segment_folder.mkdir(exist_ok=True)

    segment_files = []
    for i, seg in enumerate(optimized_segments):
        start_time = seg["start"]
        end_time = seg["end"]

        # Pad each segment by 25 ms on both sides, clamped to the audio length.
        start_ms = max(0, int(start_time * 1000) - 25)
        end_ms = min(len(audio_pydub), int(end_time * 1000) + 25)

        segment_audio = audio_pydub[start_ms:end_ms]
        filepath = segment_folder / f"segment_{i:03d}.wav"
        segment_audio.export(str(filepath), format="wav")

        segment_type = "forced" if seg.get("is_merged", False) else "natural"

        segment_files.append({
            "filepath": filepath,
            "start_ms": int(start_time * 1000),
            "end_ms": int(end_time * 1000),
            "segment_type": segment_type
        })

    return segment_files, session_id


# ============ FastAPI ============
@asynccontextmanager
async def lifespan(app: FastAPI):
    global pyannote_pipeline, firered_model, device

    print("Loading pyannote pipeline...")
    pyannote_pipeline = load_pyannote_pipeline(PYANNOTE_CONFIG_PATH)

    print("Loading FireRedASR model...")
    firered_model, device = load_fireredasr_model(FIRERED_MODEL_PATH)

    if pyannote_pipeline and firered_model:
        print("Service ready")
    else:
        print("Model loading failed")

    yield


app = FastAPI(lifespan=lifespan, title="Audio Transcription Service")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


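# Main endpoint. Short uploads (< DURATION_THRESHOLD seconds) go straight to
# FireRedASR; longer uploads are first diarized with pyannote and transcribed
# segment by segment. All text is then punctuated by the external LLM. Errors are
# reported in the JSON body (code 500) while the HTTP status stays 200.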
@app.post("/transcriptions")
async def transcribe_audio(file: UploadFile = File(...)):
    global pyannote_pipeline, firered_model, device

    if firered_model is None:
        raise HTTPException(status_code=500, detail="FireRedASR model not loaded")

    allowed_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
    file_ext = os.path.splitext(file.filename)[1].lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(status_code=400, detail=f"Unsupported file format: {file_ext}")

    temp_audio_path = None
    session_folder = None

    try:
        total_start_time = time.time()

        session_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
        temp_audio_path = TEMP_DIR / f"upload_{session_id}{file_ext}"

        with open(temp_audio_path, "wb") as f:
            content = await file.read()
            f.write(content)

        audio_info = sf.info(str(temp_audio_path))
        duration = audio_info.duration

        if duration < DURATION_THRESHOLD:
            # Short audio: single-pass ASR, no diarization.
            result = transcribe_single_audio(temp_audio_path, firered_model, device)

            temp_transcription = [{
                "start": 0,
                "end": int(duration * 1000),
                "text": result.get("text", ""),
                "confidence": result.get("confidence", 0.0),
                "segment_type": "direct"
            }]

            if 'error' in result:
                temp_transcription[0]['error'] = result['error']

            transcription, post_process_time = post_process_transcription(temp_transcription)

            statistics = {
                "total_segments": len(transcription),
                "processing_time": round(time.time() - total_start_time, 2),
                "post_processing_time": round(post_process_time, 2),
                "segmentation_types": {
                    "natural": 0,
                    "forced": 0,
                    "direct": len(transcription)
                },
                "processing_method": "direct_asr"
            }
        else:
            # Long audio: diarize first, then transcribe each segment.
            if pyannote_pipeline is None:
                raise HTTPException(status_code=500, detail="Pyannote model not loaded")

            segment_files, seg_session_id = perform_segmentation(
                temp_audio_path,
                pyannote_pipeline,
                torch.device("cuda" if torch.cuda.is_available() else "cpu")
            )

            session_folder = TEMP_DIR / seg_session_id

            temp_transcription = transcribe_audio_segments(segment_files, firered_model, device)
            transcription, post_process_time = post_process_transcription(temp_transcription)

            seg_types = {"natural": 0, "forced": 0, "direct": 0}
            for seg in transcription:
                seg_types[seg["segment_type"]] += 1

            statistics = {
                "total_segments": len(transcription),
                "processing_time": round(time.time() - total_start_time, 2),
                "post_processing_time": round(post_process_time, 2),
                "segmentation_types": seg_types,
                "processing_method": "pipeline_dual_model"
            }

        response_data = {
            "code": 200,
            "message": "success",
            "data": {
                "transcription": transcription,
                "statistics": statistics
            }
        }

        return JSONResponse(content=jsonable_encoder(response_data), status_code=200)

    except Exception as e:
        return JSONResponse(
            content={
                "code": 500,
                "message": f"Processing failed: {str(e)}",
                "data": {"transcription": [], "statistics": {}}
            },
            status_code=200
        )

    finally:
        # Best-effort cleanup of the uploaded file and any per-session segment WAVs.
        if temp_audio_path and temp_audio_path.exists():
            try:
                temp_audio_path.unlink()
            except Exception:
                pass

        if session_folder and session_folder.exists():
            try:
                for seg_file in session_folder.glob("*.wav"):
                    seg_file.unlink()
                session_folder.rmdir()
            except Exception:
                pass


@app.get("/health") |
|
async def health_check(): |
|
return JSONResponse( |
|
content={ |
|
"status": "healthy", |
|
"pyannote_loaded": pyannote_pipeline is not None, |
|
"firered_loaded": firered_model is not None, |
|
"device": str(device) if device else "not initialized", |
|
"message": "Service is running" |
|
}, |
|
status_code=200 |
|
) |
|
|
|
|
|
@app.get("/") |
|
async def root(): |
|
return { |
|
"service": "Audio Transcription Service", |
|
"version": "1.0", |
|
"endpoints": { |
|
"transcriptions": "POST /transcriptions", |
|
"health": "GET /health" |
|
} |
|
} |
|
|
|
|
|
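# Start the service with `python <this file>` (or `uvicorn <module>:app`); it listens
# on port 9000. An example request (the audio file name here is illustrative):
#   curl -F "file=@meeting.wav" http://localhost:9000/transcriptions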
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=9000)