You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
544 lines
16 KiB
544 lines
16 KiB
import sys |
|
|
|
sys.path.append('D:\\workstation\\voice-txt\\FireRedASR-test\\FireRedASR') |
|
|
|
import os |
|
import sys |
|
import time |
|
import json |
|
import numpy as np |
|
import torch |
|
import argparse |
|
import requests |
|
import tempfile |
|
import uvicorn |
|
from pathlib import Path |
|
from pydub import AudioSegment |
|
import librosa |
|
from contextlib import asynccontextmanager |
|
from fastapi import FastAPI, File, UploadFile |
|
from fastapi.responses import JSONResponse |
|
from fastapi.encoders import jsonable_encoder |
|
|
|
# Module-level model state, populated by the FastAPI lifespan startup hook.
model = None   # FireRedAsr instance once loaded, else None
device = None  # "cuda:0" or "cpu"; set together with `model`
|
|
|
|
|
def load_audio(audio_path):
    """Load an audio file in two parallel forms, both 16 kHz mono:
    a float waveform (for energy analysis) and a pydub clip (for
    cutting/exporting segments).

    Returns (samples, sample_rate, pydub_clip).
    """
    print(f" 加载音频: {audio_path}")

    # Numeric waveform for the energy-based segmenter.
    samples, sample_rate = librosa.load(str(audio_path), sr=16000, mono=True)

    # Matching pydub object, normalized to the same rate/channels,
    # used later to slice and export WAV pieces.
    clip = AudioSegment.from_file(str(audio_path))
    clip = clip.set_frame_rate(16000).set_channels(1)

    duration = len(samples) / sample_rate
    print(f" 时长: {duration:.2f}秒, 采样率: {sample_rate}Hz")

    return samples, sample_rate, clip
|
|
|
|
|
def energy_based_segmentation(audio_data, sr,
                              silence_threshold=0.001,
                              min_segment_length=1.0,
                              max_segment_length=455.0):
    """Split audio into speech segments at silent gaps.

    Frames of 25 ms (5 ms hop) are scored by normalized short-time energy;
    frames below `silence_threshold` mark silence. Returns a list of
    {'start_time', 'end_time', 'type'} dicts in seconds, where 'type' is
    'natural' (cut at silence), 'forced' (length-capped sub-split) or
    'fixed' (fallback when no silence is found at all).
    """
    print(" 能量分段分析...")

    frame_len = int(0.025 * sr)  # 25 ms analysis window
    hop_len = int(0.005 * sr)    # 5 ms hop

    # Mean energy per frame, normalized to the loudest frame.
    energy = np.array([
        np.sum(audio_data[pos:pos + frame_len] ** 2) / frame_len
        for pos in range(0, len(audio_data) - frame_len + 1, hop_len)
    ])
    if energy.max() > 0:
        energy = energy / energy.max()

    # Frame indices under the threshold, converted to seconds (1 hop = 5 ms).
    silence_points = [idx * 0.005
                      for idx, e in enumerate(energy)
                      if e < silence_threshold]

    if not silence_points:
        print(" 未检测到静音,使用固定分段")
        return fixed_segmentation(len(audio_data) / sr, max_segment_length)

    # Merge silence points closer than 0.1 s into intervals; keep >= 50 ms.
    silence_intervals = []
    run_start = run_end = silence_points[0]
    for t in silence_points[1:]:
        if t - run_end <= 0.1:
            run_end = t
        else:
            if run_end - run_start >= 0.05:
                silence_intervals.append((run_start, run_end))
            run_start = run_end = t
    if run_end - run_start >= 0.05:
        silence_intervals.append((run_start, run_end))

    # Emit one speech segment per span between consecutive silences.
    audio_duration = len(audio_data) / sr
    segments = []
    prev_end = 0.0
    for sil_start, sil_end in silence_intervals:
        if sil_start - prev_end >= min_segment_length:
            segments.append({
                'start_time': prev_end,
                'end_time': sil_start,
                'type': 'natural'
            })
        elif segments and sil_start - segments[-1]['start_time'] <= max_segment_length:
            # Span too short to stand alone: absorb it into the previous segment.
            segments[-1]['end_time'] = sil_start
        prev_end = sil_end

    # Tail after the last silence.
    if audio_duration - prev_end >= min_segment_length:
        segments.append({
            'start_time': prev_end,
            'end_time': audio_duration,
            'type': 'natural'
        })
    elif segments:
        segments[-1]['end_time'] = audio_duration

    # Force-split anything longer than max_segment_length into equal pieces.
    final_segments = []
    for seg in segments:
        length = seg['end_time'] - seg['start_time']
        if length > max_segment_length:
            parts = int(np.ceil(length / max_segment_length))
            step = length / parts
            for k in range(parts):
                piece_start = seg['start_time'] + k * step
                final_segments.append({
                    'start_time': piece_start,
                    'end_time': min(piece_start + step, seg['end_time']),
                    'type': 'forced'
                })
        else:
            final_segments.append(seg)

    print(f" 完成分段: {len(final_segments)}个片段")
    return final_segments
