You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

514 lines
20 KiB

import requests
import json
import pandas as pd
from typing import List, Dict
import time
class SimpleOpenAIHubClient:
def __init__(self, api_key):
self.api_key = api_key
self.base_url = "https://api.openai-hub.com"
self.headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
def chat(self, prompt, model="gpt-4.1"):
"""发送prompt并返回模型回答"""
payload = {
"model": model,
"messages": [
{
"role": "user",
"content": prompt
}
],
"max_tokens": 32768,
"temperature": 0.7
}
try:
response = requests.post(
f"{self.base_url}/v1/chat/completions",
headers=self.headers,
json=payload,
timeout=60
)
if response.status_code == 200:
result = response.json()
return result['choices'][0]['message']['content']
else:
return f"错误: {response.status_code} - {response.text}"
except requests.exceptions.RequestException as e:
return f"请求异常: {str(e)}"
print("AI客户端类定义完成!")
# 设置API Key
API_KEY = "sk-XREp2jnIXyZ6UoCnzZeO0ahmLi9OEXuVAtFLojKFpG9gCZ4e" # 请替换为你的实际API Key
# 初始化AI客户端
client = SimpleOpenAIHubClient(API_KEY)
print("AI模型加载完成!")
# 定义批量标注的Prompt模板
BATCH_SEGMENTATION_PROMPT = """你是一个专业的广播内容段落分割标注员。你的任务是批量判断多个相邻句子对之间是否应该进行段落分割,以便广播员更好地掌握停顿和语调变化。
**完整文本内容上下文:**
{context_text}
**标注规则:**
- 标签0:两个句子属于同一段落,连续播报,轻微停顿
- 标签1:两个句子属于不同段落,需要明显停顿或语调转换
**重要标注要求(请严格遵循):**
- 如果整个文本内容都在讲同一个事,你有理由只输出一段,不是追求分的段越多越细就越好
- 每个分段必须保持原始语句的绝对顺序
- 最终分段数可能等于或小于原始语句数量
- 必须保留所有原始语句文本,不得遗漏任何内容
- 应客户强烈要求,他们需要的是较粗的分段,不要太细,如同一条通告,不需要分段成具体的每个条款之类的,只需要将整个相同的通告分成一段
- 优先考虑较粗的分段,避免过度细分
**广播分段判断依据(偏向粗分段):**
1. **重大主题转换**:从一个完全不同的话题转向另一个话题(如从天气预报转向安全通知)
2. **文档类型变化**:从一个完整文档转向另一个完整文档(如从禁火令转向倡议书)
3. **内容性质变化**:从通知类内容转向完全不同性质的内容(如从法规转向天气预报)
4. **广播节目段落**:明显的广播节目结构变化(如开场白结束进入正式内容)
5. **分点阐述结构**:标题和所有分点条目内容应该合并为一个完整段落(如"森林防火十不准,一不乱扔烟头,二不随意丢弃火种,三不在林区吸烟"等整体合成一段)
**广播内容特别注意(粗分段原则):**
- 整个通告、法令、倡议书等应作为一个段落,不要拆分条款
- 同一主题的多个条款应保持在同一段落
- 只有在完全不同的文档或重大主题转换时才分段
- 广播开场白可以独立成段,但具体内容尽量合并
- 同一类型的预报信息(如天气预报的不同地区)应保持在同一段
- **分点阐述内容的特殊处理**:
- 标题性内容(如"森林防火十不准")与分点条目内容之间不需要分段
- 标题和所有的分点条目(如"一不乱扔烟头""二不随意丢弃火种""三不在林区吸烟"等)应该合并为一个完整段落
- 分点条目之间不需要分段,应该连续播报
- 整个分点阐述结构作为一个完整的内容单元,保持连贯性
**批量标注说明:**
- 每个句子对都有一个source_id,表示来源文档
- 请保持原有的source_id不变
- 将label从-1改为实际的标注结果(0或1)
- 为每个句子对提供简要的分段理由
- 结合上述完整文本内容理解句子对的上下文语境
- **特别重要:倾向于标注更多的0(同段落),减少1(分段)的使用,分点阐述结构应保持为一个完整段落**
现在请对以下句子对进行批量标注:
{batch_sentence_pairs}
请直接输出标注结果,格式如下:
```json
[
{{
"sentence1": "...",
"sentence2": "...",
"label": 0或1,
"reason": "广播分段理由",
"source_id": 原有的source_id
}}
]
```
只输出JSON数据,不要其他说明文字。"""
def load_context_data(csv_file="batch_deduplication_results_619-1103_01.csv"):
"""
从batch_deduplication_results_batch_deduplication_results_619-1103_01.csv加载上下文数据
Args:
csv_file: CSV文件路径
Returns:
字典,key为id,value为final_processed_text
"""
try:
print(f"正在读取上下文数据文件: {csv_file}")
# 尝试不同的编码格式
encodings = ['utf-8', 'gbk', 'gb2312', 'utf-8-sig', 'latin-1', 'cp1252']
context_df = None
for encoding in encodings:
try:
print(f" 尝试使用 {encoding} 编码...")
context_df = pd.read_csv(csv_file, encoding=encoding)
print(f" ✓ 成功使用 {encoding} 编码读取文件")
break
except UnicodeDecodeError:
print(f" {encoding} 编码失败")
continue
except Exception as e:
print(f" {encoding} 编码读取时出现其他错误:{str(e)}")
continue
if context_df is None:
print(f"✗ 错误:尝试了所有编码格式都无法读取文件 {csv_file}")
return {}
if 'id' not in context_df.columns or 'final_processed_text' not in context_df.columns:
print(f"✗ 错误:CSV文件缺少必需列")
print(f" 需要的列: ['id', 'final_processed_text']")
print(f" 实际的列: {list(context_df.columns)}")
return {}
# 创建id到final_processed_text的映射
context_dict = {}
for _, row in context_df.iterrows():
context_dict[row['id']] = row['final_processed_text'] if pd.notna(row['final_processed_text']) else ""
print(f"✓ 成功加载上下文数据")
print(f" - 可用ID数量: {len(context_dict)}")
print(f" - 可用ID列表: {sorted(context_dict.keys())}")
return context_dict
except FileNotFoundError:
print(f"✗ 警告:找不到上下文文件 {csv_file}")
print(" 将在没有上下文的情况下进行标注")
return {}
except Exception as e:
print(f"✗ 读取上下文文件时出错: {str(e)}")
print(" 将在没有上下文的情况下进行标注")
return {}
def process_batch_segmentation(sentence_pairs_data, context_dict, batch_size=10):
"""
批量处理句子对的段落分割标注
Args:
sentence_pairs_data: 句子对数据列表
context_dict: 上下文数据字典
batch_size: 每批处理的数量
Returns:
处理结果列表
"""
all_results = []
total_pairs = len(sentence_pairs_data)
print(f"开始批量标注,总共 {total_pairs} 个句子对")
print(f"每批处理 {batch_size} 个句子对")
# 分批处理
for i in range(0, total_pairs, batch_size):
batch_end = min(i + batch_size, total_pairs)
current_batch = sentence_pairs_data[i:batch_end]
print(f"\n处理第 {i // batch_size + 1} 批 (句子对 {i + 1}-{batch_end})")
try:
# 获取当前批次涉及的source_id的上下文
source_ids_in_batch = set(pair['source_id'] for pair in current_batch)
context_text = ""
for source_id in sorted(source_ids_in_batch):
if source_id in context_dict and context_dict[source_id]:
context_text += f"\n--- Source ID {source_id} 完整文本内容 ---\n"
context_text += context_dict[source_id] # 完整内容,不截断
context_text += "\n"
else:
context_text += f"\n--- Source ID {source_id} ---\n(未找到对应的完整文本内容)\n"
# 准备当前批次的数据
batch_json = json.dumps(current_batch, ensure_ascii=False, indent=2)
# 构建prompt
prompt = BATCH_SEGMENTATION_PROMPT.format(
context_text=context_text,
batch_sentence_pairs=batch_json
)
print(f"发送请求到AI模型...")
print(f" - 涉及source_id: {sorted(source_ids_in_batch)}")
print(f" - 上下文长度: {len(context_text)} 字符")
# 调用AI模型
ai_response = client.chat(prompt)
print(f"收到模型响应")
# 尝试解析JSON响应
try:
# 提取JSON部分(去除可能的markdown格式)
json_start = ai_response.find('[')
json_end = ai_response.rfind(']') + 1
if json_start != -1 and json_end != 0:
json_content = ai_response[json_start:json_end]
batch_deduplication_results_all = json.loads(json_content)
# 验证结果
if isinstance(batch_deduplication_results_all, list) and len(batch_deduplication_results_all) == len(current_batch):
all_results.extend(batch_deduplication_results_all)
print(f"✓ 成功处理 {len(batch_deduplication_results_all)} 个句子对")
else:
print(
f"✗ 响应格式不正确,期望 {len(current_batch)} 个结果,实际得到 {len(batch_deduplication_results_all) if isinstance(batch_deduplication_results_all, list) else 'non-list'}")
# 添加错误记录
for j, pair in enumerate(current_batch):
all_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": "处理失败:响应格式错误",
"source_id": pair["source_id"]
})
else:
print(f"✗ 无法找到有效的JSON响应")
print(f"原始响应前200字符: {ai_response[:200]}...")
# 添加错误记录
for j, pair in enumerate(current_batch):
all_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": "处理失败:JSON解析错误",
"source_id": pair["source_id"]
})
except json.JSONDecodeError as e:
print(f"✗ JSON解析失败: {str(e)}")
print(f"原始响应: {ai_response[:200]}...")
# 添加错误记录
for j, pair in enumerate(current_batch):
all_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": f"处理失败:{str(e)}",
"source_id": pair["source_id"]
})
# 添加延时,避免API调用过于频繁
time.sleep(2)
except Exception as e:
print(f"✗ 批次处理出错: {str(e)}")
# 添加错误记录
for j, pair in enumerate(current_batch):
all_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": f"处理异常:{str(e)}",
"source_id": pair["source_id"]
})
return all_results
# 执行批量标注
print("=" * 60)
print("开始执行批量段落分割标注(从source_id 7开始)")
print("=" * 60)
# 加载上下文数据
context_dict = load_context_data("batch_deduplication_results_619-1103_01.csv")
# 从JSON文件加载数据
input_file = "all_sentence_pairs_for_annotation.json"
try:
print(f"正在读取数据文件: {input_file}")
with open(input_file, 'r', encoding='utf-8') as f:
all_sentence_pairs_data = json.load(f)
print(f"✓ 成功加载数据文件")
print(f" - 总句子对数量: {len(all_sentence_pairs_data)}")
# 检查数据格式
if len(all_sentence_pairs_data) > 0:
sample_item = all_sentence_pairs_data[0]
required_fields = ['sentence1', 'sentence2', 'source_id']
missing_fields = [field for field in required_fields if field not in sample_item]
if missing_fields:
print(f"✗ 数据格式错误,缺少字段: {missing_fields}")
print(f" 实际字段: {list(sample_item.keys())}")
exit()
# 获取所有unique的source_id
all_source_ids = sorted(set(item.get('source_id') for item in all_sentence_pairs_data))
print(f"✓ 发现的source_id列表: {all_source_ids}")
# 【修改】:筛选出source_id >= 7的ID
filtered_source_ids = [sid for sid in all_source_ids if sid >= 7]
print(f"✓ 筛选后的source_id列表(>=7): {filtered_source_ids}")
if not filtered_source_ids:
print("✗ 没有找到source_id >= 7的数据")
exit()
# 统计各source_id的句子对数量
from collections import Counter
source_counts = Counter(item.get('source_id') for item in all_sentence_pairs_data)
print(f" - 各source_id的句子对数量(>=7):")
for source_id in filtered_source_ids:
print(f" source_id {source_id}: {source_counts[source_id]}")
# 检查上下文数据可用性
print(f"\n上下文数据可用性检查:")
available_context_ids = []
missing_context_ids = []
for source_id in filtered_source_ids:
if source_id in context_dict and context_dict[source_id]:
available_context_ids.append(source_id)
print(f" ✓ source_id {source_id}: 上下文长度 {len(context_dict[source_id])} 字符")
else:
missing_context_ids.append(source_id)
print(f" ✗ source_id {source_id}: 缺少上下文数据")
if missing_context_ids:
print(f"\n警告:以下source_id缺少上下文数据: {missing_context_ids}")
print("这些ID的标注可能不够准确")
print(f"\n开始处理source_id >= 7的标注任务...")
except FileNotFoundError:
print(f"✗ 错误:找不到文件 {input_file}")
print("请确保文件存在于当前目录中")
exit()
except json.JSONDecodeError as e:
print(f"✗ 错误:JSON文件格式错误 - {str(e)}")
exit()
except Exception as e:
print(f"✗ 错误:读取文件时出现异常 - {str(e)}")
exit()
# 准备所有结果
all_results = []
# 【修改】:遍历筛选后的source_id(>=7)
for current_source_id in filtered_source_ids:
print(f"\n{'=' * 50}")
print(f"处理 source_id: {current_source_id}")
print(f"{'=' * 50}")
# 筛选当前source_id的数据
current_sentence_pairs = [item for item in all_sentence_pairs_data if item.get('source_id') == current_source_id]
if len(current_sentence_pairs) == 0:
print(f"✗ 警告:source_id={current_source_id} 没有找到句子对数据")
continue
print(f" - 当前source_id的句子对数量: {len(current_sentence_pairs)}")
# 显示第一个句子对预览
if len(current_sentence_pairs) > 0:
first_pair = current_sentence_pairs[0]
print(f" - 第一个句子对预览:")
print(f" 句子1: {first_pair['sentence1'][:60]}...")
print(f" 句子2: {first_pair['sentence2'][:60]}...")
# 检查上下文数据
if current_source_id in context_dict and context_dict[current_source_id]:
print(f" - 上下文数据: 可用,长度 {len(context_dict[current_source_id])} 字符")
print(f" - 上下文预览: {context_dict[current_source_id][:100]}...")
else:
print(f" - 上下文数据: 不可用")
try:
print(f" - 开始处理标注...")
# 执行批量标注(每批处理5个句子对)
current_results = process_batch_segmentation(current_sentence_pairs, context_dict, batch_size=8)
# 添加到总结果中
all_results.extend(current_results)
# 统计当前source_id的结果
current_successful = len([r for r in current_results if r['label'] in [0, 1]])
current_failed = len([r for r in current_results if r['label'] == -1])
print(f" ✓ source_id {current_source_id} 处理完成")
print(f" - 成功标注: {current_successful}")
print(f" - 失败数量: {current_failed}")
if current_successful > 0:
current_label_0 = len([r for r in current_results if r['label'] == 0])
current_label_1 = len([r for r in current_results if r['label'] == 1])
print(f" - 标签0(同段落): {current_label_0}")
print(f" - 标签1(分段): {current_label_1}")
# 添加延时,避免处理过快
time.sleep(1)
except Exception as e:
print(f" ✗ source_id {current_source_id} 处理失败: {str(e)}")
# 为失败的source_id添加错误记录
for pair in current_sentence_pairs:
all_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": f"source_id处理异常:{str(e)}",
"source_id": pair["source_id"]
})
print(f"\n{'=' * 60}")
print("所有source_id(>=7)处理完成!")
print(f"{'=' * 60}")
# 使用所有结果进行后续处理
results = all_results
print(f"\n{'=' * 60}")
print("批量标注完成!")
print(f"{'=' * 60}")
# 统计结果
total_processed = len(results)
successful_labels = len([r for r in results if r['label'] in [0, 1]])
failed_labels = len([r for r in results if r['label'] == -1])
print(f"总处理数量: {total_processed}")
print(f"成功标注: {successful_labels}")
print(f"失败数量: {failed_labels}")
if successful_labels > 0:
label_0_count = len([r for r in results if r['label'] == 0])
label_1_count = len([r for r in results if r['label'] == 1])
print(f"标签0(同段落): {label_0_count}")
print(f"标签1(分段): {label_1_count}")
# 【修改】:保存结果到CSV,文件名增加后缀以区分
result_df = pd.DataFrame(results)
output_file = "segmentation_labeling_results_from_7.csv"
result_df.to_csv(output_file, index=False, encoding='utf-8-sig')
print(f"\n结果已保存到: {output_file}")
# 显示前几条结果
print(f"\n前3条标注结果预览:")
for i, result in enumerate(results[:3]):
print(f"\n{i + 1}. Source ID: {result['source_id']}")
print(f" 句子1: {result['sentence1'][:50]}...")
print(f" 句子2: {result['sentence2'][:50]}...")
print(f" 标签: {result['label']}")
print(f" 理由: {result['reason']}")
# 【修改】:保存详细的JSON结果,文件名增加后缀以区分
json_output_file = 'segmentation_results_from_7.json'
with open(json_output_file, 'w', encoding='utf-8') as f:
json.dump(results, f, ensure_ascii=False, indent=2)
print(f"详细JSON结果已保存到: {json_output_file}")
# 【新增】:显示处理的source_id范围统计
if results:
processed_source_ids = sorted(set(r['source_id'] for r in results))
print(f"\n实际处理的source_id范围: {min(processed_source_ids)} - {max(processed_source_ids)}")
print(f"共处理了 {len(processed_source_ids)} 个source_id")