Skip to content

Commit

Permalink
nf2go: convert nftables rules to golang code
Browse files Browse the repository at this point in the history
One of the biggest barriers to adopt the netlink format for nftables is
the complexity of writing bytecode.

This commits adds a tool that allows to take an nftables dump and
generate the corresponding golang code and validating that the generated
code produces the exact same output.

Change-Id: I491b35e0d8062de33c67091dd4126d843b231838
Signed-off-by: Antonio Ojea <[email protected]>
  • Loading branch information
aojea committed Feb 3, 2025
1 parent 69f487d commit 444f1f1
Show file tree
Hide file tree
Showing 3 changed files with 855 additions and 0 deletions.
288 changes: 288 additions & 0 deletions internal/nf2go/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
package main

import (
"bytes"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"os/exec"
"path/filepath"
"regexp"
"runtime"
"strings"

"github.com/google/go-cmp/cmp"
"github.com/google/nftables"
"github.com/vishvananda/netns"
)

func main() {
args := os.Args[1:]
if len(args) != 1 {
log.Fatalf("need to specify the file to read the \"nft list ruleset\" dump")
}

filename := args[0]

runtime.LockOSThread()
defer runtime.UnlockOSThread()

// Create a new network namespace
ns, err := netns.New()
if err != nil {
log.Fatalf("netns.New() failed: %v", err)
}
n, err := nftables.New(nftables.WithNetNSFd(int(ns)))
if err != nil {
log.Fatalf("nftables.New() failed: %v", err)
}

scriptOutput, err := applyNFTRuleset(filename)
if err != nil {
log.Fatalf("Failed to apply nftables script: %v\noutput:%s", err, scriptOutput)
}

var buf bytes.Buffer
// Helper function to print to the file
pf := func(format string, a ...interface{}) {
_, err := fmt.Fprintf(&buf, format, a...)
if err != nil {
log.Fatal(err)
}
}

pf("// Code generated by nft2go. DO NOT EDIT.\n")
pf("package main\n\n")
pf("import (\n")
pf("\t\"fmt\"\n")
pf("\t\"log\"\n")
pf("\t\"github.com/google/nftables\"\n")
pf("\t\"github.com/google/nftables/expr\"\n")
pf(")\n\n")
pf("func main() {\n")
pf("\tn, err:= nftables.New()\n")
pf("\tif err!= nil {\n")
pf("\t\tlog.Fatal(err)\n")
pf("\t}\n\n")
pf("\n")
pf("\tvar expressions []expr.Any\n")
pf("\tvar chain *nftables.Chain\n")

tables, err := n.ListTables()
if err != nil {
log.Fatalf("ListTables failed: %v", err)
}

chains, err := n.ListChains()
if err != nil {
log.Fatal(err)
}

for _, table := range tables {
pf("\ttable:= n.AddTable(&nftables.Table{Family: %s,Name: \"%s\"})\n", TableFamilyString(table.Family), table.Name)
for _, chain := range chains {
if chain.Table.Name != table.Name {
continue
}

sets, err := n.GetSets(table)
if err != nil {
log.Fatal(err)
}
for _, set := range sets {
// TODO datatype and the other options
pf("\tn.AddSet(&nftables.Set{\n")
pf("\t\tTable: table,\n")
pf("\t\tName: \"%s\",\n", set.Name)
pf("\t}, nil)\n")
}

pf("\tchain = n.AddChain(&nftables.Chain{Name: \"%s\", Table: table, Type: %s, Hooknum: %s, Priority: %s})\n",
chain.Name, ChainTypeString(chain.Type), ChainHookRef(chain.Hooknum), ChainPrioRef(chain.Priority))

rules, err := n.GetRules(table, chain)
if err != nil {
log.Fatal(err)
}

for _, rule := range rules {
pf("\texpressions = []expr.Any{\n")
for _, exp := range rule.Exprs {
pf("\t\t%#v,\n", exp)
}
pf("\t\t}\n")
pf("\tn.AddRule(&nftables.Rule{\n")
pf("\t\tTable: table,\n")
pf("\t\tChain: chain,\n")
pf("\t\tExprs: expressions,\n")
pf("\t})\n")
}
}

pf("\n\tif err:= n.Flush(); err!= nil {\n")
pf("\t\tlog.Fatalf(\"fail to flush rules: %v\", err)\n")
pf("\t}\n\n")
pf("\tfmt.Println(\"nft ruleset applied.\")\n")
pf("}\n")

// Program nftables using your Go code
if err := flushNFTRuleset(); err != nil {
log.Fatalf("Failed to flush nftables ruleset: %v", err)
}

// Create the output file
// Create a temporary directory
tempDir, err := ioutil.TempDir("", "nftables_gen")
if err != nil {
log.Fatal(err)
}
defer os.RemoveAll(tempDir) // Clean up the temporary directory

// Create the temporary Go file
tempGoFile := filepath.Join(tempDir, "nftables_recreate.go")
f, err := os.Create(tempGoFile)
if err != nil {
log.Fatal(err)
}
defer f.Close()

mw := io.MultiWriter(f, os.Stdout)
buf.WriteTo(mw)

// Format the generated code
cmd := exec.Command("gofmt", "-w", "-s", tempGoFile)
output, err := cmd.CombinedOutput()
if err != nil {
log.Fatalf("gofmt error: %v\nOutput: %s", err, output)
}

// Run the generated code
cmd = exec.Command("go", "run", tempGoFile)
output, err = cmd.CombinedOutput()
if err != nil {
log.Fatalf("Execution error: %v\nOutput: %s", err, output)
}

// Retrieve nftables state using nft
actualOutput, err := listNFTRuleset()
if err != nil {
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
}

expectedOutput, err := os.ReadFile(filename)
if err != nil {
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
}

if !compareMultilineStringsIgnoreIndentation(string(expectedOutput), actualOutput) {
log.Printf("Expected output:\n%s", string(expectedOutput))
log.Printf("Actual output:\n%s", actualOutput)

log.Fatalf("nftables ruleset mismatch:\n%s", cmp.Diff(string(expectedOutput), actualOutput))
}

if err := flushNFTRuleset(); err != nil {
log.Fatalf("Failed to flush nftables ruleset: %v", err)
}
}
}