|
|
|
|
|
def fixed_segmentation(total_duration, segment_duration):
    """Fallback segmentation: cut the timeline into equal fixed-length
    chunks (last one may be shorter). Times are in seconds."""
    chunks = []
    cursor = 0
    while cursor < total_duration:
        chunk_end = min(cursor + segment_duration, total_duration)
        chunks.append({
            'start_time': cursor,
            'end_time': chunk_end,
            'type': 'fixed'
        })
        cursor = chunk_end
    return chunks
|
|
|
|
|
def create_audio_segments(segments, audio_pydub):
    """Export each segment of `audio_pydub` as a WAV under ./temp_segments.

    Adds 25 ms of padding on both sides of every cut (clamped to the clip
    bounds). Returns a list of {'file', 'start_time', 'end_time', 'index'}
    dicts, where 'file' is a pathlib.Path to the exported WAV.
    """
    print(" 创建音频片段...")

    out_dir = Path("temp_segments")
    out_dir.mkdir(exist_ok=True)

    exported = []
    for idx, info in enumerate(segments):
        seg_start = info['start_time']
        seg_end = info['end_time']

        # Millisecond bounds with 25 ms padding, clamped to the clip length.
        ms_from = max(0, int(seg_start * 1000) - 25)
        ms_to = min(len(audio_pydub), int(seg_end * 1000) + 25)

        wav_path = out_dir / f"segment_{idx:03d}.wav"
        audio_pydub[ms_from:ms_to].export(str(wav_path), format="wav")

        exported.append({
            'file': wav_path,
            'start_time': seg_start,
            'end_time': seg_end,
            'index': idx,
        })

    print(f" 创建完成: {len(exported)}个文件")
    return exported
|
|
|
|
|
def setup_fireredasr_environment():
    """Look for a local FireRedASR checkout and prepend it to sys.path.

    Returns True as soon as one candidate directory exists, False if none do.
    """
    candidates = [
        # "D:/workstation/voice-txt/FireRedASR-test/FireRedASR",
        "./FireRedASR",
        "../FireRedASR",
        "FireRedASR"
    ]

    for candidate in candidates:
        location = Path(candidate)
        if location.exists():
            resolved = str(location.absolute())
            if resolved not in sys.path:
                sys.path.insert(0, resolved)
            return True
    return False
