# NOTE: the lines below are repository-page metadata that was captured
# together with the file (topic rules, line count, size). They are kept
# as comments so the module remains valid, importable Python.
# "You can not select more than 25 topics. Topics must start with a letter
#  or number, can include dashes ('-') and can be up to 35 characters long."
# 206 lines / 6.1 KiB
import json
import requests
import csv
from collections import defaultdict
class SimpleOpenAIHubClient:
    """Minimal client for the OpenAI-Hub chat-completions HTTP API."""

    def __init__(self, api_key):
        # Bearer-token auth; every request/response body is JSON.
        self.api_key = api_key
        self.base_url = "https://api.openai-hub.com"
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    def chat(self, prompt, model="gpt-4.1"):
        """Send *prompt* as a single user message and return the reply text.

        On a non-200 status or a transport failure, a descriptive string
        is returned instead of raising, so callers always get a str back.
        """
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 100000,
            "temperature": 0.7,
        }
        try:
            response = requests.post(
                f"{self.base_url}/v1/chat/completions",
                headers=self.headers,
                json=payload,
                timeout=60,
            )
            # Guard clause: surface HTTP-level failures as a string.
            if response.status_code != 200:
                return f"错误: {response.status_code} - {response.text}"
            body = response.json()
            return body['choices'][0]['message']['content']
        except requests.exceptions.RequestException as e:
            return f"请求异常: {str(e)}"
def load_json_data(file_path):
    """Load a JSON file and return the parsed data.

    Args:
        file_path: path of a UTF-8 encoded JSON file.

    Returns:
        The parsed data on success, or [] when the file is missing,
        unreadable, or not valid JSON — callers treat failure as
        "no records".
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        print(f"成功加载数据,共 {len(data)} 条记录")
        return data
    # Narrowed from a bare `except Exception`: OSError covers missing or
    # unreadable files, ValueError covers json.JSONDecodeError, and
    # TypeError covers len() on an unsized top-level value (e.g. a bare
    # number). Anything else is a real bug and should propagate.
    except (OSError, ValueError, TypeError) as e:
        print(f"加载JSON文件时出错: {e}")
        return []
def group_by_id(data):
    """Group records by their 'd_id' field and sort each group by start_time.

    Records with a missing d_id or empty content are dropped. Returns a
    defaultdict mapping d_id -> list of {start_time, end_time, content}.
    """
    buckets = defaultdict(list)
    for record in data:
        key = record.get('d_id')
        text = record.get('content', '')
        # Skip records we cannot place or that carry no text.
        if key is None or not text:
            continue
        buckets[key].append({
            'start_time': record.get('start_time'),
            'end_time': record.get('end_time'),
            'content': text,
        })
    # Order every group's segments chronologically.
    for segments in buckets.values():
        segments.sort(key=lambda seg: seg['start_time'])
    return buckets
def create_prompt_for_group(group_data):
    """Build the text-merge prompt for one group of segments.

    The segments' contents are embedded, in their original order, as a
    valid JSON array of {"text": ...} objects inside <audio_text> tags.

    Args:
        group_data: list of dicts each carrying a 'content' string.

    Returns:
        The filled-in prompt string.
    """
    # Bug fix: the old manual escaping doubled quotes CSV-style ('"' ->
    # '""'), which is not valid JSON and ignored backslashes and control
    # characters. json.dumps performs correct JSON escaping and emits
    # byte-identical output for escape-free content.
    json_like_text = json.dumps(
        [{"text": item["content"]} for item in group_data],
        ensure_ascii=False,
    )
    prompt_template = """任务:文本合并
要求:
1. 提取JSON数组中所有"text"字段的内容
2. 按原始顺序直接拼接
3. 不修改任何文字
4. 不添加标点符号
5. 不做错误纠正
6. 只输出合并后的文本
输入数据:
<audio_text> %s </audio_text>
输出:"""
    return prompt_template % json_like_text
def merge_texts_directly(grouped_data):
    """Concatenate each group's segments in stored order, without an AI model.

    Args:
        grouped_data: mapping of d_id -> time-ordered segment dicts
            (as produced by group_by_id).

    Returns:
        dict of d_id -> {'original_content', 'merged_text', 'segments_count'}.
    """
    merged_results = {}
    total_groups = len(grouped_data)
    print(f"开始处理 {total_groups} 个组...")
    for index, (d_id, segments) in enumerate(grouped_data.items(), 1):
        print(f"处理组 {index}/{total_groups}: d_id={d_id}, 包含{len(segments)}个片段")
        # Segments are already sorted by start_time; join their texts.
        merged_text = "".join(seg['content'] for seg in segments)
        original = [
            {"end": seg['end_time'], "start": seg['start_time'], "text": seg['content']}
            for seg in segments
        ]
        merged_results[d_id] = {
            'original_content': json.dumps(original, ensure_ascii=False),
            'merged_text': merged_text,
            'segments_count': len(segments),
        }
        print(f"完成: d_id={d_id}")
        print(f"合并结果预览: {merged_text[:100]}...")
        print("-" * 50)
    return merged_results
def save_to_csv(merged_results, output_path):
    """Write merged results to a UTF-8 CSV file.

    Rows are emitted in ascending d_id order with a fixed '成功' status.
    Returns True on success, False on any write error (best-effort).
    """
    fieldnames = ['id', 'original_content', 'merged_text', 'status']
    try:
        with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for d_id in sorted(merged_results):
                entry = merged_results[d_id]
                writer.writerow({
                    'id': d_id,
                    'original_content': entry['original_content'],
                    'merged_text': entry['merged_text'],
                    'status': '成功',
                })
        print(f"CSV文件已保存到: {output_path}")
        return True
    except Exception as e:
        # Deliberate best-effort: report the failure, signal via False.
        print(f"保存CSV文件时出错: {e}")
        return False
def main(input_file=r"D:\workstation\Data\广播\all_data619-1103.json",
         output_file=r"D:\workstation\Data\广播\merged_texts619-1103.csv"):
    """Run the whole pipeline: load JSON, group by d_id, merge, save CSV.

    Generalized: the two paths, previously hard-coded locals, are now
    keyword parameters whose defaults preserve the original values, so
    existing `main()` calls behave identically while other callers can
    point at different files.
    """
    print("加载JSON数据...")
    data = load_json_data(input_file)
    if not data:
        print("无数据可处理")
        return
    print("按d_id分组数据...")
    grouped_data = group_by_id(data)
    print(f"共分为 {len(grouped_data)} 个组")
    # Show statistics for at most the first five groups to keep the
    # console output short on large inputs.
    print("\n分组统计:")
    for d_id, group_data in list(grouped_data.items())[:5]:
        print(f"d_id {d_id}: {len(group_data)} 个片段")
    if len(grouped_data) > 5:
        print(f"... 还有 {len(grouped_data) - 5} 个组")
    print("\n开始直接文本合并...")
    merged_results = merge_texts_directly(grouped_data)
    print("\n保存合并结果...")
    if save_to_csv(merged_results, output_file):
        print("✅ 所有任务完成!")
        print(f"📁 输入文件: {input_file}")
        print(f"📄 输出文件: {output_file}")
        print(f"📊 处理统计: {len(grouped_data)} 个组,{len(data)} 个原始片段")
    else:
        print("❌ 保存结果失败")


if __name__ == "__main__":
    main()