You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
902 lines
28 KiB
902 lines
28 KiB
import sys |
|
|
|
sys.path.append('G:/Work_data/workstation/Forwork-voice-txt/FireRedASR') |
|
|
|
import os |
|
import sys |
|
import time |
|
import json |
|
import numpy as np |
|
import torch |
|
import argparse |
|
import requests |
|
import tempfile |
|
import uvicorn |
|
import yaml |
|
import soundfile as sf |
|
from pathlib import Path |
|
from pydub import AudioSegment |
|
import librosa |
|
from contextlib import asynccontextmanager |
|
from fastapi import FastAPI, File, UploadFile |
|
from fastapi.responses import JSONResponse |
|
from fastapi.encoders import jsonable_encoder |
|
|
|
# Globals holding the resident FireRedASR model; populated once in the
# FastAPI lifespan handler and read by the request handlers below.
firered_model = None
device = None  # "cuda:0" or "cpu", chosen when the model is loaded

# VAD model configuration: directory of the FunASR VAD model
# (loaded on demand per request, not kept resident).
# NOTE(review): hard-coded local Windows path — confirm before deploying.
VAD_MODEL_PATH = "G:/Work_data/广播/音频/vad"
|
|
|
|
|
def load_audio(audio_path):
    """Load an audio file for the segmentation/recognition pipeline.

    Parameters
    ----------
    audio_path : str or Path
        Path to the audio file to load.

    Returns
    -------
    tuple
        (audio_data, sr, audio_pydub) where audio_data is a mono float
        numpy array resampled to 16 kHz, sr is 16000, and audio_pydub is
        a pydub AudioSegment (16 kHz mono) used later for clip export.
    """
    # Load raw samples with soundfile.
    audio_data, sr = sf.read(str(audio_path))

    # Downmix to mono by averaging channels.
    if audio_data.ndim > 1:
        audio_data = np.mean(audio_data, axis=1)

    # Resample to 16 kHz (librosa is already imported at module level;
    # the previous redundant local import was removed).
    if sr != 16000:
        audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
        sr = 16000

    # Load the same file with pydub for millisecond-accurate clip export,
    # normalized to the same 16 kHz mono format.
    audio_pydub = AudioSegment.from_file(str(audio_path))
    audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)

    return audio_data, sr, audio_pydub
