You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
581 lines
22 KiB
581 lines
22 KiB
import json
import os
import time
from typing import List, Dict

import pandas as pd
import requests
|
|
|
|
|
class SimpleOpenAIHubClient:
    """Minimal HTTP client for the OpenAI-Hub chat-completions endpoint."""

    def __init__(self, api_key):
        # Build the bearer-token headers once; reused on every request.
        self.api_key = api_key
        self.base_url = "https://api.openai-hub.com"
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    def chat(self, prompt, model="gpt-4.1"):
        """Send a single-turn prompt and return the model's reply text.

        Never raises: network problems and non-200 responses are reported
        as human-readable strings, since callers consume the result as text.
        """
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 32768,
            "temperature": 0.7,
        }

        try:
            response = requests.post(
                f"{self.base_url}/v1/chat/completions",
                headers=self.headers,
                json=payload,
                timeout=60,
            )
        except requests.exceptions.RequestException as e:
            # Connection/timeout/etc. — only requests.post can raise this.
            return f"请求异常: {str(e)}"

        if response.status_code == 200:
            return response.json()['choices'][0]['message']['content']
        return f"错误: {response.status_code} - {response.text}"
|
|
|
|
|
print("AI客户端类定义完成!")

# SECURITY NOTE(review): this API key was committed in plain text and should
# be rotated. Prefer supplying it via the OPENAI_HUB_API_KEY environment
# variable; the literal is kept only as a backward-compatible fallback.
API_KEY = os.environ.get(
    "OPENAI_HUB_API_KEY",
    "sk-XREp2jnIXyZ6UoCnzZeO0ahmLi9OEXuVAtFLojKFpG9gCZ4e",
)