func applyNFTRuleset(scriptPath string) (string, error) {
cmd := exec.Command("nft", "--debug=netlink", "-f", scriptPath)
out, err := cmd.CombinedOutput()
if err != nil {
return string(out), err
}
return strings.TrimSpace(string(out)), nil
}

func listNFTRuleset() (string, error) {
cmd := exec.Command("nft", "list", "ruleset")
out, err := cmd.CombinedOutput()
if err != nil {
return string(out), err
}
return strings.TrimSpace(string(out)), nil
}

func flushNFTRuleset() error {
cmd := exec.Command("nft", "flush", "ruleset")
return cmd.Run()
}

func ChainHookRef(hookNum *nftables.ChainHook) string {
i := uint32(0)
if hookNum != nil {
i = uint32(*hookNum)
}
switch i {
case 0:
return "nftables.ChainHookPrerouting"
case 1:
return "nftables.ChainHookInput"
case 2:
return "nftables.ChainHookForward"
case 3:
return "nftables.ChainHookOutput"
case 4:
return "nftables.ChainHookPostrouting"
case 5:
return "nftables.ChainHookIngress"
case 6:
return "nftables.ChainHookEgress"
}
return ""
}

func ChainPrioRef(priority *nftables.ChainPriority) string {
i := int32(0)
if priority != nil {
i = int32(*priority)
}
return fmt.Sprintf("nftables.ChainPriorityRef(%d)", i)
}

func ChainTypeString(chaintype nftables.ChainType) string {
switch chaintype {
case nftables.ChainTypeFilter:
return "nftables.ChainTypeFilter"
case nftables.ChainTypeRoute:
return "nftables.ChainTypeRoute"
case nftables.ChainTypeNAT:
return "nftables.ChainTypeNAT"
default:
return "nftables.ChainTypeFilter"
}
}

func TableFamilyString(family nftables.TableFamily) string {
switch family {
case nftables.TableFamilyUnspecified:
return "nftables.TableFamilyUnspecified"
case nftables.TableFamilyINet:
return "nftables.TableFamilyINet"
case nftables.TableFamilyIPv4:
return "nftables.TableFamilyIPv4"
case nftables.TableFamilyIPv6:
return "nftables.TableFamilyIPv6"
case nftables.TableFamilyARP:
return "nftables.TableFamilyARP"
case nftables.TableFamilyNetdev:
return "nftables.TableFamilyNetdev"
case nftables.TableFamilyBridge:
return "nftables.TableFamilyBridge"
default:
return "nftables.TableFamilyIPv4"
}
}

func compareMultilineStringsIgnoreIndentation(str1, str2 string) bool {
// Remove all indentation from both strings
re := regexp.MustCompile(`(?m)^\s+`)
str1 = re.ReplaceAllString(str1, "")
str2 = re.ReplaceAllString(str2, "")

return str1 == str2
}
Loading

0 comments on commit 444f1f1

Please sign in to comment.