|
|
|
|
|
class SegmentOptimizer:
    """Segmentation optimizer — pure algorithmic processing class.

    Refines FunASR VAD output into ASR-sized segments in three levels:
    1. extract_coarse_segments: read coarse speech spans from VAD output;
    2. energy_based_segmentation: split over-long spans at low-energy
       (silence) points, with a duration-based forced fallback;
    3. _merge_adjacent_segments: greedily re-merge neighbors up to a cap.
    Internally segments are dicts with 'start_time'/'end_time'/'type';
    the public three_level_segmentation result uses (start, end, type)
    tuples with times in seconds.
    """

    def energy_based_segmentation(self, audio_data, sr,
                                  silence_threshold=0.0005,
                                  min_segment_length=1.0,
                                  max_segment_length=50.0):
        """Energy-based audio segmentation.

        Computes normalized short-time energy, treats frames below
        `silence_threshold` as silence, and cuts at merged silence
        intervals; pieces still longer than `max_segment_length` seconds
        are force-split. Returns dicts with 'start_time'/'end_time'
        (seconds, relative to `audio_data`) and 'type' ('natural' or
        'forced').
        """
        # Short-time energy: 25 ms frames advanced by a 5 ms hop.
        frame_len = int(0.025 * sr)  # 25ms
        hop_len = int(0.005 * sr)  # 5ms

        energy = []
        for i in range(0, len(audio_data) - frame_len + 1, hop_len):
            frame = audio_data[i:i + frame_len]
            frame_energy = np.sum(frame ** 2) / len(frame)
            energy.append(frame_energy)

        # Normalize to [0, 1] so silence_threshold is level-independent.
        # NOTE(review): assumes audio_data spans at least one frame,
        # otherwise `energy` is empty and .max() raises — confirm upstream.
        energy = np.array(energy)
        if energy.max() > 0:
            energy = energy / energy.max()

        # Collect the times (seconds) of all sub-threshold frames.
        silence_points = []
        for i, e in enumerate(energy):
            if e < silence_threshold:
                time_point = i * 0.005  # frame index -> seconds (5 ms hop)
                silence_points.append(time_point)

        if not silence_points:
            # No silence found anywhere: fall back to fixed-duration split.
            return self._forced_segmentation(len(audio_data) / sr, max_segment_length)

        # Merge adjacent silence points into silence intervals.
        silence_intervals = self._merge_silence_points(silence_points)

        if not silence_intervals:
            return self._forced_segmentation(len(audio_data) / sr, max_segment_length)

        # Build natural segments from the speech between silences.
        natural_segments = self._create_natural_segments(
            silence_intervals, len(audio_data) / sr, min_segment_length
        )

        # Force-split any natural segment that still exceeds the cap.
        final_segments = []
        for segment in natural_segments:
            duration = segment['end_time'] - segment['start_time']
            if duration > max_segment_length:
                forced_parts = self._forced_segmentation(duration, max_segment_length)
                for part in forced_parts:
                    # Forced parts are relative to the segment; re-anchor.
                    final_segments.append({
                        'start_time': segment['start_time'] + part['start_time'],
                        'end_time': segment['start_time'] + part['end_time'],
                        'type': 'forced'
                    })
            else:
                final_segments.append(segment)

        return final_segments

    def _merge_silence_points(self, silence_points):
        """Merge adjacent silence points (seconds) into silence intervals.

        Points closer than 0.1 s chain into one interval; intervals
        shorter than 0.05 s are discarded. Returns (start, end) tuples.
        """
        if not silence_points:
            return []

        silence_intervals = []
        current_start = silence_points[0]
        current_end = silence_points[0]

        for point in silence_points[1:]:
            if point - current_end <= 0.1:
                # Still inside the same silence run; extend it.
                current_end = point
            else:
                # Gap too large: close the current interval if long enough.
                if current_end - current_start >= 0.05:
                    silence_intervals.append((current_start, current_end))
                current_start = point
                current_end = point

        # Flush the trailing interval.
        if current_end - current_start >= 0.05:
            silence_intervals.append((current_start, current_end))

        return silence_intervals

    def _create_natural_segments(self, silence_intervals, audio_duration, min_segment_length):
        """Create natural speech segments from gaps between silences.

        A gap shorter than `min_segment_length` is absorbed into the
        previous segment (its end extended) instead of standing alone.
        Returns dicts with type 'natural'.
        """
        segments = []
        last_end = 0.0

        for silence_start, silence_end in silence_intervals:
            if silence_start - last_end >= min_segment_length:
                segments.append({
                    'start_time': last_end,
                    'end_time': silence_start,
                    'type': 'natural'
                })
            elif segments:
                # Too short to stand alone: extend the previous segment.
                segments[-1]['end_time'] = silence_start

            last_end = silence_end

        # Handle the tail after the final silence interval.
        if audio_duration - last_end >= min_segment_length:
            segments.append({
                'start_time': last_end,
                'end_time': audio_duration,
                'type': 'natural'
            })
        elif segments:
            segments[-1]['end_time'] = audio_duration

        return segments

    def _forced_segmentation(self, duration, max_segment_length):
        """Split `duration` seconds into equal parts <= max_segment_length.

        Returns dicts with type 'forced' covering [0, duration].
        """
        segments = []
        num_segments = int(np.ceil(duration / max_segment_length))
        segment_length = duration / num_segments

        for i in range(num_segments):
            start_time = i * segment_length
            end_time = min((i + 1) * segment_length, duration)
            segments.append({
                'start_time': start_time,
                'end_time': end_time,
                'type': 'forced'
            })

        return segments

    def _merge_adjacent_segments(self, segments, max_duration=50.0):
        """Greedily merge adjacent (start, end, type) tuples.

        Neighbors join a group while the span from the group's first
        start to the candidate's end stays within `max_duration` seconds;
        note the span includes any silence between segments.
        """
        if not segments:
            return []

        merged_segments = []
        current_group = [segments[0]]

        for i in range(1, len(segments)):
            current_seg = segments[i]
            group_start = current_group[0][0]
            group_end = current_seg[1]
            total_duration = group_end - group_start

            if total_duration <= max_duration:
                current_group.append(current_seg)
            else:
                # Cap reached: emit this group and start a new one.
                merged_segment = self._create_merged_segment(current_group)
                merged_segments.append(merged_segment)
                current_group = [current_seg]

        # Flush the final group.
        if current_group:
            merged_segment = self._create_merged_segment(current_group)
            merged_segments.append(merged_segment)

        return merged_segments

    def _create_merged_segment(self, segment_group):
        """Collapse a group of (start, end, type) tuples into one tuple.

        Single-element groups pass through unchanged. A homogeneous group
        becomes 'M<type>'; heterogeneous groups become 'MIXED'.
        """
        if len(segment_group) == 1:
            return segment_group[0]

        start_time = segment_group[0][0]
        end_time = segment_group[-1][1]

        types = [seg[2] for seg in segment_group]
        unique_types = list(set(types))
        if len(unique_types) == 1:
            merged_type = f"M{unique_types[0]}"
        else:
            merged_type = "MIXED"

        return (start_time, end_time, merged_type)

    def extract_coarse_segments(self, audio_data, sample_rate, vad_result):
        """Extract coarse VAD speech segments.

        `vad_result` items are expected to carry a 'value' list of
        [start_ms, end_ms] pairs (FunASR VAD convention — TODO confirm
        against the installed funasr version). Segments shorter than
        0.1 s are dropped. Returns (start_s, end_s, 'vad_coarse') tuples.
        """
        segments = []
        min_duration = 0.1

        for result in vad_result:
            if 'value' in result and result['value']:
                for segment_info in result['value']:
                    start_time = segment_info[0] / 1000.0  # ms -> s
                    end_time = segment_info[1] / 1000.0
                    duration = end_time - start_time

                    if duration >= min_duration:
                        segments.append((start_time, end_time, 'vad_coarse'))

        return segments

    def three_level_segmentation(self, audio_data, sample_rate, vad_result):
        """Main three-level segmentation pipeline.

        Level 1: coarse VAD spans; level 2: energy-based refinement of
        spans >= 50 s; level 3: greedy merge of neighbors up to 50 s.
        Returns (start_s, end_s, type) tuples; falls back to forced
        fixed-duration segmentation when VAD yields nothing.
        """
        duration = len(audio_data) / sample_rate

        # Level 1: extract coarse VAD segments.
        coarse_segments = self.extract_coarse_segments(audio_data, sample_rate, vad_result)

        if not coarse_segments:
            # Fallback: force-split the entire audio.
            fallback_segments = self._forced_segmentation(duration, 50.0)
            return [(seg['start_time'], seg['end_time'], seg['type'])
                    for seg in fallback_segments]

        # Level 2: refine long segments with energy analysis.
        fine_segments = []

        for start_time, end_time, seg_type in coarse_segments:
            duration = end_time - start_time

            if duration >= 50.0:
                # Slice this segment's samples out of the full signal.
                start_sample = int(start_time * sample_rate)
                end_sample = int(end_time * sample_rate)
                segment_audio = audio_data[start_sample:end_sample]

                # Sub-segment by energy analysis.
                sub_segments = self.energy_based_segmentation(
                    segment_audio, sample_rate,
                    silence_threshold=0.0005,
                    min_segment_length=1.0,
                    max_segment_length=50.0
                )

                # Convert segment-relative times back to absolute times.
                for sub_seg in sub_segments:
                    abs_start = start_time + sub_seg['start_time']
                    abs_end = start_time + sub_seg['end_time']
                    fine_segments.append((abs_start, abs_end, sub_seg['type']))
            else:
                fine_segments.append((start_time, end_time, 'direct'))

        # Level 3: greedily merge adjacent segments.
        merged_segments = self._merge_adjacent_segments(fine_segments, max_duration=50.0)

        return merged_segments