|
|
|
|
|
def load_fireredasr_model(model_dir):
    """Load the FireRedASR AED model from `model_dir`.

    Returns (model, device) on success or (None, None) on any failure.
    `device` is "cuda:0" only when the model was actually moved to the
    GPU; previously a failed `.to()` was swallowed by a bare `except`
    and the function still reported "cuda:0" for a CPU-resident model.
    """
    print(" 加载FireRedASR模型...")

    # Make the fireredasr package importable if a local checkout exists.
    if not setup_fireredasr_environment():
        print(" 未找到FireRedASR路径,尝试直接导入...")

    try:
        # torch >= 2.4 restricts unpickling; whitelist argparse.Namespace,
        # which the checkpoints contain. Older torch lacks this API, so guard.
        if hasattr(torch.serialization, "add_safe_globals"):
            torch.serialization.add_safe_globals([argparse.Namespace])

        # Package layout differs between checkouts; try each known form.
        try:
            from fireredasr.models.fireredasr import FireRedAsr
        except ImportError:
            try:
                from FireRedASR.fireredasr.models.fireredasr import FireRedAsr
            except ImportError:
                import fireredasr
                from fireredasr.models.fireredasr import FireRedAsr

        model = FireRedAsr.from_pretrained("aed", model_dir)

        # Prefer the GPU, but fall back to CPU if the move fails so the
        # returned device string always matches the model's real placement.
        device = "cpu"
        if torch.cuda.is_available():
            try:
                model = model.to("cuda:0")
                device = "cuda:0"
                print(f" 使用GPU: {torch.cuda.get_device_name(0)}")
            except Exception as move_err:
                print(f" GPU不可用,回退到CPU: {move_err}")
                print(" 使用CPU")
        else:
            print(" 使用CPU")

        if hasattr(model, 'eval'):
            model.eval()
        return model, device

    except Exception as e:
        print(f" 模型加载失败: {e}")
        print("请检查:")
        print("1. FireRedASR是否正确安装")
        print("2. 路径配置是否正确")
        print("3. 依赖库是否完整")
        return None, None
|
|
|
|
|
def transcribe_segments(segment_files, model, device):
    """Run ASR over every exported segment file and collect non-empty texts.

    Args:
        segment_files: dicts with 'file', 'start_time', 'end_time' keys.
        model: object exposing transcribe(uttids, wav_paths, config).
        device: "cuda:..." enables GPU decoding and cache clearing.

    Returns a list of {'start_time', 'end_time', 'text', 'confidence'}
    dicts; segments that error out or produce no text are skipped.
    """
    print(" 开始语音识别...")

    results = []
    use_gpu = device.startswith("cuda")

    for idx, info in enumerate(segment_files):
        seg_start = info['start_time']
        seg_end = info['end_time']

        print(f" [{idx + 1}/{len(segment_files)}] {seg_start:.1f}s-{seg_end:.1f}s")

        try:
            config = {
                "use_gpu": 1 if use_gpu else 0,
                "beam_size": 5,
                "nbest": 1,
                "decode_max_len": 0
            }

            with torch.no_grad():
                outputs = model.transcribe(
                    [f"segment_{idx:03d}"], [str(info['file'])], config
                )

            if outputs and len(outputs) > 0:
                first = outputs[0]
                text = first.get('text', '').strip()

                # Confidence may be a scalar, a sequence, or absent entirely.
                confidence = first.get('confidence', first.get('score', 0.0))
                if isinstance(confidence, (list, tuple)) and len(confidence) > 0:
                    confidence = float(confidence[0])
                elif not isinstance(confidence, (int, float)):
                    confidence = 0.0
                else:
                    confidence = float(confidence)

                if text:
                    results.append({
                        'start_time': seg_start,
                        'end_time': seg_end,
                        'text': text,
                        'confidence': confidence
                    })
                else:
                    print(f" 无内容")

            # Release cached GPU memory between segments.
            if use_gpu:
                torch.cuda.empty_cache()

        except Exception as e:
            print(f" 错误: {e}")
            continue

    print(f" 识别完成: {len(results)}/{len(segment_files)}个片段")
    return results
