# -*- coding: utf-8 -*-
"""
Tests for crawler data cleaning and field mapping.

Verifies the correctness of extractor_rules, extractor_dashscope and db_merge.
"""
import os
import sqlite3
import sys
import tempfile
from pathlib import Path

import pytest

# Make sure the crawler directory is importable: ROOT is two levels up from
# this file (the parent of the directory that contains it), and the project
# modules imported below are expected to live there.
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    # Prepend so the project modules take precedence over same-named packages.
    sys.path.insert(0, str(ROOT))

from extractor_rules import extract_from_news as extract_rules
|
||
|
||
|
||
class TestExtractorRules:
    """Unit tests for the rule-based extractor (``extract_rules``)."""

    def test_trump_1000_targets_no_bases(self):
        """A claim that 1000 Iranian military targets were hit must not yield base losses.

        Neither ``bases_destroyed`` nor ``bases_damaged`` may be extracted for
        either side from such a statement.
        """
        text = "特朗普说伊朗有1000个军事目标遭到袭击,美国已做好进一步打击准备"
        out = extract_rules(text)
        delta = out.get("combat_losses_delta", {})
        for side in ("us", "iran"):
            # A side that is absent from the delta trivially has no base fields.
            if side in delta:
                assert delta[side].get("bases_destroyed") is None, (
                    f"{side} bases_destroyed 不应被提取"
                )
                assert delta[side].get("bases_damaged") is None, (
                    f"{side} bases_damaged 不应被提取"
                )

    def test_base_damaged_when_explicit(self):
        """An explicit attack on Al-Asad air base must yield a key_location_updates entry.

        At least one update must be for side ``"us"`` with ``阿萨德`` (Al-Asad)
        in its ``name_keywords``.
        """
        text = "阿萨德空军基地遭袭,损失严重"
        out = extract_rules(text)
        # Per the original author's note: the rules emit key_location_updates
        # here because the text reports a base attack and matches 阿萨德.
        assert "key_location_updates" in out
        updates = out["key_location_updates"]
        assert len(updates) >= 1
        assert any(
            u.get("side") == "us" and "阿萨德" in (u.get("name_keywords") or "")
            for u in updates
        )

    def test_us_personnel_killed(self):
        """3 US troops killed and 5 wounded -> personnel_killed=3, personnel_wounded=5."""
        text = "据报道,3名美军阵亡,另有5人受伤"
        out = extract_rules(text)
        assert "combat_losses_delta" in out
        us = out["combat_losses_delta"].get("us", {})
        assert us.get("personnel_killed") == 3
        assert us.get("personnel_wounded") == 5

    def test_iran_personnel_killed(self):
        """10 Iranian soldiers dead -> iran personnel_killed=10."""
        text = "伊朗方面称10名伊朗士兵死亡"
        out = extract_rules(text)
        iran = out.get("combat_losses_delta", {}).get("iran", {})
        assert iran.get("personnel_killed") == 10

    def test_civilian_us_context(self):
        """50 civilian casualties in a US air-strike text are recorded under ``us``."""
        text = "美军空袭造成50名平民伤亡"
        out = extract_rules(text)
        us = out.get("combat_losses_delta", {}).get("us", {})
        assert us.get("civilian_killed") == 50

    def test_civilian_iran_context(self):
        """50 Iraqi civilian casualties in an Iranian air-strike text are recorded under ``iran``."""
        text = "伊朗空袭造成伊拉克平民50人伤亡"
        out = extract_rules(text)
        iran = out.get("combat_losses_delta", {}).get("iran", {})
        assert iran.get("civilian_killed") == 50

    def test_drone_attribution_iran(self):
        """US forces shooting down 10 Iranian drones -> iran drones=10."""
        text = "美军击落伊朗10架无人机"
        out = extract_rules(text)
        iran = out.get("combat_losses_delta", {}).get("iran", {})
        assert iran.get("drones") == 10

    def test_empty_or_short_text(self):
        """Empty or very short text must not produce any combat losses."""
        # Call the extractor once per input instead of once per sub-expression.
        empty_out = extract_rules("")
        # An empty dict also satisfies this, so no separate ``== {}`` check is needed.
        assert "combat_losses_delta" not in empty_out

        short_out = extract_rules("abc")
        # Passes when the key is missing or maps to an empty/falsy value.
        assert not short_out.get("combat_losses_delta")