|
|
|
|
|
def setup_vad_model_config(model_path):
    """Write a tuned copy of the VAD model's config.yaml.

    Reads <model_path>/config.yaml, overrides the model_conf knobs that
    control segment granularity (longer silences tolerated, effectively
    unlimited single-segment length) so the VAD emits fewer, longer
    segments, and writes the result to <model_path>/config_temp.yaml.

    Parameters
    ----------
    model_path : str
        Directory of the FunASR VAD model.

    Returns
    -------
    str or None
        Path of the temporary tuned config file, or None when
        config.yaml does not exist.
    """
    config_file = os.path.join(model_path, "config.yaml")
    if not os.path.exists(config_file):
        return None

    with open(config_file, 'r', encoding='utf-8') as f:
        # `or {}` guards against an empty YAML file (safe_load -> None).
        config = yaml.safe_load(f) or {}

    # Robustness fix: tolerate a config without a model_conf section —
    # the previous `config['model_conf'].update(...)` raised KeyError.
    model_conf = config.setdefault('model_conf', {})

    # Tune parameters to reduce the number of segments.
    model_conf.update({
        "max_end_silence_time": 525,
        "max_start_silence_time": 2000,
        "sil_to_speech_time_thres": 150,
        "speech_to_sil_time_thres": 150,
        "max_single_segment_time": 16000000,  # ms; effectively unlimited
        "speech_2_noise_ratio": 1.5,
        "speech_noise_thres": 0.7,
    })

    # Save the tuned config alongside the original.
    temp_config_file = os.path.join(model_path, "config_temp.yaml")
    with open(temp_config_file, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, default_flow_style=False, allow_unicode=True)

    return temp_config_file
