diff options
| -rw-r--r-- | src/main.cpp | 17 | ||||
| -rw-r--r-- | src/wg2sd.cpp | 63 | ||||
| -rw-r--r-- | src/wg2sd.hpp | 1 | 
3 files changed, 77 insertions, 4 deletions
| diff --git a/src/main.cpp b/src/main.cpp index b592aa9..19b5ced 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -89,10 +89,11 @@ void err(char const * format, ...) {  }  void print_help(char const * prog) { -	err("Usage: %s [ -o OUTPUT_PATH ] CONFIG_FILE", prog); +	err("Usage: %s [ -h | -f | -o OUTPUT_PATH ] CONFIG_FILE", prog);  	err("Options:");  	err("-o OUTPUT_PATH\tSet the output path (default is /etc/systemd/network)"); -	err("-h\t\tDisplay this help message"); +	err("-f            \tOutput firewall rules"); +	err("-h            \tDisplay this help message");  	exit(EXIT_SUCCESS);  } @@ -144,12 +145,16 @@ void write_systemd_file(SystemdFilespec const & filespec, std::string output_pat  int main(int argc, char ** argv) {  	int opt;  	std::filesystem::path output_path = "/etc/systemd/network"; +	bool print_firewall_rules = false; -	while ((opt = getopt(argc, argv, "o:h")) != -1) { +	while ((opt = getopt(argc, argv, "o:fh")) != -1) {  		switch (opt) {  			case 'o':  				output_path = optarg;  				break; +			case 'f': +				print_firewall_rules = true; +				break;  			case 'h':  				print_help(argv[0]);  				break; @@ -186,6 +191,12 @@ int main(int argc, char ** argv) {  	} +	if(print_firewall_rules) { +		fprintf(stdout, "%s", cfg.firewall.c_str()); + +		return 0; +	} +  	if(!std::filesystem::path(output_path).is_absolute()) {  		output_path = std::filesystem::absolute(output_path);  	} diff --git a/src/wg2sd.cpp b/src/wg2sd.cpp index 2087cf5..4c39a03 100644 --- a/src/wg2sd.cpp +++ b/src/wg2sd.cpp @@ -6,6 +6,7 @@  #include <regex>  #include <argon2.h> +#include <string_view>  std::string hashed_keyfile_name(std::string const & priv_key) {  	constexpr uint8_t const SALT[] = { @@ -70,6 +71,16 @@ namespace wg2sd {  		return std::regex_match(cidr, ipv4);  	} +	std::string_view _get_addr(std::string_view const & cidr) { +		size_t suffix = cidr.rfind('/'); + +		if(suffix == std::string::npos) { +			return cidr; +		} else { +			return cidr.substr(0, suffix); +		} +	} +  	constexpr uint32_t MAIN_TABLE = 254;  	constexpr uint32_t LOCAL_TABLE = 255; @@ -292,6 +303,55 @@ namespace wg2sd {  		return cfg;  	} +	static void _write_table(std::stringstream & firewall, Config const & cfg, std::vector<std::string_view> addrs, bool ipv4) { +		char const * ip = ipv4 ? "ip" : "ip6"; + +		firewall << "table " << ip << " " << cfg.intf.name << " {\n" +		         << "  chain preraw {\n" +		         << "    type filter hook prerouting priority raw; policy accept;\n"; + +		for(std::string_view const & addr : addrs) { +			firewall << "    iifname != \"" << cfg.intf.name << "\" " << ip << " daddr " << addr << " fib saddr type != local drop;\n"; +		} + +		firewall << "  }\n" +		         << "\n" +		         << "  chain premangle {\n" +		         << "    type filter hook prerouting priority mangle; policy accept;\n" +		         << "    meta l4proto udp meta mark set ct mark;\n" +		         << "  }\n" +		         << "\n" +		         << "  chain postmangle {\n" +		         << "    type filter hook postrouting priority mangle; policy accept;\n" +		         << "    meta l4proto udp meta mark " << std::hex << cfg.intf.table << std::dec << "ct mark set meta mark;\n" +		         << "  }\n" +		         << "}\n"; +		 +	} + +	std::string _gen_nftables_firewall(Config const & cfg) { +		std::stringstream firewall; + +		std::vector<std::string_view> ipv4_addrs; +		std::vector<std::string_view> ipv6_addrs; + +		for(std::string const & addr : cfg.intf.addresses) { +			if(_is_ipv4_route(addr)) { +				ipv4_addrs.push_back(_get_addr(addr)); +			} else { +				ipv6_addrs.push_back(_get_addr(addr)); +			} +		} + +		_write_table(firewall, cfg, ipv4_addrs, true); + +		firewall << "\n"; + +		_write_table(firewall, cfg, ipv6_addrs, false); + +		return firewall.str(); +	} +  	static std::string _gen_netdev_cfg(Config const & cfg, uint32_t fwd_table, std::string const & private_keyfile,  			std::vector<SystemdFilespec> & symmetric_keyfiles, std::string const & output_path) {  		std::stringstream netdev; @@ -526,7 +586,8 @@ if(!cfg.intf.field_.empty()) { \  				.contents = cfg.intf.private_key + "\n",  			},  			.symmetric_keyfiles = std::move(symmetric_keyfiles), -			.warnings = std::move(warnings) +			.warnings = std::move(warnings), +			.firewall = _gen_nftables_firewall(cfg),  		};  	} diff --git a/src/wg2sd.hpp b/src/wg2sd.hpp index ea6f7fe..e80a284 100644 --- a/src/wg2sd.hpp +++ b/src/wg2sd.hpp @@ -127,6 +127,7 @@ namespace wg2sd {  		std::vector<SystemdFilespec> symmetric_keyfiles;  		std::vector<std::string> warnings; +		std::string firewall;  	};  	std::string interface_name_from_filename(std::filesystem::path config_path); | 