# Shared AI client used by all annotation calls below.
client = SimpleOpenAIHubClient(API_KEY)
print("AI模型加载完成!")
|
|
|
# 定义批量标注的Prompt模板 |
|
# Prompt template for batch paragraph-segmentation labelling (Chinese
# broadcast content). Filled via str.format() in process_batch_segmentation:
#   {context_text}         - full source-document text for the batch
#   {batch_sentence_pairs} - JSON array of sentence pairs to label
# The literal braces in the JSON example are escaped as {{ }} so that
# str.format() leaves them intact.
BATCH_SEGMENTATION_PROMPT = """你是一个专业的广播内容段落分割标注员。你的任务是批量判断多个相邻句子对之间是否应该进行段落分割,以便广播员更好地掌握停顿和语调变化。

**完整文本内容上下文:**
{context_text}

**标注规则:**
- 标签0:两个句子属于同一段落,连续播报,轻微停顿
- 标签1:两个句子属于不同段落,需要明显停顿或语调转换

**重要标注要求(请严格遵循):**
- 如果整个文本内容都在讲同一个事,你有理由只输出一段,不是追求分的段越多越细就越好
- 每个分段必须保持原始语句的绝对顺序
- 最终分段数可能等于或小于原始语句数量
- 必须保留所有原始语句文本,不得遗漏任何内容
- 应客户强烈要求,他们需要的是较粗的分段,不要太细,如同一条通告,不需要分段成具体的每个条款之类的,只需要将整个相同的通告分成一段
- 优先考虑较粗的分段,避免过度细分

**广播分段判断依据(偏向粗分段):**
1. **重大主题转换**:从一个完全不同的话题转向另一个话题(如从天气预报转向安全通知)
2. **文档类型变化**:从一个完整文档转向另一个完整文档(如从禁火令转向倡议书)
3. **内容性质变化**:从通知类内容转向完全不同性质的内容(如从法规转向天气预报)
4. **广播节目段落**:明显的广播节目结构变化(如开场白结束进入正式内容)
5. **分点阐述结构**:标题和所有分点条目内容应该合并为一个完整段落(如"森林防火十不准,一不乱扔烟头,二不随意丢弃火种,三不在林区吸烟"等整体合成一段)

**广播内容特别注意(粗分段原则):**
- 整个通告、法令、倡议书等应作为一个段落,不要拆分条款
- 同一主题的多个条款应保持在同一段落
- 只有在完全不同的文档或重大主题转换时才分段
- 广播开场白可以独立成段,但具体内容尽量合并
- 同一类型的预报信息(如天气预报的不同地区)应保持在同一段
- **分点阐述内容的特殊处理**:
- 标题性内容(如"森林防火十不准")与分点条目内容之间不需要分段
- 标题和所有的分点条目(如"一不乱扔烟头"、"二不随意丢弃火种"、"三不在林区吸烟"等)应该合并为一个完整段落
- 分点条目之间不需要分段,应该连续播报
- 整个分点阐述结构作为一个完整的内容单元,保持连贯性

**批量标注说明:**
- 每个句子对都有一个source_id,表示来源文档
- 请保持原有的source_id不变
- 将label从-1改为实际的标注结果(0或1)
- 为每个句子对提供简要的分段理由
- 结合上述完整文本内容理解句子对的上下文语境
- **特别重要:倾向于标注更多的0(同段落),减少1(分段)的使用,分点阐述结构应保持为一个完整段落**

现在请对以下句子对进行批量标注:

{batch_sentence_pairs}

请直接输出标注结果,格式如下:
```json
[
{{
"sentence1": "...",
"sentence2": "...",
"label": 0或1,
"reason": "广播分段理由",
"source_id": 原有的source_id
}}
]
```

只输出JSON数据,不要其他说明文字。"""
|
|
|
|
|
def load_context_data(csv_file="batch_deduplication_results_619-1103_01.csv"):
    """Load the full-text context table that grounds the annotation prompts.

    Several encodings are tried in turn because the CSV may originate from
    Windows tooling (GBK/GB2312) or carry a BOM.

    Args:
        csv_file: path to a CSV containing at least the columns
            ``id`` and ``final_processed_text``.

    Returns:
        dict mapping id -> final_processed_text, with ``""`` where the
        text cell is missing/NaN. Empty dict on any read or validation
        failure (the caller then annotates without context).
    """
    try:
        print(f"正在读取上下文数据文件: {csv_file}")

        # Candidate encodings, tried in order until one parses cleanly.
        encodings = ['utf-8', 'gbk', 'gb2312', 'utf-8-sig', 'latin-1', 'cp1252']
        context_df = None

        for encoding in encodings:
            try:
                print(f" 尝试使用 {encoding} 编码...")
                context_df = pd.read_csv(csv_file, encoding=encoding)
                print(f" ✓ 成功使用 {encoding} 编码读取文件")
                break
            except UnicodeDecodeError:
                print(f" {encoding} 编码失败")
                continue
            except Exception as e:
                # Any other failure (including a missing file) moves on to
                # the next encoding; context_df then stays None.
                print(f" {encoding} 编码读取时出现其他错误:{str(e)}")
                continue

        if context_df is None:
            print(f"✗ 错误:尝试了所有编码格式都无法读取文件 {csv_file}")
            return {}

        if 'id' not in context_df.columns or 'final_processed_text' not in context_df.columns:
            print(f"✗ 错误:CSV文件缺少必需列")
            print(f" 需要的列: ['id', 'final_processed_text']")
            print(f" 实际的列: {list(context_df.columns)}")
            return {}

        # Vectorized id -> text mapping; fillna("") replaces the original
        # row-by-row iterrows loop with identical semantics.
        context_dict = dict(
            zip(context_df['id'], context_df['final_processed_text'].fillna(""))
        )

        print(f"✓ 成功加载上下文数据")
        print(f" - 可用ID数量: {len(context_dict)}")
        print(f" - 可用ID列表: {sorted(context_dict.keys())}")

        return context_dict

    except FileNotFoundError:
        print(f"✗ 警告:找不到上下文文件 {csv_file}")
        print(" 将在没有上下文的情况下进行标注")
        return {}
    except Exception as e:
        print(f"✗ 读取上下文文件时出错: {str(e)}")
        print(" 将在没有上下文的情况下进行标注")
        return {}
|
|
|
|
|
def load_failed_data_from_json(json_file="segmentation_results_from_7.json"):
    """Load a previous annotation run and extract the failed items.

    Args:
        json_file: path to the JSON results file — a list of dicts with
            keys sentence1/sentence2/label/reason/source_id.

    Returns:
        Tuple ``(failed_data, all_results)``: the items whose label is -1
        and the complete result list. Both are empty lists when the file
        is missing or unreadable. (The previous docstring claimed a single
        list; the function has always returned this tuple.)
    """
    try:
        print(f"正在读取JSON结果文件: {json_file}")
        with open(json_file, 'r', encoding='utf-8') as f:
            all_results = json.load(f)

        print(f"✓ 成功加载JSON文件")
        print(f" - 总结果数量: {len(all_results)}")

        # Split by outcome: label -1 marks items the earlier run failed on.
        failed_data = [item for item in all_results if item.get('label') == -1]
        successful_data = [item for item in all_results if item.get('label') in [0, 1]]

        print(f" - 成功标注数量: {len(successful_data)}")
        print(f" - 失败标注数量: {len(failed_data)}")

        if len(failed_data) == 0:
            print("✓ 没有发现失败的标注数据,无需重新处理")
            return [], all_results

        # Tally failure reasons for operator feedback.
        from collections import Counter
        failure_reasons = Counter(item.get('reason', '未知错误') for item in failed_data)
        print(f"\n失败原因统计:")
        for reason, count in failure_reasons.most_common():
            print(f" - {reason}: {count}次")

        # Which source documents are affected.
        failed_source_ids = sorted(set(item.get('source_id') for item in failed_data))
        print(f"\n涉及的source_id: {failed_source_ids}")

        return failed_data, all_results

    except FileNotFoundError:
        print(f"✗ 错误:找不到文件 {json_file}")
        return [], []
    except json.JSONDecodeError as e:
        print(f"✗ 错误:JSON文件格式错误 - {str(e)}")
        return [], []
    except Exception as e:
        print(f"✗ 错误:读取JSON文件时出现异常 - {str(e)}")
        return [], []