|
|
|
|
|
def load_and_run_vad(audio_path, model_path):
    """Temporarily load the FunASR VAD model, run it, and release it.

    The model directory's config.yaml is swapped for a tuned temporary
    config for the duration of the call and restored in the `finally`
    block, so the on-disk model is left unchanged even on failure.

    Parameters
    ----------
    audio_path : str
        Path of the audio file to segment.
    model_path : str
        Directory of the FunASR VAD model.

    Returns
    -------
    list or None
        The raw FunASR VAD result, or None when the model directory or
        its config is missing, or inference fails.
    """
    if not Path(model_path).exists():
        return None

    # Build the tuned config (written as config_temp.yaml in model_path).
    temp_config_file = setup_vad_model_config(model_path)
    if temp_config_file is None:
        return None

    # Paths used for the temporary config swap.
    original_config = os.path.join(model_path, "config.yaml")
    backup_config = os.path.join(model_path, "config_backup.yaml")

    try:
        # Back up the original config; skip when a backup already exists
        # (e.g. left behind by a previously crashed run).
        if not os.path.exists(backup_config):
            os.rename(original_config, backup_config)

        # Install the tuned config under the canonical name.
        os.rename(temp_config_file, original_config)

        # Import lazily so funasr is only required when VAD is used.
        from funasr import AutoModel
        vad_model = AutoModel(model=model_path, device="cpu")

        # Run VAD inference.
        vad_result = vad_model.generate(input=audio_path)

        # Drop the model explicitly to release memory.
        del vad_model

        return vad_result

    except Exception as e:
        print(f"VAD处理失败: {e}")
        return None
    finally:
        # Restore the original config from the backup.
        if os.path.exists(original_config):
            os.remove(original_config)
        if os.path.exists(backup_config):
            os.rename(backup_config, original_config)
|
|
|
|
|
def create_audio_segments_from_final_segments(segments, audio_pydub):
    """Export one WAV file per final segment.

    Each (start, end, type) tuple is cut from `audio_pydub` with 25 ms of
    padding on both sides (clamped to the audio bounds) and written into
    ./temp_segments. Returns a list of dicts describing each exported
    clip: 'file', 'start_time', 'end_time', 'index', 'type'.
    """
    out_dir = Path("temp_segments")
    out_dir.mkdir(exist_ok=True)

    total_ms = len(audio_pydub)
    exported = []

    for idx, (seg_start, seg_end, seg_kind) in enumerate(segments):
        # Pad by 25 ms so the recognizer does not clip word boundaries.
        begin_ms = max(0, int(seg_start * 1000) - 25)
        finish_ms = min(total_ms, int(seg_end * 1000) + 25)

        # Cut the clip and write it out as WAV.
        clip = audio_pydub[begin_ms:finish_ms]
        clip_path = out_dir / f"segment_{idx:03d}_{seg_kind}.wav"
        clip.export(str(clip_path), format="wav")

        exported.append({
            'file': clip_path,
            'start_time': seg_start,
            'end_time': seg_end,
            'index': idx,
            'type': seg_kind
        })

    return exported
|
|
|
|
|
def setup_fireredasr_environment():
    """Make the local FireRedASR checkout importable.

    Probes a few candidate directories relative to the current working
    directory; on the first one that exists, prepends its absolute path
    to sys.path (unless already present) and returns True. Returns False
    when no candidate directory exists.
    """
    candidates = ("./FireRedASR", "../FireRedASR", "FireRedASR")

    for candidate in candidates:
        candidate_path = Path(candidate)
        if not candidate_path.exists():
            continue
        abs_path = str(candidate_path.absolute())
        if abs_path not in sys.path:
            sys.path.insert(0, abs_path)
        return True

    return False
