# -*- coding: utf-8 -*-
"""
Tests for crawler data cleaning and field mapping.

Covers the rule-based extractor (``extractor_rules.extract_from_news``) and
the DB merge layer (``db_merge.merge``).

NOTE(review): the original header also listed ``extractor_dashscope``, but no
test in this file exercises it.
"""
import os
import sqlite3
import sys
import tempfile
from contextlib import closing, suppress
from pathlib import Path

import pytest

# Make the parent of this file's directory importable (presumably the crawler
# root that contains extractor_rules / db_merge — confirm against the layout).
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from extractor_rules import extract_from_news as extract_rules  # noqa: E402


def _fetch_one(db_path, sql, params=()):
    """Run *sql* (with bound *params*) on the SQLite DB at *db_path*.

    Returns the first result row, or ``None`` when nothing matches. The
    connection is closed even if the query raises (``sqlite3.Connection``'s
    own context manager only commits/rolls back; it does not close).
    """
    with closing(sqlite3.connect(db_path)) as conn:
        return conn.execute(sql, params).fetchone()


class TestExtractorRules:
    """Unit tests for the rule-based extractor (``extract_rules``)."""

    def test_trump_1000_targets_no_bases(self):
        """A generic "1000 military targets struck" claim yields no base losses.

        Neither side may receive ``bases_destroyed`` or ``bases_damaged``.
        """
        text = "特朗普说伊朗有1000个军事目标遭到袭击,美国已做好进一步打击准备"
        delta = extract_rules(text).get("combat_losses_delta", {})
        for side in ("us", "iran"):
            if side in delta:
                assert delta[side].get("bases_destroyed") is None, f"{side} bases_destroyed 不应被提取"
                assert delta[side].get("bases_damaged") is None, f"{side} bases_damaged 不应被提取"

    def test_base_damaged_when_explicit(self):
        """An explicit attack on 阿萨德 (Al Asad) air base yields a 'us' location update.

        Only ``key_location_updates`` is asserted; ``combat_losses_delta`` is
        not checked by this test.
        """
        text = "阿萨德空军基地遭袭,损失严重"
        out = extract_rules(text)
        # Expected to come from the base-attacked rule matching "阿萨德"
        # (per the original comment; the rule itself is not visible here).
        assert "key_location_updates" in out
        updates = out["key_location_updates"]
        assert len(updates) >= 1
        assert any(
            u.get("side") == "us" and "阿萨德" in (u.get("name_keywords") or "")
            for u in updates
        )

    def test_us_personnel_killed(self):
        """'3名美军阵亡,另有5人受伤' -> us killed == 3, wounded == 5."""
        text = "据报道,3名美军阵亡,另有5人受伤"
        out = extract_rules(text)
        assert "combat_losses_delta" in out
        us = out["combat_losses_delta"].get("us", {})
        assert us.get("personnel_killed") == 3
        assert us.get("personnel_wounded") == 5

    def test_iran_personnel_killed(self):
        """'10名伊朗士兵死亡' -> iran personnel_killed == 10."""
        text = "伊朗方面称10名伊朗士兵死亡"
        iran = extract_rules(text).get("combat_losses_delta", {}).get("iran", {})
        assert iran.get("personnel_killed") == 10

    def test_civilian_us_context(self):
        """Civilian casualties from a US strike are booked under side 'us'."""
        text = "美军空袭造成50名平民伤亡"
        us = extract_rules(text).get("combat_losses_delta", {}).get("us", {})
        assert us.get("civilian_killed") == 50

    def test_civilian_iran_context(self):
        """Civilian casualties from an Iranian strike are booked under side 'iran'."""
        text = "伊朗空袭造成伊拉克平民50人伤亡"
        iran = extract_rules(text).get("combat_losses_delta", {}).get("iran", {})
        assert iran.get("civilian_killed") == 50

    def test_drone_attribution_iran(self):
        """'美军击落伊朗10架无人机' -> the 10 downed drones count as iran losses."""
        text = "美军击落伊朗10架无人机"
        iran = extract_rules(text).get("combat_losses_delta", {}).get("iran", {})
        assert iran.get("drones") == 10

    def test_empty_or_short_text(self):
        """Empty or content-free text produces no (or an empty) combat_losses_delta."""
        for text in ("", "abc"):
            out = extract_rules(text)
            assert not out.get("combat_losses_delta"), f"unexpected combat_losses_delta for {text!r}"