|
|
|
|
|
def convert_failed_data_to_sentence_pairs(failed_data):
    """Turn failed result records back into annotatable sentence pairs.

    Args:
        failed_data: list of result dicts from a previous run.

    Returns:
        List of pair dicts (sentence1, sentence2, source_id) with label
        reset to -1, meaning "pending annotation".
    """
    return [
        {
            "sentence1": item.get("sentence1", ""),
            "sentence2": item.get("sentence2", ""),
            "source_id": item.get("source_id"),
            "label": -1,  # reset so the retry pass re-annotates it
        }
        for item in failed_data
    ]
|
|
|
|
|
def _failed_batch_records(batch, reason):
    """Build placeholder results (label -1) for every pair in *batch*."""
    return [
        {
            "sentence1": pair["sentence1"],
            "sentence2": pair["sentence2"],
            "label": -1,
            "reason": reason,
            "source_id": pair["source_id"],
        }
        for pair in batch
    ]


def process_batch_segmentation(sentence_pairs_data, context_dict, batch_size=8):
    """Annotate sentence pairs in batches via the module-level `client`.

    Args:
        sentence_pairs_data: list of pair dicts (sentence1/sentence2/source_id).
        context_dict: id -> full document text used as prompt context.
        batch_size: number of pairs sent per model call.

    Returns:
        One result dict per input pair; pairs the model could not label
        keep label -1 together with a failure reason.
    """
    all_results = []
    total_pairs = len(sentence_pairs_data)

    print(f"开始批量标注,总共 {total_pairs} 个句子对")
    print(f"每批处理 {batch_size} 个句子对")

    for i in range(0, total_pairs, batch_size):
        batch_end = min(i + batch_size, total_pairs)
        current_batch = sentence_pairs_data[i:batch_end]

        print(f"\n处理第 {i // batch_size + 1} 批 (句子对 {i + 1}-{batch_end})")

        try:
            # Assemble full-document context for every source_id in the
            # batch; missing ids are noted explicitly in the prompt.
            source_ids_in_batch = set(pair['source_id'] for pair in current_batch)
            context_text = ""
            for source_id in sorted(source_ids_in_batch):
                if source_id in context_dict and context_dict[source_id]:
                    context_text += f"\n--- Source ID {source_id} 完整文本内容 ---\n"
                    context_text += context_dict[source_id]  # full text, no truncation
                    context_text += "\n"
                else:
                    context_text += f"\n--- Source ID {source_id} ---\n(未找到对应的完整文本内容)\n"

            # Serialize the batch for inclusion in the prompt.
            batch_json = json.dumps(current_batch, ensure_ascii=False, indent=2)

            prompt = BATCH_SEGMENTATION_PROMPT.format(
                context_text=context_text,
                batch_sentence_pairs=batch_json
            )

            print(f"发送请求到AI模型...")
            print(f" - 涉及source_id: {sorted(source_ids_in_batch)}")
            print(f" - 上下文长度: {len(context_text)} 字符")
            print(f" - Prompt总长度: {len(prompt)} 字符")

            ai_response = client.chat(prompt)

            print(f"收到模型响应")

            try:
                # Strip any markdown fencing: keep the outermost [...] span.
                json_start = ai_response.find('[')
                json_end = ai_response.rfind(']') + 1

                if json_start != -1 and json_end != 0:
                    json_content = ai_response[json_start:json_end]
                    batch_results = json.loads(json_content)

                    # The model must return exactly one result per pair.
                    if isinstance(batch_results, list) and len(batch_results) == len(current_batch):
                        all_results.extend(batch_results)
                        print(f"✓ 成功处理 {len(batch_results)} 个句子对")
                    else:
                        print(
                            f"✗ 响应格式不正确,期望 {len(current_batch)} 个结果,实际得到 {len(batch_results) if isinstance(batch_results, list) else 'non-list'}")
                        all_results.extend(
                            _failed_batch_records(current_batch, "重试失败:响应格式错误"))
                else:
                    print(f"✗ 无法找到有效的JSON响应")
                    print(f"原始响应前200字符: {ai_response[:200]}...")
                    all_results.extend(
                        _failed_batch_records(current_batch, "重试失败:JSON解析错误"))

            except json.JSONDecodeError as e:
                print(f"✗ JSON解析失败: {str(e)}")
                print(f"原始响应: {ai_response[:200]}...")
                all_results.extend(
                    _failed_batch_records(current_batch, f"重试失败:{str(e)}"))

            # Throttle between batches to avoid hammering the API.
            time.sleep(2)

        except Exception as e:
            print(f"✗ 批次处理出错: {str(e)}")
            all_results.extend(
                _failed_batch_records(current_batch, f"重试异常:{str(e)}"))

    return all_results
