import sys
sys.path.append('D:\\workstation\\voice-txt\\FireRedASR-test\\FireRedASR')
import os
import time
import numpy as np
import torch
import argparse
import requests
import tempfile
import uvicorn
from pathlib import Path
from pydub import AudioSegment
import librosa
from contextlib import asynccontextmanager
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder

# Global variables holding the model and device
model = None
device = None
def load_audio(audio_path):
    """Load an audio file for both analysis and export."""
    print(f" Loading audio: {audio_path}")
    # Load with librosa (16 kHz mono float) for energy analysis
    audio_data, sr = librosa.load(str(audio_path), sr=16000, mono=True)
    # Load with pydub for slicing and exporting segments
    audio_pydub = AudioSegment.from_file(str(audio_path))
    audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)
    duration = len(audio_data) / sr
    print(f" Duration: {duration:.2f}s, sample rate: {sr}Hz")
    return audio_data, sr, audio_pydub
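# Note: librosa returns float32 samples normalized to [-1, 1], which the energy
# thresholds below assume, while the pydub object keeps the original integer
# PCM and is only used for cutting and exporting WAV segments.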
def energy_based_segmentation(audio_data, sr,
                              silence_threshold=0.001,
                              min_segment_length=1.0,
                              max_segment_length=455.0):
    """Segment audio at silences detected via short-time energy."""
    print(" Running energy-based segmentation...")
    # Compute short-time energy
    frame_len = int(0.025 * sr)  # 25 ms frames
    hop_len = int(0.005 * sr)    # 5 ms hop
    energy = []
    for i in range(0, len(audio_data) - frame_len + 1, hop_len):
        frame = audio_data[i:i + frame_len]
        frame_energy = np.sum(frame ** 2) / len(frame)
        energy.append(frame_energy)
    energy = np.array(energy)
    if energy.max() > 0:
        energy = energy / energy.max()
    # Collect frames below the silence threshold
    silence_points = []
    for i, e in enumerate(energy):
        if e < silence_threshold:
            time_point = i * 0.005  # hop index -> seconds
            silence_points.append(time_point)
    if not silence_points:
        print(" No silence detected, falling back to fixed-length segmentation")
        return fixed_segmentation(len(audio_data) / sr, max_segment_length)
    # Merge adjacent silent frames into silence intervals
    silence_intervals = []
    current_start = silence_points[0]
    current_end = silence_points[0]
    for point in silence_points[1:]:
        if point - current_end <= 0.1:  # gaps under 0.1 s count as contiguous
            current_end = point
        else:
            if current_end - current_start >= 0.05:  # keep silences of at least 50 ms
                silence_intervals.append((current_start, current_end))
            current_start = point
            current_end = point
    if current_end - current_start >= 0.05:
        silence_intervals.append((current_start, current_end))
    # Build speech segments between silence intervals
    segments = []
    last_end = 0.0
    audio_duration = len(audio_data) / sr
    for silence_start, silence_end in silence_intervals:
        if silence_start - last_end >= min_segment_length:
            segments.append({
                'start_time': last_end,
                'end_time': silence_start,
                'type': 'natural'
            })
        elif segments and silence_start - segments[-1]['start_time'] <= max_segment_length:
            # Too short to stand alone: extend the previous segment
            segments[-1]['end_time'] = silence_start
        last_end = silence_end
    # Handle trailing speech after the last silence
    if audio_duration - last_end >= min_segment_length:
        segments.append({
            'start_time': last_end,
            'end_time': audio_duration,
            'type': 'natural'
        })
    elif segments:
        segments[-1]['end_time'] = audio_duration
    # Split segments that exceed the maximum length
    final_segments = []
    for segment in segments:
        duration = segment['end_time'] - segment['start_time']
        if duration > max_segment_length:
            num_subsegments = int(np.ceil(duration / max_segment_length))
            sub_duration = duration / num_subsegments
            for i in range(num_subsegments):
                sub_start = segment['start_time'] + i * sub_duration
                sub_end = min(sub_start + sub_duration, segment['end_time'])
                final_segments.append({
                    'start_time': sub_start,
                    'end_time': sub_end,
                    'type': 'forced'
                })
        else:
            final_segments.append(segment)
    print(f" Segmentation done: {len(final_segments)} segments")
    return final_segments
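# Worked example: at sr=16000 the loop above uses frame_len=400 samples (25 ms)
# and hop_len=80 samples (5 ms), so a 60 s recording produces about
# (60*16000 - 400) / 80 + 1 ≈ 11996 energy frames to threshold.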
def fixed_segmentation(total_duration, segment_duration):
    """Fixed-length segmentation (fallback when no silence is found)."""
    segments = []
    start_time = 0
    while start_time < total_duration:
        end_time = min(start_time + segment_duration, total_duration)
        segments.append({
            'start_time': start_time,
            'end_time': end_time,
            'type': 'fixed'
        })
        start_time = end_time
    return segments
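# Example: fixed_segmentation(1000.0, 455.0) yields three segments covering
# 0-455 s, 455-910 s and 910-1000 s, the last one shorter than the rest.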
def create_audio_segments(segments, audio_pydub):
    """Export each segment to a temporary WAV file."""
    print(" Creating audio segments...")
    temp_dir = Path("temp_segments")
    temp_dir.mkdir(exist_ok=True)
    segment_files = []
    for i, segment_info in enumerate(segments):
        start_time = segment_info['start_time']
        end_time = segment_info['end_time']
        # Convert to milliseconds and add 25 ms of padding on each side
        start_ms = max(0, int(start_time * 1000) - 25)
        end_ms = min(len(audio_pydub), int(end_time * 1000) + 25)
        # Slice and export the segment
        segment = audio_pydub[start_ms:end_ms]
        segment_file = temp_dir / f"segment_{i:03d}.wav"
        segment.export(str(segment_file), format="wav")
        segment_files.append({
            'file': segment_file,
            'start_time': start_time,
            'end_time': end_time,
            'index': i
        })
    print(f" Created {len(segment_files)} files")
    return segment_files
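# The 25 ms padding avoids clipping phonemes right at a cut point; since cuts
# land inside detected silences, the up-to-50 ms overlap between neighbouring
# segments should not duplicate any speech.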
def setup_fireredasr_environment():
    """Add a local FireRedASR checkout to sys.path if one can be found."""
    possible_paths = [
        # "D:/workstation/voice-txt/FireRedASR-test/FireRedASR",
        "./FireRedASR",
        "../FireRedASR",
        "FireRedASR"
    ]
    for path in possible_paths:
        if Path(path).exists():
            if str(Path(path).absolute()) not in sys.path:
                sys.path.insert(0, str(Path(path).absolute()))
                # print(f" Added path: {Path(path).absolute()}")
            return True
    return False
def load_fireredasr_model(model_dir):
    print(" Loading FireRedASR model...")
    # Make sure the FireRedASR package is importable
    if not setup_fireredasr_environment():
        print(" FireRedASR path not found, trying a direct import...")
    try:
        # PyTorch >= 2.6 compatibility: checkpoints containing argparse.Namespace
        # must be allow-listed for weights-only loading
        torch.serialization.add_safe_globals([argparse.Namespace])
        # Try several import layouts
        try:
            from fireredasr.models.fireredasr import FireRedAsr
        except ImportError:
            try:
                from FireRedASR.fireredasr.models.fireredasr import FireRedAsr
            except ImportError:
                import fireredasr
                from fireredasr.models.fireredasr import FireRedAsr
        model = FireRedAsr.from_pretrained("aed", model_dir)
        # Prefer the GPU when available
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        if torch.cuda.is_available():
            try:
                model = model.to(device)
                print(f" Using GPU: {torch.cuda.get_device_name(0)}")
            except Exception:
                device = "cpu"  # fall back to CPU if moving the model fails
                print(" Failed to move model to GPU, using CPU")
        else:
            print(" Using CPU")
        if hasattr(model, 'eval'):
            model.eval()
        return model, device
    except Exception as e:
        print(f" Model loading failed: {e}")
        print("Please check that:")
        print("1. FireRedASR is installed correctly")
        print("2. The paths are configured correctly")
        print("3. All dependencies are present")
        return None, None
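# For reference, the FireRedASR-AED-L pretrained directory is expected to hold
# the checkpoint plus its CMVN stats and dictionary (e.g. model.pth.tar,
# cmvn.ark, dict.txt); exact file names depend on the FireRedASR release.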
def transcribe_segments(segment_files, model, device):
    print(" Starting speech recognition...")
    results = []
    use_gpu = device.startswith("cuda")
    for i, segment_info in enumerate(segment_files):
        segment_file = segment_info['file']
        start_time = segment_info['start_time']
        end_time = segment_info['end_time']
        print(f" [{i + 1}/{len(segment_files)}] {start_time:.1f}s-{end_time:.1f}s")
        try:
            batch_uttid = [f"segment_{i:03d}"]
            batch_wav_path = [str(segment_file)]
            config = {
                "use_gpu": 1 if use_gpu else 0,
                "beam_size": 5,
                "nbest": 1,
                "decode_max_len": 0
            }
            with torch.no_grad():
                transcription_result = model.transcribe(
                    batch_uttid, batch_wav_path, config
                )
            if transcription_result and len(transcription_result) > 0:
                result = transcription_result[0]
                text = result.get('text', '').strip()
                # Extract the confidence score if one is present
                confidence = result.get('confidence', result.get('score', 0.0))
                if isinstance(confidence, (list, tuple)) and len(confidence) > 0:
                    confidence = float(confidence[0])
                elif not isinstance(confidence, (int, float)):
                    confidence = 0.0
                else:
                    confidence = float(confidence)
                if text:
                    results.append({
                        'start_time': start_time,
                        'end_time': end_time,
                        'text': text,
                        'confidence': confidence
                    })
                    # print(f" {text} (confidence: {confidence:.3f})")
                else:
                    print(" Empty transcription")
            # Free cached GPU memory between segments
            if use_gpu:
                torch.cuda.empty_cache()
        except Exception as e:
            print(f" Error: {e}")
            continue
    print(f" Recognition finished: {len(results)}/{len(segment_files)} segments")
    return results
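# Each element of `results` looks like
# {'start_time': 0.0, 'end_time': 5.2, 'text': '...', 'confidence': 0.87},
# with times in seconds relative to the start of the original file.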
def call_ai_model(sentence):
    url = "http://192.168.3.8:7777/v1/chat/completions"
    prompt = f"""Add punctuation to the following text and convert Chinese numerals to Arabic numerals. Do not change the wording itself. A sentence may end with any suitable mark such as a colon, comma, question mark, exclamation mark or full stop.
{sentence}"""
    payload = {
        "model": "Qwen3-14B",
        "messages": [
            {"role": "user", "content": prompt}
        ]
    }
    try:
        response = requests.post(url, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
        processed_text = result["choices"][0]["message"]["content"].strip()
        return processed_text
    except requests.exceptions.RequestException as e:
        print(f" API call failed: {e}")
        return sentence  # fall back to the raw text on failure
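# The URL above is assumed to be an OpenAI-compatible chat-completions endpoint
# (e.g. vLLM or a similar server hosting Qwen3-14B). A minimal smoke test:
#     print(call_ai_model("今天天气怎么样"))
# should return the sentence with punctuation restored.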
def process_transcription_results(results):
    print(" Starting text post-processing...")
    if not results:
        print(" No recognition results")
        return []
    # Sort by start time
    results.sort(key=lambda x: x['start_time'])
    processed_data = []
    total_segments = len(results)
    for i, item in enumerate(results):
        start_time = item['start_time']
        end_time = item['end_time']
        original_content = item['text']
        confidence = item['confidence']
        print(f" [{i + 1}/{total_segments}] Processing: {start_time:.1f}s-{end_time:.1f}s")
        print(f" Raw text: {original_content}")
        # Restore punctuation via the LLM
        processed_content = call_ai_model(original_content)
        print(f" Processed: {processed_content}")
        processed_data.append({
            'start': int(start_time * 1000),  # milliseconds
            'end': int(end_time * 1000),      # milliseconds
            'content': processed_content,
            'confidence': round(confidence, 3)
        })
    return processed_data
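# A processed entry therefore looks like, for example,
# {'start': 0, 'end': 5230, 'content': '今天天气怎么样?', 'confidence': 0.87},
# which is the shape returned in the API response's "data" field.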
def cleanup_temp_files():
    """Remove temporary segment files."""
    temp_dir = Path("temp_segments")
    if temp_dir.exists():
        for file in temp_dir.glob("segment_*.wav"):
            file.unlink(missing_ok=True)
        try:
            temp_dir.rmdir()
        except OSError:
            pass  # directory not empty or still in use
def process_audio_file(audio_path):
    global model, device
    try:
        # 1. Load the audio
        audio_data, sr, audio_pydub = load_audio(audio_path)
        # 2. Energy-based segmentation
        segments = energy_based_segmentation(
            audio_data, sr,
            silence_threshold=0.001,   # silence threshold on normalized energy
            min_segment_length=1.0,    # minimum segment length: 1 second
            max_segment_length=455.0   # maximum segment length: 455 seconds
        )
        # 3. Export audio segments
        segment_files = create_audio_segments(segments, audio_pydub)
        # 4. Speech recognition
        results = transcribe_segments(segment_files, model, device)
        # 5. Post-process the results
        processed_data = process_transcription_results(results)
        return processed_data
    finally:
        cleanup_temp_files()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup and clean up on shutdown."""
    global model, device
    print("Starting service...")
    model_dir = "D:/workstation/voice-txt/FireRedASR-test/FireRedASR/pretrained_models/FireRedASR-AED-L"
    if not Path(model_dir).exists():
        print(f" Model directory does not exist: {model_dir}")
    else:
        model, device = load_fireredasr_model(model_dir)
        if model is None:
            print(" Model loading failed, the service may not work properly")
        else:
            print("Model loaded, service is ready")
    yield  # application runs here
    # Shutdown
    print(" Shutting down service...")
    cleanup_temp_files()

# Create the FastAPI application
app = FastAPI(lifespan=lifespan)
@app.post("/transcriptions")
async def create_file(file: UploadFile = File(...)):
    temp_file_path = None
    try:
        # Make sure the model is loaded
        if model is None:
            return JSONResponse(
                content={
                    "data": "",
                    "message": "Model not loaded, please restart the service",
                    "code": 500
                },
                status_code=200
            )
        # Check the file type
        if not file.filename.lower().endswith(('.mp3', '.wav', '.m4a', '.flac')):
            return JSONResponse(
                content={
                    "data": "",
                    "message": "Unsupported audio format, please upload an mp3, wav, m4a or flac file",
                    "code": 400
                },
                status_code=200
            )
        # Write the upload to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
            temp_file_path = temp_file.name
            contents = await file.read()
            temp_file.write(contents)
        print(f"Processing audio file: {file.filename}")
        start_time = time.time()
        # Run the full pipeline
        result_data = process_audio_file(temp_file_path)
        elapsed_time = time.time() - start_time
        print(f"Done, elapsed: {elapsed_time:.2f}s")
        return JSONResponse(
            content={
                "data": jsonable_encoder(result_data),
                "message": "success",
                "code": 200
            },
            status_code=200
        )
    except Exception as e:
        print(f" Processing error: {str(e)}")
        import traceback
        traceback.print_exc()
        return JSONResponse(
            content={
                "data": "",
                "message": str(e),
                "code": 500
            },
            status_code=200
        )
    finally:
        # Always delete the temporary upload
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
@app.get("/health")
async def health_check():
    """Health-check endpoint."""
    return JSONResponse(
        content={
            "status": "healthy",
            "model_loaded": model is not None,
            "message": "Service is running"
        },
        status_code=200
    )
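# Example requests against a local instance (hypothetical host/port matching
# the uvicorn settings below):
#     curl http://localhost:7777/health
#     curl -F "file=@test.wav" http://localhost:7777/transcriptions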
if __name__ == '__main__':
    print("🎵 === FireRedASR API service starting === 🎵")
    uvicorn.run(app, host='0.0.0.0', port=7777)