summaryrefslogtreecommitdiff
path: root/blocklist2nft
blob: fc5fc4ad02c03dcd66cf17453dcdab8c1bb0f442 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/python3
""" blocklist2nft """

import argparse
import ipaddress
import logging
import re
import subprocess
import tempfile

import requests

# Plain-text blocklist feeds downloaded on every run. Each is expected to
# list one IPv4/IPv6 address or CIDR block per line (see get_address_feed).
blocklist_urls = [
    "https://rules.emergingthreats.net/fwrules/emerging-Block-IPs.txt",
    "https://rules.emergingthreats.net/blockrules/compromised-ips.txt",
    "https://iplists.firehol.org/files/firehol_level1.netset",
]

# Default offset relative to the standard "filter" priority of the INPUT chain.
NFT_PRIORITY = -2

def _nft_priority_expr(priority):
    """Render *priority* as an offset from the standard ``filter`` priority.

    nft accepts ``filter``, ``filter + N`` or ``filter - N`` here; a bare
    ``filter 5`` (or ``filter 0``) is a syntax error, so the sign is spelled
    out explicitly.
    """
    if priority > 0:
        return f"filter + {priority}"
    if priority < 0:
        return f"filter - {-priority}"
    return "filter"


def nft_commands(v4_list, v6_list, priority):
    """Generate the nft script that (re)creates the ``blocklist2nft`` table.

    Args:
        v4_list: pre-formatted, comma-separated IPv4 set elements (may be
            empty).
        v6_list: pre-formatted, comma-separated IPv6 set elements (may be
            empty); they are appended after the fixed ``FEC0::/10`` entry.
        priority: integer offset added to the ``filter`` hook priority of
            the INPUT chain.

    Returns:
        The script text, intended to be fed to ``nft -f``.
    """
    # nft rejects an empty ``elements = { }`` block, so omit the line
    # entirely when there are no IPv4 entries.
    if v4_list:
        v4_elements = f"\n        elements = {{ {v4_list}\n        }}"
    else:
        v4_elements = ""
    v6_members = f"FEC0::/10, {v6_list}" if v6_list else "FEC0::/10"

    # "add" before "delete" guarantees the table exists, so the delete does
    # not fail on a first run; the table is then rebuilt from scratch.
    return f"""
add table inet blocklist2nft
delete table inet blocklist2nft
add table inet blocklist2nft {{
    set addr-set-drop4 {{
        type ipv4_addr
        flags interval{v4_elements}
    }}
    set addr-set-drop6 {{
        type ipv6_addr
        flags interval
        elements = {{
            {v6_members}
        }}
    }}
    chain INPUT {{
        type filter hook input priority {_nft_priority_expr(priority)}; policy accept;
        ip saddr @addr-set-drop4 drop
        ip6 saddr @addr-set-drop6 drop
    }}
}}
    """

# Module-level accumulators: get_address_feed() appends parsed networks to
# them and main() collapses and renders them into the nft script.
ipv4_list, ipv6_list = [], []

def is_ip_address(a, *, strict=True):
    """Parse *a* as an IPv4/IPv6 network, or return ``None`` if it is not one.

    Args:
        a: candidate address or CIDR block, e.g. ``"192.0.2.1"`` or
            ``"2001:db8::/32"``. Surrounding whitespace on strings is ignored.
        strict: passed to ``ipaddress.ip_network``. With the default
            ``True``, a block with host bits set (``"192.0.2.1/24"``) is
            rejected rather than silently widened to its network.

    Returns:
        An ``IPv4Network``/``IPv6Network``, or ``None`` when *a* does not
        parse. Callers rely on ``None`` to skip non-address lines.
    """
    if isinstance(a, str):
        a = a.strip()
    try:
        return ipaddress.ip_network(a, strict=strict)
    except ValueError:
        # ip_network raises instead of returning None; translate so callers'
        # `if net is None: continue` checks actually work.
        return None

