|
|
|
|
|
import sys

# Make the local FireRedASR checkout importable (developer-machine absolute
# path); setup_fireredasr_environment() below repeats this more portably.
sys.path.append('D:\\workstation\\voice-txt\\FireRedASR-test\\FireRedASR')
|
|
|
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
import json |
|
|
import numpy as np |
|
|
import torch |
|
|
import argparse |
|
|
import requests |
|
|
from pathlib import Path |
|
|
from pydub import AudioSegment |
|
|
import librosa |
|
|
|
|
|
|
|
|
def format_time_hms(seconds):
    """Render a non-negative duration in seconds as an HH:MM:SS string.

    Fractional seconds are truncated, matching int() flooring of a
    non-negative value.
    """
    whole = int(seconds)
    hours, remainder = divmod(whole, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
|
|
|
|
|
|
|
|
def load_audio(audio_path):
    """Load one audio file in two representations.

    Returns (samples, sample_rate, pydub_segment): raw 16 kHz mono numpy
    samples for energy analysis, plus a pydub AudioSegment (normalized to
    the same rate/channels) used later for slicing and export.
    """
    print(f"📁 加载音频: {audio_path}")

    # Raw samples drive the energy-based segmentation.
    samples, sample_rate = librosa.load(str(audio_path), sr=16000, mono=True)

    # pydub copy, normalized to match, for millisecond-accurate slicing.
    segment_audio = AudioSegment.from_file(str(audio_path))
    segment_audio = segment_audio.set_frame_rate(16000).set_channels(1)

    total_seconds = len(samples) / sample_rate
    print(f" 时长: {total_seconds:.2f}秒, 采样率: {sample_rate}Hz")

    return samples, sample_rate, segment_audio
|
|
|
|
|
|
|
|
def _short_time_energy(audio_data, frame_len, hop_len):
    """Mean squared amplitude of each frame, via a cumulative sum.

    O(n) and loop-free; numerically equivalent (float64 accumulation) to
    summing each frame separately. Returns an empty array when the signal
    is shorter than one frame.
    """
    samples = np.asarray(audio_data, dtype=np.float64)
    if len(samples) < frame_len:
        return np.empty(0)
    starts = np.arange(0, len(samples) - frame_len + 1, hop_len)
    csum = np.concatenate(([0.0], np.cumsum(samples ** 2)))
    return (csum[starts + frame_len] - csum[starts]) / frame_len


def _merge_silence_points(silence_points):
    """Merge silence frame times into (start, end) intervals.

    Points closer than 0.1 s are treated as one continuous silence; only
    intervals lasting at least 50 ms are kept.
    """
    intervals = []
    current_start = current_end = silence_points[0]
    for point in silence_points[1:]:
        if point - current_end <= 0.1:  # still the same silent stretch
            current_end = point
        else:
            if current_end - current_start >= 0.05:
                intervals.append((current_start, current_end))
            current_start = current_end = point
    if current_end - current_start >= 0.05:
        intervals.append((current_start, current_end))
    return intervals


def _speech_segments(silence_intervals, audio_duration,
                     min_segment_length, max_segment_length):
    """Build speech segments spanning the gaps between silence intervals."""
    segments = []
    last_end = 0.0
    for silence_start, silence_end in silence_intervals:
        if silence_start - last_end >= min_segment_length:
            segments.append({
                'start_time': last_end,
                'end_time': silence_start,
                'type': 'natural'
            })
        elif segments and silence_start - segments[-1]['start_time'] <= max_segment_length:
            # Gap too short to stand alone: absorb it into the previous segment.
            segments[-1]['end_time'] = silence_start
        last_end = silence_end

    # Trailing speech after the final silence interval.
    if audio_duration - last_end >= min_segment_length:
        segments.append({
            'start_time': last_end,
            'end_time': audio_duration,
            'type': 'natural'
        })
    elif segments:
        segments[-1]['end_time'] = audio_duration
    return segments


def _split_long_segments(segments, max_segment_length):
    """Force-split any segment longer than max_segment_length into equal parts."""
    result = []
    for segment in segments:
        duration = segment['end_time'] - segment['start_time']
        if duration <= max_segment_length:
            result.append(segment)
            continue
        num_subsegments = int(np.ceil(duration / max_segment_length))
        sub_duration = duration / num_subsegments
        for i in range(num_subsegments):
            sub_start = segment['start_time'] + i * sub_duration
            result.append({
                'start_time': sub_start,
                'end_time': min(sub_start + sub_duration, segment['end_time']),
                'type': 'forced'
            })
    return result


def energy_based_segmentation(audio_data, sr,
                              silence_threshold=0.001,
                              min_segment_length=2.0,
                              max_segment_length=455.0):
    """Split audio into speech segments at low-energy (silent) regions.

    Args:
        audio_data: 1-D numpy array of mono samples.
        sr: sample rate in Hz.
        silence_threshold: peak-normalized energy below which a frame counts
            as silence.
        min_segment_length: minimum speech segment duration in seconds.
        max_segment_length: maximum segment duration; longer segments are
            force-split into equal sub-segments.

    Returns:
        List of {'start_time', 'end_time', 'type'} dicts (seconds). Falls
        back to fixed_segmentation() when no silence is found or the clip
        is shorter than one analysis frame.
    """
    print("🧠 能量分段分析...")

    frame_len = int(0.025 * sr)  # 25 ms analysis window
    hop_len = int(0.005 * sr)    # 5 ms hop
    audio_duration = len(audio_data) / sr

    energy = _short_time_energy(audio_data, frame_len, hop_len)
    if energy.size == 0:
        # Shorter than one frame: nothing to analyze (previously crashed on
        # energy.max() of an empty array).
        print(" 未检测到静音,使用固定分段")
        return fixed_segmentation(audio_duration, max_segment_length)

    peak = energy.max()
    if peak > 0:
        energy = energy / peak

    # Frame index -> seconds via the actual hop (hop_len/sr), not a
    # hard-coded 0.005, so non-16kHz rates map correctly.
    hop_seconds = hop_len / sr
    silence_points = [i * hop_seconds
                      for i in np.flatnonzero(energy < silence_threshold)]

    if not silence_points:
        print(" 未检测到静音,使用固定分段")
        return fixed_segmentation(audio_duration, max_segment_length)

    silence_intervals = _merge_silence_points(silence_points)
    segments = _speech_segments(silence_intervals, audio_duration,
                                min_segment_length, max_segment_length)
    final_segments = _split_long_segments(segments, max_segment_length)

    print(f" 完成分段: {len(final_segments)}个片段")
    return final_segments
|
|
|
|
|
|
|
|
def fixed_segmentation(total_duration, segment_duration):
    """Fallback segmentation: chop the timeline into equal fixed-length slices."""
    slices = []
    cursor = 0
    while cursor < total_duration:
        upper = min(cursor + segment_duration, total_duration)
        slices.append({
            'start_time': cursor,
            'end_time': upper,
            'type': 'fixed'
        })
        cursor = upper
    return slices
|
|
|
|
|
|
|
|
def create_audio_segments(segments, audio_pydub):
    """Export each detected segment as a WAV file under temp_segments/.

    Returns a list of {'file', 'start_time', 'end_time', 'index'} records,
    one per exported file.
    """
    print("✂️ 创建音频片段...")

    out_dir = Path("temp_segments")
    out_dir.mkdir(exist_ok=True)

    exported = []
    for idx, info in enumerate(segments):
        begin_s = info['start_time']
        finish_s = info['end_time']

        # Millisecond bounds with 25 ms padding on each side, clamped to
        # the clip (pydub slicing/len operate in milliseconds).
        begin_ms = max(0, int(begin_s * 1000) - 25)
        finish_ms = min(len(audio_pydub), int(finish_s * 1000) + 25)

        wav_path = out_dir / f"segment_{idx:03d}.wav"
        audio_pydub[begin_ms:finish_ms].export(str(wav_path), format="wav")

        exported.append({
            'file': wav_path,
            'start_time': begin_s,
            'end_time': finish_s,
            'index': idx
        })

    print(f" 创建完成: {len(exported)}个文件")
    return exported
|
|
|
|
|
|
|
|
def setup_fireredasr_environment():
    """Add the first existing FireRedASR checkout to sys.path.

    Returns True when a candidate directory exists (inserting it into
    sys.path if not already present), False when none is found.
    """
    candidate_dirs = (
        "D:/workstation/voice-txt/FireRedASR-test/FireRedASR",
        "./FireRedASR",
        "../FireRedASR",
        "FireRedASR",
    )

    for candidate in candidate_dirs:
        location = Path(candidate)
        if not location.exists():
            continue
        resolved = str(location.absolute())
        if resolved not in sys.path:
            sys.path.insert(0, resolved)
            print(f" 添加路径: {location.absolute()}")
        return True
    return False
|
|
|
|
|
|
|
|
def load_fireredasr_model(model_dir):
    """Load the FireRedASR AED model from model_dir.

    Returns:
        (model, device) on success, where device is "cuda:0" only when the
        model was actually moved to the GPU, otherwise "cpu".
        (None, None) on any failure (import, checkpoint load, etc.).
    """
    print("🚀 加载FireRedASR模型...")

    # Make the FireRedASR package importable if a checkout is nearby.
    if not setup_fireredasr_environment():
        print("⚠️ 未找到FireRedASR路径,尝试直接导入...")

    try:
        # Newer PyTorch defaults to weights_only loading; allow the
        # argparse.Namespace objects stored inside FireRedASR checkpoints.
        torch.serialization.add_safe_globals([argparse.Namespace])

        # Package layout differs between install styles; try each in turn.
        try:
            from fireredasr.models.fireredasr import FireRedAsr
        except ImportError:
            try:
                from FireRedASR.fireredasr.models.fireredasr import FireRedAsr
            except ImportError:
                import fireredasr
                from fireredasr.models.fireredasr import FireRedAsr

        model = FireRedAsr.from_pretrained("aed", model_dir)

        # Prefer GPU, but fall back to CPU if the move fails so the
        # returned device string always matches where the model lives
        # (previously a bare except left device="cuda:0" on failure).
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        if device.startswith("cuda"):
            try:
                model = model.to(device)
                print(f" 使用GPU: {torch.cuda.get_device_name(0)}")
            except Exception as move_err:
                device = "cpu"
                print(f"⚠️ 移动到GPU失败,回退CPU: {move_err}")
        else:
            print(" 使用CPU")

        if hasattr(model, 'eval'):
            model.eval()
        return model, device

    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        print("请检查:")
        print("1. FireRedASR是否正确安装")
        print("2. 路径配置是否正确")
        print("3. 依赖库是否完整")
        return None, None
|
|
|
|
|
|
|
|
def transcribe_segments(segment_files, model, device):
    """Run ASR over each exported segment file.

    Returns a list of {'start_time', 'end_time', 'text'} dicts for every
    segment that produced non-empty text; failures are logged and skipped.
    """
    print("🎤 开始语音识别...")

    transcripts = []
    on_gpu = device.startswith("cuda")
    total = len(segment_files)

    for idx, info in enumerate(segment_files):
        seg_path = info['file']
        seg_start = info['start_time']
        seg_end = info['end_time']

        print(f" [{idx + 1}/{total}] {seg_start:.1f}s-{seg_end:.1f}s")

        try:
            # Decode options rebuilt per call in case the model mutates them.
            decode_config = {
                "use_gpu": 1 if on_gpu else 0,
                "beam_size": 5,
                "nbest": 1,
                "decode_max_len": 0
            }

            with torch.no_grad():
                output = model.transcribe(
                    [f"segment_{idx:03d}"], [str(seg_path)], decode_config
                )

            if output and len(output) > 0:
                text = output[0].get('text', '').strip()
                if text:
                    transcripts.append({
                        'start_time': seg_start,
                        'end_time': seg_end,
                        'text': text
                    })
                    print(f" ✅ {text}")
                else:
                    print(f" ⚠️ 无内容")

            # Keep GPU memory in check between segments.
            if on_gpu:
                torch.cuda.empty_cache()

        except Exception as e:
            print(f" ❌ 错误: {e}")
            continue

    print(f" 识别完成: {len(transcripts)}/{len(segment_files)}个片段")
    return transcripts
|
|
|
|
|
|
|
|
def save_results(results, base_filename):
    """Persist transcription results as JSON plus a timestamped plain TXT.

    Sorts the caller's list in place by start time. base_filename is the
    .txt path; the JSON/clean-TXT paths are derived from it.
    """
    if not results:
        print("⚠️ 没有识别结果")
        return

    # Chronological order (in-place, visible to the caller).
    results.sort(key=lambda x: x['start_time'])

    # Structured JSON sibling of the TXT output.
    json_file = base_filename.replace('.txt', '.json')
    json_data = [
        {'start': r['start_time'], 'end': r['end_time'], 'content': r['text']}
        for r in results
    ]
    with open(json_file, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, ensure_ascii=False, indent=2)
    print(f"✅ JSON保存: {json_file}")

    # Human-readable "HH:MM:SS-HH:MM:SS" line followed by the text.
    clean_txt_file = base_filename.replace('.txt', '_clean.txt')
    with open(clean_txt_file, 'w', encoding='utf-8') as f:
        for r in results:
            span = f"{format_time_hms(r['start_time'])}-{format_time_hms(r['end_time'])}"
            f.write(f"{span}\n")
            f.write(f"{r['text']}\n")
    print(f"✅ 干净文本保存: {clean_txt_file}")
|
|
|
|
|
|
|
|
def call_ai_model(sentence):
    """Ask the local LLM endpoint to punctuate/normalize the sentence.

    Returns the model's processed text, or the original sentence unchanged
    if the HTTP call fails.
    """
    url = "http://192.168.3.8:7777/v1/chat/completions"

    prompt = f"""对以下文本添加标点符号,中文数字转阿拉伯数字。不修改文字内容。句末可以是冒号、逗号、问号、感叹号和句号等任意合适标点。
{sentence}"""

    payload = {
        "model": "Qwen3-14B",
        "messages": [
            {"role": "user", "content": prompt}
        ]
    }

    try:
        response = requests.post(url, json=payload, timeout=120)
        response.raise_for_status()
        body = response.json()
        return body["choices"][0]["message"]["content"].strip()
    except requests.exceptions.RequestException as e:
        print(f" ❌ API调用失败: {e}")
        return sentence  # degrade gracefully: keep the raw transcript
|
|
|
|
|
|
|
|
def process_transcription_results(input_json):
    """Run every transcript segment in input_json through the LLM post-processor.

    Returns a new list of {'start', 'end', 'content'} dicts with the
    processed text; the input file is not modified.
    """
    print("🔄 开始文本后处理...")

    with open(input_json, 'r', encoding='utf-8') as f:
        segments = json.load(f)

    print(f" 读取到 {len(segments)} 个文本段")

    refined = []
    for idx, entry in enumerate(segments):
        seg_start = entry['start']
        seg_end = entry['end']
        raw_text = entry['content']

        print(f" [{idx + 1}/{len(segments)}] 处理: {seg_start:.1f}s-{seg_end:.1f}s")
        print(f" 原文: {raw_text}")

        # Punctuation / number normalization via the local model.
        polished = call_ai_model(raw_text)
        print(f" 处理后: {polished}")

        refined.append({
            'start': seg_start,
            'end': seg_end,
            'content': polished
        })

    return refined
|
|
|
|
|
|
|
|
def save_processed_results(processed_data, output_json, output_txt):
    """Write the post-processed segments to a JSON file and a timestamped TXT."""
    print("💾 保存处理结果...")

    # Structured output for downstream tooling.
    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(processed_data, f, ensure_ascii=False, indent=2)
    print(f"✅ JSON保存: {output_json}")

    # Readable "HH:MM:SS-HH:MM:SS" + text pairs.
    with open(output_txt, 'w', encoding='utf-8') as f:
        for entry in processed_data:
            stamp = f"{format_time_hms(entry['start'])}-{format_time_hms(entry['end'])}"
            f.write(f"{stamp}\n")
            f.write(f"{entry['content']}\n")
    print(f"✅ 干净文本保存: {output_txt}")
|
|
|
|
|
|
|
|
def cleanup_temp_files():
    """Best-effort removal of exported segment WAVs and their directory.

    Runs inside main()'s finally block, so it must never raise: rmdir is
    guarded because it raises OSError if the directory still contains
    files other than segment_*.wav (the original would mask the real
    exception in that case).
    """
    temp_dir = Path("temp_segments")
    if not temp_dir.exists():
        return
    for wav in temp_dir.glob("segment_*.wav"):
        wav.unlink(missing_ok=True)
    try:
        temp_dir.rmdir()
    except OSError:
        # Directory not empty (foreign files present) — leave it in place.
        pass
|
|
|
|
|
|
|
|
def main():
    """Entry point: energy-segment the input audio, transcribe each segment
    with FireRedASR, then punctuate the transcript via a local LLM endpoint.

    Input/output paths and the model directory are hard-coded below.
    Temporary segment WAVs are always cleaned up on exit (finally block).
    """
    print("🎵 === FireRedASR能量分段转录 + 文本后处理工具 === 🎵")

    # Hard-coded I/O configuration.
    input_audio = "test.mp3"
    output_base = "test.txt"
    model_dir = "D:/workstation/voice-txt/FireRedASR-test/FireRedASR/pretrained_models/FireRedASR-AED-L"

    # Fail fast if inputs are missing.
    if not Path(input_audio).exists():
        print(f"❌ 音频文件不存在: {input_audio}")
        return

    if not Path(model_dir).exists():
        print(f"❌ 模型目录不存在: {model_dir}")
        return

    total_start_time = time.time()

    try:
        # === Stage 1: ASR transcription ===
        print("\n=== 第一阶段:ASR转录 ===")
        asr_start_time = time.time()

        # 1. Load the audio (numpy samples + pydub segment).
        audio_data, sr, audio_pydub = load_audio(input_audio)

        # 2. Energy-based segmentation.
        segments = energy_based_segmentation(
            audio_data, sr,
            silence_threshold=0.001,  # silence threshold (normalized energy)
            min_segment_length=2.0,  # minimum segment length: 2 s
            max_segment_length=455.0  # maximum segment length: 455 s
        )

        # 3. Export per-segment WAV files.
        segment_files = create_audio_segments(segments, audio_pydub)

        # 4. Load the ASR model (aborts if it cannot be loaded).
        model, device = load_fireredasr_model(model_dir)
        if model is None:
            return

        # 5. Speech recognition over all segments.
        results = transcribe_segments(segment_files, model, device)

        # 6. Save the raw results (JSON + clean TXT).
        save_results(results, output_base)

        asr_elapsed_time = time.time() - asr_start_time
        print(f"\n✅ ASR转录完成! 耗时: {asr_elapsed_time:.1f}秒")
        print(f" 成功识别: {len(results)}/{len(segments)}段")

        # === Stage 2: text post-processing ===
        print("\n=== 第二阶段:文本后处理 ===")
        processing_start_time = time.time()

        # Punctuate every transcript segment via the local LLM.
        json_file = output_base.replace('.txt', '.json')
        processed_data = process_transcription_results(json_file)

        # Save the post-processed results.
        output_json = "test_processed.json"
        output_txt = "test_processed_clean.txt"
        save_processed_results(processed_data, output_json, output_txt)

        processing_elapsed_time = time.time() - processing_start_time
        print(f"\n✅ 文本后处理完成! 耗时: {processing_elapsed_time / 60:.2f}分钟")

        # === Summary ===
        total_elapsed_time = time.time() - total_start_time
        print(f"\n🎉 全部处理完成!")
        print(f" ASR转录耗时: {asr_elapsed_time:.1f}秒")
        print(f" 文本处理耗时: {processing_elapsed_time / 60:.2f}分钟")
        print(f" 总耗时: {total_elapsed_time / 60:.2f}分钟")

    except KeyboardInterrupt:
        print("\n⏹️ 用户中断")
    except Exception as e:
        print(f"\n❌ 错误: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Always remove the exported segment WAVs, even on error/interrupt.
        cleanup_temp_files()
|
|
|
|
|
|
|
|
# Script entry point: run the full pipeline when executed directly.
if __name__ == "__main__":
    main()