You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
239 lines
6.8 KiB
239 lines
6.8 KiB
import csv
import json
import os
import re
import time
from typing import List, Dict

import pandas as pd
import requests

# %%
# Load the input CSV into `df`, trying several text encodings in turn.
csv_file_path = "ai_ai_broadcast_info.csv"

# Candidate encodings, tried in order; the first one that decodes wins.
# - 'utf-8-sig' goes first: it reads plain UTF-8 as well and also strips a
#   leading byte-order mark.
# - 'latin-1' goes last: it maps every byte to a character and therefore never
#   fails, so any encoding listed after it could never be tried.
encodings = ['utf-8-sig', 'utf-8', 'gbk', 'gb2312', 'cp1252', 'latin-1']

df = None

for encoding in encodings:
    print(f"尝试使用 {encoding} 编码读取文件...")
    try:
        df = pd.read_csv(csv_file_path, encoding=encoding)
    except FileNotFoundError as e:
        # A different encoding cannot make a missing file appear: stop now
        # instead of repeating the same failure for every candidate.
        print(f"错误:找不到文件 {csv_file_path}:{str(e)}")
        raise SystemExit(1)
    except UnicodeDecodeError as e:
        print(f" {encoding} 编码失败:{str(e)}")
        continue
    except Exception as e:
        # Best effort: any other read error just moves on to the next encoding.
        print(f" {encoding} 编码读取时出现其他错误:{str(e)}")
        continue
    else:
        print(f"成功使用 {encoding} 编码读取CSV文件,共 {len(df)} 行数据")
        print(f"列名:{list(df.columns)}")
        break

if df is None:
    print("错误:尝试了所有编码格式都无法读取文件")
    # Exit with a non-zero status so callers can tell the load failed.
    raise SystemExit(1)

# The rest of the script reads the 'id' and 'content' columns; abort on the
# first one that is missing.
for required_column in ('id', 'content'):
    if required_column not in df.columns:
        print(f"错误:CSV文件中没有找到'{required_column}'列")
        print(f"可用列:{list(df.columns)}")
        raise SystemExit(1)

print(f"数据加载完成,可用的ID值:{sorted(df['id'].unique())}")

# %%
class SimpleOpenAIHubClient:
    """Minimal client for an OpenAI-style ``/v1/chat/completions`` endpoint.

    The client stores the API key, the service base URL and the HTTP headers
    (bearer authorization plus JSON content type) sent with every request.
    """

    DEFAULT_BASE_URL = "https://api.openai-hub.com"

    def __init__(self, api_key, base_url=DEFAULT_BASE_URL):
        """Create a client.

        Args:
            api_key: Secret sent as ``Authorization: Bearer <api_key>``.
            base_url: Root URL of the service; a trailing slash is removed so
                the endpoint path can be appended safely.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _build_payload(prompt, model, max_tokens, temperature):
        """Return the JSON body for a single-turn user message."""
        return {
            "model": model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

    def chat(self, prompt, model="gpt-4.1", *, max_tokens=100000,
             temperature=0.7, timeout=60):
        """Send ``prompt`` as one user message and return the model's reply.

        Failures are reported through the return value, not by raising:
        a non-200 response yields ``"错误: <status> - <body>"`` and a
        ``requests`` exception yields ``"请求异常: <message>"``. Callers must
        inspect the returned string to tell a reply from an error.

        Args:
            prompt: Text sent as the content of the user message.
            model: Model name placed in the request body.
            max_tokens: Value of the ``max_tokens`` request field.
                NOTE(review): the default of 100000 may exceed what the
                selected model accepts - confirm against the service limits.
            temperature: Value of the ``temperature`` request field.
            timeout: Timeout in seconds passed to ``requests.post``.

        Returns:
            The ``content`` of the first choice's message on HTTP 200,
            otherwise one of the error strings described above.

        Raises:
            KeyError, IndexError, TypeError: If a 200 response does not have
                the expected ``choices[0].message.content`` structure.
        """
        payload = self._build_payload(prompt, model, max_tokens, temperature)

        try:
            response = requests.post(
                f"{self.base_url}/v1/chat/completions",
                headers=self.headers,
                json=payload,
                timeout=timeout,
            )

            if response.status_code == 200:
                result = response.json()
                return result['choices'][0]['message']['content']
            else:
                return f"错误: {response.status_code} - {response.text}"
        except requests.exceptions.RequestException as e:
            return f"请求异常: {str(e)}"


print("AI客户端类定义完成!")

# %%
# Read the API key from the environment instead of embedding it in the source.
# Any key that has ever been committed to a repository should be treated as
# leaked and rotated with the provider.
API_KEY = os.environ.get("OPENAI_HUB_API_KEY", "").strip()

if not API_KEY:
    print("错误:未设置环境变量 OPENAI_HUB_API_KEY,无法初始化AI客户端")
    raise SystemExit(1)

# Initialise the AI client used by the batch-processing cell.
client = SimpleOpenAIHubClient(API_KEY)

print("AI模型加载完成!")

# %%
# Prompt sent to the model for every record. It contains exactly one "%s"
# placeholder, filled with the record's raw `content` value via the `%`
# operator; the surrounding text instructs the model to concatenate the
# "text" fields of that JSON array without altering them.
prompt_template = """任务:文本合并
要求:
1. 提取JSON数组中所有"text"字段的内容
2. 按原始顺序直接拼接
3. 不修改任何文字
4. 不添加标点符号
5. 不做错误纠正
6. 只输出合并后的文本

