Skip to content

Commit

Permalink
nf2go: convert nftables rules to golang code
Browse files Browse the repository at this point in the history
One of the biggest barriers to adopt the netlink format for nftables is
the complexity of writing bytecode.

This commits adds a tool that allows to take an nftables dump and
generate the corresponding golang code and validating that the generated
code produces the exact same output.

Change-Id: I491b35e0d8062de33c67091dd4126d843b231838
Signed-off-by: Antonio Ojea <[email protected]>
  • Loading branch information
aojea committed Feb 2, 2025
1 parent 69f487d commit a2644bc
Show file tree
Hide file tree
Showing 3 changed files with 826 additions and 0 deletions.
251 changes: 251 additions & 0 deletions internal/nf2go/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
package main

import (
"fmt"
"log"
"os"
"os/exec"
"runtime"
"strings"

"github.com/google/go-cmp/cmp"
"github.com/google/nftables"
"github.com/vishvananda/netns"
)

func main() {
args := os.Args[1:]
if len(args) != 1 {
log.Fatalf("need to specify the file to read the \"nft list ruleset\" dump")
}

filename := args[0]

runtime.LockOSThread()
defer runtime.UnlockOSThread()

// Create a new network namespace
ns, err := netns.New()
if err != nil {
log.Fatalf("netns.New() failed: %v", err)
}
n, err := nftables.New(nftables.WithNetNSFd(int(ns)))
if err != nil {
log.Fatalf("nftables.New() failed: %v", err)
}

scriptOutput, err := applyNFTRuleset(filename)
if err != nil {
log.Fatalf("Failed to apply nftables script: %v\noutput:%s", err, scriptOutput)
}
if len(scriptOutput) > 0 {
log.Printf("nft output:\n%s", scriptOutput)
}

// Create the output file
f, err := os.Create("nftables_recreate.go")
if err != nil {
log.Fatal(err)
}
defer f.Close()

// Helper function to print to the file
pf := func(format string, a ...interface{}) {
_, err := fmt.Fprintf(f, format, a...)
if err != nil {
log.Fatal(err)
}
}

pf("// Code generated by nft2go. DO NOT EDIT.\n")
pf("package main\n\n")
pf("import (\n")
pf("\t\"fmt\"\n")
pf("\t\"log\"\n")
pf("\t\"github.com/google/nftables\"\n")
pf("\t\"github.com/google/nftables/expr\"\n")
pf(")\n\n")
pf("func main() {\n")
pf("\tn, err:= nftables.New()\n")
pf("\tif err!= nil {\n")
pf("\t\tlog.Fatal(err)\n")
pf("\t}\n\n")
pf("\n")
pf("\tvar expressions []expr.Any\n")
pf("\tvar chain *nftables.Chain\n")

tables, err := n.ListTables()
if err != nil {
log.Fatalf("ListTables failed: %v", err)
}

chains, err := n.ListChains()
if err != nil {
log.Fatal(err)
}

for _, table := range tables {
pf("\ttable:= n.AddTable(&nftables.Table{Family: %s,Name: \"%s\"})\n", TableFamilyString(table.Family), table.Name)
for _, chain := range chains {
if chain.Table.Name != table.Name {
continue
}

pf("\tchain = n.AddChain(&nftables.Chain{Name: \"%s\", Table: table, Type: %s, Hooknum: %s, Priority: %s})\n",
chain.Name, ChainTypeString(chain.Type), ChainHookRef(chain.Hooknum), ChainPrioRef(chain.Priority))

rules, err := n.GetRules(table, chain)
if err != nil {
log.Fatal(err)
}

for _, rule := range rules {
pf("\texpressions = []expr.Any{\n")
for _, exp := range rule.Exprs {
pf("\t\t%#v,\n", exp)
}
pf("\t\t}\n")
pf("\tn.AddRule(&nftables.Rule{\n")
pf("\t\tTable: table,\n")
pf("\t\tChain: chain,\n")
pf("\t\tExprs: expressions,\n")
pf("\t})\n")
}
}

pf("\n\tif err:= n.Flush(); err!= nil {\n")
pf("\t\tlog.Fatalf(\"fail to flush rules: %v\", err)\n")
pf("\t}\n\n")
pf("\tfmt.Println(\"nft ruleset applied.\")\n")
pf("}\n")

// Program nftables using your Go code
if err := flushNFTRuleset(); err != nil {
log.Fatalf("Failed to flush nftables ruleset: %v", err)
}

// Format the generated code
cmd := exec.Command("gofmt", "-w", "-s", "nftables_recreate.go")
output, err := cmd.CombinedOutput()
if err != nil {
log.Fatalf("gofmt error: %v\nOutput: %s", err, output)
}

// Run the generated code
cmd = exec.Command("go", "run", "nftables_recreate.go")
output, err = cmd.CombinedOutput()
if err != nil {
log.Fatalf("Execution error: %v\nOutput: %s", err, output)
}

// Retrieve nftables state using nft
actualOutput, err := listNFTRuleset()
if err != nil {
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
}

log.Printf("Actual output:\n%s", actualOutput)

expectedOutput, err := os.ReadFile(filename)
if err != nil {
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
}

if string(expectedOutput) != actualOutput {
log.Fatalf("nftables ruleset mismatch:\n%s", cmp.Diff(expectedOutput, actualOutput))
}

if err := flushNFTRuleset(); err != nil {
log.Fatalf("Failed to flush nftables ruleset: %v", err)
}
}
}

