diff --git a/crawler/__pycache__/db_merge.cpython-39.pyc b/crawler/__pycache__/db_merge.cpython-39.pyc index 603ecce..7c24f51 100644 Binary files a/crawler/__pycache__/db_merge.cpython-39.pyc and b/crawler/__pycache__/db_merge.cpython-39.pyc differ diff --git a/crawler/__pycache__/extractor_rules.cpython-39.pyc b/crawler/__pycache__/extractor_rules.cpython-39.pyc index 64d4792..2688afe 100644 Binary files a/crawler/__pycache__/extractor_rules.cpython-39.pyc and b/crawler/__pycache__/extractor_rules.cpython-39.pyc differ diff --git a/crawler/db_merge.py b/crawler/db_merge.py index 78fbe63..3875dcb 100644 --- a/crawler/db_merge.py +++ b/crawler/db_merge.py @@ -140,18 +140,19 @@ def merge(extracted: Dict[str, Any], db_path: Optional[str] = None) -> bool: if "key_location_updates" in extracted: try: for u in extracted["key_location_updates"]: - kw = (u.get("name_keywords") or "").replace("|", " ").split() + kw_raw = (u.get("name_keywords") or "").strip() + if not kw_raw: + continue + # 支持 "a|b|c" 或 "a b c" 分隔 + kw = [k.strip() for k in kw_raw.replace("|", " ").split() if k.strip()] side = u.get("side") - status = u.get("status", "attacked")[:20] + status = (u.get("status") or "attacked")[:20] dmg = u.get("damage_level", 2) if not kw or side not in ("us", "iran"): continue - conditions = " OR ".join( - "(LOWER(name) LIKE ? OR name LIKE ?)" for _ in kw - ) - params = [status, dmg, side] - for k in kw: - params.extend([f"%{k}%", f"%{k}%"]) + # 简化:name LIKE '%kw%' 对每个关键词 OR 连接,支持中英文 + conditions = " OR ".join("name LIKE ?" for _ in kw) + params = [status, dmg, side] + [f"%{k}%" for k in kw] cur = conn.execute( f"UPDATE key_location SET status=?, damage_level=? WHERE side=? AND ({conditions})", params, diff --git a/crawler/extractor_ai.py b/crawler/extractor_ai.py index cd38e80..e87d6bd 100644 --- a/crawler/extractor_ai.py +++ b/crawler/extractor_ai.py @@ -30,8 +30,10 @@ def _call_ollama_extract(text: str, timeout: int = 10) -> Optional[Dict[str, Any - 战损(仅当新闻明确提及数字时填写,格式 us_XXX / iran_XXX): us_personnel_killed, iran_personnel_killed, us_personnel_wounded, iran_personnel_wounded, us_civilian_killed, iran_civilian_killed, us_civilian_wounded, iran_civilian_wounded, - us_bases_destroyed, iran_bases_destroyed, us_bases_damaged, iran_bases_damaged, - us_aircraft, iran_aircraft, us_warships, iran_warships, us_armor, iran_armor, us_vehicles, iran_vehicles + us_bases_destroyed, iran_bases_destroyed, us_bases_damaged, iran_bases_damaged. + 重要:bases_* 仅指已确认损毁/受损的基地数量;"军事目标"/targets 等泛指不是基地,若报道只说"X个军事目标遭袭"而无具体基地名,不填写 bases_* + us_aircraft, iran_aircraft, us_warships, iran_warships, us_armor, iran_armor, us_vehicles, iran_vehicles, + us_drones, iran_drones, us_missiles, iran_missiles, us_helicopters, iran_helicopters, us_submarines, iran_submarines - retaliation_sentiment: 0-100,仅当新闻涉及伊朗报复情绪时 - wall_street_value: 0-100,仅当新闻涉及美股/市场反应时 - key_location_updates: 当新闻提及具体基地/地点遭袭时,数组项 { "name_keywords": "asad|阿萨德|assad", "side": "us", "status": "attacked", "damage_level": 1-3 } @@ -79,7 +81,7 @@ def extract_from_news(text: str, timestamp: Optional[str] = None) -> Dict[str, A # combat_losses 增量(仅数字字段) loss_us = {} loss_ir = {} - for k in ["personnel_killed", "personnel_wounded", "civilian_killed", "civilian_wounded", "bases_destroyed", "bases_damaged", "aircraft", "warships", "armor", "vehicles"]: + for k in ["personnel_killed", "personnel_wounded", "civilian_killed", "civilian_wounded", "bases_destroyed", "bases_damaged", "aircraft", "warships", "armor", "vehicles", "drones", "missiles", "helicopters", "submarines"]: uk = f"us_{k}" ik = f"iran_{k}" if uk in parsed and isinstance(parsed[uk], (int, float)): diff --git a/crawler/extractor_dashscope.py b/crawler/extractor_dashscope.py index 8d92bbf..ff4c81e 100644 --- a/crawler/extractor_dashscope.py +++ b/crawler/extractor_dashscope.py @@ -33,11 +33,13 @@ def _call_dashscope_extract(text: str, timeout: int = 15) -> Optional[Dict[str, - 战损(仅当新闻明确提及数字时填写): us_personnel_killed, iran_personnel_killed, us_personnel_wounded, iran_personnel_wounded, us_civilian_killed, iran_civilian_killed, us_civilian_wounded, iran_civilian_wounded, - us_bases_destroyed, iran_bases_destroyed, us_bases_damaged, iran_bases_damaged, - us_aircraft, iran_aircraft, us_warships, iran_warships, us_armor, iran_armor, us_vehicles, iran_vehicles + us_bases_destroyed, iran_bases_destroyed, us_bases_damaged, iran_bases_damaged. + 重要:bases_* 仅指已确认损毁/受损的基地数量;"军事目标"/"targets"等泛指不是基地,若报道只说"X个军事目标遭袭"而无具体基地名,不填写 bases_* + us_aircraft, iran_aircraft, us_warships, iran_warships, us_armor, iran_armor, us_vehicles, iran_vehicles, + us_drones, iran_drones, us_missiles, iran_missiles, us_helicopters, iran_helicopters, us_submarines, iran_submarines - retaliation_sentiment: 0-100,仅当新闻涉及伊朗报复/反击情绪时 - wall_street_value: 0-100,仅当新闻涉及美股/市场反应时 -- key_location_updates: 当新闻提及具体基地遭袭时,数组 [{{"name_keywords":"阿萨德|asad|assad","side":"us","status":"attacked","damage_level":1-3}}] +- key_location_updates: 当新闻提及具体基地/设施遭袭时必填,数组 [{{"name_keywords":"阿萨德|asad|assad|阿因","side":"us","status":"attacked","damage_level":1-3}}]。常用关键词:阿萨德|asad|巴格达|baghdad|乌代德|udeid|埃尔比勒|erbil|因吉尔利克|incirlik|德黑兰|tehran|阿巴斯|abbas|布什尔|bushehr|伊斯法罕|isfahan|纳坦兹|natanz 原文: {str(text)[:800]} @@ -82,7 +84,8 @@ def extract_from_news(text: str, timestamp: Optional[str] = None) -> Dict[str, A loss_us = {} loss_ir = {} for k in ["personnel_killed", "personnel_wounded", "civilian_killed", "civilian_wounded", - "bases_destroyed", "bases_damaged", "aircraft", "warships", "armor", "vehicles"]: + "bases_destroyed", "bases_damaged", "aircraft", "warships", "armor", "vehicles", + "drones", "missiles", "helicopters", "submarines"]: uk, ik = f"us_{k}", f"iran_{k}" if uk in parsed and isinstance(parsed[uk], (int, float)): loss_us[k] = max(0, int(parsed[uk])) diff --git a/crawler/extractor_rules.py b/crawler/extractor_rules.py index e067349..b70a7fd 100644 --- a/crawler/extractor_rules.py +++ b/crawler/extractor_rules.py @@ -36,6 +36,8 @@ def extract_from_news(text: str, timestamp: Optional[str] = None) -> Dict[str, A if v is not None: loss_us["personnel_killed"] = v v = _first_int(t, r"(\d+)\s*名?\s*(?:美军|美国)\s*受伤") + if v is None and ("美军" in (text or "") or "美国" in (text or "")): + v = _first_int(text or t, r"另有\s*(\d+)\s*人\s*受伤") if v is not None: loss_us["personnel_wounded"] = v v = _first_int(t, r"美军\s*伤亡\s*(\d+)") @@ -57,7 +59,7 @@ def extract_from_news(text: str, timestamp: Optional[str] = None) -> Dict[str, A v = _first_int(t, r"(\d+)\s*名?\s*伊朗\s*伤亡") if v is not None: loss_ir["personnel_killed"] = v - v = _first_int(t, r"(\d+)\s*名?\s*(?:伊朗|伊朗军队)\s*(?:死亡|阵亡)") + v = _first_int(t, r"(\d+)\s*名?\s*(?:伊朗|伊朗军队)[\s\w]*(?:死亡|阵亡)") if v is not None: loss_ir["personnel_killed"] = v v = _first_int(t, r"(\d+)\s*名?\s*伊朗\s*受伤") @@ -75,28 +77,42 @@ def extract_from_news(text: str, timestamp: Optional[str] = None) -> Dict[str, A if v is not None: loss_ir["personnel_wounded"] = v - # 平民伤亡(中英文) + # 平民伤亡(中英文,按阵营归属) v = _first_int(t, r"(\d+)\s*名?\s*平民\s*(?:伤亡|死亡)") if v is not None: - loss_us["civilian_killed"] = v - v = _first_int(t, r"(\d+)[\s\w]*(?:civilian|civil)[\s\w]*(?:killed|dead)") if loss_us.get("civilian_killed") is None else None + if "伊朗" in text or "iran" in t: + loss_ir["civilian_killed"] = v + else: + loss_us["civilian_killed"] = v + v = _first_int(t, r"(\d+)[\s\w]*(?:civilian|civil)[\s\w]*(?:killed|dead)") if loss_us.get("civilian_killed") is None and loss_ir.get("civilian_killed") is None else None if v is not None: - loss_us["civilian_killed"] = v + if "iran" in t: + loss_ir["civilian_killed"] = v + else: + loss_us["civilian_killed"] = v v = _first_int(t, r"(\d+)[\s\w]*(?:civilian|civil)[\s\w]*(?:wounded|injured)") if v is not None: - loss_us["civilian_wounded"] = v + if "iran" in t: + loss_ir["civilian_wounded"] = v + else: + loss_us["civilian_wounded"] = v + v = _first_int(text or t, r"伊朗[\s\w]*(?:空袭|打击)[\s\w]*造成[^\d]*(\d+)[\s\w]*(?:平民|人|伤亡)") + if v is not None: + loss_ir["civilian_killed"] = v - # 基地损毁(美方基地居多)+ 中文 - v = _first_int(t, r"(\d+)[\s\w]*(?:base|基地)[\s\w]*(?:destroyed|leveled|摧毁|夷平)") - if v is not None: - loss_us["bases_destroyed"] = v - v = _first_int(t, r"(\d+)[\s\w]*(?:base|基地)[\s\w]*(?:damaged|hit|struck|受损|袭击)") - if v is not None: - loss_us["bases_damaged"] = v - if ("base" in t or "基地" in t) and ("destroy" in t or "level" in t or "摧毁" in t or "夷平" in t) and not loss_us.get("bases_destroyed"): - loss_us["bases_destroyed"] = 1 - if ("base" in t or "基地" in t) and ("damage" in t or "hit" in t or "struck" in t or "strike" in t or "袭击" in t or "受损" in t) and not loss_us.get("bases_damaged"): - loss_us["bases_damaged"] = 1 + # 基地损毁(仅匹配 base/基地,排除"军事目标"等泛指) + skip_bases = "军事目标" in (text or "") and "基地" not in (text or "") and "base" not in t + if not skip_bases: + v = _first_int(t, r"(\d+)[\s\w]*(?:base|基地)[\s\w]*(?:destroyed|leveled|摧毁|夷平)") + if v is not None: + loss_us["bases_destroyed"] = v + v = _first_int(t, r"(\d+)[\s\w]*(?:base|基地)[\s\w]*(?:damaged|hit|struck|受损|袭击)") + if v is not None: + loss_us["bases_damaged"] = v + if ("base" in t or "基地" in t) and ("destroy" in t or "level" in t or "摧毁" in t or "夷平" in t) and not loss_us.get("bases_destroyed"): + loss_us["bases_destroyed"] = 1 + if ("base" in t or "基地" in t) and ("damage" in t or "hit" in t or "struck" in t or "strike" in t or "袭击" in t or "受损" in t) and not loss_us.get("bases_damaged"): + loss_us["bases_damaged"] = 1 # 战机 / 舰船(根据上下文判断阵营) v = _first_int(t, r"(\d+)[\s\w]*(?:aircraft|plane|jet|fighter|f-?16|f-?35|f-?18)[\s\w]*(?:down|destroyed|lost|shot)") @@ -114,6 +130,48 @@ def extract_from_news(text: str, timestamp: Optional[str] = None) -> Dict[str, A else: loss_us["warships"] = v + # 无人机 drone / uav / 无人机 + v = _first_int(t, r"(\d+)[\s\w]*(?:drone|uav|无人机)[\s\w]*(?:down|destroyed|shot|击落|摧毁)") + if v is None: + v = _first_int(text or t, r"(?:击落|摧毁)[^\d]*(\d+)[\s\w]*(?:drone|uav|无人机|架)") + if v is None: + v = _first_int(t, r"(?:drone|uav|无人机)[\s\w]*(\d+)[\s\w]*(?:down|destroyed|shot|击落|摧毁)") + if v is not None: + if "iran" in t or "iranian" in t or "shahed" in t or "沙希德" in t or "伊朗" in (text or ""): + loss_ir["drones"] = v + else: + loss_us["drones"] = v + + # 导弹 missile / 导弹 + v = _first_int(t, r"(\d+)[\s\w]*(?:missile|导弹)[\s\w]*(?:fired|launched|intercepted|destroyed|发射|拦截|击落)") + if v is not None: + if "iran" in t or "iranian" in t: + loss_ir["missiles"] = v + else: + loss_us["missiles"] = v + v = _first_int(t, r"(?:missile|导弹)[\s\w]*(\d+)[\s\w]*(?:fired|launched|intercepted|destroyed|发射|拦截)") if not loss_us.get("missiles") and not loss_ir.get("missiles") else None + if v is not None: + if "iran" in t: + loss_ir["missiles"] = v + else: + loss_us["missiles"] = v + + # 直升机 helicopter / 直升机 + v = _first_int(t, r"(\d+)[\s\w]*(?:helicopter|直升机)[\s\w]*(?:down|destroyed|crashed|crashes|击落|坠毁)") + if v is not None: + if "iran" in t or "iranian" in t: + loss_ir["helicopters"] = v + else: + loss_us["helicopters"] = v + + # 潜艇 submarine / 潜艇 + v = _first_int(t, r"(\d+)[\s\w]*(?:submarine|潜艇)[\s\w]*(?:sunk|damaged|hit|destroyed|击沉|受损)") + if v is not None: + if "iran" in t or "iranian" in t: + loss_ir["submarines"] = v + else: + loss_us["submarines"] = v + if loss_us: out.setdefault("combat_losses_delta", {})["us"] = loss_us if loss_ir: @@ -124,11 +182,14 @@ def extract_from_news(text: str, timestamp: Optional[str] = None) -> Dict[str, A out["wall_street"] = {"time": ts, "value": 55} # key_location_updates:受袭基地(与 key_location.name 匹配) - # 新闻提及基地遭袭时,更新对应基地 status - base_attacked = ("base" in t or "基地" in t) and ("attack" in t or "hit" in t or "strike" in t or "damage" in t or "袭击" in t or "打击" in t) + # 新闻提及基地遭袭时,更新对应基地 status;放宽触发词以匹配更多英文报道 + attack_words = ("attack" in t or "attacked" in t or "hit" in t or "strike" in t or "struck" in t or "strikes" in t + or "damage" in t or "damaged" in t or "target" in t or "targeted" in t or "bomb" in t or "bombed" in t + or "袭击" in (text or "") or "遭袭" in (text or "") or "打击" in (text or "") or "受损" in (text or "") or "摧毁" in (text or "")) + base_attacked = ("base" in t or "基地" in t or "outpost" in t or "facility" in t) and attack_words if base_attacked: updates: list = [] - # 常见美军基地关键词 -> name_keywords(用于 db_merge 的 LIKE 匹配) + # 常见美军基地关键词 -> name_keywords(用于 db_merge 的 LIKE 匹配,需与 key_location.name 能匹配) bases_all = [ ("阿萨德|阿因|asad|assad|ain", "us"), ("巴格达|baghdad", "us"), diff --git a/crawler/requirements.txt b/crawler/requirements.txt index 5facd77..427768e 100644 --- a/crawler/requirements.txt +++ b/crawler/requirements.txt @@ -1,5 +1,6 @@ requests>=2.31.0 feedparser>=6.0.0 +pytest>=7.0.0 fastapi>=0.109.0 uvicorn>=0.27.0 deep-translator>=1.11.0 diff --git a/crawler/tests/__init__.py b/crawler/tests/__init__.py new file mode 100644 index 0000000..ca23acf --- /dev/null +++ b/crawler/tests/__init__.py @@ -0,0 +1 @@ +# crawler tests diff --git a/crawler/tests/__pycache__/__init__.cpython-39.pyc b/crawler/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..56b1bc7 Binary files /dev/null and b/crawler/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/crawler/tests/__pycache__/test_extraction.cpython-39-pytest-8.4.2.pyc b/crawler/tests/__pycache__/test_extraction.cpython-39-pytest-8.4.2.pyc new file mode 100644 index 0000000..5842544 Binary files /dev/null and b/crawler/tests/__pycache__/test_extraction.cpython-39-pytest-8.4.2.pyc differ diff --git a/crawler/tests/test_extraction.py b/crawler/tests/test_extraction.py new file mode 100644 index 0000000..05bd1de --- /dev/null +++ b/crawler/tests/test_extraction.py @@ -0,0 +1,198 @@ +# -*- coding: utf-8 -*- +""" +爬虫数据清洗与字段映射测试 +验证 extractor_rules、extractor_dashscope、db_merge 的正确性 +""" +import os +import sqlite3 +import tempfile +from pathlib import Path + +import pytest + +# 确保 crawler 在 path 中 +ROOT = Path(__file__).resolve().parent.parent +if str(ROOT) not in __import__("sys").path: + __import__("sys").path.insert(0, str(ROOT)) + +from extractor_rules import extract_from_news as extract_rules + + +class TestExtractorRules: + """规则提取器单元测试""" + + def test_trump_1000_targets_no_bases(self): + """特朗普说伊朗有1000个军事目标遭到袭击 -> 不应提取 bases_destroyed/bases_damaged""" + text = "特朗普说伊朗有1000个军事目标遭到袭击,美国已做好进一步打击准备" + out = extract_rules(text) + delta = out.get("combat_losses_delta", {}) + for side in ("us", "iran"): + if side in delta: + assert delta[side].get("bases_destroyed") is None, f"{side} bases_destroyed 不应被提取" + assert delta[side].get("bases_damaged") is None, f"{side} bases_damaged 不应被提取" + + def test_base_damaged_when_explicit(self): + """阿萨德基地遭袭 -> 应提取 key_location_updates,且 combat_losses 若有则正确""" + text = "阿萨德空军基地遭袭,损失严重" + out = extract_rules(text) + # 规则会触发 key_location_updates(因为 base_attacked 且匹配 阿萨德) + assert "key_location_updates" in out + kl = out["key_location_updates"] + assert len(kl) >= 1 + assert any(u.get("side") == "us" and "阿萨德" in (u.get("name_keywords") or "") for u in kl) + + def test_us_personnel_killed(self): + """3名美军阵亡 -> personnel_killed=3""" + text = "据报道,3名美军阵亡,另有5人受伤" + out = extract_rules(text) + assert "combat_losses_delta" in out + us = out["combat_losses_delta"].get("us", {}) + assert us.get("personnel_killed") == 3 + assert us.get("personnel_wounded") == 5 + + def test_iran_personnel_killed(self): + """10名伊朗士兵死亡""" + text = "伊朗方面称10名伊朗士兵死亡" + out = extract_rules(text) + iran = out.get("combat_losses_delta", {}).get("iran", {}) + assert iran.get("personnel_killed") == 10 + + def test_civilian_us_context(self): + """美军空袭造成50名平民伤亡 -> loss_us""" + text = "美军空袭造成50名平民伤亡" + out = extract_rules(text) + us = out.get("combat_losses_delta", {}).get("us", {}) + assert us.get("civilian_killed") == 50 + + def test_civilian_iran_context(self): + """伊朗空袭造成伊拉克平民50人伤亡 -> loss_ir""" + text = "伊朗空袭造成伊拉克平民50人伤亡" + out = extract_rules(text) + iran = out.get("combat_losses_delta", {}).get("iran", {}) + assert iran.get("civilian_killed") == 50 + + def test_drone_attribution_iran(self): + """美军击落伊朗10架无人机 -> iran drones=10""" + text = "美军击落伊朗10架无人机" + out = extract_rules(text) + iran = out.get("combat_losses_delta", {}).get("iran", {}) + assert iran.get("drones") == 10 + + def test_empty_or_short_text(self): + """短文本或无内容 -> 无 combat_losses""" + assert extract_rules("") == {} or "combat_losses_delta" not in extract_rules("") + assert "combat_losses_delta" not in extract_rules("abc") or not extract_rules("abc").get("combat_losses_delta") + + +class TestDbMerge: + """db_merge 字段映射与增量逻辑测试""" + + @pytest.fixture + def temp_db(self): + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + path = f.name + yield path + try: + os.unlink(path) + except OSError: + pass + + def test_merge_combat_losses_delta(self, temp_db): + """merge 正确将 combat_losses_delta 叠加到 DB""" + from db_merge import merge + + merge({"combat_losses_delta": {"us": {"personnel_killed": 3, "personnel_wounded": 2}}}, db_path=temp_db) + merge({"combat_losses_delta": {"us": {"personnel_killed": 2}}}, db_path=temp_db) + + conn = sqlite3.connect(temp_db) + row = conn.execute("SELECT personnel_killed, personnel_wounded FROM combat_losses WHERE side='us'").fetchone() + conn.close() + assert row[0] == 5 + assert row[1] == 2 + + def test_merge_all_combat_fields(self, temp_db): + """merge 正确映射所有 combat_losses 字段""" + from db_merge import merge + + delta = { + "personnel_killed": 1, + "personnel_wounded": 2, + "civilian_killed": 3, + "civilian_wounded": 4, + "bases_destroyed": 1, + "bases_damaged": 2, + "aircraft": 3, + "warships": 4, + "armor": 5, + "vehicles": 6, + "drones": 7, + "missiles": 8, + "helicopters": 9, + "submarines": 10, + } + merge({"combat_losses_delta": {"iran": delta}}, db_path=temp_db) + + conn = sqlite3.connect(temp_db) + row = conn.execute( + """SELECT personnel_killed, personnel_wounded, civilian_killed, civilian_wounded, + bases_destroyed, bases_damaged, aircraft, warships, armor, vehicles, + drones, missiles, helicopters, submarines FROM combat_losses WHERE side='iran'""" + ).fetchone() + conn.close() + assert row == (1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + + def test_merge_key_location_requires_table(self, temp_db): + """key_location_updates 需要 key_location 表中有行才能更新""" + from db_merge import merge + + conn = sqlite3.connect(temp_db) + conn.execute( + """CREATE TABLE IF NOT EXISTS key_location (id INTEGER PRIMARY KEY, side TEXT, name TEXT, lat REAL, lng REAL, type TEXT, region TEXT, status TEXT, damage_level INTEGER)""" + ) + conn.execute( + "INSERT INTO key_location (side, name, lat, lng, type, region, status, damage_level) VALUES ('us', '阿萨德空军基地', 33.0, 43.0, 'Base', 'IRQ', 'operational', 0)" + ) + conn.commit() + conn.close() + + merge( + {"key_location_updates": [{"name_keywords": "阿萨德|asad", "side": "us", "status": "attacked", "damage_level": 2}]}, + db_path=temp_db, + ) + + conn = sqlite3.connect(temp_db) + row = conn.execute("SELECT status, damage_level FROM key_location WHERE name LIKE '%阿萨德%'").fetchone() + conn.close() + assert row[0] == "attacked" + assert row[1] == 2 + + +class TestEndToEndTrumpExample: + """端到端:特朗普 1000 军事目标案例""" + + def test_full_pipeline_trump_no_bases(self, tmp_path): + """完整流程:规则提取 + merge,特朗普案例不应增加 bases""" + from db_merge import merge + + db_path = str(tmp_path / "test.db") + (tmp_path / "test.db").touch() # 确保文件存在,merge 才会执行 + merge({"combat_losses_delta": {"us": {"bases_destroyed": 0, "bases_damaged": 0}, "iran": {"bases_destroyed": 0, "bases_damaged": 0}}}, db_path=db_path) + + text = "特朗普说伊朗有1000个军事目标遭到袭击" + out = extract_rules(text) + # 规则提取不应包含 bases + assert "combat_losses_delta" not in out or ( + "iran" not in out.get("combat_losses_delta", {}) + or out["combat_losses_delta"].get("iran", {}).get("bases_destroyed") is None + and out["combat_losses_delta"].get("iran", {}).get("bases_damaged") is None + ) + if "combat_losses_delta" in out: + merge(out, db_path=db_path) + + conn = sqlite3.connect(db_path) + iran = conn.execute("SELECT bases_destroyed, bases_damaged FROM combat_losses WHERE side='iran'").fetchone() + conn.close() + # 若提取器没输出 bases,merge 不会改;若有错误输出则需要为 0 + if iran: + assert iran[0] == 0 + assert iran[1] == 0 diff --git a/docs/DATA_FLOW.md b/docs/DATA_FLOW.md new file mode 100644 index 0000000..1a5856c --- /dev/null +++ b/docs/DATA_FLOW.md @@ -0,0 +1,62 @@ +# 前端数据更新链路与字段映射 + +## 1. 前端数据点 + +| 组件 | 数据 | API 字段 | DB 表/列 | +|------|------|----------|----------| +| HeaderPanel | lastUpdated | situation.lastUpdated | situation.updated_at | +| HeaderPanel | powerIndex | usForces/iranForces.powerIndex | power_index | +| HeaderPanel | feedbackCount, shareCount | POST /api/feedback, /api/share | feedback, share_count | +| TimelinePanel | recentUpdates | situation.recentUpdates | situation_update | +| WarMap | keyLocations | usForces/iranForces.keyLocations | key_location | +| BaseStatusPanel | 基地统计 | keyLocations (status, damage_level) | key_location | +| CombatLossesPanel | 人员/平民伤亡 | combatLosses, civilianCasualtiesTotal | combat_losses | +| CombatLossesOtherPanel | 装备毁伤 | combatLosses (bases, aircraft, drones, …) | combat_losses | +| PowerChart | 雷达图 | powerIndex | power_index | +| WallStreetTrend | 美股趋势 | wallStreetInvestmentTrend | wall_street_trend | +| RetaliationGauge | 报复指数 | retaliationSentiment | retaliation_current/history | + +**轮询**: `fetchSituation()` 加载,WebSocket `/ws` 每 3 秒广播。`GET /api/situation` → `getSituation()`。 + +## 2. 爬虫 → DB 字段映射 + +| 提取器输出 | DB 表 | 逻辑 | +|------------|-------|------| +| situation_update | situation_update | INSERT | +| combat_losses_delta | combat_losses | 增量叠加 (ADD) | +| retaliation | retaliation_current, retaliation_history | REPLACE / APPEND | +| wall_street | wall_street_trend | INSERT | +| key_location_updates | key_location | UPDATE status, damage_level WHERE name LIKE | + +### combat_losses 字段对应 + +| 提取器 (us/iran) | DB 列 | +|------------------|-------| +| personnel_killed | personnel_killed | +| personnel_wounded | personnel_wounded | +| civilian_killed | civilian_killed | +| civilian_wounded | civilian_wounded | +| bases_destroyed | bases_destroyed | +| bases_damaged | bases_damaged | +| aircraft, warships, armor, vehicles | 同名 | +| drones, missiles, helicopters, submarines | 同名 | + +## 3. 测试用例 + +运行: `npm run crawler:test:extraction` + +| 用例 | 输入 | 预期 | +|------|------|------| +| 特朗普 1000 军事目标 | "特朗普说伊朗有1000个军事目标遭到袭击" | 不提取 bases_destroyed/bases_damaged | +| 阿萨德基地遭袭 | "阿萨德空军基地遭袭,损失严重" | 输出 key_location_updates | +| 美军伤亡 | "3名美军阵亡,另有5人受伤" | personnel_killed=3, personnel_wounded=5 | +| 伊朗平民 | "伊朗空袭造成伊拉克平民50人伤亡" | iran.civilian_killed=50 | +| 伊朗无人机 | "美军击落伊朗10架无人机" | iran.drones=10 | +| db_merge 增量 | 两次 merge 3+2 | personnel_killed=5 | + +## 4. 注意事项 + +- **bases_***: 仅指已确认损毁/受损的基地;"军事目标"/targets 不填 bases_*。 +- **正则 [\s\w]***: 会匹配数字,导致 (\d+) 只捕获末位;数字前用 `[^\d]*`。 +- **伊朗平民**: 规则已支持 "伊朗空袭造成…平民" 归入 loss_ir。 +- **key_location**: 需 name LIKE '%keyword%' 匹配,关键词见 extractor_rules.bases_all。 diff --git a/package.json b/package.json index 3345eb0..07e3cc5 100644 --- a/package.json +++ b/package.json @@ -11,6 +11,7 @@ "crawler": "cd crawler && python main.py", "gdelt": "cd crawler && uvicorn realtime_conflict_service:app --host 0.0.0.0 --port 8000", "crawler:test": "cd crawler && python3 -c \"import sys; sys.path.insert(0,'.'); from scrapers.rss_scraper import fetch_all; n=len(fetch_all()); print('RSS 抓取:', n, '条' if n else '(0 条,检查网络或关键词过滤)')\"", + "crawler:test:extraction": "cd crawler && python3 -m pytest tests/test_extraction.py -v", "build": "vite build", "typecheck": "tsc --noEmit", "lint": "eslint .", diff --git a/server/data.db-shm b/server/data.db-shm index beb6dd2..14174d9 100644 Binary files a/server/data.db-shm and b/server/data.db-shm differ diff --git a/server/data.db-wal b/server/data.db-wal index 60a97c7..eebc6c7 100644 Binary files a/server/data.db-wal and b/server/data.db-wal differ diff --git a/server/index.js b/server/index.js index 33535b7..66350fa 100644 --- a/server/index.js +++ b/server/index.js @@ -40,18 +40,22 @@ if (fs.existsSync(distPath)) { const server = http.createServer(app) +const { getStats } = require('./stats') + const wss = new WebSocketServer({ server, path: '/ws' }) wss.on('connection', (ws) => { - ws.send(JSON.stringify({ type: 'situation', data: getSituation() })) + ws.send(JSON.stringify({ type: 'situation', data: getSituation(), stats: getStats() })) }) + function broadcastSituation() { try { - const data = JSON.stringify({ type: 'situation', data: getSituation() }) + const data = JSON.stringify({ type: 'situation', data: getSituation(), stats: getStats() }) wss.clients.forEach((c) => { if (c.readyState === 1) c.send(data) }) } catch (_) {} } +app.set('broadcastSituation', broadcastSituation) setInterval(broadcastSituation, 3000) // 供爬虫调用:更新 situation.updated_at 并立即广播 diff --git a/server/routes.js b/server/routes.js index d0e6438..5d621be 100644 --- a/server/routes.js +++ b/server/routes.js @@ -1,5 +1,6 @@ const express = require('express') const { getSituation } = require('./situationData') +const { getStats } = require('./stats') const db = require('./db') const router = express.Router() @@ -84,16 +85,6 @@ function getClientIp(req) { return req.ip || req.socket?.remoteAddress || 'unknown' } -function getStats() { - const viewers = db.prepare( - "SELECT COUNT(*) as n FROM visits WHERE last_seen > datetime('now', '-2 minutes')" - ).get().n - const cumulative = db.prepare('SELECT total FROM visitor_count WHERE id = 1').get()?.total ?? 0 - const feedbackCount = db.prepare('SELECT COUNT(*) as n FROM feedback').get().n ?? 0 - const shareCount = db.prepare('SELECT total FROM share_count WHERE id = 1').get()?.total ?? 0 - return { viewers, cumulative, feedbackCount, shareCount } -} - router.post('/visit', (req, res) => { try { const ip = getClientIp(req) @@ -103,6 +94,8 @@ router.post('/visit', (req, res) => { db.prepare( 'INSERT INTO visitor_count (id, total) VALUES (1, 1) ON CONFLICT(id) DO UPDATE SET total = total + 1' ).run() + const broadcast = req.app?.get?.('broadcastSituation') + if (typeof broadcast === 'function') broadcast() res.json(getStats()) } catch (err) { console.error(err) diff --git a/server/stats.js b/server/stats.js new file mode 100644 index 0000000..d943e6f --- /dev/null +++ b/server/stats.js @@ -0,0 +1,13 @@ +const db = require('./db') + +function getStats() { + const viewers = db.prepare( + "SELECT COUNT(*) as n FROM visits WHERE last_seen > datetime('now', '-2 minutes')" + ).get().n + const cumulative = db.prepare('SELECT total FROM visitor_count WHERE id = 1').get()?.total ?? 0 + const feedbackCount = db.prepare('SELECT COUNT(*) as n FROM feedback').get().n ?? 0 + const shareCount = db.prepare('SELECT total FROM share_count WHERE id = 1').get()?.total ?? 0 + return { viewers, cumulative, feedbackCount, shareCount } +} + +module.exports = { getStats } diff --git a/src/api/websocket.ts b/src/api/websocket.ts index 0adc10b..10e702e 100644 --- a/src/api/websocket.ts +++ b/src/api/websocket.ts @@ -15,7 +15,7 @@ export function connectSituationWebSocket(onData: Handler): () => void { ws.onmessage = (e) => { try { const msg = JSON.parse(e.data) - if (msg.type === 'situation' && msg.data) handler?.(msg.data) + if (msg.type === 'situation') handler?.({ situation: msg.data, stats: msg.stats }) } catch (_) {} } ws.onclose = () => { diff --git a/src/components/HeaderPanel.tsx b/src/components/HeaderPanel.tsx index a351c7f..c778991 100644 --- a/src/components/HeaderPanel.tsx +++ b/src/components/HeaderPanel.tsx @@ -1,6 +1,7 @@ import { useState, useEffect } from 'react' import { StatCard } from './StatCard' import { useSituationStore } from '@/store/situationStore' +import { useStatsStore } from '@/store/statsStore' import { useReplaySituation } from '@/hooks/useReplaySituation' import { usePlaybackStore } from '@/store/playbackStore' import { Wifi, WifiOff, Clock, Share2, Heart, Eye, MessageSquare } from 'lucide-react' @@ -23,10 +24,12 @@ export function HeaderPanel() { const [now, setNow] = useState(() => new Date()) const [likes, setLikes] = useState(getStoredLikes) const [liked, setLiked] = useState(false) - const [viewers, setViewers] = useState(0) - const [cumulative, setCumulative] = useState(0) - const [feedbackCount, setFeedbackCount] = useState(0) - const [shareCount, setShareCount] = useState(0) + const stats = useStatsStore((s) => s.stats) + const setStats = useStatsStore((s) => s.setStats) + const viewers = stats.viewers ?? 0 + const cumulative = stats.cumulative ?? 0 + const feedbackCount = stats.feedbackCount ?? 0 + const shareCount = stats.shareCount ?? 0 const [feedbackOpen, setFeedbackOpen] = useState(false) const [feedbackText, setFeedbackText] = useState('') const [feedbackSending, setFeedbackSending] = useState(false) @@ -41,13 +44,14 @@ export function HeaderPanel() { try { const res = await fetch('/api/visit', { method: 'POST' }) const data = await res.json() - if (data.viewers != null) setViewers(data.viewers) - if (data.cumulative != null) setCumulative(data.cumulative) - if (data.feedbackCount != null) setFeedbackCount(data.feedbackCount) - if (data.shareCount != null) setShareCount(data.shareCount) + setStats({ + viewers: data.viewers, + cumulative: data.cumulative, + feedbackCount: data.feedbackCount, + shareCount: data.shareCount, + }) } catch { - setViewers((v) => (v > 0 ? v : 0)) - setCumulative((c) => (c > 0 ? c : 0)) + setStats({ viewers: 0, cumulative: 0 }) } } @@ -79,7 +83,7 @@ export function HeaderPanel() { try { const res = await fetch('/api/share', { method: 'POST' }) const data = await res.json() - if (data.shareCount != null) setShareCount(data.shareCount) + if (data.shareCount != null) setStats({ shareCount: data.shareCount }) } catch {} } } @@ -102,7 +106,7 @@ export function HeaderPanel() { if (data.ok) { setFeedbackText('') setFeedbackDone(true) - setFeedbackCount((c) => c + 1) + setStats({ feedbackCount: (feedbackCount ?? 0) + 1 }) setTimeout(() => { setFeedbackOpen(false) setFeedbackDone(false) diff --git a/src/store/situationStore.ts b/src/store/situationStore.ts index 870c09f..d52d67a 100644 --- a/src/store/situationStore.ts +++ b/src/store/situationStore.ts @@ -3,6 +3,7 @@ import type { MilitarySituation } from '@/data/mockData' import { INITIAL_MOCK_DATA } from '@/data/mockData' import { fetchSituation } from '@/api/situation' import { connectSituationWebSocket } from '@/api/websocket' +import { useStatsStore } from './statsStore' interface SituationState { situation: MilitarySituation @@ -60,9 +61,11 @@ function pollSituation() { export function startSituationWebSocket(): () => void { useSituationStore.getState().setLastError(null) - disconnectWs = connectSituationWebSocket((data) => { + disconnectWs = connectSituationWebSocket((payload) => { + const { situation, stats } = payload as { situation?: MilitarySituation; stats?: { viewers?: number; cumulative?: number; feedbackCount?: number; shareCount?: number } } useSituationStore.getState().setConnected(true) - useSituationStore.getState().setSituation(data as MilitarySituation) + if (situation) useSituationStore.getState().setSituation(situation) + if (stats) useStatsStore.getState().setStats(stats) }) pollSituation() diff --git a/src/store/statsStore.ts b/src/store/statsStore.ts new file mode 100644 index 0000000..4290524 --- /dev/null +++ b/src/store/statsStore.ts @@ -0,0 +1,18 @@ +import { create } from 'zustand' + +export interface Stats { + viewers?: number + cumulative?: number + feedbackCount?: number + shareCount?: number +} + +interface StatsState { + stats: Stats + setStats: (stats: Stats) => void +} + +export const useStatsStore = create((set) => ({ + stats: {}, + setStats: (stats) => set((s) => ({ stats: { ...s.stats, ...stats } })), +}))