|
|
|
|
|
def merge_results(original_results, retry_results):
    """Splice retry results into the original run.

    An original entry is replaced only when it failed (label -1) and a
    retry entry with the same (sentence1, sentence2, source_id) key exists;
    all other entries pass through unchanged, order preserved.

    Args:
        original_results: complete result list from the previous run.
        retry_results: results from the retry pass.

    Returns:
        Merged result list, same length/order as ``original_results``.
    """
    print("正在合并结果...")

    # Index retry results by their identifying triple for O(1) lookup.
    retry_map = {
        (r['sentence1'], r['sentence2'], r['source_id']): r
        for r in retry_results
    }

    merged_results = []
    replaced_count = 0

    for entry in original_results:
        key = (entry['sentence1'], entry['sentence2'], entry['source_id'])
        replacement = retry_map.get(key)
        if entry.get('label') == -1 and replacement is not None:
            merged_results.append(replacement)
            replaced_count += 1
        else:
            merged_results.append(entry)

    print(f"✓ 合并完成,替换了 {replaced_count} 个失败结果")
    return merged_results
|
|
|
|
|
# --- Re-annotation driver: re-run the model on previously failed pairs ---
print("=" * 60)
print("开始重新标注失败数据")
print("=" * 60)

# Load id -> full document text used as prompt context.
context_dict = load_context_data("batch_deduplication_results_619-1103_01.csv")

# Load the previous run: the failed subset plus the complete result list.
failed_data, original_results = load_failed_data_from_json("segmentation_results_from_7.json")

if len(failed_data) == 0:
    print("没有需要重新标注的数据,程序结束。")
    exit()

print(f"\n开始重新标注 {len(failed_data)} 个失败的句子对...")

# Reset failed items to label -1 so the retry pass re-annotates them.
retry_sentence_pairs = convert_failed_data_to_sentence_pairs(failed_data)

# Group by source_id (mirrors how the original annotation run batched work).
from collections import defaultdict

grouped_by_source_id = defaultdict(list)
for pair in retry_sentence_pairs:
    grouped_by_source_id[pair['source_id']].append(pair)

print(f"失败数据涉及 {len(grouped_by_source_id)} 个source_id")
|
|
|
# Re-annotate each source_id's failed pairs independently, accumulating
# everything into all_retry_results.
all_retry_results = []

