-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
nf2go: convert nftables rules to golang code
One of the biggest barriers to adopt the netlink format for nftables is the complexity of writing bytecode. This commits adds a tool that allows to take an nftables dump and generate the corresponding golang code and validating that the generated code produces the exact same output. Change-Id: I491b35e0d8062de33c67091dd4126d843b231838 Signed-off-by: Antonio Ojea <[email protected]>
- Loading branch information
Showing
3 changed files
with
826 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,251 @@ | ||
package main | ||
|
||
import ( | ||
"fmt" | ||
"log" | ||
"os" | ||
"os/exec" | ||
"runtime" | ||
"strings" | ||
|
||
"github.com/google/go-cmp/cmp" | ||
"github.com/google/nftables" | ||
"github.com/vishvananda/netns" | ||
) | ||
|
||
func main() { | ||
args := os.Args[1:] | ||
if len(args) != 1 { | ||
log.Fatalf("need to specify the file to read the \"nft list ruleset\" dump") | ||
} | ||
|
||
filename := args[0] | ||
|
||
runtime.LockOSThread() | ||
defer runtime.UnlockOSThread() | ||
|
||
// Create a new network namespace | ||
ns, err := netns.New() | ||
if err != nil { | ||
log.Fatalf("netns.New() failed: %v", err) | ||
} | ||
n, err := nftables.New(nftables.WithNetNSFd(int(ns))) | ||
if err != nil { | ||
log.Fatalf("nftables.New() failed: %v", err) | ||
} | ||
|
||
scriptOutput, err := applyNFTRuleset(filename) | ||
if err != nil { | ||
log.Fatalf("Failed to apply nftables script: %v\noutput:%s", err, scriptOutput) | ||
} | ||
if len(scriptOutput) > 0 { | ||
log.Printf("nft output:\n%s", scriptOutput) | ||
} | ||
|
||
// Create the output file | ||
f, err := os.Create("nftables_recreate.go") | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
defer f.Close() | ||
|
||
// Helper function to print to the file | ||
pf := func(format string, a ...interface{}) { | ||
_, err := fmt.Fprintf(f, format, a...) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
} | ||
|
||
pf("// Code generated by nft2go. DO NOT EDIT.\n") | ||
pf("package main\n\n") | ||
pf("import (\n") | ||
pf("\t\"fmt\"\n") | ||
pf("\t\"log\"\n") | ||
pf("\t\"github.com/google/nftables\"\n") | ||
pf("\t\"github.com/google/nftables/expr\"\n") | ||
pf(")\n\n") | ||
pf("func main() {\n") | ||
pf("\tn, err:= nftables.New()\n") | ||
pf("\tif err!= nil {\n") | ||
pf("\t\tlog.Fatal(err)\n") | ||
pf("\t}\n\n") | ||
pf("\n") | ||
pf("\tvar expressions []expr.Any\n") | ||
pf("\tvar chain *nftables.Chain\n") | ||
|
||
tables, err := n.ListTables() | ||
if err != nil { | ||
log.Fatalf("ListTables failed: %v", err) | ||
} | ||
|
||
chains, err := n.ListChains() | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
for _, table := range tables { | ||
pf("\ttable:= n.AddTable(&nftables.Table{Family: %s,Name: \"%s\"})\n", TableFamilyString(table.Family), table.Name) | ||
for _, chain := range chains { | ||
if chain.Table.Name != table.Name { | ||
continue | ||
} | ||
|
||
pf("\tchain = n.AddChain(&nftables.Chain{Name: \"%s\", Table: table, Type: %s, Hooknum: %s, Priority: %s})\n", | ||
chain.Name, ChainTypeString(chain.Type), ChainHookRef(chain.Hooknum), ChainPrioRef(chain.Priority)) | ||
|
||
rules, err := n.GetRules(table, chain) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
for _, rule := range rules { | ||
pf("\texpressions = []expr.Any{\n") | ||
for _, exp := range rule.Exprs { | ||
pf("\t\t%#v,\n", exp) | ||
} | ||
pf("\t\t}\n") | ||
pf("\tn.AddRule(&nftables.Rule{\n") | ||
pf("\t\tTable: table,\n") | ||
pf("\t\tChain: chain,\n") | ||
pf("\t\tExprs: expressions,\n") | ||
pf("\t})\n") | ||
} | ||
} | ||
|
||
pf("\n\tif err:= n.Flush(); err!= nil {\n") | ||
pf("\t\tlog.Fatalf(\"fail to flush rules: %v\", err)\n") | ||
pf("\t}\n\n") | ||
pf("\tfmt.Println(\"nft ruleset applied.\")\n") | ||
pf("}\n") | ||
|
||
// Program nftables using your Go code | ||
if err := flushNFTRuleset(); err != nil { | ||
log.Fatalf("Failed to flush nftables ruleset: %v", err) | ||
} | ||
|
||
// Format the generated code | ||
cmd := exec.Command("gofmt", "-w", "-s", "nftables_recreate.go") | ||
output, err := cmd.CombinedOutput() | ||
if err != nil { | ||
log.Fatalf("gofmt error: %v\nOutput: %s", err, output) | ||
} | ||
|
||
// Run the generated code | ||
cmd = exec.Command("go", "run", "nftables_recreate.go") | ||
output, err = cmd.CombinedOutput() | ||
if err != nil { | ||
log.Fatalf("Execution error: %v\nOutput: %s", err, output) | ||
} | ||
|
||
// Retrieve nftables state using nft | ||
actualOutput, err := listNFTRuleset() | ||
if err != nil { | ||
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput) | ||
} | ||
|
||
log.Printf("Actual output:\n%s", actualOutput) | ||
|
||
expectedOutput, err := os.ReadFile(filename) | ||
if err != nil { | ||
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput) | ||
} | ||
|
||
if string(expectedOutput) != actualOutput { | ||
log.Fatalf("nftables ruleset mismatch:\n%s", cmp.Diff(expectedOutput, actualOutput)) | ||
} | ||
|
||
if err := flushNFTRuleset(); err != nil { | ||
log.Fatalf("Failed to flush nftables ruleset: %v", err) | ||
} | ||
} | ||
} | ||
|
||
func applyNFTRuleset(scriptPath string) (string, error) { | ||
cmd := exec.Command("nft", "--debug=netlink", "-f", scriptPath) | ||
out, err := cmd.CombinedOutput() | ||
if err != nil { | ||
return string(out), err | ||
} | ||
return strings.TrimSpace(string(out)), nil | ||
} | ||
|
||
func listNFTRuleset() (string, error) { | ||
cmd := exec.Command("nft", "list", "ruleset") | ||
out, err := cmd.CombinedOutput() | ||
if err != nil { | ||
return string(out), err | ||
} | ||
return strings.TrimSpace(string(out)), nil | ||
} | ||
|
||
func flushNFTRuleset() error { | ||
cmd := exec.Command("nft", "flush", "ruleset") | ||
return cmd.Run() | ||
} | ||
|
||
func ChainHookRef(hookNum *nftables.ChainHook) string { | ||
i := uint32(0) | ||
if hookNum != nil { | ||
i = uint32(*hookNum) | ||
} | ||
switch i { | ||
case 0: | ||
return "nftables.ChainHookPrerouting" | ||
case 1: | ||
return "nftables.ChainHookInput" | ||
case 2: | ||
return "nftables.ChainHookForward" | ||
case 3: | ||
return "nftables.ChainHookOutput" | ||
case 4: | ||
return "nftables.ChainHookPostrouting" | ||
case 5: | ||
return "nftables.ChainHookIngress" | ||
case 6: | ||
return "nftables.ChainHookEgress" | ||
} | ||
return "" | ||
} | ||
|
||
func ChainPrioRef(priority *nftables.ChainPriority) string { | ||
i := int32(0) | ||
if priority != nil { | ||
i = int32(*priority) | ||
} | ||
return fmt.Sprintf("nftables.ChainPriorityRef(%d)", i) | ||
} | ||
|
||
func ChainTypeString(chaintype nftables.ChainType) string { | ||
switch chaintype { | ||
case nftables.ChainTypeFilter: | ||
return "nftables.ChainTypeFilter" | ||
case nftables.ChainTypeRoute: | ||
return "nftables.ChainTypeRoute" | ||
case nftables.ChainTypeNAT: | ||
return "nftables.ChainTypeNAT" | ||
default: | ||
return "nftables.ChainTypeFilter" | ||
} | ||
} | ||
|
||
func TableFamilyString(family nftables.TableFamily) string { | ||
switch family { | ||
case nftables.TableFamilyUnspecified: | ||
return "nftables.TableFamilyUnspecified" | ||
case nftables.TableFamilyINet: | ||
return "nftables.TableFamilyINet" | ||
case nftables.TableFamilyIPv4: | ||
return "nftables.TableFamilyIPv4" | ||
case nftables.TableFamilyIPv6: | ||
return "nftables.TableFamilyIPv6" | ||
case nftables.TableFamilyARP: | ||
return "nftables.TableFamilyARP" | ||
case nftables.TableFamilyNetdev: | ||
return "nftables.TableFamilyNetdev" | ||
case nftables.TableFamilyBridge: | ||
return "nftables.TableFamilyBridge" | ||
default: | ||
return "nftables.TableFamilyIPv4" | ||
} | ||
} |
Oops, something went wrong.