def get_address_feed(feed):
    """Download *feed* and append every network it lists to the module lists.

    Each line is parsed as an IPv4 or IPv6 network; IPv4 networks are
    appended to ``ipv4_list`` and IPv6 networks to ``ipv6_list``. Lines that
    do not start like an address (comments, headers, blank lines) or that
    fail to parse are skipped.

    A feed that cannot be fetched (connection error, timeout, HTTP error
    status) is logged and skipped so the remaining feeds still load.

    Args:
        feed: URL of a plain-text blocklist, one address or CIDR per line.
    """
    try:
        with requests.get(feed, stream=True, timeout=30) as response:
            response.raise_for_status()
            for line in response.iter_lines(decode_unicode=True):
                # iter_lines(decode_unicode=True) still yields bytes when the
                # response declares no encoding.
                if isinstance(line, bytes):
                    line = line.decode("utf-8", errors="replace")
                line = line.strip()
                # Cheap pre-filter: addresses start with a hex digit, or ':'
                # for compressed IPv6 such as "::ffff:...".
                if not re.match(r'^[0-9a-fA-F:]', line):
                    continue
                try:
                    net = is_ip_address(line)
                except ValueError:
                    # Malformed entry: skip it instead of aborting the run.
                    net = None
                if net is None:
                    log.debug("skipping unparseable line in %s: %r", feed, line)
                    continue
                if net.version == 4:
                    ipv4_list.append(net)
                elif net.version == 6:
                    log.debug("IPv6 entry from %s: %s", feed, net)
                    ipv6_list.append(net)
    except requests.exceptions.RequestException as e:
        # Covers ConnectionError as before, plus HTTPError from
        # raise_for_status() and read timeouts.
        log.error(e)
        log.error(e)

def _run_nft(nft_contents):
    """Load *nft_contents* with ``nft -f``; return True if nft exited 0."""
    with tempfile.NamedTemporaryFile(delete=True) as nft_file:
        nft_file.write(nft_contents.encode("utf-8"))
        # nft reads the file by name from a separate process, so the Python
        # write buffer must be flushed first or nft may see a truncated script.
        nft_file.flush()
        nft_result = subprocess.run(
            ["nft", "-f", nft_file.name],
            check=False,  # inspect returncode ourselves so stderr is reported
            capture_output=True,
            text=True,
        )
    if nft_result.returncode != 0:
        log.error("nft call failed with code %d", nft_result.returncode)
        if nft_result.stderr:
            log.error(nft_result.stderr)
        if nft_result.stdout:
            log.error(nft_result.stdout)
        return False
    return True


def main():
    """Fetch all feeds, build the nft blocklist script and load it.

    Returns:
        0 on success; 1 if no addresses were fetched, nft failed, or any
        other error occurred (errors are logged, not raised).
    """
    try:
        for feed in blocklist_urls:
            get_address_feed(feed)

        if not ipv4_list and not ipv6_list:
            # Loading an empty set is an nft syntax error anyway; bail out
            # clearly and leave any existing table in place.
            log.error("no addresses fetched from any feed; ruleset not changed")
            return 1

        separator = ", \n"
        ipv4_list_str = separator.join(
            str(net) for net in ipaddress.collapse_addresses(ipv4_list))
        ipv6_list_str = separator.join(
            str(net) for net in ipaddress.collapse_addresses(ipv6_list))

        nft_contents = nft_commands(ipv4_list_str, ipv6_list_str, args.priority)
        if not _run_nft(nft_contents):
            return 1

    except Exception as e:  # top-level boundary: log and report failure
        log.error(e)
        return 1
    return 0


ap = argparse.ArgumentParser(description="Create an nftables blocklist from a URL data source")
# %(default)s keeps the help text in sync with NFT_PRIORITY.
ap.add_argument("--priority", type=int, default=NFT_PRIORITY,
                help="nft chain hook priority (default %(default)s)")
ap.add_argument("--debug", action="store_true", help="Enable debug logging")
args = ap.parse_args()

log = logging.getLogger("blocklist2nft")
lh = logging.StreamHandler()
lf = logging.Formatter(fmt="%(levelname)s: %(message)s")
lh.setFormatter(lf)
log.addHandler(lh)
log.setLevel(logging.DEBUG if args.debug else logging.INFO)

if __name__ == '__main__':
    # Propagate main()'s status as the process exit code (None -> 0).
    raise SystemExit(main())