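"""Audio transcription service (segmentation + ASR).

FastAPI app that routes uploads by duration: clips shorter than
DURATION_THRESHOLD (50 s) go straight to FireRedASR; longer clips are first
segmented with a pyannote speaker-diarization pipeline, boundary-optimized,
then transcribed segment by segment. Serves on port 9000.
"""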
import sys
import os
import time
import torch
import argparse
import uvicorn
import soundfile as sf
import librosa
import numpy as np
import torchaudio
import uuid

from pathlib import Path
from pydub import AudioSegment
from contextlib import asynccontextmanager
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.encoders import jsonable_encoder

# ============ Global configuration ============

# Force Hugging Face offline mode
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '1'

# FireRedASR paths
FIRERED_ASR_PATH = 'G:/Work_data/workstation/Forwork-voice-txt/FireRedASR'
FIRERED_MODEL_PATH = "G:/Work_data/workstation/Forwork-voice-txt/FireRedASR/pretrained_models/FireRedASR-AED-L"

# Pyannote config path
PYANNOTE_CONFIG_PATH = r"G:\Work_data\workstation\Audio_classification\classify_model\speaker-diarization-3.1\config.yaml"

# Temporary working directory
TEMP_DIR = Path("temp_transcription")
TEMP_DIR.mkdir(exist_ok=True)

# Duration threshold in seconds: shorter files are transcribed directly,
# longer files are segmented with pyannote first
DURATION_THRESHOLD = 50.0

# Global model handles, populated at startup
pyannote_pipeline = None
firered_model = None
device = None

# ============ FireRedASR helpers ============

def setup_fireredasr_environment():
    """Put the FireRedASR checkout on sys.path."""
    if FIRERED_ASR_PATH not in sys.path:
        sys.path.insert(0, FIRERED_ASR_PATH)
    return True

def load_fireredasr_model(model_dir):
    """Load the FireRedASR model and pick a device."""
    setup_fireredasr_environment()

    try:
        # PyTorch compatibility: allowlist argparse.Namespace so checkpoints
        # unpickle under torch.load(weights_only=True), the default in newer PyTorch
        torch.serialization.add_safe_globals([argparse.Namespace])

        try:
            from fireredasr.models.fireredasr import FireRedAsr
        except ImportError:
            try:
                from FireRedASR.fireredasr.models.fireredasr import FireRedAsr
            except ImportError:
                import fireredasr
                from fireredasr.models.fireredasr import FireRedAsr

        model = FireRedAsr.from_pretrained("aed", model_dir)

        device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        if torch.cuda.is_available():
            try:
                model = model.to(device_name)
            except Exception:
                # Fall back to CPU weights if the move to GPU fails
                pass

        if hasattr(model, 'eval'):
            model.eval()

        print(f"✓ FireRedASR model loaded, device: {device_name}")
        return model, device_name

    except Exception as e:
        print(f"✗ Failed to load FireRedASR model: {e}")
        return None, None

def preprocess_audio_for_asr(audio_path, output_path):
    """
    Preprocess audio into the format FireRedASR expects (16 kHz mono WAV).

    Args:
        audio_path: path to the source audio
        output_path: path for the converted file

    Returns:
        True on success, False otherwise
    """
    try:
        # Load audio
        audio_data, sr = sf.read(str(audio_path))

        # Downmix to mono
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)

        # Resample to 16 kHz
        if sr != 16000:
            audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
            sr = 16000

        # Normalize the container/format via pydub
        temp_file = str(output_path) + ".temp.wav"
        sf.write(temp_file, audio_data, sr)

        audio_pydub = AudioSegment.from_file(temp_file)
        audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)
        audio_pydub.export(str(output_path), format="wav")

        if os.path.exists(temp_file):
            os.remove(temp_file)

        return True

    except Exception as e:
        print(f"  ✗ Preprocessing failed: {e}")
        return False

def transcribe_single_audio(audio_path, model, device_name):
    """
    Transcribe a single audio file.

    Args:
        audio_path: path to the audio file
        model: FireRedASR model
        device_name: device string, e.g. "cuda:0" or "cpu"

    Returns:
        dict {"text": str, "confidence": float} (plus "error" on failure)
    """
    try:
        # Preprocess
        temp_path = TEMP_DIR / f"temp_asr_{uuid.uuid4().hex[:8]}.wav"
        success = preprocess_audio_for_asr(audio_path, temp_path)

        if not success:
            return {"text": "", "confidence": 0.0, "error": "Preprocessing failed"}

        # Recognize
        batch_uttid = ["audio_001"]
        batch_wav_path = [str(temp_path)]

        config = {
            "use_gpu": 1 if device_name.startswith("cuda") else 0,
            "beam_size": 5,
            "nbest": 1,
            "decode_max_len": 0
        }

        with torch.no_grad():
            result = model.transcribe(batch_uttid, batch_wav_path, config)

        # Clean up the temp file
        if temp_path.exists():
            temp_path.unlink()

        if result and len(result) > 0:
            text = result[0].get('text', '').strip()
            confidence = result[0].get('confidence', result[0].get('score', 0.0))

            if isinstance(confidence, (list, tuple)) and len(confidence) > 0:
                confidence = float(confidence[0])
            else:
                confidence = float(confidence) if isinstance(confidence, (int, float)) else 0.0

            return {"text": text, "confidence": round(confidence, 3)}

        return {"text": "", "confidence": 0.0}

    except Exception as e:
        print(f"  ✗ Recognition failed: {e}")
        return {"text": "", "confidence": 0.0, "error": str(e)}

def transcribe_audio_segments(segment_files, model, device_name):
    """
    Transcribe a batch of audio segments.

    Args:
        segment_files: list of segment file info dicts (see perform_segmentation)
        model: FireRedASR model
        device_name: device string

    Returns:
        list of per-segment result dicts
    """
    results = []

    print(f"Transcribing {len(segment_files)} audio segments...")

    for i, file_info in enumerate(segment_files):
        try:
            # Preprocess
            temp_path = TEMP_DIR / f"temp_seg_{i:03d}.wav"
            success = preprocess_audio_for_asr(file_info['filepath'], temp_path)

            if not success:
                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "content": "",
                    "confidence": 0.0,
                    "segment_type": file_info['segment_type'],
                    "error": "Preprocessing failed"
                })
                continue

            # Recognize
            batch_uttid = [f"seg_{i:03d}"]
            batch_wav_path = [str(temp_path)]

            config = {
                "use_gpu": 1 if device_name.startswith("cuda") else 0,
                "beam_size": 5,
                "nbest": 1,
                "decode_max_len": 0
            }

            with torch.no_grad():
                transcription = model.transcribe(batch_uttid, batch_wav_path, config)

            # Clean up the temp file
            if temp_path.exists():
                temp_path.unlink()

            if transcription and len(transcription) > 0:
                text = transcription[0].get('text', '').strip()
                confidence = transcription[0].get('confidence', transcription[0].get('score', 0.0))

                if isinstance(confidence, (list, tuple)) and len(confidence) > 0:
                    confidence = float(confidence[0])
                else:
                    confidence = float(confidence) if isinstance(confidence, (int, float)) else 0.0

                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "content": text,
                    "confidence": round(confidence, 3),
                    "segment_type": file_info['segment_type']
                })

                print(f"  ✓ [{i + 1}/{len(segment_files)}] {text[:30]}...")
            else:
                results.append({
                    "start": file_info['start_ms'],
                    "end": file_info['end_ms'],
                    "content": "",
                    "confidence": 0.0,
                    "segment_type": file_info['segment_type']
                })

            # Free GPU cache between segments
            if device_name.startswith("cuda"):
                torch.cuda.empty_cache()

        except Exception as e:
            print(f"  ✗ [{i + 1}/{len(segment_files)}] Recognition failed: {e}")
            results.append({
                "start": file_info['start_ms'],
                "end": file_info['end_ms'],
                "content": "",
                "confidence": 0.0,
                "segment_type": file_info['segment_type'],
                "error": str(e)
            })

    return results

# ============ Pyannote helpers ============

def load_pyannote_pipeline(config_path):
    """Load the pyannote diarization pipeline."""
    try:
        from pyannote.audio import Pipeline

        pipeline = Pipeline.from_pretrained(config_path)
        device_obj = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipeline.to(device_obj)

        print(f"✓ Pyannote pipeline loaded, device: {device_obj}")
        return pipeline

    except Exception as e:
        print(f"✗ Failed to load pyannote pipeline: {e}")
        return None

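# Note: Pipeline.from_pretrained also accepts a local config.yaml path, which
# is what lets this work with HF_HUB_OFFLINE=1 -- assuming the checkpoint
# paths referenced inside that config.yaml point at local files as well.
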
def load_audio_for_segmentation(audio_path):
    """Load audio for segmentation: numpy array for energy analysis, pydub for export."""
    # Load with soundfile
    audio_data, sr = sf.read(str(audio_path))

    # Downmix to mono
    if len(audio_data.shape) > 1:
        audio_data = np.mean(audio_data, axis=1)

    # Resample to 16 kHz
    if sr != 16000:
        audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
        sr = 16000

    # Parallel pydub copy, used later for clip export
    audio_pydub = AudioSegment.from_file(str(audio_path))
    audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)

    return audio_data, sr, audio_pydub

def calculate_frame_energy(audio_data, sr, time_point, frame_length=0.025):
    """RMS energy of a single frame centered at time_point."""
    frame_samples = int(frame_length * sr)
    center_sample = int(time_point * sr)

    start_sample = max(0, center_sample - frame_samples // 2)
    end_sample = min(len(audio_data), start_sample + frame_samples)

    frame = audio_data[start_sample:end_sample]

    if len(frame) == 0:
        return 0.0

    return np.sqrt(np.mean(frame ** 2))

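# calculate_frame_energy appears unused by the rest of this module (the
# boundary search below recomputes RMS inline); kept as a standalone utility.
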
def find_optimal_boundary(audio_data, sr, start_time, end_time,
                          frame_length=0.025, hop_length=0.005):
    """Find the minimum-energy point within [start_time, end_time]."""
    if start_time >= end_time:
        return start_time, 0.0

    start_sample = int(start_time * sr)
    end_sample = int(end_time * sr)
    segment = audio_data[start_sample:end_sample]

    if len(segment) == 0:
        return start_time, 0.0

    frame_samples = int(frame_length * sr)
    hop_samples = int(hop_length * sr)

    min_energy = float('inf')
    optimal_time = start_time

    # Slide a 25 ms RMS window in 5 ms hops and keep the quietest frame center
    for i in range(0, len(segment) - frame_samples + 1, hop_samples):
        frame = segment[i:i + frame_samples]
        energy = np.sqrt(np.mean(frame ** 2))

        if energy < min_energy:
            min_energy = energy
            optimal_time = start_time + (i + frame_samples / 2) / sr

    return optimal_time, min_energy

def optimize_segment_boundaries(segments, audio_data, sr,
                                small_gap_threshold=0.1,
                                min_gap_to_keep=0.05):
    """Snap segment boundaries to low-energy points in the surrounding gaps."""
    if not segments:
        return []

    audio_duration = len(audio_data) / sr
    optimized = []

    for i, seg in enumerate(segments):
        current_seg = seg.copy()

        # Optimize the start boundary
        if i == 0:
            search_start = 0.0
            search_end = seg["start"]
        else:
            prev_end = segments[i - 1]["end"]
            search_start = prev_end
            search_end = seg["start"]

        gap_size = search_end - search_start

        if gap_size > 0 and gap_size >= small_gap_threshold:
            optimal_start, _ = find_optimal_boundary(audio_data, sr, search_start, search_end)
            current_seg["start"] = optimal_start

        # Optimize the end boundary
        if i == len(segments) - 1:
            search_start = seg["end"]
            search_end = audio_duration
        else:
            next_start = segments[i + 1]["start"]
            search_start = seg["end"]
            search_end = next_start

        gap_size = search_end - search_start

        if gap_size > 0 and gap_size >= small_gap_threshold:
            optimal_end, _ = find_optimal_boundary(audio_data, sr, search_start, search_end)
            current_seg["end"] = optimal_end

        optimized.append(current_seg)

    # Small gaps: snap both neighbors to one shared boundary
    for i in range(len(optimized) - 1):
        gap = optimized[i + 1]["start"] - optimized[i]["end"]

        if 0 < gap < small_gap_threshold:
            optimal_point, _ = find_optimal_boundary(
                audio_data, sr,
                optimized[i]["end"],
                optimized[i + 1]["start"]
            )
            optimized[i]["end"] = optimal_point
            optimized[i + 1]["start"] = optimal_point

    # Resolve overlaps and close tiny residual gaps
    final_optimized = []
    for i in range(len(optimized)):
        if i == 0:
            final_optimized.append(optimized[i])
            continue

        prev_seg = final_optimized[-1]
        curr_seg = optimized[i]

        gap = curr_seg["start"] - prev_seg["end"]

        if gap < 0:
            # Overlap: split the difference at the midpoint
            midpoint = (prev_seg["end"] + curr_seg["start"]) / 2
            prev_seg["end"] = midpoint
            curr_seg["start"] = midpoint
        elif 0 < gap < min_gap_to_keep:
            prev_seg["end"] = curr_seg["start"]

        final_optimized.append(curr_seg)

    return final_optimized

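# Illustrative trace (hypothetical numbers): given segments
# [{"start": 1.00, "end": 3.50}, {"start": 3.53, "end": 6.00}], the 0.03 s gap
# is below small_gap_threshold, so the first pass leaves it alone and the
# second pass picks one minimum-energy instant inside (3.50, 3.53) and assigns
# it to both seg[0]["end"] and seg[1]["start"], leaving the pair contiguous.
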
def merge_segments_by_gap(segments, max_duration=50.0):
    """Greedily merge adjacent segments, smallest gap first, under a duration cap."""
    if not segments:
        return []

    if len(segments) == 1:
        return segments

    groups = [[seg] for seg in segments]

    while True:
        gaps = []
        for i in range(len(groups) - 1):
            group1_end = groups[i][-1]["end"]
            group2_start = groups[i + 1][0]["start"]
            gap = group2_start - group1_end

            merged_duration = groups[i + 1][-1]["end"] - groups[i][0]["start"]

            gaps.append({
                "index": i,
                "gap": gap,
                "merged_duration": merged_duration
            })

        if not gaps:
            break

        gaps.sort(key=lambda x: x["gap"])

        # Merge the closest pair whose combined span still fits the cap
        merged = False
        for gap_info in gaps:
            if gap_info["merged_duration"] <= max_duration:
                idx = gap_info["index"]
                groups[idx].extend(groups[idx + 1])
                groups.pop(idx + 1)
                merged = True
                break

        if not merged:
            break

    merged_segments = []
    for group in groups:
        merged_seg = {
            "start": group[0]["start"],
            "end": group[-1]["end"],
            "original_count": len(group),
            "is_merged": len(group) > 1
        }
        merged_segments.append(merged_seg)

    return merged_segments

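# Worked example (hypothetical numbers): with max_duration=50.0 and segments
# spanning 0.0-10.0, 10.5-30.0, and 31.0-55.0 s, the smallest gap (0.5 s) is
# merged first because the combined span (30.0 s) fits the cap; the remaining
# candidate would span 55.0 s, so merging stops, yielding
# [{"start": 0.0, "end": 30.0, "original_count": 2, "is_merged": True},
#  {"start": 31.0, "end": 55.0, "original_count": 1, "is_merged": False}].
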
def perform_segmentation(audio_path, pipeline_model, device_obj):
    """
    Segment an audio file.

    Args:
        audio_path: path to the audio file
        pipeline_model: pyannote pipeline
        device_obj: torch device

    Returns:
        (segment_files, session_id)
    """
    print("\n[Pyannote segmentation] Processing...")

    # Load audio
    audio_data, sr, audio_pydub = load_audio_for_segmentation(audio_path)
    duration = len(audio_data) / sr
    print(f"  Audio duration: {duration:.1f}s")

    # Run pyannote
    waveform, sample_rate = torchaudio.load(str(audio_path))

    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        waveform = resampler(waveform)
        sample_rate = 16000

    waveform = waveform.to(device_obj)
    audio_dict = {"waveform": waveform, "sample_rate": sample_rate}

    diarization_result = pipeline_model(audio_dict)

    # Extract initial segments (speaker labels are discarded)
    initial_segments = []
    if hasattr(diarization_result, 'speaker_diarization'):
        annotation = diarization_result.speaker_diarization
        for segment, track, speaker in annotation.itertracks(yield_label=True):
            initial_segments.append({
                "start": round(segment.start, 3),
                "end": round(segment.end, 3)
            })
    else:
        for segment, track, speaker in diarization_result.itertracks(yield_label=True):
            initial_segments.append({
                "start": round(segment.start, 3),
                "end": round(segment.end, 3)
            })

    print(f"  Initial segments: {len(initial_segments)}")

    # Merge segments
    merged_segments = merge_segments_by_gap(initial_segments, max_duration=DURATION_THRESHOLD)
    print(f"  Segments after merging: {len(merged_segments)}")

    # Optimize boundaries
    optimized_segments = optimize_segment_boundaries(merged_segments, audio_data, sr)
    print("  Boundary optimization done")

    # Export clips
    session_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    segment_folder = TEMP_DIR / session_id
    segment_folder.mkdir(exist_ok=True)

    segment_files = []
    for i, seg in enumerate(optimized_segments):
        start_time = seg["start"]
        end_time = seg["end"]

        # Pad exported clips by 200 ms on each side (len(audio_pydub) is in ms)
        start_ms = max(0, int(start_time * 1000) - 200)
        end_ms = min(len(audio_pydub), int(end_time * 1000) + 200)

        segment_audio = audio_pydub[start_ms:end_ms]

        filename = f"segment_{i:03d}.wav"
        filepath = segment_folder / filename

        segment_audio.export(str(filepath), format="wav")

        # Merged groups were force-joined; single segments are natural
        segment_type = "forced" if seg.get("is_merged", False) else "natural"

        segment_files.append({
            "filepath": filepath,
            "start_ms": int(start_time * 1000),
            "end_ms": int(end_time * 1000),
            "segment_type": segment_type
        })

    print(f"  Exported {len(segment_files)} files")

    return segment_files, session_id

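# Each entry in segment_files carries "filepath", "start_ms", "end_ms", and
# "segment_type". The reported start_ms/end_ms are the unpadded segment times;
# only the exported clips include the 200 ms padding.
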
# ============ FastAPI application ============

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifecycle: load models at startup."""
    global pyannote_pipeline, firered_model, device

    print("=" * 60)
    print("Starting audio transcription service...")
    print("=" * 60)

    # Load pyannote
    print("\n[1/2] Loading pyannote pipeline...")
    pyannote_pipeline = load_pyannote_pipeline(PYANNOTE_CONFIG_PATH)

    # Load FireRedASR
    print("\n[2/2] Loading FireRedASR model...")
    firered_model, device = load_fireredasr_model(FIRERED_MODEL_PATH)

    if pyannote_pipeline is None or firered_model is None:
        print("\n✗ Model loading failed; the service cannot serve requests")
    else:
        print("\n✓ All models loaded")

    print("=" * 60)

    yield

    print("\nShutting down...")


app = FastAPI(
    lifespan=lifespan,
    title="Audio Transcription Service",
    description="Audio transcription service (segmentation + recognition)"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post("/transcriptions") |
|
async def transcribe_audio(file: UploadFile = File(...)): |
|
""" |
|
音频转录API |
|
|
|
处理流程: |
|
1. 检查音频时长 |
|
2. < 50s: 直接识别 |
|
3. >= 50s: Pyannote分段 → FireRedASR批量识别 |
|
""" |
|
global pyannote_pipeline, firered_model, device |
|
|
|
# 检查模型 |
|
if firered_model is None: |
|
raise HTTPException(status_code=500, detail="FireRedASR模型未加载") |
|
|
|
# 验证文件格式 |
|
allowed_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg'] |
|
file_ext = os.path.splitext(file.filename)[1].lower() |
|
if file_ext not in allowed_extensions: |
|
raise HTTPException( |
|
status_code=400, |
|
detail=f"不支持的文件格式。支持: {', '.join(allowed_extensions)}" |
|
) |
|
|
|
temp_audio_path = None |
|
session_folder = None |
|
|
|
try: |
|
print(f"\n{'=' * 60}") |
|
print(f"收到转录请求: {file.filename}") |
|
|
|
total_start_time = time.time() |
|
|
|
# 保存上传的文件 |
|
session_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}" |
|
temp_audio_path = TEMP_DIR / f"upload_{session_id}{file_ext}" |
|
|
|
with open(temp_audio_path, "wb") as f: |
|
content = await file.read() |
|
f.write(content) |
|
|
|
# 检查音频时长 |
|
audio_info = sf.info(str(temp_audio_path)) |
|
duration = audio_info.duration |
|
print(f"音频时长: {duration:.2f}秒") |
|
|
|
# 路由选择 |
|
if duration < DURATION_THRESHOLD: |
|
# 直接识别 |
|
print(f"\n模式: 直接识别(< {DURATION_THRESHOLD}秒)") |
|
|
|
result = transcribe_single_audio(temp_audio_path, firered_model, device) |
|
|
|
transcription = [{ |
|
"start": 0, |
|
"end": int(duration * 1000), |
|
"content": result.get("text", ""), |
|
"confidence": result.get("confidence", 0.0), |
|
"segment_type": "direct" |
|
}] |
|
|
|
statistics = { |
|
"total_segments": 1, |
|
"processing_time": round(time.time() - total_start_time, 2), |
|
"segmentation_types": { |
|
"natural": 0, |
|
"forced": 0, |
|
"direct": 1 |
|
}, |
|
"processing_method": "direct_asr" |
|
} |
|
|
|
else: |
|
# 分段 + 识别 |
|
print(f"\n模式: 分段识别(>= {DURATION_THRESHOLD}秒)") |
|
|
|
if pyannote_pipeline is None: |
|
raise HTTPException(status_code=500, detail="Pyannote模型未加载") |
|
|
|
# 分段 |
|
segment_files, seg_session_id = perform_segmentation( |
|
temp_audio_path, |
|
pyannote_pipeline, |
|
torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
) |
|
|
|
session_folder = TEMP_DIR / seg_session_id |
|
|
|
# 识别 |
|
print(f"\n[FireRedASR识别] 开始处理...") |
|
transcription = transcribe_audio_segments(segment_files, firered_model, device) |
|
|
|
# 统计 |
|
seg_types = {"natural": 0, "forced": 0, "direct": 0} |
|
for seg in transcription: |
|
seg_types[seg["segment_type"]] += 1 |
|
|
|
statistics = { |
|
"total_segments": len(transcription), |
|
"processing_time": round(time.time() - total_start_time, 2), |
|
"segmentation_types": seg_types, |
|
"processing_method": "pipeline_dual_model" |
|
} |
|
|
|
print(f"\n{'=' * 60}") |
|
print(f"处理完成,总耗时: {statistics['processing_time']}秒") |
|
print(f"{'=' * 60}\n") |
|
|
|
# 构建响应 |
|
response_data = { |
|
"code": 200, |
|
"message": "success", |
|
"data": { |
|
"transcription": transcription, |
|
"statistics": statistics |
|
} |
|
} |
|
|
|
return JSONResponse( |
|
content=jsonable_encoder(response_data), |
|
status_code=200 |
|
) |
|
|
|
except Exception as e: |
|
print(f"\n处理错误: {str(e)}") |
|
import traceback |
|
traceback.print_exc() |
|
|
|
return JSONResponse( |
|
content={ |
|
"code": 500, |
|
"message": f"处理失败: {str(e)}", |
|
"data": { |
|
"transcription": [], |
|
"statistics": {} |
|
} |
|
}, |
|
status_code=200 |
|
) |
|
|
|
finally: |
|
# 清理临时文件 |
|
if temp_audio_path and temp_audio_path.exists(): |
|
try: |
|
temp_audio_path.unlink() |
|
except: |
|
pass |
|
|
|
if session_folder and session_folder.exists(): |
|
try: |
|
for file in session_folder.glob("*.wav"): |
|
file.unlink() |
|
session_folder.rmdir() |
|
except: |
|
pass |
|
|
|
|
|
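# Example request (assuming the server runs locally on the default port):
#   curl -F "file=@audio.wav" http://localhost:9000/transcriptions
# Successful responses look like (values illustrative):
#   {"code": 200, "message": "success",
#    "data": {"transcription": [{"start": 0, "end": 12340, "content": "...",
#                                "confidence": 0.97, "segment_type": "direct"}],
#             "statistics": {...}}}
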
@app.get("/health") |
|
async def health_check(): |
|
"""健康检查""" |
|
return JSONResponse( |
|
content={ |
|
"status": "healthy", |
|
"pyannote_loaded": pyannote_pipeline is not None, |
|
"firered_loaded": firered_model is not None, |
|
"device": str(device) if device else "not initialized", |
|
"duration_threshold": DURATION_THRESHOLD, |
|
"message": "Service is running" |
|
}, |
|
status_code=200 |
|
) |
|
|
|
|
|
@app.get("/") |
|
async def root(): |
|
"""服务信息""" |
|
return { |
|
"service": "Audio Transcription Service", |
|
"version": "1.0", |
|
"description": "音频转录服务(智能分段 + 语音识别)", |
|
"features": [ |
|
f"短音频(<{DURATION_THRESHOLD}s)直接识别", |
|
f"长音频(≥{DURATION_THRESHOLD}s)智能分段后识别", |
|
"Pyannote高质量分段", |
|
"FireRedASR高精度识别", |
|
"自动边界优化" |
|
], |
|
"endpoints": { |
|
"transcriptions": "POST /transcriptions - 音频转录", |
|
"health": "GET /health - 健康检查", |
|
"root": "GET / - 服务信息" |
|
}, |
|
"models": { |
|
"segmentation": "Pyannote Speaker Diarization 3.1", |
|
"transcription": "FireRedASR-AED-L" |
|
} |
|
} |
|
|
|
|
|
if __name__ == "__main__":
    print("=" * 60)
    print("Audio Transcription Service v1.0")
    print("=" * 60)
    print("Features:")
    print(f"  ✓ Smart routing ({DURATION_THRESHOLD}s threshold)")
    print("  ✓ Pyannote segmentation + FireRedASR recognition")
    print("  ✓ Automatic boundary optimization")
    print("  ✓ Speaker labels removed")
    print("=" * 60)
    print("Starting server (port 9000)...")
    print()

    uvicorn.run(app, host="0.0.0.0", port=9000)