Flask app progress

This commit is contained in:
Matthew Grotke 2026-05-17 03:26:01 -04:00
parent c4fe022d42
commit b0994069ad
38 changed files with 6631 additions and 220 deletions

View file

@ -100,6 +100,7 @@ import urllib.error
import argparse
from datetime import datetime
from pathlib import Path
from validation import VALID_PROTOCOLS, VALID_BLOCKLIST_FORMATS
SCRIPT_DIR = Path(__file__).parent
CONFIG_FILE = SCRIPT_DIR / "core.json"
@ -119,14 +120,14 @@ NAT_SERVICE_FILE = SYSTEMD_DIR / f"{NAT_SERVICE_NAME}.service"
log = None
# ------------------------------------------------------------------------------
# ===================================================================
# Logging
# ------------------------------------------------------------------------------
# ===================================================================
def chown_to_script_dir_owner(path):
"""Chown a file to the owner of the script directory.
This works correctly whether invoked via sudo, directly as root (e.g. systemd timer),
or as a normal user the script directory owner is always the right target.
or as a normal user - the script directory owner is always the right target.
"""
try:
stat = SCRIPT_DIR.stat()
@ -159,9 +160,9 @@ def setup_logging(max_kb, errors_only):
)
log = logging.getLogger("dns-dhcp")
# ------------------------------------------------------------------------------
# ===================================================================
# Helpers
# ------------------------------------------------------------------------------
# ===================================================================
def service_warning(action, svc, stderr):
"""Print a service start/restart warning, adding --install hint if unit not found."""
@ -172,7 +173,7 @@ def service_warning(action, svc, stderr):
def die(msg):
print(f"ERROR: {msg}")
print(f"ERROR: {msg}", file=sys.stderr)
sys.exit(1)
def check_root():
@ -279,9 +280,9 @@ def expand_protocols(rule):
return [("tcp", rule, " (tcp)"), ("udp", rule, " (udp)")]
return [(proto, rule, "")]
# ------------------------------------------------------------------------------
# ===================================================================
# Load
# ------------------------------------------------------------------------------
# ===================================================================
def load_config():
if not CONFIG_FILE.exists():
@ -292,9 +293,9 @@ def load_config():
die("No vlans defined in core.json.")
return data
# ------------------------------------------------------------------------------
# ===================================================================
# Validate
# ------------------------------------------------------------------------------
# ===================================================================
def validate_config(data):
errors = []
@ -330,8 +331,8 @@ def validate_config(data):
for field in ("name", "description", "save_as", "url", "format"):
if not bl.get(field):
errors.append(f"{label}: missing or empty field '{field}'.")
if bl.get("format") and bl["format"] not in ("dnsmasq", "hosts"):
errors.append(f"{label}: format must be 'dnsmasq' or 'hosts'.")
if bl.get("format") and bl["format"] not in VALID_BLOCKLIST_FORMATS:
errors.append(f"{label}: format must be one of: {', '.join(sorted(VALID_BLOCKLIST_FORMATS))}.")
if name:
if name in blocklists_by_name:
errors.append(f"{label}: duplicate blocklist name '{name}'.")
@ -365,6 +366,9 @@ def validate_config(data):
else:
seen_interfaces[iface] = name
if vlan.get("mdns_reflection") is True and is_wg(vlan):
errors.append(f"{label}: mdns_reflection must be false for WireGuard interfaces.")
if is_wg(vlan):
vpi = vlan.get("vpn_information")
if not isinstance(vpi, dict):
@ -538,7 +542,7 @@ def validate_config(data):
errors.append(f"{label}: use_blocklists references unknown blocklist '{bl_name}'.")
# -- NAT / firewall validation ---------------------------------------------
valid_protos = {"tcp", "udp", "both"}
valid_protos = VALID_PROTOCOLS
known_interfaces = set(seen_interfaces.keys())
def nat_check_port(label, port):
@ -621,25 +625,11 @@ def validate_config(data):
if r.get("dst_port") is not None:
nat_check_port(f"{label} dst_port", r.get("dst_port"))
# -- mdns_reflection validation --------------------------------------------
mdns = data.get("mdns_reflection", {})
if mdns.get("enabled") is True:
known_vlan_names = {v["name"] for v in data["vlans"]}
reflect_vlans = mdns.get("reflect_vlans", [])
for vname in reflect_vlans:
if vname not in known_vlan_names:
errors.append(f"mdns_reflection.reflect_vlans: '{vname}' is not a known VLAN name.")
else:
vlan = next(v for v in data["vlans"] if v["name"] == vname)
if is_wg(vlan):
errors.append(f"mdns_reflection.reflect_vlans: '{vname}' is a WireGuard VLAN "
f"and cannot participate in mDNS reflection.")
if not reflect_vlans:
errors.append("mdns_reflection.reflect_vlans is empty. "
"Add at least two VLAN names or set enabled: false.")
elif len(reflect_vlans) < 2:
errors.append("mdns_reflection.reflect_vlans must contain at least two VLANs — "
"reflecting mDNS on a single VLAN has no effect.")
# -- radius_default uniqueness check ---------------------------------------
defaults = [v["name"] for v in data["vlans"] if v.get("radius_default") is True]
if len(defaults) > 1:
errors.append(f"Multiple VLANs have radius_default: true ({', '.join(defaults)}). "
f"Only one VLAN may be the RADIUS default.")
# -- banned_ips validation -------------------------------------------------
for idx, entry in enumerate(data.get("banned_ips", [])):
@ -654,14 +644,14 @@ def validate_config(data):
errors.append(f"{lbl}: {e}")
if errors:
print("Validation failed:")
print("Validation failed:", file=sys.stderr)
for e in errors:
print(f" - {e}")
print(f" - {e}", file=sys.stderr)
sys.exit(1)
# ------------------------------------------------------------------------------
# ===================================================================
# Build systemd-networkd files
# ------------------------------------------------------------------------------
# ===================================================================
def build_netdev(vlan):
return "\n".join([
@ -787,9 +777,9 @@ def apply_networkd(data, dry_run=False, only_if_changed=False):
print("systemd-networkd: no changes. Good.")
# ------------------------------------------------------------------------------
# ===================================================================
# Blocklist management
# ------------------------------------------------------------------------------
# ===================================================================
def combo_hash(names):
"""Return a stable 8-char hex hash for a list/set of blocklist names."""
@ -934,9 +924,9 @@ def update_blocklists(data):
any_failed = any(content is None for content, _ in downloaded.values())
return not any_failed
# ------------------------------------------------------------------------------
# ===================================================================
# Build per-VLAN dnsmasq config
# ------------------------------------------------------------------------------
# ===================================================================
def _wan_has_ipv6(iface):
"""Return True if the WAN interface has a non-link-local IPv6 address."""
@ -1087,9 +1077,9 @@ def build_vlan_dnsmasq_conf(vlan, data):
return "\n".join(L)
# ------------------------------------------------------------------------------
# ===================================================================
# Build per-VLAN systemd service unit
# ------------------------------------------------------------------------------
# ===================================================================
def build_vlan_service(vlan):
name = vlan["name"]
@ -1133,9 +1123,9 @@ def build_vlan_service(vlan):
return "\n".join(lines)
# ------------------------------------------------------------------------------
# ===================================================================
# System dnsmasq / resolv.conf
# ------------------------------------------------------------------------------
# ===================================================================
def ensure_resolv_conf(data):
"""Ensure /etc/resolv.conf points to the physical VLAN gateway (vlan_id=1)."""
@ -1297,9 +1287,9 @@ def restore_ntp():
else:
print("systemd-timesyncd is not available on this system.")
# ------------------------------------------------------------------------------
# ===================================================================
# Apply dnsmasq instances
# ------------------------------------------------------------------------------
# ===================================================================
def wg_interface_up(iface):
"""Return True if the WireGuard interface exists and is up."""
@ -1452,9 +1442,9 @@ def apply_dnsmasq_instances(data, dry_run=False, start_if_needed=True):
else:
print(f" WARNING: {svc} is not running -- skipping (run --apply to start it)")
# ------------------------------------------------------------------------------
# ===================================================================
# Timer management
# ------------------------------------------------------------------------------
# ===================================================================
def parse_time_to_calendar(time_str):
parts = time_str.strip().split(":")
@ -1519,9 +1509,9 @@ def remove_timer():
print(f"Not found, skipping: {f}")
subprocess.run(["systemctl", "daemon-reload"], capture_output=True, text=True)
# ------------------------------------------------------------------------------
# ===================================================================
# banned_ips expansion
# ------------------------------------------------------------------------------
# ===================================================================
def _expand_banned_ipv4(ip_str):
"""Convert an IPv4 pattern (CIDR, wildcard, range) to nftables set elements."""
@ -1531,7 +1521,7 @@ def _expand_banned_ipv4(ip_str):
parts = ip_str.split('.')
if len(parts) != 4:
raise ValueError(f"Invalid IPv4 pattern: {ip_str!r} expected 4 octets")
raise ValueError(f"Invalid IPv4 pattern: {ip_str!r} - expected 4 octets")
def parse_octet(s, pos):
if s == '*':
@ -1587,7 +1577,7 @@ def _expand_banned_ipv4(ip_str):
_enum_cidr(idx + 1, chosen + [v])
_enum_cidr(0, [])
else:
# No trailing wildcards enumerate outer 3 octets, express last as range
# No trailing wildcards - enumerate outer 3 octets, express last as range
outer_ranges = ranges[:3]
lo4, hi4 = ranges[3]
@ -1682,9 +1672,9 @@ def banned_ip_sets(data):
return v4, v6
# ------------------------------------------------------------------------------
# ===================================================================
# nftables config generation
# ------------------------------------------------------------------------------
# ===================================================================
def build_nft_config(data, dry_run=False):
wan = data["general"]["wan_interface"]
@ -1946,9 +1936,9 @@ def build_nft_config(data, dry_run=False):
return "\n".join(L)
# ------------------------------------------------------------------------------
# ===================================================================
# nftables apply / disable / status
# ------------------------------------------------------------------------------
# ===================================================================
def table_exists(family, name):
result = subprocess.run(
@ -1977,8 +1967,8 @@ def apply_nft_config(config_text):
capture_output=True, text=True
)
if result.returncode != 0:
print("ERROR: nft rejected the ruleset:")
print(result.stderr)
print("ERROR: nft rejected the ruleset:", file=sys.stderr)
print(result.stderr, file=sys.stderr)
sys.exit(1)
def apply_nftables(data, dry_run=False):
@ -2075,9 +2065,9 @@ def show_rules():
else:
print(result.stdout)
# ------------------------------------------------------------------------------
# ===================================================================
# nftables boot service
# ------------------------------------------------------------------------------
# ===================================================================
def install_nat_service():
script_path = Path(__file__).resolve()
@ -2121,13 +2111,13 @@ def remove_nat_service():
else:
print(f"Boot service not found, skipping: {NAT_SERVICE_NAME}.service")
# ------------------------------------------------------------------------------
# ===================================================================
# Status
# ------------------------------------------------------------------------------
# ===================================================================
# ------------------------------------------------------------------------------
# ===================================================================
# RADIUS
# ------------------------------------------------------------------------------
# ===================================================================
RADIUS_SECRET_FILE = SCRIPT_DIR / ".radius-secret"
RADIUS_CLIENTS_CONF = Path("/etc/freeradius/3.0/clients.conf")
@ -2275,25 +2265,19 @@ def apply_radius(data):
service_warning("start", "freeradius", result.stderr)
# ------------------------------------------------------------------------------
# ===================================================================
# Avahi mDNS Reflector
# ------------------------------------------------------------------------------
# ===================================================================
AVAHI_CONF_FILE = Path("/etc/avahi/avahi-daemon.conf")
def avahi_enabled(data):
"""Return True if mdns_reflection is enabled with at least two VLANs configured."""
mdns = data.get("mdns_reflection", {})
return mdns.get("enabled") is True
"""Return True if at least one non-WireGuard VLAN has mdns_reflection enabled."""
return any(v.get("mdns_reflection") is True for v in data.get("vlans", []) if not is_wg(v))
def avahi_interfaces(data):
"""Return list of interface names for mDNS reflection based on reflect_vlans."""
reflect = data.get("mdns_reflection", {}).get("reflect_vlans", [])
ifaces = []
for vlan in data["vlans"]:
if vlan["name"] in reflect and not is_wg(vlan):
ifaces.append(vlan["interface"])
return ifaces
"""Return list of interface names for VLANs with mdns_reflection enabled."""
return [v["interface"] for v in data.get("vlans", []) if v.get("mdns_reflection") is True and not is_wg(v)]
def build_avahi_conf(data):
"""Patch avahi-daemon.conf directives needed for cross-VLAN mDNS reflection.
@ -2317,7 +2301,7 @@ def build_avahi_conf(data):
replacement = f"{directive}={value}"
if pattern.search(text):
return pattern.sub(replacement, text)
# Not present at all this shouldn't happen with a standard avahi install
# Not present at all - this shouldn't happen with a standard avahi install
# but append it to the relevant section if needed
return text + f"\n{replacement}\n"
@ -2403,8 +2387,8 @@ def show_status(data):
r_enabled = subprocess.run(["systemctl", "is-enabled", unit], capture_output=True, text=True)
active = r_active.stdout.strip()
enabled = r_enabled.stdout.strip()
active_sym = "" if active == "active" else ""
enabled_sym = "" if enabled == "enabled" else ""
active_sym = "+" if active == "active" else "x"
enabled_sym = "+" if enabled == "enabled" else "x"
active_ok = "(OK) " if active == expected_active else "(BAD)"
enabled_ok = "(OK) " if enabled == "enabled" else "(BAD)"
return active_sym, active, active_ok, enabled_sym, enabled, enabled_ok
@ -2416,7 +2400,7 @@ def show_status(data):
else:
units.append((vlan_service_name(vlan), None, "active"))
units.append((f"{TIMER_NAME}.timer", None, "active"))
units.append((NAT_SERVICE_NAME, None, "inactive")) # oneshot exits after running
units.append((NAT_SERVICE_NAME, None, "inactive")) # oneshot - exits after running
units.append(("freeradius", None, "active"))
units.append(("avahi-daemon", None, "active"))
@ -2456,9 +2440,9 @@ def show_configs(data):
else:
print(f"No config found at {cf} (not yet applied).")
# ------------------------------------------------------------------------------
# ===================================================================
# Leases
# ------------------------------------------------------------------------------
# ===================================================================
def reset_leases(data, vlan_name=None):
"""Stop dnsmasq instances, delete lease files, restart instances.
@ -2572,9 +2556,9 @@ def show_leases(data):
if not any_leases:
print("No active leases found.")
# ------------------------------------------------------------------------------
# ===================================================================
# Metrics
# ------------------------------------------------------------------------------
# ===================================================================
def collect_metrics(data):
"""
@ -2755,9 +2739,9 @@ def show_metrics(data):
print(f" NXDOMAIN : {s['nxdomain']:,}")
print(f" Latency : {s['avg_latency_ms']}ms (last recorded)")
# ------------------------------------------------------------------------------
# ===================================================================
# Stop / disable
# ------------------------------------------------------------------------------
# ===================================================================
def stop_instances(data):
"""Remove timer and stop all per-VLAN instances (config files preserved)."""
@ -2867,19 +2851,19 @@ def _suggest_static_ip(physical_vlan):
chosen = max(non_gateway, key=lambda ip: ip.packed[-1])
return f"{chosen}/{prefix}"
# All identities end in .1 pick a random unused host in the subnet
# All identities end in .1 - pick a random unused host in the subnet
hosts = list(network.hosts())
candidates = [h for h in hosts if h not in known_ips and h.packed[-1] != 1]
if candidates:
chosen = random.choice(candidates)
return f"{chosen}/{prefix}"
# Degenerate fallback extremely small subnet
# Degenerate fallback - extremely small subnet
return f"{list(network.hosts())[0]}/{prefix}"
# ------------------------------------------------------------------------------
# ===================================================================
# Dry-run helpers
# ------------------------------------------------------------------------------
# ===================================================================
def _svc_state(unit):
"""Return 'active', 'inactive', or 'unknown' for a systemd unit."""
@ -2900,12 +2884,12 @@ def _dry_run_conflicting_services(data):
if state == "active":
print(f" Would stop and disable: {label} (currently: active)")
else:
print(f" {label}: not active no action needed")
print(f" {label}: not active - no action needed")
chrony_ok = subprocess.run(["systemctl", "cat", "chrony"],
capture_output=True, text=True).returncode == 0
if not chrony_ok:
print(" chrony: not installed dependency check would have prompted to install it")
print(" chrony: not installed - dependency check would have prompted to install it")
else:
chrony_conf = Path("/etc/chrony/chrony.conf")
if chrony_conf.exists():
@ -2922,7 +2906,7 @@ def _dry_run_conflicting_services(data):
if missing:
print(f" Would add chrony allow directives for: {', '.join(missing)}")
else:
print(" chrony.conf already has required allow directives no change needed")
print(" chrony.conf already has required allow directives - no change needed")
print(f" Would enable and restart: chrony")
if subprocess.run(["which", "ufw"], capture_output=True, text=True).returncode == 0:
@ -2930,20 +2914,20 @@ def _dry_run_conflicting_services(data):
if "Status: active" in status.stdout:
print(" Would disable: ufw (currently: active)")
else:
print(" ufw: not active no rule action needed")
print(" ufw: not active - no rule action needed")
if _svc_enabled("ufw"):
print(" Would disable: ufw.service (currently: enabled at boot)")
else:
print(" ufw.service: not enabled at boot no action needed")
print(" ufw.service: not enabled at boot - no action needed")
else:
print(" ufw: not installed no action needed")
print(" ufw: not installed - no action needed")
r = subprocess.run(["systemctl", "is-enabled", "dnsmasq"],
capture_output=True, text=True)
if r.stdout.strip() in ("enabled", "enabled-runtime"):
print(f" Would stop and disable: system dnsmasq.service (currently: enabled)")
else:
print(" system dnsmasq.service: not enabled no action needed")
print(" system dnsmasq.service: not enabled - no action needed")
physical = next((v for v in data["vlans"] if is_physical(v)), None)
if physical:
@ -2956,7 +2940,7 @@ def _dry_run_conflicting_services(data):
if wanted not in current:
print(f" Would update /etc/resolv.conf: nameserver {gw}")
else:
print(f" /etc/resolv.conf already points to {gw} no change needed")
print(f" /etc/resolv.conf already points to {gw} - no change needed")
def _dry_run_blocklists(data):
print("-- Blocklists (dry-run) ----------------------------------------------")
@ -2982,7 +2966,7 @@ def _dry_run_timer(data):
for path, label in [(TIMER_FILE, "timer unit"), (TIMER_SVC_FILE, "service unit")]:
action = "update" if path.exists() else "create and enable"
print(f" Would {action}: {path}")
print(f" Schedule: daily at {execute_time} local time (Persistent=true catches up if missed)")
print(f" Schedule: daily at {execute_time} local time (Persistent=true - catches up if missed)")
def _dry_run_boot_service():
print("-- Boot service (dry-run) --------------------------------------------")
@ -3016,11 +3000,11 @@ def _dry_run_disable(data, iface, use_dhcp, static_cidr, resolv_ok, dns_choice,
if r.returncode == 0:
print(f" Would flush nftables table: {table}")
else:
print(f" nftables table {table}: not present no action needed")
print(f" nftables table {table}: not present - no action needed")
if NAT_SERVICE_FILE.exists():
print(f" Would stop, disable, and remove: {NAT_SERVICE_NAME}.service")
else:
print(f" {NAT_SERVICE_NAME}.service: not installed no action needed")
print(f" {NAT_SERVICE_NAME}.service: not installed - no action needed")
print()
print("-- Restoring NTP client (dry-run) ------------------------------------")
@ -3028,7 +3012,7 @@ def _dry_run_disable(data, iface, use_dhcp, static_cidr, resolv_ok, dns_choice,
if state == "active":
print(f" Would stop and disable: chrony (currently: active)")
else:
print(f" chrony: not active no action needed")
print(f" chrony: not active - no action needed")
r = subprocess.run(["systemctl", "cat", "systemd-timesyncd"],
capture_output=True, text=True)
if r.returncode == 0:
@ -3063,9 +3047,9 @@ def _dry_run_disable(data, iface, use_dhcp, static_cidr, resolv_ok, dns_choice,
print(f" nameserver {static_nameserver}")
print()
# ------------------------------------------------------------------------------
# ===================================================================
# Disable wizard
# ------------------------------------------------------------------------------
# ===================================================================
def cmd_disable(data, dry_run=False):
"""Interactive wizard to revert the machine from router to plain network client."""
@ -3085,7 +3069,7 @@ def cmd_disable(data, dry_run=False):
print()
# ------------------------------------------------------------------
# Step 1 Confirmation
# Step 1 - Confirmation
# ------------------------------------------------------------------
while True:
print(" [1] Proceed with reversion")
@ -3100,7 +3084,7 @@ def cmd_disable(data, dry_run=False):
print()
# ------------------------------------------------------------------
# Step 2 IP configuration
# Step 2 - IP configuration
# ------------------------------------------------------------------
physical = next((v for v in data["vlans"] if is_physical(v)), None)
if physical is None:
@ -3110,7 +3094,7 @@ def cmd_disable(data, dry_run=False):
print(" How should this machine obtain its IP address after reversion?")
print()
print(" [1] Obtain IP via DHCP (recommended let the new router assign one)")
print(" [1] Obtain IP via DHCP (recommended - let the new router assign one)")
print(" [2] Use a static IP")
print()
@ -3156,7 +3140,7 @@ def cmd_disable(data, dry_run=False):
print()
# ------------------------------------------------------------------
# Step 3 DNS resolver
# Step 3 - DNS resolver
# ------------------------------------------------------------------
# If resolv.conf is already a plain file with no router gateway IPs, leave it alone.
@ -3187,7 +3171,7 @@ def cmd_disable(data, dry_run=False):
print()
if resolved_available:
print(" [1] Re-enable systemd-resolved (recommended adapts to any network)")
print(" [1] Re-enable systemd-resolved (recommended - adapts to any network)")
print(" [2] Enter a static nameserver IP")
while True:
choice = input(" Choice [1/2]: ").strip()
@ -3219,7 +3203,7 @@ def cmd_disable(data, dry_run=False):
print()
# ------------------------------------------------------------------
# Step 4 Execute (or dry-run summary)
# Step 4 - Execute (or dry-run summary)
# ------------------------------------------------------------------
if dry_run:
_dry_run_disable(data, iface, use_dhcp, static_cidr, resolv_ok, dns_choice, static_nameserver)
@ -3260,9 +3244,9 @@ def cmd_disable(data, dry_run=False):
else:
print(f" Interface {iface} will use static IP: {static_cidr}")
# ------------------------------------------------------------------------------
# ===================================================================
# Main
# ------------------------------------------------------------------------------
# ===================================================================
def cmd_install(data):
@ -3455,7 +3439,7 @@ def main():
sys.exit(0)
if args.dry_run and not any([args.apply, args.disable]):
print("ERROR: --dry-run must be combined with --apply or --disable.")
print("ERROR: --dry-run must be combined with --apply or --disable.", file=sys.stderr)
sys.exit(1)
data = load_config()