"""FireRedASR transcription service.

A FastAPI wrapper around FireRedASR-AED that:
  1. splits input audio at low-energy (silence) points,
  2. refines over-long segments and merges neighbours into groups of at
     most 45 seconds,
  3. transcribes each group with FireRedASR, and
  4. post-processes the text (punctuation, numeral normalization) through
     an external chat-completions endpoint.
"""
import sys

# Machine-specific location of the FireRedASR checkout; adjust as needed.
sys.path.append('D:\\workstation\\voice-txt\\FireRedASR-test\\FireRedASR')

import os
import time
import json
import argparse
import tempfile
from contextlib import asynccontextmanager
from pathlib import Path

import numpy as np
import torch
import requests
import librosa
import uvicorn
from pydub import AudioSegment
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder

# Populated at startup by the FastAPI lifespan handler
model = None
device = None

def format_time_hms(seconds):
    """Format a duration in seconds as HH:MM:SS."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"

def load_audio(audio_path):
    """Load an audio file for both energy analysis and segment export."""
    print(f"Loading audio: {audio_path}")
    # Load with librosa for energy analysis (16 kHz mono float samples)
    audio_data, sr = librosa.load(str(audio_path), sr=16000, mono=True)
    # Load with pydub for exporting segments
    audio_pydub = AudioSegment.from_file(str(audio_path))
    audio_pydub = audio_pydub.set_frame_rate(16000).set_channels(1)
    duration = len(audio_data) / sr
    print(f"Duration: {duration:.2f}s, sample rate: {sr}Hz")
    return audio_data, sr, audio_pydub

def energy_based_segmentation(audio_data, sr,
                              silence_threshold=0.001,
                              min_segment_length=1.0,
                              max_segment_length=455.0):
    """Split audio into speech segments at low-energy (silence) points."""
    # Short-time energy over 25ms frames with a 5ms hop
    frame_len = int(0.025 * sr)  # 25ms
    hop_len = int(0.005 * sr)  # 5ms
    energy = []
    for i in range(0, len(audio_data) - frame_len + 1, hop_len):
        frame = audio_data[i:i + frame_len]
        frame_energy = np.sum(frame ** 2) / len(frame)
        energy.append(frame_energy)
    energy = np.array(energy)
    if energy.max() > 0:
        energy = energy / energy.max()
    # Locate silence points
    silence_points = []
    for i, e in enumerate(energy):
        if e < silence_threshold:
            time_point = i * hop_len / sr  # hop index -> seconds
            silence_points.append(time_point)
    if not silence_points:
        # No silence detected; fall back to fixed-length segmentation
        return fixed_segmentation(len(audio_data) / sr, max_segment_length)
    # Merge adjacent silence points into intervals
    silence_intervals = []
    if silence_points:
        current_start = silence_points[0]
        current_end = silence_points[0]
        for point in silence_points[1:]:
            if point - current_end <= 0.1:  # gaps under 0.1s count as continuous
                current_end = point
            else:
                if current_end - current_start >= 0.05:  # keep silences of at least 50ms
                    silence_intervals.append((current_start, current_end))
                current_start = point
                current_end = point
        if current_end - current_start >= 0.05:
            silence_intervals.append((current_start, current_end))
    # Build speech segments between silence intervals
    segments = []
    last_end = 0.0
    audio_duration = len(audio_data) / sr
    for silence_start, silence_end in silence_intervals:
        if silence_start - last_end >= min_segment_length:
            segments.append({
                'start_time': last_end,
                'end_time': silence_start,
                'type': 'natural'
            })
        elif segments and silence_start - segments[-1]['start_time'] <= max_segment_length:
            segments[-1]['end_time'] = silence_start
        last_end = silence_end
    # Handle trailing audio after the last silence
    if audio_duration - last_end >= min_segment_length:
        segments.append({
            'start_time': last_end,
            'end_time': audio_duration,
            'type': 'natural'
        })
    elif segments:
        segments[-1]['end_time'] = audio_duration
    # Force-split segments that exceed the maximum length
    final_segments = []
    for segment in segments:
        duration = segment['end_time'] - segment['start_time']
        if duration > max_segment_length:
            num_subsegments = int(np.ceil(duration / max_segment_length))
            sub_duration = duration / num_subsegments
            for i in range(num_subsegments):
                sub_start = segment['start_time'] + i * sub_duration
                sub_end = min(sub_start + sub_duration, segment['end_time'])
                final_segments.append({
                    'start_time': sub_start,
                    'end_time': sub_end,
                    'type': 'forced'
                })
        else:
            final_segments.append(segment)
    print(f"Segmentation complete: {len(final_segments)} segments")
    return final_segments

def refine_long_segments(segments, audio_data, sr, max_duration=60.0):
    """Re-segment any segment longer than max_duration using finer VAD parameters."""
    # More fine-grained segmentation parameters
    refined_params = {
        'silence_threshold': 0.0001,  # more sensitive silence detection
        'min_segment_length': 0.5,    # shorter minimum segment
        'max_segment_length': 45.0    # stay under the model's length limit
    }
    refined_segments = []
    for segment in segments:
        duration = segment['end_time'] - segment['start_time']
        if duration > max_duration:
            # Extract the samples belonging to this segment
            start_sample = int(segment['start_time'] * sr)
            end_sample = int(segment['end_time'] * sr)
            segment_audio = audio_data[start_sample:end_sample]
            # Run a finer-grained segmentation on just this span
            sub_segments = energy_based_segmentation(
                segment_audio, sr,
                silence_threshold=refined_params['silence_threshold'],
                min_segment_length=refined_params['min_segment_length'],
                max_segment_length=refined_params['max_segment_length']
            )
            # Shift sub-segment times back onto the original audio's timeline
            for sub_segment in sub_segments:
                sub_segment['start_time'] += segment['start_time']
                sub_segment['end_time'] += segment['start_time']
                sub_segment['type'] = 'refined'  # mark as a refined sub-segment
                refined_segments.append(sub_segment)
        else:
            # Keep the segment unchanged
            refined_segments.append(segment)
    # Verify that every segment now satisfies the duration limit
    oversized_count = sum(1 for seg in refined_segments
                          if seg['end_time'] - seg['start_time'] > max_duration)
    if oversized_count > 0:
        print(f"Warning: {oversized_count} segments still exceed {max_duration}s")
    else:
        print(f"All segments are within {max_duration}s")
    return refined_segments

def merge_segments_sequentially(segments, max_duration=45.0):
    """Greedily merge consecutive segments into groups of at most max_duration seconds."""
    if not segments:
        return []
    # Sort by start time to guarantee sequential order
    segments_sorted = sorted(segments, key=lambda x: x['start_time'])
    merged_segments = []
    current_group = []
    current_start = None
    current_end = None
    for i, segment in enumerate(segments_sorted):
        segment_duration = segment['end_time'] - segment['start_time']
        # The very first segment starts the first group
        if not current_group:
            current_group.append(segment)
            current_start = segment['start_time']
            current_end = segment['end_time']
            print(f"New group: [{i + 1}] {current_start:.1f}s-{current_end:.1f}s (duration: {segment_duration:.1f}s)")
            continue
        # Total group duration if this segment were added
        potential_duration = segment['end_time'] - current_start
        if potential_duration <= max_duration:
            # Fits: extend the current group
            current_group.append(segment)
            current_end = segment['end_time']
        else:
            # Does not fit: close the current group and start a new one
            merged_segment = {
                'start_time': current_start,
                'end_time': current_end,
                'duration': current_end - current_start,
                'type': 'merged',
                'original_segments': current_group.copy(),
                'segment_count': len(current_group)
            }
            merged_segments.append(merged_segment)
            current_group = [segment]
            current_start = segment['start_time']
            current_end = segment['end_time']
    # Close the final group
    if current_group:
        merged_segment = {
            'start_time': current_start,
            'end_time': current_end,
            'duration': current_end - current_start,
            'type': 'merged',
            'original_segments': current_group.copy(),
            'segment_count': len(current_group)
        }
        merged_segments.append(merged_segment)
    # Summary statistics
    original_count = len(segments_sorted)
    merged_count = len(merged_segments)
    total_duration = sum(seg['duration'] for seg in merged_segments)
    avg_duration = total_duration / merged_count if merged_count > 0 else 0
    max_merged_duration = max(seg['duration'] for seg in merged_segments) if merged_segments else 0
    min_merged_duration = min(seg['duration'] for seg in merged_segments) if merged_segments else 0
    print("Merging complete:")
    print(f"  Original segments: {original_count}")
    print(f"  Merged segments: {merged_count}")
    print(f"  Compression: {(1 - merged_count / original_count) * 100:.1f}%")
    print(f"  Average duration: {avg_duration:.1f}s")
    print(f"  Duration range: {min_merged_duration:.1f}s - {max_merged_duration:.1f}s")
    # Verify the duration limit
    oversized = [seg for seg in merged_segments if seg['duration'] > max_duration]
    if oversized:
        print(f"Warning: {len(oversized)} merged segments exceed {max_duration}s")
    else:
        print(f"All merged segments are within {max_duration}s")
    return merged_segments

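# Illustrative example of the greedy merge above (assumed inputs, max_duration=45s):
#   input segments [0.0-20.0s], [20.5-40.0s], [40.5-62.0s]
#   -> the second segment joins the first group (40.0s total <= 45s);
#      the third would stretch that group to 62.0s > 45s, so it opens a
#      new group. Result: groups [0.0-40.0s] and [40.5-62.0s].
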
def fixed_segmentation(total_duration, segment_duration):
    segments = []
    start_time = 0
    while start_time < total_duration:
        end_time = min(start_time + segment_duration, total_duration)
        segments.append({
            'start_time': start_time,
            'end_time': end_time,
            'type': 'fixed'
        })
        start_time = end_time
    return segments

def create_audio_segments(segments, audio_pydub):
    """Export each segment to a temporary WAV file for transcription."""
    print("Creating audio segments...")
    temp_dir = Path("temp_segments")
    temp_dir.mkdir(exist_ok=True)
    segment_files = []
    for i, segment_info in enumerate(segments):
        start_time = segment_info['start_time']
        end_time = segment_info['end_time']
        # Convert to milliseconds and add 25ms of padding on each side
        start_ms = max(0, int(start_time * 1000) - 25)
        end_ms = min(len(audio_pydub), int(end_time * 1000) + 25)
        # Extract and save the audio segment
        segment = audio_pydub[start_ms:end_ms]
        segment_file = temp_dir / f"segment_{i:03d}.wav"
        segment.export(str(segment_file), format="wav")
        # Attach segment metadata
        segment_info_with_file = {
            'file': segment_file,
            'start_time': start_time,
            'end_time': end_time,
            'duration': end_time - start_time,
            'index': i
        }
        # For merged segments, carry over the original segment info
        if 'original_segments' in segment_info:
            segment_info_with_file['original_segments'] = segment_info['original_segments']
            segment_info_with_file['segment_count'] = segment_info['segment_count']
            segment_info_with_file['type'] = 'merged'
        segment_files.append(segment_info_with_file)
    print(f"Created {len(segment_files)} files")
    return segment_files

def setup_fireredasr_environment():
    """Add a local FireRedASR checkout to sys.path if one can be found."""
    possible_paths = [
        "./FireRedASR",
        "../FireRedASR",
        "FireRedASR"
    ]
    for path in possible_paths:
        if Path(path).exists():
            if str(Path(path).absolute()) not in sys.path:
                sys.path.insert(0, str(Path(path).absolute()))
            return True
    return False

def load_fireredasr_model(model_dir):
    """Load the FireRedASR-AED model, preferring GPU when available."""
    if not setup_fireredasr_environment():
        print("FireRedASR path not found, attempting direct import...")
    try:
        # PyTorch compatibility fix: allow argparse.Namespace when loading checkpoints
        torch.serialization.add_safe_globals([argparse.Namespace])
        # Try several import layouts
        try:
            from fireredasr.models.fireredasr import FireRedAsr
        except ImportError:
            try:
                from FireRedASR.fireredasr.models.fireredasr import FireRedAsr
            except ImportError:
                import fireredasr
                from fireredasr.models.fireredasr import FireRedAsr
        model = FireRedAsr.from_pretrained("aed", model_dir)
        # Prefer GPU, falling back to CPU if the transfer fails
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        if torch.cuda.is_available():
            try:
                model = model.to(device)
                print(f"Using GPU: {torch.cuda.get_device_name(0)}")
            except Exception:
                device = "cpu"
                print("Failed to move model to GPU, using CPU")
        else:
            print("Using CPU")
        if hasattr(model, 'eval'):
            model.eval()
        return model, device
    except Exception as e:
        print(f"Model loading failed: {e}")
        print("Please check that:")
        print("1. FireRedASR is installed correctly")
        print("2. The path configuration is correct")
        print("3. All dependencies are present")
        return None, None

def call_ai_model(sentence):
    """Post-process a transcript through an external chat-completions endpoint."""
    url = "http://192.168.3.8:7777/v1/chat/completions"
    # Prompt (kept in Chinese for the Chinese-language model): add punctuation,
    # convert Chinese numerals to Arabic numerals, do not modify or add to the
    # text content; sentences may end with any suitable punctuation mark.
    prompt = f"""对以下文本添加标点符号,中文数字转阿拉伯数字。不修改文字内容(一定不对文本做补充和提示)。句末可以是冒号、逗号、问号、感叹号和句号等任意合适标点。
{sentence}"""
    payload = {
        "model": "Qwen3-14B",
        "messages": [
            {"role": "user", "content": prompt}
        ]
    }
    try:
        response = requests.post(url, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
        processed_text = result["choices"][0]["message"]["content"].strip()
        return processed_text
    except requests.exceptions.RequestException as e:
        print(f"AI model call failed: {e}")
        return sentence  # fall back to the original text on failure

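# The endpoint above is assumed to implement the OpenAI-style chat-completions
# protocol, so a successful response is expected to look roughly like:
#   {"choices": [{"message": {"role": "assistant", "content": "<processed text>"}}], ...}
# which is what the result["choices"][0]["message"]["content"] access relies on.
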
def transcribe_segments(segment_files, model, device):
    """Run FireRedASR over each exported segment file."""
    print("Starting speech recognition...")
    results = []
    use_gpu = device.startswith("cuda")
    for i, segment_info in enumerate(segment_files):
        segment_file = segment_info['file']
        start_time = segment_info['start_time']
        end_time = segment_info['end_time']
        duration = segment_info.get('duration', end_time - start_time)
        # Log segment info
        segment_type = segment_info.get('type', 'normal')
        if segment_type == 'merged':
            original_count = segment_info.get('segment_count', 1)
            print(
                f"[{i + 1}/{len(segment_files)}] {start_time:.1f}s-{end_time:.1f}s "
                f"(duration: {duration:.1f}s, merged from {original_count} segments)")
        else:
            print(f"[{i + 1}/{len(segment_files)}] {start_time:.1f}s-{end_time:.1f}s (duration: {duration:.1f}s)")
        try:
            batch_uttid = [f"segment_{i:03d}"]
            batch_wav_path = [str(segment_file)]
            config = {
                "use_gpu": 1 if use_gpu else 0,
                "beam_size": 5,
                "nbest": 1,
                "decode_max_len": 0
            }
            with torch.no_grad():
                transcription_result = model.transcribe(
                    batch_uttid, batch_wav_path, config
                )
            if transcription_result and len(transcription_result) > 0:
                text = transcription_result[0].get('text', '').strip()
                if text:
                    result = {
                        'start_time': start_time,
                        'end_time': end_time,
                        'text': text
                    }
                    # For merged segments, record the extra metadata
                    if segment_type == 'merged':
                        result['is_merged'] = True
                        result['original_segment_count'] = segment_info.get('segment_count', 1)
                    results.append(result)
                else:
                    print("No content")
            # Free GPU cache between segments
            if use_gpu:
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"Error: {e}")
            continue
    print(f"Recognition complete: {len(results)}/{len(segment_files)} segments")
    return results

def process_transcription_results(results):
    """Run AI text post-processing over the raw recognition results."""
    print("Starting AI text post-processing...")
    if not results:
        print("No recognition results")
        return []
    # Sort chronologically
    results.sort(key=lambda x: x['start_time'])
    processed_data = []
    for item in results:
        start_time = item['start_time']
        end_time = item['end_time']
        original_content = item['text']
        # Punctuate and normalize through the AI model
        processed_content = call_ai_model(original_content)
        # Build the final record
        processed_item = {
            'start': int(start_time * 1000),  # milliseconds
            'end': int(end_time * 1000),  # milliseconds
            'content': processed_content,
            'original_content': original_content,  # keep the raw recognition output
            'start_time_hms': format_time_hms(start_time),
            'end_time_hms': format_time_hms(end_time),
            'duration': round(end_time - start_time, 2)
        }
        # Carry over merge metadata
        if item.get('is_merged', False):
            processed_item['is_merged'] = True
            processed_item['original_segment_count'] = item.get('original_segment_count', 1)
        processed_data.append(processed_item)
    return processed_data

def cleanup_temp_files():
    """Remove temporary segment WAV files and their directory."""
    temp_dir = Path("temp_segments")
    if temp_dir.exists():
        for file in temp_dir.glob("segment_*.wav"):
            file.unlink(missing_ok=True)
        try:
            temp_dir.rmdir()
        except OSError:
            pass

def process_audio_file(audio_path):
    """Full pipeline: segment, refine, merge, transcribe, and post-process."""
    global model, device
    try:
        # 1. Load audio
        audio_data, sr, audio_pydub = load_audio(audio_path)
        # 2. Energy-based segmentation
        segments = energy_based_segmentation(
            audio_data, sr,
            silence_threshold=0.001,   # silence threshold
            min_segment_length=1.0,    # minimum segment length: 1s
            max_segment_length=455.0   # maximum segment length: 455s
        )
        # 3. Refine segments longer than 60 seconds
        refined_segments = refine_long_segments(segments, audio_data, sr, max_duration=60.0)
        # 4. Merge segments sequentially
        merged_segments = merge_segments_sequentially(refined_segments, max_duration=45.0)
        # 5. Export audio segment files
        segment_files = create_audio_segments(merged_segments, audio_pydub)
        # 6. Speech recognition
        results = transcribe_segments(segment_files, model, device)
        # 7. AI text post-processing
        processed_data = process_transcription_results(results)
        # 8. Collect statistics
        stats = {
            'total_segments': len(results),
            'original_segments': len(refined_segments),
            'merged_segments': len(merged_segments),
            'compression_rate': round((1 - len(merged_segments) / len(refined_segments)) * 100, 1) if len(
                refined_segments) > 0 else 0,
            'total_duration': round(sum(item['duration'] for item in processed_data), 2)
        }
        return processed_data, stats
    finally:
        cleanup_temp_files()

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifecycle: load the model at startup, clean up at shutdown."""
    global model, device
    model_dir = "D:/workstation/voice-txt/FireRedASR-test/FireRedASR/pretrained_models/FireRedASR-AED-L"
    if not Path(model_dir).exists():
        print(f"Model directory does not exist: {model_dir}")
    else:
        model, device = load_fireredasr_model(model_dir)
        if model is None:
            print("Model loading failed; the service may not work correctly")
        else:
            print("Model loaded, service is ready")
    yield  # application runs here
    # Shutdown
    print("Service shutting down...")
    cleanup_temp_files()


# Create the FastAPI application instance
app = FastAPI(lifespan=lifespan)

@app.post("/transcriptions")
async def create_transcription(file: UploadFile = File(...)):
temp_file_path = None
try:
# 检查模型是否已加载
if model is None:
return JSONResponse(
content={
"data": "",
"message": "模型未加载,请重启服务",
"code": 500
},
status_code=200
)
# 检查文件类型
if not file.filename.lower().endswith(('.mp3', '.wav', '.m4a', '.flac')):
return JSONResponse(
content={
"data": "",
"message": "不支持的音频格式,请上传mp3、wav、m4a或flac文件",
"code": 400
},
status_code=200
)
# 创建临时文件
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
temp_file_path = temp_file.name
# 将上传文件内容写入临时文件
contents = await file.read()
temp_file.write(contents)
print(f" 开始处理音频文件: {file.filename}")
start_time = time.time()
# 处理音频文件
result_data, stats = process_audio_file(temp_file_path)
elapsed_time = time.time() - start_time
print(f" 处理完成,耗时: {elapsed_time:.2f}")
return JSONResponse(
content={
"data": jsonable_encoder(result_data),
"stats": stats,
"processing_time": round(elapsed_time, 2),
"message": "success",
"code": 200
},
status_code=200
)
except Exception as e:
print(f" 处理错误: {str(e)}")
import traceback
traceback.print_exc()
return JSONResponse(
content={
"data": "",
"message": str(e),
"code": 500
},
status_code=200
)
finally:
# 确保删除临时文件
if temp_file_path and os.path.exists(temp_file_path):
try:
os.remove(temp_file_path)
except:
pass
@app.get("/health")
async def health_check():
"""健康检查接口"""
return JSONResponse(
content={
"status": "healthy",
"model_loaded": model is not None,
"device": device if device else "unknown",
"message": "Professional ASR Service is running"
},
status_code=200
)
@app.get("/")
async def root():
"""根路径接口"""
return JSONResponse(
content={
"service": "FireRedASR Professional ASR API",
"version": "1.0.0",
"features": [
"能量分段(VAD)",
"智能片段拼装",
"长片段细分",
"AI文本后处理",
"完整时间戳",
"GPU加速支持"
],
"endpoints": {
"transcriptions": "POST /transcriptions - 语音转录(含AI后处理)",
"health": "GET /health - 健康检查"
}
},
status_code=200
)
if __name__ == '__main__':
    print("🎵 === FireRedASR Professional ASR API service starting === 🎵")
    uvicorn.run(app, host='0.0.0.0', port=7777)
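
# --- Example usage (a sketch; the script filename and audio filename are assumptions) ---
#   Start the service:   python fireredasr_api.py
#   Health check:        curl http://localhost:7777/health
#   Transcribe a file:   curl -F "file=@example.mp3" http://localhost:7777/transcriptions
# A successful response carries the per-segment transcripts (millisecond
# "start"/"end" values plus HH:MM:SS strings) under "data" and run
# statistics under "stats".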