class TestDbMerge:
    """Tests for db_merge field mapping and incremental (additive) merge logic."""

    @pytest.fixture
    def temp_db(self):
        """Yield the path of an empty temporary ``.db`` file; remove it after the test."""
        # delete=False keeps the file on disk once the handle is closed, so the
        # test and db_merge can open it by path.
        with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
            path = f.name
        yield path
        try:
            os.unlink(path)
        except OSError:
            # Best-effort cleanup: the file may already be gone or still locked.
            pass

    def test_merge_combat_losses_delta(self, temp_db):
        """Successive merges accumulate combat_losses_delta values in the DB."""
        from db_merge import merge

        merge(
            {"combat_losses_delta": {"us": {"personnel_killed": 3, "personnel_wounded": 2}}},
            db_path=temp_db,
        )
        merge({"combat_losses_delta": {"us": {"personnel_killed": 2}}}, db_path=temp_db)

        conn = sqlite3.connect(temp_db)
        try:
            row = conn.execute(
                "SELECT personnel_killed, personnel_wounded FROM combat_losses WHERE side='us'"
            ).fetchone()
        finally:
            # Close even when the query fails so the temp file can be removed.
            conn.close()
        # 3 + 2 killed, 2 + 0 wounded; a missing row (None) fails here cleanly.
        assert row == (5, 2)

    def test_merge_all_combat_fields(self, temp_db):
        """merge maps every combat_losses field to its column."""
        from db_merge import merge

        delta = {
            "personnel_killed": 1,
            "personnel_wounded": 2,
            "civilian_killed": 3,
            "civilian_wounded": 4,
            "bases_destroyed": 1,
            "bases_damaged": 2,
            "aircraft": 3,
            "warships": 4,
            "armor": 5,
            "vehicles": 6,
            "drones": 7,
            "missiles": 8,
            "helicopters": 9,
            "submarines": 10,
        }
        merge({"combat_losses_delta": {"iran": delta}}, db_path=temp_db)

        conn = sqlite3.connect(temp_db)
        try:
            row = conn.execute(
                """SELECT personnel_killed, personnel_wounded, civilian_killed, civilian_wounded,
                   bases_destroyed, bases_damaged, aircraft, warships, armor, vehicles,
                   drones, missiles, helicopters, submarines FROM combat_losses WHERE side='iran'"""
            ).fetchone()
        finally:
            conn.close()
        # Column order of the SELECT matches the insertion order of ``delta``.
        assert row == (1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

    def test_merge_key_location_requires_table(self, temp_db):
        """key_location_updates modify an existing, matching key_location row."""
        from db_merge import merge

        # Seed the table with one operational US base for the update to match.
        conn = sqlite3.connect(temp_db)
        try:
            conn.execute(
                """CREATE TABLE IF NOT EXISTS key_location (id INTEGER PRIMARY KEY, side TEXT, name TEXT, lat REAL, lng REAL, type TEXT, region TEXT, status TEXT, damage_level INTEGER)"""
            )
            conn.execute(
                "INSERT INTO key_location (side, name, lat, lng, type, region, status, damage_level) VALUES ('us', '阿萨德空军基地', 33.0, 43.0, 'Base', 'IRQ', 'operational', 0)"
            )
            conn.commit()
        finally:
            conn.close()

        merge(
            {
                "key_location_updates": [
                    {
                        "name_keywords": "阿萨德|asad",
                        "side": "us",
                        "status": "attacked",
                        "damage_level": 2,
                    }
                ]
            },
            db_path=temp_db,
        )

        conn = sqlite3.connect(temp_db)
        try:
            row = conn.execute(
                "SELECT status, damage_level FROM key_location WHERE name LIKE '%阿萨德%'"
            ).fetchone()
        finally:
            conn.close()
        assert row == ("attacked", 2)


class TestEndToEndTrumpExample:
    """End-to-end test for the "Trump: 1000 military targets" case."""

    def test_full_pipeline_trump_no_bases(self, tmp_path):
        """Full pipeline (rule extraction + merge): the statement must not add base losses."""
        from db_merge import merge

        db_file = tmp_path / "test.db"
        # NOTE(review): per the original author's note, merge only runs when the
        # DB file already exists, hence the explicit touch() - confirm in db_merge.
        db_file.touch()
        db_path = str(db_file)

        # Seed both sides with zero base losses as the baseline.
        merge(
            {
                "combat_losses_delta": {
                    "us": {"bases_destroyed": 0, "bases_damaged": 0},
                    "iran": {"bases_destroyed": 0, "bases_damaged": 0},
                }
            },
            db_path=db_path,
        )

        text = "特朗普说伊朗有1000个军事目标遭到袭击"
        out = extract_rules(text)

        # The extraction must not carry base fields for iran. A missing
        # "combat_losses_delta" or a missing "iran" entry both collapse to {}.
        iran_delta = out.get("combat_losses_delta", {}).get("iran", {})
        assert iran_delta.get("bases_destroyed") is None
        assert iran_delta.get("bases_damaged") is None

        if "combat_losses_delta" in out:
            merge(out, db_path=db_path)

        conn = sqlite3.connect(db_path)
        try:
            iran = conn.execute(
                "SELECT bases_destroyed, bases_damaged FROM combat_losses WHERE side='iran'"
            ).fetchone()
        finally:
            # Close even when the query fails.
            conn.close()

        # If the extractor emitted no base fields, merge leaves them untouched;
        # either way the stored values must still be 0.
        if iran is not None:
            assert iran == (0, 0)