|
|
|
|
|
def load_fireredasr_model(model_dir):
    """Load the FireRedASR AED model from `model_dir`.

    Parameters
    ----------
    model_dir : str
        Directory containing the pretrained FireRedASR-AED model.

    Returns
    -------
    tuple
        (model, device): the loaded model and the device string
        ("cuda:0" or "cpu") the model actually lives on; (None, None)
        when loading fails.
    """
    # Best effort: make the local FireRedASR checkout importable. A False
    # return is not fatal — the package may already be on sys.path.
    setup_fireredasr_environment()

    try:
        # PyTorch compatibility: newer torch.load defaults to
        # weights_only=True; allow argparse.Namespace so the pretrained
        # checkpoint can be unpickled.
        torch.serialization.add_safe_globals([argparse.Namespace])

        # Package layout differs between checkouts; try each import path.
        try:
            from fireredasr.models.fireredasr import FireRedAsr
        except ImportError:
            try:
                from FireRedASR.fireredasr.models.fireredasr import FireRedAsr
            except ImportError:
                import fireredasr
                from fireredasr.models.fireredasr import FireRedAsr

        model = FireRedAsr.from_pretrained("aed", model_dir)

        # Prefer GPU, but fall back to CPU. Fix: when .to(device) fails,
        # also report device as "cpu" — the old bare `except: pass` left
        # device at "cuda:0" while the model stayed on CPU, so callers
        # would wrongly request GPU decoding.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        if torch.cuda.is_available():
            try:
                model = model.to(device)
            except Exception as move_err:
                print(f"模型移动到GPU失败,回退CPU: {move_err}")
                device = "cpu"

        # Inference mode, when the model supports it.
        if hasattr(model, 'eval'):
            model.eval()
        return model, device

    except Exception as e:
        print(f"FireRedASR模型加载失败: {e}")
        return None, None
|
|
|
|
|
def transcribe_segments(segment_files, model, device):
    """Run FireRedASR recognition over the exported audio segments.

    Each segment file is decoded individually with beam search; failures
    are logged and skipped. Empty transcriptions are dropped. Returns a
    list of dicts with 'start_time', 'end_time', 'text', 'confidence'
    and 'segment_type'.
    """
    results = []
    use_gpu = device.startswith("cuda")

    print(f"开始识别 {len(segment_files)} 个音频片段...")

    for idx, info in enumerate(segment_files):
        seg_path = info['file']
        seg_start = info['start_time']
        seg_end = info['end_time']
        seg_kind = info['type']

        try:
            decode_config = {
                "use_gpu": 1 if use_gpu else 0,
                "beam_size": 5,
                "nbest": 1,
                "decode_max_len": 0
            }

            with torch.no_grad():
                output = model.transcribe(
                    [f"segment_{idx:03d}"], [str(seg_path)], decode_config
                )

            if output and len(output) > 0:
                first = output[0]
                text = first.get('text', '').strip()

                # Confidence may arrive as a scalar, a list/tuple, or be
                # absent ('score' is an alternate key); normalize to float.
                confidence = first.get('confidence', first.get('score', 0.0))
                if isinstance(confidence, (list, tuple)) and len(confidence) > 0:
                    confidence = float(confidence[0])
                elif not isinstance(confidence, (int, float)):
                    confidence = 0.0
                else:
                    confidence = float(confidence)

                if text:
                    results.append({
                        'start_time': seg_start,
                        'end_time': seg_end,
                        'text': text,
                        'confidence': confidence,
                        'segment_type': seg_kind
                    })

            # Keep GPU memory bounded across many segments.
            if use_gpu:
                torch.cuda.empty_cache()

        except Exception as e:
            print(f"片段 {idx + 1} 识别失败: {e}")
            continue

    print(f"识别完成: {len(results)}/{len(segment_files)} 个片段")
    return results
