Files
usa/crawler/tests/test_extraction.py
2026-03-02 23:21:07 +08:00

199 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
爬虫数据清洗与字段映射测试
验证 extractor_rules、extractor_dashscope、db_merge 的正确性
"""
import os
import sqlite3
import tempfile
from pathlib import Path
import pytest
# Put the directory two levels above this file (the parent of the tests
# directory) at the front of sys.path so that the top-level imports below
# (e.g. `extractor_rules`) are looked up there first.
import sys

ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from extractor_rules import extract_from_news as extract_rules
class TestExtractorRules:
    """Unit tests for the rule-based extractor (``extract_rules``)."""

    def test_trump_1000_targets_no_bases(self):
        """A generic "1000 military targets were hit" claim must not produce
        ``bases_destroyed`` / ``bases_damaged`` for either side."""
        news = "特朗普说伊朗有1000个军事目标遭到袭击美国已做好进一步打击准备"
        result = extract_rules(news)
        losses = result.get("combat_losses_delta", {})
        for side in ("us", "iran"):
            if side not in losses:
                # Nothing extracted for this side, so nothing to verify.
                continue
            assert losses[side].get("bases_destroyed") is None, f"{side} bases_destroyed 不应被提取"
            assert losses[side].get("bases_damaged") is None, f"{side} bases_damaged 不应被提取"

    def test_base_damaged_when_explicit(self):
        """An explicit attack on Al-Asad air base must yield a
        ``key_location_updates`` entry for the US side."""
        news = "阿萨德空军基地遭袭,损失严重"
        result = extract_rules(news)

        # The text names the base and says it was attacked, so a location
        # update whose keywords mention the base is expected.
        assert "key_location_updates" in result
        updates = result["key_location_updates"]
        assert len(updates) >= 1

        def is_us_asad(update):
            keywords = update.get("name_keywords") or ""
            return update.get("side") == "us" and "阿萨德" in keywords

        assert any(is_us_asad(update) for update in updates)

    def test_us_personnel_killed(self):
        """"3 US troops killed, 5 more wounded" -> killed=3, wounded=5 for ``us``."""
        news = "据报道3名美军阵亡另有5人受伤"
        result = extract_rules(news)
        assert "combat_losses_delta" in result
        us_losses = result["combat_losses_delta"].get("us", {})
        assert us_losses.get("personnel_killed") == 3
        assert us_losses.get("personnel_wounded") == 5

    def test_iran_personnel_killed(self):
        """"10 Iranian soldiers dead" -> ``personnel_killed`` == 10 for ``iran``."""
        news = "伊朗方面称10名伊朗士兵死亡"
        result = extract_rules(news)
        iran_losses = result.get("combat_losses_delta", {}).get("iran", {})
        assert iran_losses.get("personnel_killed") == 10

    def test_civilian_us_context(self):
        """Civilian casualties reported in a US-airstrike sentence are
        expected under the ``us`` entry as ``civilian_killed`` == 50."""
        news = "美军空袭造成50名平民伤亡"
        result = extract_rules(news)
        us_losses = result.get("combat_losses_delta", {}).get("us", {})
        assert us_losses.get("civilian_killed") == 50

    def test_civilian_iran_context(self):
        """Civilian casualties reported in an Iranian-airstrike sentence are
        expected under the ``iran`` entry as ``civilian_killed`` == 50."""
        news = "伊朗空袭造成伊拉克平民50人伤亡"
        result = extract_rules(news)
        iran_losses = result.get("combat_losses_delta", {}).get("iran", {})
        assert iran_losses.get("civilian_killed") == 50

    def test_drone_attribution_iran(self):
        """"US forces shot down 10 Iranian drones" -> ``drones`` == 10 for ``iran``."""
        news = "美军击落伊朗10架无人机"
        result = extract_rules(news)
        iran_losses = result.get("combat_losses_delta", {}).get("iran", {})
        assert iran_losses.get("drones") == 10

    def test_empty_or_short_text(self):
        """Empty or very short input must not produce any combat losses."""
        # Empty string: either an empty result or one without the losses key.
        empty_ok = extract_rules("") == {}
        if not empty_ok:
            empty_ok = "combat_losses_delta" not in extract_rules("")
        assert empty_ok

        # Short junk text: the losses key is absent or holds a falsy value.
        short_ok = "combat_losses_delta" not in extract_rules("abc")
        if not short_ok:
            short_ok = not extract_rules("abc").get("combat_losses_delta")
        assert short_ok
class TestDbMerge:
    """Tests for ``db_merge.merge``: field mapping and incremental accumulation."""

    @staticmethod
    def _fetch_one(db_path, sql):
        """Run ``sql`` against the SQLite file at ``db_path`` and return the
        first row, or ``None`` when the query yields no row.

        The connection is closed even if the query raises, so a failing test
        never leaves an open handle on the temporary database file.
        """
        conn = sqlite3.connect(db_path)
        try:
            return conn.execute(sql).fetchone()
        finally:
            conn.close()

    @pytest.fixture
    def temp_db(self):
        """Yield the path of an empty temporary ``.db`` file; delete it afterwards."""
        # delete=False keeps the file on disk after the handle is closed so
        # the code under test can open it by path.
        with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
            path = f.name
        yield path
        # Best-effort cleanup: a file that is already gone (or still locked)
        # must not turn a passing test into an error.
        try:
            os.unlink(path)
        except OSError:
            pass

    def test_merge_combat_losses_delta(self, temp_db):
        """Two successive merges accumulate ``combat_losses_delta`` in the DB."""
        from db_merge import merge

        merge({"combat_losses_delta": {"us": {"personnel_killed": 3, "personnel_wounded": 2}}}, db_path=temp_db)
        merge({"combat_losses_delta": {"us": {"personnel_killed": 2}}}, db_path=temp_db)

        row = self._fetch_one(
            temp_db,
            "SELECT personnel_killed, personnel_wounded FROM combat_losses WHERE side='us'",
        )
        assert row is not None, "no combat_losses row found for side 'us'"
        assert row[0] == 5  # 3 + 2 from the two merges
        assert row[1] == 2  # only the first merge carried wounded

    def test_merge_all_combat_fields(self, temp_db):
        """Every ``combat_losses`` field in the delta is written to its column."""
        from db_merge import merge

        delta = {
            "personnel_killed": 1,
            "personnel_wounded": 2,
            "civilian_killed": 3,
            "civilian_wounded": 4,
            "bases_destroyed": 1,
            "bases_damaged": 2,
            "aircraft": 3,
            "warships": 4,
            "armor": 5,
            "vehicles": 6,
            "drones": 7,
            "missiles": 8,
            "helicopters": 9,
            "submarines": 10,
        }
        merge({"combat_losses_delta": {"iran": delta}}, db_path=temp_db)

        row = self._fetch_one(
            temp_db,
            """SELECT personnel_killed, personnel_wounded, civilian_killed, civilian_wounded,
            bases_destroyed, bases_damaged, aircraft, warships, armor, vehicles,
            drones, missiles, helicopters, submarines FROM combat_losses WHERE side='iran'""",
        )
        assert row is not None, "no combat_losses row found for side 'iran'"
        assert row == (1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

    def test_merge_key_location_requires_table(self, temp_db):
        """``key_location_updates`` modifies an existing ``key_location`` row."""
        from db_merge import merge

        # Seed the table with one US location that the update below should
        # match through its name keywords.
        conn = sqlite3.connect(temp_db)
        try:
            conn.execute(
                """CREATE TABLE IF NOT EXISTS key_location (id INTEGER PRIMARY KEY, side TEXT, name TEXT, lat REAL, lng REAL, type TEXT, region TEXT, status TEXT, damage_level INTEGER)"""
            )
            conn.execute(
                "INSERT INTO key_location (side, name, lat, lng, type, region, status, damage_level) VALUES ('us', '阿萨德空军基地', 33.0, 43.0, 'Base', 'IRQ', 'operational', 0)"
            )
            conn.commit()
        finally:
            conn.close()

        merge(
            {"key_location_updates": [{"name_keywords": "阿萨德|asad", "side": "us", "status": "attacked", "damage_level": 2}]},
            db_path=temp_db,
        )

        row = self._fetch_one(
            temp_db,
            "SELECT status, damage_level FROM key_location WHERE name LIKE '%阿萨德%'",
        )
        assert row is not None, "seeded key_location row is missing after merge"
        assert row[0] == "attacked"
        assert row[1] == 2
class TestEndToEndTrumpExample:
    """End-to-end check for the "Trump: 1000 military targets" example."""

    def test_full_pipeline_trump_no_bases(self, tmp_path):
        """Rule extraction followed by merge must not increase any base counters."""
        from db_merge import merge

        db_file = tmp_path / "test.db"
        # The original author notes that the file must already exist for
        # merge to run; that behaviour lives in db_merge and is not visible
        # here -- TODO confirm.
        db_file.touch()
        db_path = str(db_file)

        # Baseline: both sides start with zero destroyed/damaged bases.
        merge(
            {
                "combat_losses_delta": {
                    "us": {"bases_destroyed": 0, "bases_damaged": 0},
                    "iran": {"bases_destroyed": 0, "bases_damaged": 0},
                }
            },
            db_path=db_path,
        )

        text = "特朗普说伊朗有1000个军事目标遭到袭击"
        out = extract_rules(text)

        # The rule extractor must not emit base counters for Iran. A missing
        # "combat_losses_delta" or missing "iran" entry also satisfies this,
        # because the chained .get() calls then fall back to None.
        iran_delta = out.get("combat_losses_delta", {}).get("iran", {})
        assert iran_delta.get("bases_destroyed") is None
        assert iran_delta.get("bases_damaged") is None

        if "combat_losses_delta" in out:
            merge(out, db_path=db_path)

        # Close the connection even if the query raises so the temporary
        # database file is never left with an open handle.
        conn = sqlite3.connect(db_path)
        try:
            iran_row = conn.execute(
                "SELECT bases_destroyed, bases_damaged FROM combat_losses WHERE side='iran'"
            ).fetchone()
        finally:
            conn.close()

        # If the extractor emitted no base counters the merge leaves them
        # untouched; in either case the stored values must still be 0.
        # NOTE(review): when no row exists this check is skipped silently.
        if iran_row:
            assert iran_row[0] == 0
            assert iran_row[1] == 0