|
|
|
|
|
def call_ai_model(sentence, url="http://192.168.3.8:7777/v1/chat/completions", timeout=120):
    """Ask the local LLM to punctuate `sentence` (and normalize Chinese
    numerals to digits) without altering the wording.

    Args:
        sentence: Raw ASR text to post-process.
        url: Chat-completions endpoint (was hard-coded; now overridable).
        timeout: Request timeout in seconds.

    Returns the processed text, or the original sentence unchanged on any
    request failure or unexpected response shape (punctuation is
    best-effort, never fatal).
    """
    prompt = f"""对以下文本添加标点符号,中文数字转阿拉伯数字。不修改文字内容。句末可以是冒号、逗号、问号、感叹号和句号等任意合适标点。
{sentence}"""

    payload = {
        "model": "Qwen3-14B",
        "messages": [
            {"role": "user", "content": prompt}
        ]
    }

    try:
        response = requests.post(url, json=payload, timeout=timeout)
        response.raise_for_status()
        result = response.json()
        # KeyError/IndexError/ValueError here previously escaped the
        # except clause and crashed the whole pipeline on a malformed reply.
        return result["choices"][0]["message"]["content"].strip()
    except (requests.exceptions.RequestException, KeyError, IndexError, ValueError) as e:
        print(f" API调用失败: {e}")
        return sentence  # degrade gracefully to the unpunctuated text
|
|
|
|
|
def process_transcription_results(results):
    """Post-process ASR results: order by start time, punctuate each text
    via the local LLM, and convert times to milliseconds.

    Unlike the previous version, this does NOT mutate the caller's list
    (it used to `sort()` it in place as a hidden side effect).

    Args:
        results: dicts with 'start_time', 'end_time', 'text', 'confidence'.

    Returns a list of {'start', 'end', 'content', 'confidence'} dicts with
    times in ms and confidence rounded to 3 decimals; [] when empty.
    """
    print(" 开始文本后处理...")

    if not results:
        print(" 没有识别结果")
        return []

    # Work on a sorted copy instead of reordering the input in place.
    ordered = sorted(results, key=lambda x: x['start_time'])

    processed_data = []
    total_segments = len(ordered)

    for i, item in enumerate(ordered):
        start_time = item['start_time']
        end_time = item['end_time']
        original_content = item['text']
        confidence = item['confidence']

        print(f" [{i + 1}/{total_segments}] 处理: {start_time:.1f}s-{end_time:.1f}s")
        print(f" 原文: {original_content}")

        # LLM adds punctuation; falls back to the original text on failure.
        processed_content = call_ai_model(original_content)
        print(f" 处理后: {processed_content}")

        processed_data.append({
            'start': int(start_time * 1000),  # seconds -> milliseconds
            'end': int(end_time * 1000),      # seconds -> milliseconds
            'content': processed_content,
            'confidence': round(confidence, 3)
        })

    return processed_data
|
|
|
|
|
def cleanup_temp_files():
    """Remove exported segment WAVs and the temp_segments directory.

    Best-effort: the directory is left in place if rmdir fails (e.g. it
    still contains unrelated files). The previous bare `except:` also
    swallowed KeyboardInterrupt/SystemExit; narrowed to OSError.
    """
    temp_dir = Path("temp_segments")
    if temp_dir.exists():
        for wav in temp_dir.glob("segment_*.wav"):
            wav.unlink(missing_ok=True)
        try:
            temp_dir.rmdir()
        except OSError:
            # Directory not empty or locked — leave it behind.
            pass
