import sys
import os
import time
import torch
import argparse
import uvicorn
import soundfile as sf
import librosa
import numpy as np
import torchaudio
import uuid
import requests
from pathlib import Path
from pydub import AudioSegment
from contextlib import asynccontextmanager
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.encoders import jsonable_encoder
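# ---------------------------------------------------------------------------
# Overview: this service combines pyannote speaker-diarization segmentation
# with FireRedASR-AED speech recognition behind a FastAPI HTTP interface.
# Clips shorter than DURATION_THRESHOLD seconds are transcribed directly;
# longer recordings are segmented first, transcribed per segment, and then
# punctuated via an external LLM endpoint (see call_ai_model).
# ---------------------------------------------------------------------------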
# ============ Global configuration ============
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '1'
FIRERED_ASR_PATH = 'G:/Work_data/workstation/Forwork-voice-txt/FireRedASR'
FIRERED_MODEL_PATH = "G:/Work_data/workstation/Forwork-voice-txt/FireRedASR/pretrained_models/FireRedASR-AED-L"
PYANNOTE_CONFIG_PATH = r"G:\Work_data\workstation\Audio_classification\classify_model\speaker-diarization-3.1\config.yaml"
TEMP_DIR = Path("temp_transcription")
TEMP_DIR.mkdir(exist_ok=True)
DURATION_THRESHOLD = 50.0
pyannote_pipeline = None
firered_model = None
device = None
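# Note: the model paths above are machine-specific (local Windows G: drive);
# point them at your own FireRedASR checkout/checkpoint and pyannote
# speaker-diarization-3.1 config. The three globals are populated at startup
# by the lifespan handler defined further below.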
# ============ Text post-processing ============
def call_ai_model(sentence):
    """Send raw ASR text to an external LLM endpoint for punctuation."""
    url = "http://36.158.183.88:7777/v1/chat/completions"
    # Prompt (Chinese): add punctuation and convert Chinese numerals to Arabic
    # numerals without changing the wording; any suitable sentence-final
    # punctuation (colon, comma, question mark, exclamation mark, period) is allowed.
    prompt = f"""对以下文本添加标点符号,中文数字转阿拉伯数字。不修改文字内容。句末可以是冒号、逗号、问号、感叹号和句号等任意合适标点。
{sentence}"""
    payload = {
        "model": "Qwen3-32B",
        "messages": [{"role": "user", "content": prompt}],
        "chat_template_kwargs": {"enable_thinking": False}
    }
    try:
        response = requests.post(url, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
        return result["choices"][0]["message"]["content"].strip()
    except Exception:
        # Fall back to the unpunctuated text if the LLM call fails.
        return sentence
def post_process_transcription(transcription_results):
    if not transcription_results:
        return [], 0.0
    start_time = time.time()
    processed_results = []
    for segment in transcription_results:
        if 'error' in segment or not segment.get('text', '').strip():
            continue
        processed_text = call_ai_model(segment['text'])
        processed_results.append({
            "start": segment['start'],
            "end": segment['end'],
            "content": processed_text,
            "confidence": segment['confidence'],
            "segment_type": segment['segment_type']
        })
    return processed_results, time.time() - start_time
# ============ FireRedASR ============
def setup_fireredasr_environment():
    if FIRERED_ASR_PATH not in sys.path:
        sys.path.insert(0, FIRERED_ASR_PATH)
    return True
def load_fireredasr_model(model_dir):
    setup_fireredasr_environment()
    try:
        torch.serialization.add_safe_globals([argparse.Namespace])
        try:
            from fireredasr.models.fireredasr import FireRedAsr
        except ImportError:
            try:
                from FireRedASR.fireredasr.models.fireredasr import FireRedAsr
            except ImportError:
                import fireredasr
                from fireredasr.models.fireredasr import FireRedAsr
        model = FireRedAsr.from_pretrained("aed", model_dir)
        device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        if torch.cuda.is_available():
            try:
                model = model.to(device_name)
            except Exception:
                pass
        if hasattr(model, 'eval'):
            model.eval()
        return model, device_name
    except Exception as e:
        print(f"Failed to load FireRedASR: {e}")
        return None, None
def preprocess_audio_for_asr(audio_path, output_path):
    try:
        audio_data, sr = sf.read(str(audio_path))
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)
        if sr != 16000:
            audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
            sr = 16000
        temp_file = str(output_path) + ".temp.wav"
        sf.write(temp_file, audio_data, sr)
        audio_pydub = AudioSegment.from_file(temp_file)
        audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)
        audio_pydub.export(str(output_path), format="wav")
        if os.path.exists(temp_file):
            os.remove(temp_file)
        return True
    except Exception:
        return False
def transcribe_single_audio(audio_path, model, device_name):
    try:
        temp_path = TEMP_DIR / f"temp_asr_{uuid.uuid4().hex[:8]}.wav"
        if not preprocess_audio_for_asr(audio_path, temp_path):
            return {"text": "", "confidence": 0.0, "error": "Preprocessing failed"}
        config = {
            "use_gpu": 1 if device_name.startswith("cuda") else 0,
            "beam_size": 5,
            "nbest": 1,
            "decode_max_len": 0
        }
        with torch.no_grad():
            result = model.transcribe(["audio_001"], [str(temp_path)], config)
        if temp_path.exists():
            temp_path.unlink()
        if result and len(result) > 0:
            text = result[0].get('text', '').strip()
            confidence = result[0].get('confidence', result[0].get('score', 0.0))
            if isinstance(confidence, (list, tuple)) and len(confidence) > 0:
                confidence = float(confidence[0])
            else:
                confidence = float(confidence) if isinstance(confidence, (int, float)) else 0.0
            return {"text": text, "confidence": round(confidence, 3)}
        return {"text": "", "confidence": 0.0}
    except Exception as e:
        return {"text": "", "confidence": 0.0, "error": str(e)}
def transcribe_audio_segments(segment_files, model, device_name):
    results = []
    for i, file_info in enumerate(segment_files):
        try:
            temp_path = TEMP_DIR / f"temp_seg_{i:03d}.wav"
            if not preprocess_audio_for_asr(file_info['filepath'], temp_path):
                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "text": "",
                    "confidence": 0.0,
                    "segment_type": file_info['segment_type'],
                    "error": "Preprocessing failed"
                })
                continue
            config = {
                "use_gpu": 1 if device_name.startswith("cuda") else 0,
                "beam_size": 5,
                "nbest": 1,
                "decode_max_len": 0
            }
            with torch.no_grad():
                transcription = model.transcribe([f"seg_{i:03d}"], [str(temp_path)], config)
            if temp_path.exists():
                temp_path.unlink()
            if transcription and len(transcription) > 0:
                text = transcription[0].get('text', '').strip()
                confidence = transcription[0].get('confidence', transcription[0].get('score', 0.0))
                if isinstance(confidence, (list, tuple)) and len(confidence) > 0:
                    confidence = float(confidence[0])
                else:
                    confidence = float(confidence) if isinstance(confidence, (int, float)) else 0.0
                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "text": text,
                    "confidence": round(confidence, 3),
                    "segment_type": file_info['segment_type']
                })
            else:
                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "text": "",
                    "confidence": 0.0,
                    "segment_type": file_info['segment_type']
                })
            if device_name.startswith("cuda"):
                torch.cuda.empty_cache()
        except Exception as e:
            results.append({
                "start": file_info['start_ms'],
                "end": file_info['end_ms'],
                "text": "",
                "confidence": 0.0,
                "segment_type": file_info['segment_type'],
                "error": str(e)
            })
    return results
# ============ Pyannote ============
def load_pyannote_pipeline(config_path):
    try:
        from pyannote.audio import Pipeline
        pipeline = Pipeline.from_pretrained(config_path)
        device_obj = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipeline.to(device_obj)
        return pipeline
    except Exception as e:
        print(f"Failed to load pyannote pipeline: {e}")
        return None
def load_audio_for_segmentation(audio_path):
    audio_data, sr = sf.read(str(audio_path))
    if len(audio_data.shape) > 1:
        audio_data = np.mean(audio_data, axis=1)
    if sr != 16000:
        audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
        sr = 16000
    audio_pydub = AudioSegment.from_file(str(audio_path))
    audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)
    return audio_data, sr, audio_pydub
def find_optimal_boundary(audio_data, sr, start_time, end_time, frame_length=0.025, hop_length=0.005):
    if start_time >= end_time:
        return start_time, 0.0
    start_sample = int(start_time * sr)
    end_sample = int(end_time * sr)
    segment = audio_data[start_sample:end_sample]
    if len(segment) == 0:
        return start_time, 0.0
    frame_samples = int(frame_length * sr)
    hop_samples = int(hop_length * sr)
    min_energy = float('inf')
    optimal_time = start_time
    for i in range(0, len(segment) - frame_samples + 1, hop_samples):
        frame = segment[i:i + frame_samples]
        energy = np.sqrt(np.mean(frame ** 2))
        if energy < min_energy:
            min_energy = energy
            optimal_time = start_time + (i + frame_samples / 2) / sr
    return optimal_time, min_energy
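# For reference: at sr=16000 the defaults above correspond to 400-sample
# analysis frames hopped every 80 samples; find_optimal_boundary returns the
# centre time of the quietest (lowest-RMS) frame in [start_time, end_time],
# which the boundary-optimization step below uses as a cut point.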
def optimize_segment_boundaries(segments, audio_data, sr, small_gap_threshold=0.1, min_gap_to_keep=0.05):
    if not segments:
        return []
    audio_duration = len(audio_data) / sr
    optimized = []
    for i, seg in enumerate(segments):
        current_seg = seg.copy()
        if i == 0:
            search_start = 0.0
            search_end = seg["start"]
        else:
            search_start = segments[i - 1]["end"]
            search_end = seg["start"]
        gap_size = search_end - search_start
        if gap_size > 0 and gap_size >= small_gap_threshold:
            optimal_start, _ = find_optimal_boundary(audio_data, sr, search_start, search_end)
            current_seg["start"] = optimal_start
        if i == len(segments) - 1:
            search_start = seg["end"]
            search_end = audio_duration
        else:
            search_start = seg["end"]
            search_end = segments[i + 1]["start"]
        gap_size = search_end - search_start
        if gap_size > 0 and gap_size >= small_gap_threshold:
            optimal_end, _ = find_optimal_boundary(audio_data, sr, search_start, search_end)
            current_seg["end"] = optimal_end
        optimized.append(current_seg)
    for i in range(len(optimized) - 1):
        gap = optimized[i + 1]["start"] - optimized[i]["end"]
        if 0 < gap < small_gap_threshold:
            optimal_point, _ = find_optimal_boundary(audio_data, sr, optimized[i]["end"], optimized[i + 1]["start"])
            optimized[i]["end"] = optimal_point
            optimized[i + 1]["start"] = optimal_point
    final_optimized = []
    for i in range(len(optimized)):
        if i == 0:
            final_optimized.append(optimized[i])
            continue
        prev_seg = final_optimized[-1]
        curr_seg = optimized[i]
        gap = curr_seg["start"] - prev_seg["end"]
        if gap < 0:
            midpoint = (prev_seg["end"] + curr_seg["start"]) / 2
            prev_seg["end"] = midpoint
            curr_seg["start"] = midpoint
        elif 0 < gap < min_gap_to_keep:
            prev_seg["end"] = curr_seg["start"]
        final_optimized.append(curr_seg)
    return final_optimized
def merge_segments_by_gap(segments, max_duration=50.0):
    if not segments:
        return []
    if len(segments) == 1:
        return segments
    groups = [[seg] for seg in segments]
    while True:
        gaps = []
        for i in range(len(groups) - 1):
            group1_end = groups[i][-1]["end"]
            group2_start = groups[i + 1][0]["start"]
            gap = group2_start - group1_end
            merged_duration = groups[i + 1][-1]["end"] - groups[i][0]["start"]
            gaps.append({
                "index": i,
                "gap": gap,
                "merged_duration": merged_duration
            })
        if not gaps:
            break
        gaps.sort(key=lambda x: x["gap"])
        merged = False
        for gap_info in gaps:
            if gap_info["merged_duration"] <= max_duration:
                idx = gap_info["index"]
                groups[idx].extend(groups[idx + 1])
                groups.pop(idx + 1)
                merged = True
                break
        if not merged:
            break
    merged_segments = []
    for group in groups:
        merged_seg = {
            "start": group[0]["start"],
            "end": group[-1]["end"],
            "original_count": len(group),
            "is_merged": len(group) > 1
        }
        merged_segments.append(merged_seg)
    return merged_segments
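# Illustrative example (not executed by the service): the merge is greedy and
# always closes the smallest remaining gap first, as long as the merged span
# stays within max_duration. With
#   segments = [{"start": 0.0, "end": 10.0}, {"start": 10.5, "end": 20.0}, {"start": 25.0, "end": 40.0}]
# and max_duration=50.0, the 0.5 s gap is closed first, then the 5.0 s gap, so
# merge_segments_by_gap returns a single span
#   {"start": 0.0, "end": 40.0, "original_count": 3, "is_merged": True}.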
def perform_segmentation(audio_path, pipeline_model, device_obj):
    audio_data, sr, audio_pydub = load_audio_for_segmentation(audio_path)
    waveform, sample_rate = torchaudio.load(str(audio_path))
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        waveform = resampler(waveform)
    waveform = waveform.to(device_obj)
    audio_dict = {"waveform": waveform, "sample_rate": 16000}
    diarization_result = pipeline_model(audio_dict)
    initial_segments = []
    if hasattr(diarization_result, 'speaker_diarization'):
        annotation = diarization_result.speaker_diarization
        for segment, track, speaker in annotation.itertracks(yield_label=True):
            initial_segments.append({
                "start": round(segment.start, 3),
                "end": round(segment.end, 3)
            })
    else:
        for segment, track, speaker in diarization_result.itertracks(yield_label=True):
            initial_segments.append({
                "start": round(segment.start, 3),
                "end": round(segment.end, 3)
            })
    merged_segments = merge_segments_by_gap(initial_segments, max_duration=DURATION_THRESHOLD)
    optimized_segments = optimize_segment_boundaries(merged_segments, audio_data, sr)
    session_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    segment_folder = TEMP_DIR / session_id
    segment_folder.mkdir(exist_ok=True)
    segment_files = []
    for i, seg in enumerate(optimized_segments):
        start_time = seg["start"]
        end_time = seg["end"]
        start_ms = max(0, int(start_time * 1000) - 25)
        end_ms = min(len(audio_pydub), int(end_time * 1000) + 25)
        segment_audio = audio_pydub[start_ms:end_ms]
        filepath = segment_folder / f"segment_{i:03d}.wav"
        segment_audio.export(str(filepath), format="wav")
        segment_type = "forced" if seg.get("is_merged", False) else "natural"
        segment_files.append({
            "filepath": filepath,
            "start_ms": int(start_time * 1000),
            "end_ms": int(end_time * 1000),
            "segment_type": segment_type
        })
    return segment_files, session_id
# ============ FastAPI ============
@asynccontextmanager
async def lifespan(app: FastAPI):
    global pyannote_pipeline, firered_model, device
    print("Loading pyannote model...")
    pyannote_pipeline = load_pyannote_pipeline(PYANNOTE_CONFIG_PATH)
    print("Loading FireRedASR model...")
    firered_model, device = load_fireredasr_model(FIRERED_MODEL_PATH)
    if pyannote_pipeline and firered_model:
        print("Service ready")
    else:
        print("Model loading failed")
    yield
app = FastAPI(lifespan=lifespan, title="Audio Transcription Service")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/transcriptions")
async def transcribe_audio(file: UploadFile = File(...)):
global pyannote_pipeline, firered_model, device
if firered_model is None:
raise HTTPException(status_code=500, detail="FireRedASR模型未加载")
allowed_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
file_ext = os.path.splitext(file.filename)[1].lower()
if file_ext not in allowed_extensions:
raise HTTPException(status_code=400, detail=f"不支持的文件格式")
temp_audio_path = None
session_folder = None
try:
total_start_time = time.time()
session_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
temp_audio_path = TEMP_DIR / f"upload_{session_id}{file_ext}"
with open(temp_audio_path, "wb") as f:
content = await file.read()
f.write(content)
audio_info = sf.info(str(temp_audio_path))
duration = audio_info.duration
if duration < DURATION_THRESHOLD:
result = transcribe_single_audio(temp_audio_path, firered_model, device)
temp_transcription = [{
"start": 0,
"end": int(duration * 1000),
"text": result.get("text", ""),
"confidence": result.get("confidence", 0.0),
"segment_type": "direct"
}]
if 'error' in result:
temp_transcription[0]['error'] = result['error']
transcription, post_process_time = post_process_transcription(temp_transcription)
statistics = {
"total_segments": len(transcription),
"processing_time": round(time.time() - total_start_time, 2),
"post_processing_time": round(post_process_time, 2),
"segmentation_types": {
"natural": 0,
"forced": 0,
"direct": len(transcription)
},
"processing_method": "direct_asr"
}
else:
if pyannote_pipeline is None:
raise HTTPException(status_code=500, detail="Pyannote模型未加载")
segment_files, seg_session_id = perform_segmentation(
temp_audio_path,
pyannote_pipeline,
torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
session_folder = TEMP_DIR / seg_session_id
temp_transcription = transcribe_audio_segments(segment_files, firered_model, device)
transcription, post_process_time = post_process_transcription(temp_transcription)
seg_types = {"natural": 0, "forced": 0, "direct": 0}
for seg in transcription:
seg_types[seg["segment_type"]] += 1
statistics = {
"total_segments": len(transcription),
"processing_time": round(time.time() - total_start_time, 2),
"post_processing_time": round(post_process_time, 2),
"segmentation_types": seg_types,
"processing_method": "pipeline_dual_model"
}
response_data = {
"code": 200,
"message": "success",
"data": {
"transcription": transcription,
"statistics": statistics
}
}
return JSONResponse(content=jsonable_encoder(response_data), status_code=200)
except Exception as e:
return JSONResponse(
content={
"code": 500,
"message": f"处理失败: {str(e)}",
"data": {"transcription": [], "statistics": {}}
},
status_code=200
)
finally:
if temp_audio_path and temp_audio_path.exists():
try:
temp_audio_path.unlink()
except:
pass
if session_folder and session_folder.exists():
try:
for file in session_folder.glob("*.wav"):
file.unlink()
session_folder.rmdir()
except:
pass
@app.get("/health")
async def health_check():
return JSONResponse(
content={
"status": "healthy",
"pyannote_loaded": pyannote_pipeline is not None,
"firered_loaded": firered_model is not None,
"device": str(device) if device else "not initialized",
"message": "Service is running"
},
status_code=200
)
@app.get("/")
async def root():
return {
"service": "Audio Transcription Service",
"version": "1.0",
"endpoints": {
"transcriptions": "POST /transcriptions",
"health": "GET /health"
}
}
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=9000)
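# Minimal client sketch (assumes the service is reachable at localhost:9000
# and that a local "sample.wav" exists; adjust host, port, and path as needed):
#
#   import requests
#   with open("sample.wav", "rb") as audio:
#       resp = requests.post(
#           "http://localhost:9000/transcriptions",
#           files={"file": ("sample.wav", audio, "audio/wav")},
#       )
#   print(resp.json()["data"]["transcription"])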