class TestDbMerge:
    """Field-mapping and incremental-merge tests for ``db_merge.merge``."""

    @pytest.fixture
    def temp_db(self):
        """Yield the path of a fresh, empty temporary file to use as a SQLite DB."""
        with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
            path = f.name
        yield path
        # Best-effort cleanup: failing to delete the temp file must not fail the test.
        with suppress(OSError):
            os.unlink(path)

    def test_merge_combat_losses_delta(self, temp_db):
        """Successive ``combat_losses_delta`` merges are added onto the stored row."""
        from db_merge import merge

        merge(
            {"combat_losses_delta": {"us": {"personnel_killed": 3, "personnel_wounded": 2}}},
            db_path=temp_db,
        )
        merge({"combat_losses_delta": {"us": {"personnel_killed": 2}}}, db_path=temp_db)

        row = _fetch_one(
            temp_db,
            "SELECT personnel_killed, personnel_wounded FROM combat_losses WHERE side='us'",
        )
        assert row is not None, "merge did not create a combat_losses row for 'us'"
        # killed: 3 + 2; wounded is untouched by the second delta.
        assert row == (5, 2)

    def test_merge_all_combat_fields(self, temp_db):
        """Every combat_losses field in the delta maps to the same-named column."""
        from db_merge import merge

        delta = {
            "personnel_killed": 1,
            "personnel_wounded": 2,
            "civilian_killed": 3,
            "civilian_wounded": 4,
            "bases_destroyed": 1,
            "bases_damaged": 2,
            "aircraft": 3,
            "warships": 4,
            "armor": 5,
            "vehicles": 6,
            "drones": 7,
            "missiles": 8,
            "helicopters": 9,
            "submarines": 10,
        }
        merge({"combat_losses_delta": {"iran": delta}}, db_path=temp_db)

        # Column list and expected values both come from `delta` (insertion
        # order), so they cannot drift apart. Names are fixed literals above.
        row = _fetch_one(
            temp_db,
            f"SELECT {', '.join(delta)} FROM combat_losses WHERE side='iran'",
        )
        assert row is not None, "merge did not create a combat_losses row for 'iran'"
        assert row == tuple(delta.values())

    def test_merge_key_location_requires_table(self, temp_db):
        """key_location_updates modify an existing key_location row.

        The test seeds the table with one matching row, merges an update whose
        ``name_keywords`` match it, and checks status/damage_level were written.
        """
        from db_merge import merge

        with closing(sqlite3.connect(temp_db)) as conn:
            conn.execute(
                """CREATE TABLE IF NOT EXISTS key_location (
                    id INTEGER PRIMARY KEY, side TEXT, name TEXT, lat REAL, lng REAL,
                    type TEXT, region TEXT, status TEXT, damage_level INTEGER
                )"""
            )
            conn.execute(
                "INSERT INTO key_location (side, name, lat, lng, type, region, status, damage_level) "
                "VALUES ('us', '阿萨德空军基地', 33.0, 43.0, 'Base', 'IRQ', 'operational', 0)"
            )
            conn.commit()

        merge(
            {
                "key_location_updates": [
                    {"name_keywords": "阿萨德|asad", "side": "us", "status": "attacked", "damage_level": 2}
                ]
            },
            db_path=temp_db,
        )

        row = _fetch_one(temp_db, "SELECT status, damage_level FROM key_location WHERE name LIKE '%阿萨德%'")
        assert row is not None, "seeded key_location row disappeared"
        assert row == ("attacked", 2)


class TestEndToEndTrumpExample:
    """End-to-end: the "Trump says 1000 military targets were hit" case."""

    def test_full_pipeline_trump_no_bases(self, tmp_path):
        """Rule extraction followed by merge must not raise base losses for this text."""
        from db_merge import merge

        db_file = tmp_path / "test.db"
        db_file.touch()  # per the original comment, merge only runs if the file exists
        db_path = str(db_file)
        sides = ("us", "iran")

        # Seed both sides with zero base losses.
        merge(
            {"combat_losses_delta": {side: {"bases_destroyed": 0, "bases_damaged": 0} for side in sides}},
            db_path=db_path,
        )

        out = extract_rules("特朗普说伊朗有1000个军事目标遭到袭击")

        # The extractor itself must not emit base losses for either side.
        delta = out.get("combat_losses_delta", {})
        for side in sides:
            side_delta = delta.get(side, {})
            assert side_delta.get("bases_destroyed") is None, f"{side} bases_destroyed extracted"
            assert side_delta.get("bases_damaged") is None, f"{side} bases_damaged extracted"

        if "combat_losses_delta" in out:
            merge(out, db_path=db_path)

        for side in sides:
            row = _fetch_one(
                db_path,
                "SELECT bases_destroyed, bases_damaged FROM combat_losses WHERE side=?",
                (side,),
            )
            # If merge never wrote a row for this side there is nothing to check;
            # otherwise the seeded zeros must be unchanged.
            if row:
                assert row == (0, 0), f"{side} bases changed: {row}"