|
|
|
|
|
# def call_ai_model(sentence): |
|
# """调用AI模型进行文本后处理""" |
|
# url = "http://192.168.3.8:7777/v1/chat/completions" |
|
# |
|
# prompt = f"""对以下文本添加标点符号,中文数字转阿拉伯数字。不修改文字内容。句末可以是冒号、逗号、问号、感叹号和句号等任意合适标点。 |
|
# {sentence}""" |
|
# |
|
# payload = { |
|
# "model": "Qwen3-32B", |
|
# "messages": [ |
|
# {"role": "user", "content": prompt} |
|
# ], |
|
# "chat_template_kwargs": {"enable_thinking": False} |
|
# } |
|
# |
|
# try: |
|
# response = requests.post(url, json=payload, timeout=120) |
|
# response.raise_for_status() |
|
# result = response.json() |
|
# processed_text = result["choices"][0]["message"]["content"].strip() |
|
# return processed_text |
|
# except requests.exceptions.RequestException: |
|
# return sentence # 失败时返回原文 |
|
# |
|
|
|
def process_transcription_results(results):
    """Normalize raw transcription results into API-ready dicts.

    Parameters
    ----------
    results : list[dict]
        Items carrying 'start_time'/'end_time' (seconds), 'text',
        'confidence', and optionally 'segment_type'.

    Returns
    -------
    list[dict]
        Time-ordered items with millisecond 'start'/'end', 'content',
        'confidence' rounded to 3 decimals, and 'segment_type'
        (defaulting to 'unknown'). Empty input yields [].
    """
    if not results:
        return []

    # Fix: use sorted() rather than list.sort() so the caller's list is
    # not mutated as a side effect.
    ordered = sorted(results, key=lambda x: x['start_time'])

    processed_data = []
    print(f"开始文本后处理 {len(results)} 个片段...")

    for item in ordered:
        original_content = item['text']

        # AI-based punctuation post-processing is currently disabled:
        # processed_content = call_ai_model(original_content)

        processed_data.append({
            'start': int(item['start_time'] * 1000),  # seconds -> ms
            'end': int(item['end_time'] * 1000),
            'content': original_content,
            'confidence': round(item['confidence'], 3),
            'segment_type': item.get('segment_type', 'unknown')
        })

    return processed_data
|
|
|
|
|
def cleanup_temp_files():
    """Remove the per-segment temp WAV files and their directory.

    Deletes every ``segment_*.wav`` under ``temp_segments`` and then
    attempts to remove the directory itself. A missing directory is a
    no-op; a non-empty directory is tolerated (rmdir is best effort).
    """
    temp_dir = Path("temp_segments")
    if temp_dir.exists():
        for file in temp_dir.glob("segment_*.wav"):
            file.unlink(missing_ok=True)
        # The directory may still hold unrelated files. Narrowed from a
        # bare `except:` so only filesystem errors are swallowed, not
        # e.g. KeyboardInterrupt.
        try:
            temp_dir.rmdir()
        except OSError:
            pass
