|
|
import sys

# Make the local FireRedASR checkout importable (machine-specific path).
sys.path.append('D:\\workstation\\voice-txt\\FireRedASR-test\\FireRedASR')

import argparse
import os
import tempfile
import time
from contextlib import asynccontextmanager
from pathlib import Path

import librosa
import numpy as np
import requests
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydub import AudioSegment

# Globals populated at startup by the FastAPI lifespan handler.
model = None
device = None
|
|
def format_time_hms(seconds):
    """Format a duration in seconds as HH:MM:SS, e.g. 3725 -> "01:02:05"."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"

|
def load_audio(audio_path):
    """Load an audio file twice: once for analysis, once for segment export."""
    print(f"Loading audio: {audio_path}")

    # Load with librosa (16 kHz mono) for energy analysis.
    audio_data, sr = librosa.load(str(audio_path), sr=16000, mono=True)

    # Load with pydub for exporting segments later.
    audio_pydub = AudioSegment.from_file(str(audio_path))
    audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)

    duration = len(audio_data) / sr
    print(f"Duration: {duration:.2f}s, sample rate: {sr}Hz")

    return audio_data, sr, audio_pydub

|
def energy_based_segmentation(audio_data, sr,
                              silence_threshold=0.001,
                              min_segment_length=1.0,
                              max_segment_length=455.0):
    """Split audio into speech segments at low-energy (silence) points."""
    # Short-time energy over 25 ms frames with a 5 ms hop.
    frame_len = int(0.025 * sr)  # 25 ms
    hop_len = int(0.005 * sr)    # 5 ms

    energy = []
    for i in range(0, len(audio_data) - frame_len + 1, hop_len):
        frame = audio_data[i:i + frame_len]
        frame_energy = np.sum(frame ** 2) / len(frame)
        energy.append(frame_energy)

    energy = np.array(energy)
    if energy.max() > 0:
        energy = energy / energy.max()

    # Collect frame times whose normalized energy falls below the threshold.
    silence_points = []
    for i, e in enumerate(energy):
        if e < silence_threshold:
            time_point = i * 0.005  # hop index -> seconds (5 ms hop at 16 kHz)
            silence_points.append(time_point)

    if not silence_points:
        # No silence detected: fall back to fixed-length segmentation.
        return fixed_segmentation(len(audio_data) / sr, max_segment_length)

    # Merge adjacent silence points into silence intervals.
    silence_intervals = []
    current_start = silence_points[0]
    current_end = silence_points[0]

    for point in silence_points[1:]:
        if point - current_end <= 0.1:  # gaps under 0.1 s count as continuous
            current_end = point
        else:
            if current_end - current_start >= 0.05:  # keep intervals of at least 50 ms
                silence_intervals.append((current_start, current_end))
            current_start = point
            current_end = point

    if current_end - current_start >= 0.05:
        silence_intervals.append((current_start, current_end))

    # Build speech segments from the gaps between silence intervals.
    segments = []
    last_end = 0.0
    audio_duration = len(audio_data) / sr

    for silence_start, silence_end in silence_intervals:
        if silence_start - last_end >= min_segment_length:
            segments.append({
                'start_time': last_end,
                'end_time': silence_start,
                'type': 'natural'
            })
        elif segments and silence_start - segments[-1]['start_time'] <= max_segment_length:
            # Too short to stand alone: extend the previous segment.
            segments[-1]['end_time'] = silence_start

        last_end = silence_end

    # Handle the trailing audio after the last silence interval.
    if audio_duration - last_end >= min_segment_length:
        segments.append({
            'start_time': last_end,
            'end_time': audio_duration,
            'type': 'natural'
        })
    elif segments:
        segments[-1]['end_time'] = audio_duration

    # Split any segment exceeding the maximum length into equal sub-segments.
    final_segments = []
    for segment in segments:
        duration = segment['end_time'] - segment['start_time']
        if duration > max_segment_length:
            num_subsegments = int(np.ceil(duration / max_segment_length))
            sub_duration = duration / num_subsegments

            for i in range(num_subsegments):
                sub_start = segment['start_time'] + i * sub_duration
                sub_end = min(sub_start + sub_duration, segment['end_time'])
                final_segments.append({
                    'start_time': sub_start,
                    'end_time': sub_end,
                    'type': 'forced'
                })
        else:
            final_segments.append(segment)

    print(f"Segmentation done: {len(final_segments)} segments")
    return final_segments

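# Note: the frame-energy loop above is equivalent to librosa's vectorized RMS
# helper up to a square (mean power per frame = rms**2). A minimal sketch,
# assuming the same 25 ms / 5 ms framing (center=False keeps the frames
# aligned with the manual loop); the threshold would apply unchanged since
# both are normalized to a 0-1 range:
#
#   rms = librosa.feature.rms(y=audio_data,
#                             frame_length=int(0.025 * sr),
#                             hop_length=int(0.005 * sr),
#                             center=False)[0]
#   energy = rms ** 2
#   if energy.max() > 0:
#       energy = energy / energy.max()
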
|
def refine_long_segments(segments, audio_data, sr, max_duration=60.0):
    """Re-segment any segment longer than max_duration with stricter parameters."""
    # Stricter parameters for the second segmentation pass.
    refined_params = {
        'silence_threshold': 0.0001,   # more sensitive silence detection
        'min_segment_length': 0.5,     # shorter minimum segment
        'max_segment_length': 45.0     # stay within the model's limit
    }

    refined_segments = []
    refined_count = 0

    for segment in segments:
        duration = segment['end_time'] - segment['start_time']

        if duration > max_duration:
            refined_count += 1

            # Extract the corresponding slice of audio samples.
            start_sample = int(segment['start_time'] * sr)
            end_sample = int(segment['end_time'] * sr)
            segment_audio = audio_data[start_sample:end_sample]

            # Run a finer-grained segmentation on this slice.
            sub_segments = energy_based_segmentation(
                segment_audio, sr,
                silence_threshold=refined_params['silence_threshold'],
                min_segment_length=refined_params['min_segment_length'],
                max_segment_length=refined_params['max_segment_length']
            )

            # Shift sub-segment times back onto the original audio's timeline.
            for sub_segment in sub_segments:
                sub_segment['start_time'] += segment['start_time']
                sub_segment['end_time'] += segment['start_time']
                sub_segment['type'] = 'refined'  # mark as a refined sub-segment
                refined_segments.append(sub_segment)
        else:
            # Keep the segment unchanged.
            refined_segments.append(segment)

    print(f"Refined {refined_count} over-long segments: "
          f"{len(segments)} -> {len(refined_segments)} segments")

    # Verify that every segment now satisfies the duration limit.
    oversized_count = sum(1 for seg in refined_segments
                          if seg['end_time'] - seg['start_time'] > max_duration)
    if oversized_count > 0:
        print(f"Warning: {oversized_count} segments still exceed {max_duration}s")
    else:
        print(f"All segments are within {max_duration}s")

    return refined_segments

|
def merge_segments_sequentially(segments, max_duration=45.0):
    """Greedily pack consecutive segments into groups spanning at most max_duration seconds."""
    if not segments:
        return []

    # Sort by start time to guarantee sequential order.
    segments_sorted = sorted(segments, key=lambda x: x['start_time'])

    merged_segments = []
    current_group = []
    current_start = None
    current_end = None

    for segment in segments_sorted:
        # The first segment starts a new group.
        if not current_group:
            current_group.append(segment)
            current_start = segment['start_time']
            current_end = segment['end_time']
            continue

        # Total span of the group if this segment were added.
        potential_duration = segment['end_time'] - current_start

        if potential_duration <= max_duration:
            # Fits: add to the current group.
            current_group.append(segment)
            current_end = segment['end_time']
        else:
            # Does not fit: close the current group and start a new one.
            merged_segments.append({
                'start_time': current_start,
                'end_time': current_end,
                'duration': current_end - current_start,
                'type': 'merged',
                'original_segments': current_group.copy(),
                'segment_count': len(current_group)
            })
            current_group = [segment]
            current_start = segment['start_time']
            current_end = segment['end_time']

    # Close the final group.
    if current_group:
        merged_segments.append({
            'start_time': current_start,
            'end_time': current_end,
            'duration': current_end - current_start,
            'type': 'merged',
            'original_segments': current_group.copy(),
            'segment_count': len(current_group)
        })

    # Summary statistics.
    original_count = len(segments_sorted)
    merged_count = len(merged_segments)
    total_duration = sum(seg['duration'] for seg in merged_segments)
    avg_duration = total_duration / merged_count if merged_count > 0 else 0
    max_merged_duration = max(seg['duration'] for seg in merged_segments) if merged_segments else 0
    min_merged_duration = min(seg['duration'] for seg in merged_segments) if merged_segments else 0

    print("Merging done:")
    print(f"  original segments: {original_count}")
    print(f"  merged segments: {merged_count}")
    print(f"  compression: {(1 - merged_count / original_count) * 100:.1f}%")
    print(f"  average duration: {avg_duration:.1f}s")
    print(f"  duration range: {min_merged_duration:.1f}s - {max_merged_duration:.1f}s")

    # Verify the duration limit.
    oversized = [seg for seg in merged_segments if seg['duration'] > max_duration]
    if oversized:
        print(f"Warning: {len(oversized)} merged segments exceed {max_duration}s")
    else:
        print(f"All merged segments are within {max_duration}s")

    return merged_segments

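# Worked example of the greedy packing above (durations in seconds): with
# max_duration=45 and consecutive segments spanning 0-20, 20-40, and 40-55,
# the first two fit in one group (span 40s <= 45s); adding the third would
# stretch the span to 55s, so it opens a second group:
#
#   merge_segments_sequentially([
#       {'start_time': 0.0,  'end_time': 20.0, 'type': 'natural'},
#       {'start_time': 20.0, 'end_time': 40.0, 'type': 'natural'},
#       {'start_time': 40.0, 'end_time': 55.0, 'type': 'natural'},
#   ], max_duration=45.0)
#   # -> two merged segments: 0.0-40.0s (2 originals) and 40.0-55.0s (1 original)
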
|
def fixed_segmentation(total_duration, segment_duration):
    """Fallback: split the timeline into fixed-length segments."""
    segments = []
    start_time = 0
    while start_time < total_duration:
        end_time = min(start_time + segment_duration, total_duration)
        segments.append({
            'start_time': start_time,
            'end_time': end_time,
            'type': 'fixed'
        })
        start_time = end_time
    return segments

|
def create_audio_segments(segments, audio_pydub):
    """Export each segment to a WAV file under temp_segments/."""
    print("Creating audio segments...")

    temp_dir = Path("temp_segments")
    temp_dir.mkdir(exist_ok=True)

    segment_files = []
    for i, segment_info in enumerate(segments):
        start_time = segment_info['start_time']
        end_time = segment_info['end_time']

        # Convert to milliseconds and add 25 ms of padding on each side.
        start_ms = max(0, int(start_time * 1000) - 25)
        end_ms = min(len(audio_pydub), int(end_time * 1000) + 25)

        # Extract and save the audio segment.
        segment = audio_pydub[start_ms:end_ms]
        segment_file = temp_dir / f"segment_{i:03d}.wav"
        segment.export(str(segment_file), format="wav")

        # Attach segment metadata.
        segment_info_with_file = {
            'file': segment_file,
            'start_time': start_time,
            'end_time': end_time,
            'duration': end_time - start_time,
            'index': i
        }

        # For merged segments, carry over the original segment information.
        if 'original_segments' in segment_info:
            segment_info_with_file['original_segments'] = segment_info['original_segments']
            segment_info_with_file['segment_count'] = segment_info['segment_count']
            segment_info_with_file['type'] = 'merged'

        segment_files.append(segment_info_with_file)

    print(f"Created {len(segment_files)} files")
    return segment_files

|
def setup_fireredasr_environment():
    """Add a local FireRedASR checkout to sys.path if one can be found."""
    possible_paths = [
        "./FireRedASR",
        "../FireRedASR",
        "FireRedASR"
    ]

    for path in possible_paths:
        if Path(path).exists():
            if str(Path(path).absolute()) not in sys.path:
                sys.path.insert(0, str(Path(path).absolute()))
            return True
    return False

|
def load_fireredasr_model(model_dir):
    """Load the FireRedASR AED model and pick a device."""
    if not setup_fireredasr_environment():
        print("FireRedASR path not found, trying a direct import...")

    try:
        # PyTorch compatibility fix: newer torch defaults torch.load to
        # weights_only=True, so allowlist argparse.Namespace for the
        # checkpoint's stored config.
        torch.serialization.add_safe_globals([argparse.Namespace])

        # Try several import layouts.
        try:
            from fireredasr.models.fireredasr import FireRedAsr
        except ImportError:
            try:
                from FireRedASR.fireredasr.models.fireredasr import FireRedAsr
            except ImportError:
                import fireredasr
                from fireredasr.models.fireredasr import FireRedAsr

        model = FireRedAsr.from_pretrained("aed", model_dir)

        # Prefer the GPU when available.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        if torch.cuda.is_available():
            try:
                model = model.to(device)
            except Exception:
                pass
            print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        else:
            print("Using CPU")

        if hasattr(model, 'eval'):
            model.eval()
        return model, device

    except Exception as e:
        print(f"Model loading failed: {e}")
        print("Please check that:")
        print("1. FireRedASR is installed correctly")
        print("2. the paths are configured correctly")
        print("3. all dependencies are available")
        return None, None

|
def call_ai_model(sentence):
    """Post-process ASR text via an OpenAI-compatible chat-completions endpoint."""
    url = "http://192.168.3.8:7777/v1/chat/completions"

    # The prompt stays in Chinese because it post-processes Chinese transcripts.
    # It asks the model to add punctuation and convert Chinese numerals to
    # Arabic numerals without modifying or adding to the text itself.
    prompt = f"""对以下文本添加标点符号,中文数字转阿拉伯数字。不修改文字内容(一定不对文本做补充和提示)。句末可以是冒号、逗号、问号、感叹号和句号等任意合适标点。
{sentence}"""

    payload = {
        "model": "Qwen3-14B",
        "messages": [
            {"role": "user", "content": prompt}
        ]
    }

    try:
        response = requests.post(url, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
        processed_text = result["choices"][0]["message"]["content"].strip()
        return processed_text
    except (requests.exceptions.RequestException, KeyError, IndexError) as e:
        print(f"AI model call failed: {e}")
        return sentence  # fall back to the original text on failure

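# Example exchange with the endpoint above (an OpenAI-compatible
# /v1/chat/completions API; host, port, and model name come from the config
# in this function, and the response shape is the standard chat-completions
# layout the parsing code assumes):
#
#   POST http://192.168.3.8:7777/v1/chat/completions
#   {"model": "Qwen3-14B", "messages": [{"role": "user", "content": "..."}]}
#
#   response: {"choices": [{"message": {"role": "assistant",
#                                       "content": "<punctuated text>"}}], ...}
#
# The function extracts choices[0].message.content and falls back to the raw
# sentence on any failure.
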
|
def transcribe_segments(segment_files, model, device):
    """Run FireRedASR on each exported segment file."""
    print("Starting speech recognition...")

    results = []
    use_gpu = device.startswith("cuda")

    for i, segment_info in enumerate(segment_files):
        segment_file = segment_info['file']
        start_time = segment_info['start_time']
        end_time = segment_info['end_time']
        duration = segment_info.get('duration', end_time - start_time)

        # Log segment info.
        segment_type = segment_info.get('type', 'normal')
        if segment_type == 'merged':
            original_count = segment_info.get('segment_count', 1)
            print(f"[{i + 1}/{len(segment_files)}] {start_time:.1f}s-{end_time:.1f}s "
                  f"(duration: {duration:.1f}s, merged from {original_count} segments)")
        else:
            print(f"[{i + 1}/{len(segment_files)}] {start_time:.1f}s-{end_time:.1f}s "
                  f"(duration: {duration:.1f}s)")

        try:
            batch_uttid = [f"segment_{i:03d}"]
            batch_wav_path = [str(segment_file)]

            config = {
                "use_gpu": 1 if use_gpu else 0,
                "beam_size": 5,
                "nbest": 1,
                "decode_max_len": 0
            }

            with torch.no_grad():
                transcription_result = model.transcribe(
                    batch_uttid, batch_wav_path, config
                )

            if transcription_result and len(transcription_result) > 0:
                text = transcription_result[0].get('text', '').strip()
                if text:
                    result = {
                        'start_time': start_time,
                        'end_time': end_time,
                        'text': text
                    }
                    # For merged segments, keep the merge metadata.
                    if segment_type == 'merged':
                        result['is_merged'] = True
                        result['original_segment_count'] = segment_info.get('segment_count', 1)

                    results.append(result)
                else:
                    print("  (no content)")

            # Free GPU memory between segments.
            if use_gpu:
                torch.cuda.empty_cache()

        except Exception as e:
            print(f"  Error: {e}")
            continue

    print(f"Recognition done: {len(results)}/{len(segment_files)} segments")
    return results

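# model.transcribe takes parallel lists of utterance IDs and WAV paths, so the
# per-segment loop above could in principle submit several files per call.
# A hedged sketch (untested here; memory use grows with batch size, so this
# only pays off for short segments):
#
#   batch_uttid = [f"segment_{i:03d}" for i in range(len(segment_files))]
#   batch_wav_path = [str(s['file']) for s in segment_files]
#   with torch.no_grad():
#       batch_results = model.transcribe(batch_uttid, batch_wav_path, config)
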
|
def process_transcription_results(results):
    """Run AI punctuation post-processing over the raw recognition results."""
    print("Starting AI text post-processing...")

    if not results:
        print("No recognition results")
        return []

    # Sort chronologically.
    results.sort(key=lambda x: x['start_time'])

    processed_data = []

    for item in results:
        start_time = item['start_time']
        end_time = item['end_time']
        original_content = item['text']

        # Punctuate and normalize via the AI model.
        processed_content = call_ai_model(original_content)

        # Build the final record.
        processed_item = {
            'start': int(start_time * 1000),  # milliseconds
            'end': int(end_time * 1000),      # milliseconds
            'content': processed_content,
            'original_content': original_content,  # keep the raw ASR output
            'start_time_hms': format_time_hms(start_time),
            'end_time_hms': format_time_hms(end_time),
            'duration': round(end_time - start_time, 2)
        }

        # Carry over merge metadata.
        if item.get('is_merged', False):
            processed_item['is_merged'] = True
            processed_item['original_segment_count'] = item.get('original_segment_count', 1)

        processed_data.append(processed_item)

    return processed_data

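# Shape of one item in the returned list (values illustrative):
#
#   {
#       "start": 12500, "end": 43750,          # milliseconds
#       "content": "带标点的文本。",
#       "original_content": "带标点的文本",
#       "start_time_hms": "00:00:12",
#       "end_time_hms": "00:00:43",
#       "duration": 31.25,
#       "is_merged": True, "original_segment_count": 3   # merged segments only
#   }
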
|
def cleanup_temp_files():
    """Delete exported segment WAVs and their directory, tolerating races."""
    temp_dir = Path("temp_segments")
    if temp_dir.exists():
        for file in temp_dir.glob("segment_*.wav"):
            file.unlink(missing_ok=True)
        try:
            temp_dir.rmdir()
        except OSError:
            pass

|
def process_audio_file(audio_path):
    """Full pipeline: load, segment, merge, transcribe, and post-process one file."""
    global model, device

    try:
        # 1. Load the audio.
        audio_data, sr, audio_pydub = load_audio(audio_path)

        # 2. Energy-based segmentation.
        segments = energy_based_segmentation(
            audio_data, sr,
            silence_threshold=0.001,   # silence threshold
            min_segment_length=1.0,    # minimum segment length: 1 s
            max_segment_length=455.0   # maximum segment length: 455 s
        )

        # 3. Refine segments longer than 60 seconds.
        refined_segments = refine_long_segments(segments, audio_data, sr, max_duration=60.0)

        # 4. Merge segments sequentially.
        merged_segments = merge_segments_sequentially(refined_segments, max_duration=45.0)

        # 5. Export segment audio files.
        segment_files = create_audio_segments(merged_segments, audio_pydub)

        # 6. Speech recognition.
        results = transcribe_segments(segment_files, model, device)

        # 7. AI text post-processing.
        processed_data = process_transcription_results(results)

        # 8. Collect statistics.
        stats = {
            'total_segments': len(results),
            'original_segments': len(refined_segments),
            'merged_segments': len(merged_segments),
            'compression_rate': round((1 - len(merged_segments) / len(refined_segments)) * 100, 1)
                                if len(refined_segments) > 0 else 0,
            'total_duration': round(sum(item['duration'] for item in processed_data), 2)
        }

        return processed_data, stats

    finally:
        cleanup_temp_files()

|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the model at startup, clean up at shutdown."""
    global model, device

    model_dir = "D:/workstation/voice-txt/FireRedASR-test/FireRedASR/pretrained_models/FireRedASR-AED-L"

    if not Path(model_dir).exists():
        print(f"Model directory does not exist: {model_dir}")
    else:
        model, device = load_fireredasr_model(model_dir)
        if model is None:
            print("Model loading failed; the service may not work correctly")
        else:
            print("Model loaded; service is ready")

    yield  # application runs here

    # Shutdown cleanup.
    print("Service is shutting down...")
    cleanup_temp_files()


# Create the FastAPI application instance.
app = FastAPI(lifespan=lifespan)

|
@app.post("/transcriptions")
async def create_transcription(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file and return punctuated, timestamped text.

    Note: errors are reported via the JSON 'code' field; the HTTP status
    stays 200 in all cases.
    """
    temp_file_path = None

    try:
        # Make sure the model is loaded.
        if model is None:
            return JSONResponse(
                content={
                    "data": "",
                    "message": "Model not loaded; please restart the service",
                    "code": 500
                },
                status_code=200
            )

        # Check the file extension.
        if not file.filename.lower().endswith(('.mp3', '.wav', '.m4a', '.flac')):
            return JSONResponse(
                content={
                    "data": "",
                    "message": "Unsupported audio format; please upload an mp3, wav, m4a, or flac file",
                    "code": 400
                },
                status_code=200
            )

        # Write the upload to a temporary file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
            temp_file_path = temp_file.name
            contents = await file.read()
            temp_file.write(contents)

        print(f"Processing audio file: {file.filename}")
        start_time = time.time()

        # Run the full pipeline.
        result_data, stats = process_audio_file(temp_file_path)

        elapsed_time = time.time() - start_time
        print(f"Done in {elapsed_time:.2f}s")

        return JSONResponse(
            content={
                "data": jsonable_encoder(result_data),
                "stats": stats,
                "processing_time": round(elapsed_time, 2),
                "message": "success",
                "code": 200
            },
            status_code=200
        )

    except Exception as e:
        print(f"Processing error: {e}")
        import traceback
        traceback.print_exc()

        return JSONResponse(
            content={
                "data": "",
                "message": str(e),
                "code": 500
            },
            status_code=200
        )

    finally:
        # Always remove the temporary file.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError:
                pass

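# Example client call (host/port as configured in __main__ below; the file
# name is illustrative):
#
#   curl -X POST http://localhost:7777/transcriptions -F "file=@meeting.wav"
#
# On success the response body carries code=200, the timestamped transcript
# in "data", segment statistics in "stats", and the elapsed seconds in
# "processing_time".
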
|
@app.get("/health")
async def health_check():
    """Health-check endpoint."""
    return JSONResponse(
        content={
            "status": "healthy",
            "model_loaded": model is not None,
            "device": device if device else "unknown",
            "message": "Professional ASR Service is running"
        },
        status_code=200
    )

|
@app.get("/")
async def root():
    """Root endpoint: service metadata."""
    return JSONResponse(
        content={
            "service": "FireRedASR Professional ASR API",
            "version": "1.0.0",
            "features": [
                "energy-based segmentation (VAD)",
                "smart segment merging",
                "long-segment refinement",
                "AI text post-processing",
                "full timestamps",
                "GPU acceleration support"
            ],
            "endpoints": {
                "transcriptions": "POST /transcriptions - speech transcription (with AI post-processing)",
                "health": "GET /health - health check"
            }
        },
        status_code=200
    )

|
if __name__ == '__main__':
    print("🎵 === FireRedASR professional ASR API service starting === 🎵")
    uvicorn.run(app, host='0.0.0.0', port=7777)