- 
                Notifications
    You must be signed in to change notification settings 
- Fork 159
[WIP] nf2go: convert nftables rules to golang code #298
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,299 @@ | ||
| package main | ||
|  | ||
| import ( | ||
| "bytes" | ||
| "context" | ||
| "fmt" | ||
| "io" | ||
| "io/ioutil" | ||
| "log" | ||
| "os" | ||
| "os/exec" | ||
| "path/filepath" | ||
| "regexp" | ||
| "runtime" | ||
| "strings" | ||
| "time" | ||
|  | ||
| "github.com/google/go-cmp/cmp" | ||
| "github.com/google/nftables" | ||
| "github.com/vishvananda/netns" | ||
| ) | ||
|  | ||
| func main() { | ||
| args := os.Args[1:] | ||
| if len(args) != 1 { | ||
| log.Fatalf("need to specify the file to read the \"nft list ruleset\" dump") | ||
| } | ||
|  | ||
| filename := args[0] | ||
|  | ||
| runtime.LockOSThread() | ||
| defer runtime.UnlockOSThread() | ||
|  | ||
| // Create a new network namespace | ||
| ns, err := netns.New() | ||
| if err != nil { | ||
| log.Fatalf("netns.New() failed: %v", err) | ||
| } | ||
| n, err := nftables.New(nftables.WithNetNSFd(int(ns))) | ||
| if err != nil { | ||
| log.Fatalf("nftables.New() failed: %v", err) | ||
| } | ||
|  | ||
| scriptOutput, err := applyNFTRuleset(filename) | ||
| if err != nil { | ||
| log.Fatalf("Failed to apply nftables script: %v\noutput:%s", err, scriptOutput) | ||
| } | ||
|  | ||
| var buf bytes.Buffer | ||
| // Helper function to print to the file | ||
| pf := func(format string, a ...interface{}) { | ||
| _, err := fmt.Fprintf(&buf, format, a...) | ||
| if err != nil { | ||
| log.Fatal(err) | ||
| } | ||
| } | ||
|  | ||
| pf("// Code generated by nft2go. DO NOT EDIT.\n") | ||
| pf("package main\n\n") | ||
| pf("import (\n") | ||
| pf("\t\"fmt\"\n") | ||
| pf("\t\"log\"\n") | ||
| pf("\t\"github.com/google/nftables\"\n") | ||
| pf("\t\"github.com/google/nftables/expr\"\n") | ||
| pf(")\n\n") | ||
| pf("func main() {\n") | ||
| pf("\tn, err:= nftables.New()\n") | ||
| pf("\tif err!= nil {\n") | ||
| pf("\t\tlog.Fatal(err)\n") | ||
| pf("\t}\n\n") | ||
| pf("\n") | ||
| pf("\tvar expressions []expr.Any\n") | ||
| pf("\tvar chain *nftables.Chain\n") | ||
| pf("\tvar table *nftables.Table\n") | ||
|  | ||
| tables, err := n.ListTables() | ||
| if err != nil { | ||
| log.Fatalf("ListTables failed: %v", err) | ||
| } | ||
|  | ||
| chains, err := n.ListChains() | ||
| if err != nil { | ||
| log.Fatal(err) | ||
| } | ||
|  | ||
| for _, table := range tables { | ||
| log.Printf("processing table: %s", table.Name) | ||
|  | ||
| pf("\ttable = n.AddTable(&nftables.Table{Family: %s,Name: \"%s\"})\n", TableFamilyString(table.Family), table.Name) | ||
| for _, chain := range chains { | ||
| if chain.Table.Name != table.Name { | ||
| continue | ||
| } | ||
|  | ||
| sets, err := n.GetSets(table) | ||
| if err != nil { | ||
| log.Fatal(err) | ||
| } | ||
| for _, set := range sets { | ||
| // TODO datatype and the other options | ||
| pf("\tn.AddSet(&nftables.Set{\n") | ||
| pf("\t\tTable: table,\n") | ||
| pf("\t\tName: \"%s\",\n", set.Name) | ||
| pf("\t}, nil)\n") | ||
| } | ||
|  | ||
| pf("\tchain = n.AddChain(&nftables.Chain{Name: \"%s\", Table: table, Type: %s, Hooknum: %s, Priority: %s})\n", | ||
| chain.Name, ChainTypeString(chain.Type), ChainHookRef(chain.Hooknum), ChainPrioRef(chain.Priority)) | ||
|  | ||
| rules, err := n.GetRules(table, chain) | ||
| if err != nil { | ||
| log.Fatal(err) | ||
| } | ||
|  | ||
| for _, rule := range rules { | ||
| pf("\texpressions = []expr.Any{\n") | ||
| for _, exp := range rule.Exprs { | ||
| pf("\t\t%#v,\n", exp) | ||
| } | ||
| pf("\t\t}\n") | ||
| pf("\tn.AddRule(&nftables.Rule{\n") | ||
| pf("\t\tTable: table,\n") | ||
| pf("\t\tChain: chain,\n") | ||
| pf("\t\tExprs: expressions,\n") | ||
| pf("\t})\n") | ||
| } | ||
| } | ||
| } | ||
|  | ||
| pf("\n\tif err:= n.Flush(); err!= nil {\n") | ||
| pf("\t\tlog.Fatal(err)\n") | ||
| pf("\t}\n\n") | ||
| pf("\tfmt.Println(\"nft ruleset applied.\")\n") | ||
| pf("}\n") | ||
|  | ||
| // Program nftables using your Go code | ||
| if err := flushNFTRuleset(); err != nil { | ||
| log.Fatalf("Failed to flush nftables ruleset: %v", err) | ||
| } | ||
|  | ||
| // Create the output file | ||
| // Create a temporary directory | ||
| tempDir, err := ioutil.TempDir("", "nftables_gen") | ||
| if err != nil { | ||
| log.Fatal(err) | ||
| } | ||
| defer os.RemoveAll(tempDir) // Clean up the temporary directory | ||
|  | ||
| // Create the temporary Go file | ||
| tempGoFile := filepath.Join(tempDir, "nftables_recreate.go") | ||
| f, err := os.Create(tempGoFile) | ||
| if err != nil { | ||
| log.Fatal(err) | ||
| } | ||
| defer f.Close() | ||
|  | ||
| mw := io.MultiWriter(f, os.Stdout) | ||
| buf.WriteTo(mw) | ||
|  | ||
| // Format the generated code | ||
| log.Printf("formating file: %s", tempGoFile) | ||
| cmd := exec.Command("gofmt", "-w", "-s", tempGoFile) | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use https://pkg.go.dev/go/format#Source instead of shelling out to gofmt There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That has a tradeoff in respecting the current GOTOOLCHAIN at the time of execution vs at the time of building this binary? | ||
| output, err := cmd.CombinedOutput() | ||
| if err != nil { | ||
| log.Fatalf("gofmt error: %v\nOutput: %s", err, output) | ||
| } | ||
|  | ||
| // Run the generated code | ||
| ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) | ||
| defer cancel() | ||
|  | ||
| log.Printf("executing file: %s", tempGoFile) | ||
| cmd = exec.CommandContext(ctx, "go", "run", tempGoFile) | ||
| output, err = cmd.CombinedOutput() | ||
| if err != nil { | ||
| log.Fatalf("Execution error: %v\nOutput: %s", err, output) | ||
| } | ||
|  | ||
| // Retrieve nftables state using nft | ||
| log.Printf("obtain current ruleset: %s", tempGoFile) | ||
| actualOutput, err := listNFTRuleset() | ||
| if err != nil { | ||
| log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput) | ||
| } | ||
|  | ||
| expectedOutput, err := os.ReadFile(filename) | ||
| if err != nil { | ||
| log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput) | ||
| } | ||
|  | ||
| if !compareMultilineStringsIgnoreIndentation(string(expectedOutput), actualOutput) { | ||
| log.Printf("Expected output:\n%s", string(expectedOutput)) | ||
| log.Printf("Actual output:\n%s", actualOutput) | ||
|  | ||
| log.Fatalf("nftables ruleset mismatch:\n%s", cmp.Diff(string(expectedOutput), actualOutput)) | ||
| } | ||
|  | ||
| if err := flushNFTRuleset(); err != nil { | ||
| log.Fatalf("Failed to flush nftables ruleset: %v", err) | ||
| } | ||
| } | ||
|  | ||
| func applyNFTRuleset(scriptPath string) (string, error) { | ||
| cmd := exec.Command("nft", "--debug=netlink", "-f", scriptPath) | ||
| out, err := cmd.CombinedOutput() | ||
| if err != nil { | ||
| return string(out), err | ||
| } | ||
| return strings.TrimSpace(string(out)), nil | ||
| } | ||
|  | ||
| func listNFTRuleset() (string, error) { | ||
| cmd := exec.Command("nft", "list", "ruleset") | ||
| out, err := cmd.CombinedOutput() | ||
| if err != nil { | ||
| return string(out), err | ||
| } | ||
| return strings.TrimSpace(string(out)), nil | ||
| } | ||
|  | ||
| func flushNFTRuleset() error { | ||
| cmd := exec.Command("nft", "flush", "ruleset") | ||
| return cmd.Run() | ||
| } | ||
|  | ||
| func ChainHookRef(hookNum *nftables.ChainHook) string { | ||
| i := uint32(0) | ||
| if hookNum != nil { | ||
| i = uint32(*hookNum) | ||
| } | ||
| switch i { | ||
| case 0: | ||
| return "nftables.ChainHookPrerouting" | ||
| case 1: | ||
| return "nftables.ChainHookInput" | ||
| case 2: | ||
| return "nftables.ChainHookForward" | ||
| case 3: | ||
| return "nftables.ChainHookOutput" | ||
| case 4: | ||
| return "nftables.ChainHookPostrouting" | ||
| case 5: | ||
| return "nftables.ChainHookIngress" | ||
| case 6: | ||
| return "nftables.ChainHookEgress" | ||
| } | ||
| return "" | ||
| } | ||
|  | ||
| func ChainPrioRef(priority *nftables.ChainPriority) string { | ||
| i := int32(0) | ||
| if priority != nil { | ||
| i = int32(*priority) | ||
| } | ||
| return fmt.Sprintf("nftables.ChainPriorityRef(%d)", i) | ||
| } | ||
|  | ||
| func ChainTypeString(chaintype nftables.ChainType) string { | ||
| switch chaintype { | ||
| case nftables.ChainTypeFilter: | ||
| return "nftables.ChainTypeFilter" | ||
| case nftables.ChainTypeRoute: | ||
| return "nftables.ChainTypeRoute" | ||
| case nftables.ChainTypeNAT: | ||
| return "nftables.ChainTypeNAT" | ||
| default: | ||
| return "nftables.ChainTypeFilter" | ||
| } | ||
| } | ||
|  | ||
| func TableFamilyString(family nftables.TableFamily) string { | ||
| switch family { | ||
| case nftables.TableFamilyUnspecified: | ||
| return "nftables.TableFamilyUnspecified" | ||
| case nftables.TableFamilyINet: | ||
| return "nftables.TableFamilyINet" | ||
| case nftables.TableFamilyIPv4: | ||
| return "nftables.TableFamilyIPv4" | ||
| case nftables.TableFamilyIPv6: | ||
| return "nftables.TableFamilyIPv6" | ||
| case nftables.TableFamilyARP: | ||
| return "nftables.TableFamilyARP" | ||
| case nftables.TableFamilyNetdev: | ||
| return "nftables.TableFamilyNetdev" | ||
| case nftables.TableFamilyBridge: | ||
| return "nftables.TableFamilyBridge" | ||
| default: | ||
| return "nftables.TableFamilyIPv4" | ||
| } | ||
| } | ||
|  | ||
| func compareMultilineStringsIgnoreIndentation(str1, str2 string) bool { | ||
| // Remove all indentation from both strings | ||
| re := regexp.MustCompile(`(?m)^\s+`) | ||
| str1 = re.ReplaceAllString(str1, "") | ||
| str2 = re.ReplaceAllString(str2, "") | ||
|  | ||
| return str1 == str2 | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use flag.Parse :)