|
|
|
|
|
def process_audio_file(audio_path):
    """Run the full transcription pipeline on one audio file.

    Stages: load/resample audio -> VAD segmentation (with fixed-duration
    fallback when VAD fails) -> three-level segment optimization ->
    per-segment WAV export -> FireRedASR recognition using the resident
    model -> result post-processing. Temp segment files are always
    cleaned up. Returns the processed result list, or [] on error.
    """

    # Reads the module-level model loaded by the lifespan handler.
    global firered_model, device

    try:
        print(f"开始处理音频文件...")

        # Stage 1: audio preprocessing.
        audio_data, sr, audio_pydub = load_audio(audio_path)
        duration = len(audio_data) / sr
        print(f"音频时长: {duration:.1f}秒")

        # Stage 2: VAD segmentation.
        print("执行VAD分段...")
        vad_result = load_and_run_vad(audio_path, VAD_MODEL_PATH)

        if vad_result is None:
            print("VAD分段失败,使用兜底分段方案")
            # Fallback: simple fixed-duration segmentation (~50 s chunks).
            num_segments = int(np.ceil(duration / 50.0))
            segment_length = duration / num_segments

            final_segments = []
            for i in range(num_segments):
                start_time = i * segment_length
                end_time = min((i + 1) * segment_length, duration)
                final_segments.append((start_time, end_time, 'fallback'))
        else:
            # Stage 3: smart segment optimization.
            print("执行智能分段优化...")
            optimizer = SegmentOptimizer()
            final_segments = optimizer.three_level_segmentation(audio_data, sr, vad_result)

        if not final_segments:
            print("未能获得有效的音频分段")
            return []

        print(f"分段完成: {len(final_segments)} 个片段")

        # Stage 4: export per-segment WAV files.
        segment_files = create_audio_segments_from_final_segments(final_segments, audio_pydub)

        # Stage 5: FireRedASR speech recognition.
        results = transcribe_segments(segment_files, firered_model, device)

        # Stage 6: text post-processing.
        processed_data = process_transcription_results(results)

        print(f"处理完成: {len(processed_data)} 个转录片段")

        return processed_data

    except Exception as e:
        print(f"处理过程中发生错误: {e}")
        return []
    finally:
        # Always drop the exported segment files, even on failure.
        cleanup_temp_files()
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI application lifespan manager.

    On startup: check the VAD model path (VAD itself is loaded lazily
    per request) and load the FireRedASR model into the module globals
    `firered_model`/`device` so it stays resident across requests.
    On shutdown: clean up any leftover temp segment files.
    """
    global firered_model, device

    print("正在启动服务...")

    # Model path configuration.
    # NOTE(review): hard-coded local Windows path — confirm before deploying.
    firered_model_dir = "G:/Work_data/workstation/Forwork-voice-txt/FireRedASR/pretrained_models/FireRedASR-AED-L"

    # Check the VAD model path.
    if Path(VAD_MODEL_PATH).exists():
        print("VAD模型路径检查通过")
    else:
        print(f"警告: VAD模型目录不存在: {VAD_MODEL_PATH}")

    # Load the resident FireRedASR model.
    if Path(firered_model_dir).exists():
        print("加载FireRedASR模型...")
        firered_model, device = load_fireredasr_model(firered_model_dir)
        if firered_model is not None:
            print("FireRedASR模型加载成功")
        else:
            print("FireRedASR模型加载失败")
    else:
        print(f"错误: FireRedASR模型目录不存在: {firered_model_dir}")

    # Report service readiness.
    if firered_model is None:
        print("警告: FireRedASR模型未加载,语音识别功能不可用")
    else:
        print("服务已就绪")

    yield  # application serves requests here

    # Shutdown cleanup.
    cleanup_temp_files()
|
|
|
|
|
# Create the FastAPI application instance; `lifespan` loads the resident
# FireRedASR model on startup and cleans temp files on shutdown.
app = FastAPI(lifespan=lifespan)
|
|
|
|
|
@app.post("/transcriptions")
async def create_file(file: UploadFile = File(...)):
    """Audio transcription API endpoint.

    Accepts an uploaded mp3/wav/m4a/flac file, runs the full VAD +
    FireRedASR pipeline, and returns the transcription plus per-segment
    statistics. Errors are reported in-band via the "code"/"message"
    fields while the HTTP status stays 200 (existing API contract).
    """
    temp_file_path = None

    try:
        # The ASR model is loaded once at startup; without it we cannot serve.
        if firered_model is None:
            return JSONResponse(
                content={
                    "data": "",
                    "message": "FireRedASR模型未加载,请检查模型路径并重启服务",
                    "code": 500
                },
                status_code=200
            )

        # Validate the extension. Fix: `file.filename` can be None for
        # some clients, which previously crashed on .lower().
        filename = file.filename or ""
        if not filename.lower().endswith(('.mp3', '.wav', '.m4a', '.flac')):
            return JSONResponse(
                content={
                    "data": "",
                    "message": "不支持的音频格式,请上传mp3、wav、m4a或flac文件",
                    "code": 400
                },
                status_code=200
            )

        # Persist the upload so downstream libraries can open it by path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(filename)[1]) as temp_file:
            temp_file_path = temp_file.name
            contents = await file.read()
            temp_file.write(contents)

        print(f"开始处理音频文件: {file.filename}")
        start_time = time.time()

        # Run the pipelined processing (segmentation + recognition).
        result_data = process_audio_file(temp_file_path)

        elapsed_time = time.time() - start_time
        print(f"总耗时: {elapsed_time:.2f}秒")

        # Count how many segments each segmentation strategy produced.
        segmentation_stats = {}
        for item in result_data:
            seg_type = item.get('segment_type', 'unknown')
            segmentation_stats[seg_type] = segmentation_stats.get(seg_type, 0) + 1

        # Build the response payload.
        response_data = {
            "transcription": result_data,
            "statistics": {
                "total_segments": len(result_data),
                "processing_time": round(elapsed_time, 2),
                "segmentation_types": segmentation_stats,
                "processing_method": "pipeline_dual_model"
            }
        }

        return JSONResponse(
            content={
                "data": jsonable_encoder(response_data),
                "message": "success",
                "code": 200
            },
            status_code=200
        )

    except Exception as e:
        print(f"处理错误: {str(e)}")

        return JSONResponse(
            content={
                "data": "",
                "message": str(e),
                "code": 500
            },
            status_code=200
        )

    finally:
        # Always remove the temp upload. Fix: narrowed the bare `except:`
        # to OSError so non-filesystem errors are not silently swallowed.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Health-check endpoint: liveness plus model availability flags."""
    payload = {
        "status": "healthy",
        "firered_model_loaded": firered_model is not None,
        "vad_model_available": Path(VAD_MODEL_PATH).exists(),
        "architecture": "pipeline_dual_model",
        "description": "FireRedASR常驻,VAD按需加载",
        "message": "Pipeline dual-model service is running"
    }
    return JSONResponse(content=payload, status_code=200)