输入数据:
<audio_text> %s </audio_text>

输出:"""

print("Prompt模板设置完成!")

# %%
# Batch-process the records whose ID lies in range(8, 618), i.e. 8 through 617.
target_ids = list(range(8, 618))
results = []

# `client.chat` reports HTTP and network failures by RETURNING a string of the
# form "错误: <status> - <body>" or "请求异常: <message>" rather than raising.
# This pattern recognises those strings so failures are not stored as results.
_API_ERROR_PATTERN = re.compile(r"(?:错误: \d{3} - |请求异常: )")


def _record_result(record_id, original_content, merged_text, status):
    """Append one outcome row to `results` using the shared column layout."""
    results.append({
        'id': record_id,
        'original_content': original_content,
        'merged_text': merged_text,
        'status': status
    })


print(f"开始批量处理ID {target_ids[0]} 到 {target_ids[-1]} 的数据...")
print(f"目标ID列表:{target_ids}")

# Work out which target IDs are present in the data. A set gives constant-time
# membership tests; `available_ids` itself keeps its original array form.
available_ids = df['id'].unique()
available_id_set = set(available_ids.tolist())
missing_ids = [tid for tid in target_ids if tid not in available_id_set]
processable_ids = [tid for tid in target_ids if tid in available_id_set]

if missing_ids:
    print(f"警告:以下ID在数据中不存在:{missing_ids}")

print(f"将要处理的ID:{processable_ids}")

# Process every target ID; missing and empty records still get a result row.
for current_id in target_ids:
    print(f"\n--- 处理ID {current_id} ---")

    # Find the row(s) for this ID; only the first match is used.
    target_row = df[df['id'] == current_id]

    if len(target_row) == 0:
        print(f"警告:没有找到id={current_id}的数据行")
        _record_result(current_id, "", "", "数据不存在")
        continue

    # Normalise the cell to a string; missing values become "".
    target_content = target_row['content'].iloc[0]
    content = str(target_content) if pd.notna(target_content) else ""

    print(f"Content内容长度:{len(content)}")
    print(f"Content预览:{content[:100]}..." if content else "Content为空")

    if not content or content == 'nan':
        print("Content为空,无法处理")
        _record_result(current_id, content, "", "内容为空")
        continue

    # Build the prompt; the one-element tuple keeps `%` formatting unambiguous.
    prompt = prompt_template % (content,)

    print("正在调用AI模型处理...")

    # Only the API call sits inside `try`, so an error while reporting a result
    # can no longer add a second row for the same ID.
    try:
        ai_response = client.chat(prompt)
    except Exception as e:
        print(f"处理出错:{str(e)}")
        _record_result(current_id, content, "", f"处理出错: {str(e)}")
    else:
        if isinstance(ai_response, str) and not _API_ERROR_PATTERN.match(ai_response):
            _record_result(current_id, content, ai_response, "成功")
            print("处理完成!")
            print(f"合并后文本预览:{ai_response[:100]}...")
        else:
            # The client returned an error string (or no text at all): record
            # it as a failure instead of storing it as merged text.
            print(f"API调用失败:{ai_response}")
            _record_result(current_id, content, "", f"API调用失败: {ai_response}")

    # Pause after every API attempt, successful or not, to limit request rate.
    time.sleep(1)

print("\n=== 批量处理完成!===")
print(f"总共尝试处理:{len(target_ids)} 条记录")
print(f"实际处理完成:{len(results)} 条记录")

# %%
# Build the result report. Naming the columns explicitly keeps the frame
# well-formed (and the lookups below valid) even if `results` is empty.
result_df = pd.DataFrame(
    results,
    columns=['id', 'original_content', 'merged_text', 'status'],
)

# Count of records per status.
print("\n处理结果统计:")
status_counts = result_df['status'].value_counts()
print(status_counts)

# IDs grouped by status.
for status in status_counts.index:
    status_ids = result_df[result_df['status'] == status]['id'].tolist()
    print(f"{status}的ID:{status_ids}")


def _preview(value, limit=100):
    """Return `value` as text, cut to `limit` characters plus "..." if longer."""
    text = str(value)
    return text[:limit] + "..." if len(text) > limit else text


# Preview of the first five rows with the merged text truncated.
print("\n结果预览(前5行):")
preview_df = result_df[['id', 'status', 'merged_text']].copy()
preview_df['merged_text_preview'] = preview_df['merged_text'].apply(_preview)
print(preview_df[['id', 'status', 'merged_text_preview']].head())

# Save all rows; 'utf-8-sig' writes a byte-order mark at the start of the file.
output_file = "merged_results_all.csv"
result_df.to_csv(output_file, index=False, encoding='utf-8-sig')

print(f"\n结果已保存到文件:{output_file}")

# Final counts: everything whose status is not '成功' is reported as failed.
successful_results = [r for r in results if r['status'] == '成功']
print("\n最终统计:")
print(f"成功处理:{len(successful_results)} 条记录")
print(f"失败处理:{len(results) - len(successful_results)} 条记录")

if successful_results:
    successful_ids = [r['id'] for r in successful_results]
    print(f"成功处理的ID:{successful_ids}")