From 2f0a9c87bd5acd8fc0852f599599d031cde44bbe Mon Sep 17 00:00:00 2001
From: flu0r1ne <flu0r1ne@flu0r1ne.net>
Date: Fri, 18 Aug 2023 01:13:07 -0500
Subject: Add firewall rule generator

---
 src/main.cpp  | 17 +++++++++++++---
 src/wg2sd.cpp | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 src/wg2sd.hpp |  1 +
 3 files changed, 77 insertions(+), 4 deletions(-)

diff --git a/src/main.cpp b/src/main.cpp
index b592aa9..19b5ced 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -89,10 +89,11 @@ void err(char const * format, ...) {
 }
 
 void print_help(char const * prog) {
-	err("Usage: %s [ -o OUTPUT_PATH ] CONFIG_FILE", prog);
+	err("Usage: %s [ -h | -f | -o OUTPUT_PATH ] CONFIG_FILE", prog);
 	err("Options:");
 	err("-o OUTPUT_PATH\tSet the output path (default is /etc/systemd/network)");
-	err("-h\t\tDisplay this help message");
+	err("-f            \tOutput firewall rules");
+	err("-h            \tDisplay this help message");
 	exit(EXIT_SUCCESS);
 }
 
@@ -144,12 +145,16 @@ void write_systemd_file(SystemdFilespec const & filespec, std::string output_pat
 int main(int argc, char ** argv) {
 	int opt;
 	std::filesystem::path output_path = "/etc/systemd/network";
+	bool print_firewall_rules = false;
 
-	while ((opt = getopt(argc, argv, "o:h")) != -1) {
+	while ((opt = getopt(argc, argv, "o:fh")) != -1) {
 		switch (opt) {
 			case 'o':
 				output_path = optarg;
 				break;
+			case 'f':
+				print_firewall_rules = true;
+				break;
 			case 'h':
 				print_help(argv[0]);
 				break;
@@ -186,6 +191,12 @@ int main(int argc, char ** argv) {
 
 	}
 
+	if(print_firewall_rules) {
+		fprintf(stdout, "%s", cfg.firewall.c_str());
+
+		return 0;
+	}
+
 	if(!std::filesystem::path(output_path).is_absolute()) {
 		output_path = std::filesystem::absolute(output_path);
 	}
diff --git a/src/wg2sd.cpp b/src/wg2sd.cpp
index 2087cf5..4c39a03 100644
--- a/src/wg2sd.cpp
+++ b/src/wg2sd.cpp
@@ -6,6 +6,7 @@
 #include <regex>
 
 #include <argon2.h>
+#include <string_view>
 
 std::string hashed_keyfile_name(std::string const & priv_key) {
 	constexpr uint8_t const SALT[] = {
@@ -70,6 +71,16 @@ namespace wg2sd {
 		return std::regex_match(cidr, ipv4);
 	}
 
+	std::string_view _get_addr(std::string_view const & cidr) {
+		size_t suffix = cidr.rfind('/');
+
+		if(suffix == std::string::npos) {
+			return cidr;
+		} else {
+			return cidr.substr(0, suffix);
+		}
+	}
+
 	constexpr uint32_t MAIN_TABLE = 254;
 	constexpr uint32_t LOCAL_TABLE = 255;
 
@@ -292,6 +303,55 @@ namespace wg2sd {
 		return cfg;
 	}
 
+	static void _write_table(std::stringstream & firewall, Config const & cfg, std::vector<std::string_view> addrs, bool ipv4) {
+		char const * ip = ipv4 ? "ip" : "ip6";
+
+		firewall << "table " << ip << " " << cfg.intf.name << " {\n"
+		         << "  chain preraw {\n"
+		         << "    type filter hook prerouting priority raw; policy accept;\n";
+
+		for(std::string_view const & addr : addrs) {
+			firewall << "    iifname != \"" << cfg.intf.name << "\" " << ip << " daddr " << addr << " fib saddr type != local drop;\n";
+		}
+
+		firewall << "  }\n"
+		         << "\n"
+		         << "  chain premangle {\n"
+		         << "    type filter hook prerouting priority mangle; policy accept;\n"
+		         << "    meta l4proto udp meta mark set ct mark;\n"
+		         << "  }\n"
+		         << "\n"
+		         << "  chain postmangle {\n"
+		         << "    type filter hook postrouting priority mangle; policy accept;\n"
+		         << "    meta l4proto udp meta mark " << std::hex << cfg.intf.table << std::dec << "ct mark set meta mark;\n"
+		         << "  }\n"
+		         << "}\n";
+		
+	}
+
+	std::string _gen_nftables_firewall(Config const & cfg) {
+		std::stringstream firewall;
+
+		std::vector<std::string_view> ipv4_addrs;
+		std::vector<std::string_view> ipv6_addrs;
+
+		for(std::string const & addr : cfg.intf.addresses) {
+			if(_is_ipv4_route(addr)) {
+				ipv4_addrs.push_back(_get_addr(addr));
+			} else {
+				ipv6_addrs.push_back(_get_addr(addr));
+			}
+		}
+
+		_write_table(firewall, cfg, ipv4_addrs, true);
+
+		firewall << "\n";
+
+		_write_table(firewall, cfg, ipv6_addrs, false);
+
+		return firewall.str();
+	}
+
 	static std::string _gen_netdev_cfg(Config const & cfg, uint32_t fwd_table, std::string const & private_keyfile,
 			std::vector<SystemdFilespec> & symmetric_keyfiles, std::string const & output_path) {
 		std::stringstream netdev;
@@ -526,7 +586,8 @@ if(!cfg.intf.field_.empty()) { \
 				.contents = cfg.intf.private_key + "\n",
 			},
 			.symmetric_keyfiles = std::move(symmetric_keyfiles),
-			.warnings = std::move(warnings)
+			.warnings = std::move(warnings),
+			.firewall = _gen_nftables_firewall(cfg),
 		};
 	}
 
diff --git a/src/wg2sd.hpp b/src/wg2sd.hpp
index ea6f7fe..e80a284 100644
--- a/src/wg2sd.hpp
+++ b/src/wg2sd.hpp
@@ -127,6 +127,7 @@ namespace wg2sd {
 		std::vector<SystemdFilespec> symmetric_keyfiles;
 
 		std::vector<std::string> warnings;
+		std::string firewall;
 	};
 
 	std::string interface_name_from_filename(std::filesystem::path config_path);
-- 
cgit v1.2.3