You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
206 lines · 6.1 KiB
import json |
|
import requests |
|
import csv |
|
from collections import defaultdict |
|
|
|
|
|
class SimpleOpenAIHubClient:
    """Minimal chat-completion client for the OpenAI-Hub HTTP API."""

    def __init__(self, api_key):
        # Bearer-token auth; every request body is JSON.
        self.api_key = api_key
        self.base_url = "https://api.openai-hub.com"
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    def chat(self, prompt, model="gpt-4.1"):
        """Send a single user-message prompt and return the model's reply.

        Returns the reply text on success; on a non-200 response or a
        transport failure, returns a formatted error string instead of
        raising, so callers never need a try/except.
        """
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 100000,
            "temperature": 0.7,
        }
        try:
            response = requests.post(
                f"{self.base_url}/v1/chat/completions",
                headers=self.headers,
                json=payload,
                timeout=60,
            )
            if response.status_code == 200:
                return response.json()['choices'][0]['message']['content']
            return f"错误: {response.status_code} - {response.text}"
        except requests.exceptions.RequestException as e:
            return f"请求异常: {str(e)}"
|
|
|
|
|
def load_json_data(file_path):
    """Load a UTF-8 JSON file and return its parsed content.

    Args:
        file_path: Path to a JSON file, expected to hold a list of records.

    Returns:
        The parsed data on success, or an empty list when the file cannot
        be read or parsed (the error is printed, not raised — this script
        treats loading as a best-effort step).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        print(f"成功加载数据,共 {len(data)} 条记录")
        return data
    # Narrowed from a bare `except Exception`: only I/O and parse errors
    # are expected here; a programming bug should surface, not silently
    # become "no data".
    except (OSError, json.JSONDecodeError) as e:
        print(f"加载JSON文件时出错: {e}")
        return []
|
|
|
|
|
def group_by_id(data):
    """Group records by their ``d_id``, each group sorted by start time.

    Records missing a ``d_id`` or carrying empty ``content`` are skipped.
    Returns a defaultdict mapping d_id -> list of segment dicts with
    ``start_time``, ``end_time`` and ``content`` keys.
    """
    grouped = defaultdict(list)
    for record in data:
        key = record.get('d_id')
        text = record.get('content', '')
        # Guard clause: ignore unkeyed or empty segments.
        if key is None or not text:
            continue
        grouped[key].append({
            'start_time': record.get('start_time'),
            'end_time': record.get('end_time'),
            'content': text,
        })

    # Put each group's segments in chronological order.
    for segments in grouped.values():
        segments.sort(key=lambda seg: seg['start_time'])

    return grouped
|
|
|
|
|
def create_prompt_for_group(group_data):
    """Build the text-merging prompt for one group of segments.

    Args:
        group_data: List of segment dicts, each with a ``content`` string.

    Returns:
        The full prompt string with the segments embedded as a valid JSON
        array of ``{"text": ...}`` objects.
    """
    # Bug fix: the original doubled quotes CSV-style ('""'), which is not
    # valid JSON escaping — the prompt asks the model to read "text"
    # fields from a JSON array, so the payload must actually be JSON.
    # json.dumps handles quoting/escaping correctly; ensure_ascii=False
    # keeps the Chinese text readable rather than \uXXXX-encoded.
    json_like_text = json.dumps(
        [{"text": item["content"]} for item in group_data],
        ensure_ascii=False,
    )

    prompt_template = """任务:文本合并
要求:
1. 提取JSON数组中所有"text"字段的内容
2. 按原始顺序直接拼接
3. 不修改任何文字
4. 不添加标点符号
5. 不做错误纠正
6. 只输出合并后的文本

输入数据:
<audio_text> %s </audio_text>

输出:"""

    return prompt_template % json_like_text
|
|
|
|
|
def merge_texts_directly(grouped_data):
    """Concatenate each group's segments in time order (no AI model).

    Args:
        grouped_data: Mapping of d_id -> list of segment dicts
            (assumed already sorted by start_time).

    Returns:
        Mapping of d_id -> dict with the original segments serialized as
        JSON ('original_content'), the concatenated text ('merged_text'),
        and the segment count ('segments_count').
    """
    results = {}
    group_count = len(grouped_data)

    print(f"开始处理 {group_count} 个组...")

    for index, (group_id, segments) in enumerate(grouped_data.items(), 1):
        print(f"处理组 {index}/{group_count}: d_id={group_id}, 包含{len(segments)}个片段")

        # Single join instead of repeated += concatenation.
        combined = "".join(seg['content'] for seg in segments)

        serialized = [
            {"end": seg['end_time'], "start": seg['start_time'], "text": seg['content']}
            for seg in segments
        ]
        results[group_id] = {
            'original_content': json.dumps(serialized, ensure_ascii=False),
            'merged_text': combined,
            'segments_count': len(segments),
        }

        print(f"完成: d_id={group_id}")
        print(f"合并结果预览: {combined[:100]}...")
        print("-" * 50)

    return results
|
|
|
|
|
def save_to_csv(merged_results, output_path):
    """Write the merged results to a UTF-8 CSV file.

    Rows are written in ascending id order with a fixed header of
    id / original_content / merged_text / status.

    Returns:
        True on success; False when writing fails (the error is printed
        rather than raised — saving is a best-effort final step).
    """
    columns = ['id', 'original_content', 'merged_text', 'status']
    try:
        with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=columns)
            writer.writeheader()

            # Sorted ids give a deterministic, diff-friendly output file.
            for key in sorted(merged_results):
                entry = merged_results[key]
                writer.writerow({
                    'id': key,
                    'original_content': entry['original_content'],
                    'merged_text': entry['merged_text'],
                    'status': '成功',
                })

        print(f"CSV文件已保存到: {output_path}")
        return True
    except Exception as e:
        print(f"保存CSV文件时出错: {e}")
        return False
|
|
|
|
|
def main():
    """Script entry point: load, group, merge, and export the data."""
    # Hard-coded locations for this one-off batch job.
    input_file = r"D:\workstation\Data\广播\all_data619-1103.json"
    output_file = r"D:\workstation\Data\广播\merged_texts619-1103.csv"

    print("加载JSON数据...")
    data = load_json_data(input_file)
    if not data:
        print("无数据可处理")
        return

    print("按d_id分组数据...")
    grouped_data = group_by_id(data)
    print(f"共分为 {len(grouped_data)} 个组")

    # Preview the first few groups so long runs are easy to sanity-check.
    print("\n分组统计:")
    for group_id, segments in list(grouped_data.items())[:5]:
        print(f"d_id {group_id}: {len(segments)} 个片段")
    remaining = len(grouped_data) - 5
    if remaining > 0:
        print(f"... 还有 {remaining} 个组")

    print("\n开始直接文本合并...")
    merged_results = merge_texts_directly(grouped_data)

    print("\n保存合并结果...")
    if not save_to_csv(merged_results, output_file):
        print("❌ 保存结果失败")
        return
    print("✅ 所有任务完成!")
    print(f"📁 输入文件: {input_file}")
    print(f"📄 输出文件: {output_file}")
    print(f"📊 处理统计: {len(grouped_data)} 个组,{len(data)} 个原始片段")
|
|
|
|
|
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()