|
|
|
|
|
def process_audio_file(audio_path):
    """Full transcription pipeline for one audio file.

    load -> energy segmentation -> export clips -> ASR -> LLM post-process.
    Uses the module-level `model` / `device` loaded at startup. Temp WAVs
    under ./temp_segments are always cleaned up, even on failure.

    Returns the list produced by process_transcription_results
    ({'start', 'end', 'content', 'confidence'} dicts, times in ms).
    """

    global model, device

    try:
        # 1. Load audio (librosa array for analysis + pydub clip for export)
        audio_data, sr, audio_pydub = load_audio(audio_path)

        # 2. Energy-based segmentation
        segments = energy_based_segmentation(
            audio_data, sr,
            silence_threshold=0.001,  # normalized-energy silence threshold
            min_segment_length=1.0,   # minimum segment length: 1 s
            max_segment_length=455.0  # maximum segment length: 455 s
        )

        # 3. Export each segment as a WAV file
        segment_files = create_audio_segments(segments, audio_pydub)

        # 4. Speech recognition over the exported segments
        results = transcribe_segments(segment_files, model, device)

        # 5. Punctuation / formatting post-process via the LLM
        processed_data = process_transcription_results(results)

        return processed_data

    finally:
        cleanup_temp_files()
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the ASR model once at startup and clean
    temp files at shutdown. Populates the module-level `model` / `device`
    globals that the request handlers read."""

    global model, device

    print("正在启动服务...")
    # NOTE(review): hard-coded local Windows path — consider an env var.
    model_dir = "D:/workstation/voice-txt/FireRedASR-test/FireRedASR/pretrained_models/FireRedASR-AED-L"

    if not Path(model_dir).exists():
        print(f" 模型目录不存在: {model_dir}")
    else:
        model, device = load_fireredasr_model(model_dir)
        if model is None:
            # Service stays up but /transcriptions will answer code 500.
            print(" 模型加载失败,服务可能无法正常工作")
        else:
            print("模型加载成功,服务已就绪")

    yield  # application serves requests while suspended here

    # Shutdown: best-effort cleanup of leftover segment files.
    print(" 服务正在关闭...")
    cleanup_temp_files()
|
|
|
|
|
# FastAPI application; model loading/cleanup is handled by the lifespan hook.
app = FastAPI(lifespan=lifespan)
|
|
|
|
|
@app.post("/transcriptions")
async def create_file(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Accepts mp3/wav/m4a/flac. Always answers HTTP 200 with an envelope
    {"data", "message", "code"}, where `code` carries the real status:
    200 success (data = list of {'start','end','content','confidence'}),
    400 unsupported format, 500 model-not-loaded or processing failure.
    """
    temp_file_path = None

    try:
        # Refuse early if the model failed to load at startup.
        if model is None:
            return JSONResponse(
                content={
                    "data": "",
                    "message": "模型未加载,请重启服务",
                    "code": 500
                },
                status_code=200
            )

        # Validate by filename extension only (content is not sniffed).
        if not file.filename.lower().endswith(('.mp3', '.wav', '.m4a', '.flac')):
            return JSONResponse(
                content={
                    "data": "",
                    "message": "不支持的音频格式,请上传mp3、wav、m4a或flac文件",
                    "code": 400
                },
                status_code=200
            )

        # Spool the upload to a named temp file so librosa/pydub can open
        # it by path; delete=False because it is read after the handle closes.
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
            temp_file_path = temp_file.name
            # Write the uploaded bytes to the temp file.
            contents = await file.read()
            temp_file.write(contents)

        print(f"开始处理音频文件: {file.filename}")
        start_time = time.time()

        # Full segmentation + ASR + post-processing pipeline.
        result_data = process_audio_file(temp_file_path)

        elapsed_time = time.time() - start_time
        print(f"处理完成,耗时: {elapsed_time:.2f}秒")

        return JSONResponse(
            content={
                "data": jsonable_encoder(result_data),
                "message": "success",
                "code": 200
            },
            status_code=200
        )

    except Exception as e:
        # Top-level boundary: report the error in the envelope, keep HTTP 200.
        print(f" 处理错误: {str(e)}")
        import traceback
        traceback.print_exc()

        return JSONResponse(
            content={
                "data": "",
                "message": str(e),
                "code": 500
            },
            status_code=200
        )

    finally:
        # Always remove the spooled upload, even on failure.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except:
                # NOTE(review): bare except — narrowing to OSError would be safer.
                pass
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: reports service status and whether the ASR model
    finished loading at startup."""
    payload = {
        "status": "healthy",
        "model_loaded": model is not None,
        "message": "Service is running"
    }
    return JSONResponse(content=payload, status_code=200)
|
|
|
|
|
if __name__ == '__main__':
    # Start the HTTP server on all interfaces, port 7777.
    print("🎵 === FireRedASR API服务启动中 === 🎵")
    uvicorn.run(app, host='0.0.0.0', port=7777)