func applyNFTRuleset(scriptPath string) (string, error) {
cmd := exec.Command("nft", "--debug=netlink", "-f", scriptPath)
out, err := cmd.CombinedOutput()
if err != nil {
return string(out), err
}
return strings.TrimSpace(string(out)), nil
}

func listNFTRuleset() (string, error) {
cmd := exec.Command("nft", "list", "ruleset")
out, err := cmd.CombinedOutput()
if err != nil {
return string(out), err
}
return strings.TrimSpace(string(out)), nil
}

func flushNFTRuleset() error {
cmd := exec.Command("nft", "flush", "ruleset")
return cmd.Run()
}

func ChainHookRef(hookNum *nftables.ChainHook) string {
i := uint32(0)
if hookNum != nil {
i = uint32(*hookNum)
}
switch i {
case 0:
return "nftables.ChainHookPrerouting"
case 1:
return "nftables.ChainHookInput"
case 2:
return "nftables.ChainHookForward"
case 3:
return "nftables.ChainHookOutput"
case 4:
return "nftables.ChainHookPostrouting"
case 5:
return "nftables.ChainHookIngress"
case 6:
return "nftables.ChainHookEgress"
}
return ""
}

func ChainPrioRef(priority *nftables.ChainPriority) string {
i := int32(0)
if priority != nil {
i = int32(*priority)
}
return fmt.Sprintf("nftables.ChainPriorityRef(%d)", i)
}

func ChainTypeString(chaintype nftables.ChainType) string {
switch chaintype {
case nftables.ChainTypeFilter:
return "nftables.ChainTypeFilter"
case nftables.ChainTypeRoute:
return "nftables.ChainTypeRoute"
case nftables.ChainTypeNAT:
return "nftables.ChainTypeNAT"
default:
return "nftables.ChainTypeFilter"
}
}

func TableFamilyString(family nftables.TableFamily) string {
switch family {
case nftables.TableFamilyUnspecified:
return "nftables.TableFamilyUnspecified"
case nftables.TableFamilyINet:
return "nftables.TableFamilyINet"
case nftables.TableFamilyIPv4:
return "nftables.TableFamilyIPv4"
case nftables.TableFamilyIPv6:
return "nftables.TableFamilyIPv6"
case nftables.TableFamilyARP:
return "nftables.TableFamilyARP"
case nftables.TableFamilyNetdev:
return "nftables.TableFamilyNetdev"
case nftables.TableFamilyBridge:
return "nftables.TableFamilyBridge"
default:
return "nftables.TableFamilyIPv4"
}
}
Loading

0 comments on commit a2644bc

Please sign in to comment.