for source_id in sorted(grouped_by_source_id.keys()):
    current_sentence_pairs = grouped_by_source_id[source_id]

    print(f"\n{'=' * 50}")
    print(f"重新处理 source_id: {source_id}")
    print(f"{'=' * 50}")
    print(f" - 该source_id的失败句子对数量: {len(current_sentence_pairs)}")

    # Preview the first pair for operator feedback.
    if len(current_sentence_pairs) > 0:
        first_pair = current_sentence_pairs[0]
        print(f" - 第一个句子对预览:")
        print(f" 句子1: {first_pair['sentence1'][:60]}...")
        print(f" 句子2: {first_pair['sentence2'][:60]}...")

    # Report whether full-document context is available for this id.
    if source_id in context_dict and context_dict[source_id]:
        print(f" - 上下文数据: 可用,长度 {len(context_dict[source_id])} 字符")
        print(f" - 上下文预览: {context_dict[source_id][:100]}...")
    else:
        print(f" - 上下文数据: 不可用")

    try:
        print(f" - 开始重新标注...")

        # Batch annotation (batch_size=8, same as the original run).
        current_results = process_batch_segmentation(current_sentence_pairs, context_dict, batch_size=8)

        all_retry_results.extend(current_results)

        # Per-source_id outcome statistics.
        current_successful = len([r for r in current_results if r['label'] in [0, 1]])
        current_failed = len([r for r in current_results if r['label'] == -1])

        print(f" ✓ source_id {source_id} 重新标注完成")
        print(f" - 成功标注: {current_successful}")
        print(f" - 仍然失败: {current_failed}")

        if current_successful > 0:
            current_label_0 = len([r for r in current_results if r['label'] == 0])
            current_label_1 = len([r for r in current_results if r['label'] == 1])
            print(f" - 标签0(同段落): {current_label_0}")
            print(f" - 标签1(分段): {current_label_1}")

        # Extra delay between source_ids on top of the per-batch throttle.
        time.sleep(1)

    except Exception as e:
        print(f" ✗ source_id {source_id} 重新标注失败: {str(e)}")

        # Record every pair of this source_id as still failed.
        for pair in current_sentence_pairs:
            all_retry_results.append({
                "sentence1": pair["sentence1"],
                "sentence2": pair["sentence2"],
                "label": -1,
                "reason": f"source_id重试异常:{str(e)}",
                "source_id": pair["source_id"]
            })
|
|
|
# Completion banner and summary statistics for the retry pass.
print(f"\n{'=' * 60}")
print("重新标注完成!")
print(f"{'=' * 60}")

# Count outcomes with generator sums rather than materialized lists.
retry_successful = sum(1 for r in all_retry_results if r['label'] in (0, 1))
retry_failed = sum(1 for r in all_retry_results if r['label'] == -1)

print(f"重试结果统计:")
print(f" - 总重试数量: {len(all_retry_results)}")
print(f" - 重试成功: {retry_successful}")
print(f" - 重试仍失败: {retry_failed}")

if retry_successful > 0:
    retry_label_0 = sum(1 for r in all_retry_results if r['label'] == 0)
    retry_label_1 = sum(1 for r in all_retry_results if r['label'] == 1)
    print(f" - 标签0(同段落): {retry_label_0}")
    print(f" - 标签1(分段): {retry_label_1}")
|
|
|
# Merge: retried results replace the corresponding failed originals.
final_results = merge_results(original_results, all_retry_results)

# Final statistics over the merged result set.
final_successful = len([r for r in final_results if r['label'] in [0, 1]])
final_failed = len([r for r in final_results if r['label'] == -1])

print(f"\n最终结果统计:")
print(f" - 总数据量: {len(final_results)}")
print(f" - 成功标注: {final_successful}")
print(f" - 失败标注: {final_failed}")
print(f" - 成功率: {final_successful / len(final_results) * 100:.2f}%")

# Persist merged results as CSV (utf-8-sig so Excel opens it correctly).
final_result_df = pd.DataFrame(final_results)
final_csv_file = "segmentation_results_from_7_retried.csv"
final_result_df.to_csv(final_csv_file, index=False, encoding='utf-8-sig')

print(f"\n最终结果已保存到: {final_csv_file}")

# Also persist the full JSON (ensure_ascii=False keeps Chinese readable).
final_json_file = 'segmentation_results_from_7_retried.json'
with open(final_json_file, 'w', encoding='utf-8') as f:
    json.dump(final_results, f, ensure_ascii=False, indent=2)

print(f"详细JSON结果已保存到: {final_json_file}")
|
|
|
# Before/after success-rate comparison for the whole dataset.
print(f"\n重试前后对比:")
original_successful = sum(1 for r in original_results if r['label'] in (0, 1))
original_failed = sum(1 for r in original_results if r['label'] == -1)
print(f" - 重试前成功率: {original_successful / len(original_results) * 100:.2f}%")
print(f" - 重试后成功率: {final_successful / len(final_results) * 100:.2f}%")
print(f" - 成功率提升: {(final_successful - original_successful) / len(final_results) * 100:.2f}%")

# Preview up to three pairs that the retry pass managed to label.
successful_retries = [r for r in all_retry_results if r['label'] in (0, 1)]
if successful_retries:
    print(f"\n前3条重试成功的结果预览:")
    for idx, result in enumerate(successful_retries[:3], start=1):
        print(f"\n{idx}. Source ID: {result['source_id']}")
        print(f" 句子1: {result['sentence1'][:50]}...")
        print(f" 句子2: {result['sentence2'][:50]}...")
        print(f" 标签: {result['label']}")
        print(f" 理由: {result['reason']}")