|
|
|
|
|
@app.get("/segmentation/info")
async def segmentation_info():
    """Describe the segmentation architecture, its segment type labels,
    and the tuning parameters in use (static informational payload)."""
    type_labels = {
        "vad_coarse": "VAD粗分段",
        "direct": "直接保留 (< 50s)",
        "natural": "自然切分 (静音处)",
        "forced": "强制切分 (时长)",
        "Mdirect": "拼接-直接保留",
        "Mnatural": "拼接-自然切分",
        "Mforced": "拼接-强制切分",
        "MIXED": "拼接-混合类型",
        "fallback": "兜底分段"
    }
    tuning = {
        "vad_max_segment": "160秒",
        "energy_max_segment": "50秒",
        "merge_max_duration": "50秒",
        "silence_threshold": "0.0005"
    }
    return JSONResponse(
        content={
            "architecture": "流水线式双模型架构",
            "segmentation_types": type_labels,
            "parameters": tuning
        },
        status_code=200
    )
|
|
|
|
|
@app.post("/test/vad_only")
async def test_vad_segmentation(file: UploadFile = File(...)):
    """Test-only endpoint: run VAD segmentation without recognition.

    Executes preprocessing + VAD + three-level segment optimization on
    the uploaded file and returns the resulting segment table with
    timings. Errors are reported in-band with HTTP 200.
    """
    temp_file_path = None

    try:
        # Persist the upload to a temp file. `file.filename or ""` guards
        # against a missing filename (same fix as /transcriptions).
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename or "")[1]) as temp_file:
            temp_file_path = temp_file.name
            contents = await file.read()
            temp_file.write(contents)

        start_time = time.time()

        # Audio preprocessing (mono, 16 kHz).
        audio_data, sr, _ = load_audio(temp_file_path)

        # VAD segmentation.
        vad_result = load_and_run_vad(temp_file_path, VAD_MODEL_PATH)

        if vad_result is None:
            return JSONResponse(
                content={
                    "data": "",
                    "message": "VAD分段失败",
                    "code": 500
                },
                status_code=200
            )

        # Smart segment optimization.
        optimizer = SegmentOptimizer()
        final_segments = optimizer.three_level_segmentation(audio_data, sr, vad_result)

        elapsed_time = time.time() - start_time

        # Format the segment table. Fix: loop variables renamed to
        # seg_start/seg_end so they no longer shadow the `start_time`
        # timer variable above.
        segments_info = []
        for i, (seg_start, seg_end, seg_type) in enumerate(final_segments):
            segments_info.append({
                "index": i + 1,
                "start_time": round(seg_start, 2),
                "end_time": round(seg_end, 2),
                "duration": round(seg_end - seg_start, 2),
                "type": seg_type
            })

        return JSONResponse(
            content={
                "data": {
                    "segments": segments_info,
                    "total_segments": len(final_segments),
                    "processing_time": round(elapsed_time, 2)
                },
                "message": "VAD分段测试成功",
                "code": 200
            },
            status_code=200
        )

    except Exception as e:
        return JSONResponse(
            content={
                "data": "",
                "message": f"VAD分段测试失败: {str(e)}",
                "code": 500
            },
            status_code=200
        )

    finally:
        # Always remove the temp upload; narrowed from a bare `except:`.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: print an architecture summary, then serve the
    # FastAPI app with uvicorn on all interfaces, port 7777.
    print("FireRedASR API服务启动中(流水线式双模型架构)")
    print("架构特点:")
    print("  - VAD模型:按需临时加载,专门负责分段")
    print("  - FireRedASR模型:常驻内存,专门负责识别")
    print("  - 流水线处理:分阶段执行,内存优化")
    print("  - 智能分段:三级处理算法,质量保证")
    print("  - 兜底机制:VAD失败时自动降级")
    print("启动服务器...")
    uvicorn.run(app, host='0.0.0.0', port=7777)