Skip to content

Commit

Permalink
Add test cases (#60)
Browse files Browse the repository at this point in the history
* Add test cases

Signed-off-by: Ritika Srivastava <[email protected]>

* Move unexported functions

Signed-off-by: Ritika Srivastava <[email protected]>

---------

Signed-off-by: Ritika Srivastava <[email protected]>
  • Loading branch information
ritikasrivastava authored Jan 24, 2025
1 parent 63519a8 commit 9166017
Show file tree
Hide file tree
Showing 4 changed files with 429 additions and 200 deletions.
88 changes: 54 additions & 34 deletions pkg/providers/baremetal/mnnvl.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package baremetal

import (
"bufio"
"bytes"
"context"
"fmt"
"strconv"
Expand All @@ -12,7 +13,7 @@ import (
"github.com/NVIDIA/topograph/pkg/topology"
)

// domain contains map of each domainID(clusterUUID) -> list of nodeNames in that domain
// domain holds the set of nodeNames that belong to one domain; domains are kept in a map keyed by domainID (clusterUUID followed by cliqueId)
// Each domain will be a separate NVL Domain
type domain struct {
nodeMap map[string]bool // nodeName: true
Expand All @@ -39,22 +40,8 @@ func domainIDExists(id string, domainMap map[string]domain) bool {
return false
}

func getIbTree(ctx context.Context, _ []string) (*topology.Vertex, error) {
nodeVisited := make(map[string]bool)
treeRoot := &topology.Vertex{
Vertices: make(map[string]*topology.Vertex),
}
ibPrefix := "IB"
ibCount := 0
func populatePartitions(stdout *bytes.Buffer) (map[string][]string, error) {
partitionNodeMap := make(map[string][]string)
partitionVisitedMap := make(map[string]bool)

args := []string{"-h"}
stdout, err := exec.Exec(ctx, "sinfo", args, nil)
if err != nil {
return nil, fmt.Errorf("exec error in sinfo: %v", err)
}

// scan each line containing slurm partition and the nodes in it
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
Expand All @@ -72,6 +59,28 @@ func getIbTree(ctx context.Context, _ []string) (*topology.Vertex, error) {
// map of slurm partition name -> node names
partitionNodeMap[partitionName] = append(partitionNodeMap[partitionName], nodesArr...)
}
return partitionNodeMap, nil
}

func getIbTree(ctx context.Context, _ []string) (*topology.Vertex, error) {
nodeVisited := make(map[string]bool)
treeRoot := &topology.Vertex{
Vertices: make(map[string]*topology.Vertex),
}
ibPrefix := "IB"
ibCount := 0
partitionVisitedMap := make(map[string]bool)

args := []string{"-h"}
stdout, err := exec.Exec(ctx, "sinfo", args, nil)
if err != nil {
return nil, fmt.Errorf("exec error in sinfo: %v", err)
}

partitionNodeMap, err := populatePartitions(stdout)
if err != nil {
return nil, fmt.Errorf("populatePartitions failed : %v", err)
}
for pName, nodes := range partitionNodeMap {
// for each partition in slurm, find the IB tree it belongs to
if _, exists := partitionVisitedMap[pName]; !exists {
Expand Down Expand Up @@ -116,15 +125,17 @@ func deCompressNodeNames(nodeList string) ([]string, error) {
arr := strings.Split(nodeList, ",")
prefix := ""
var nodeName string
resetPrefix := false

// example : nodename-1-[001-004 , 007, 91-99 , 100], nodename-2-89
for _, entry := range arr {
// example : nodename-1-[001-004
if strings.Contains(entry, "[") {
// example : 100]
// example : nodename-1-[001-004]
entryWithoutSuffix := strings.TrimSuffix(entry, "]")
tuple := strings.Split(entryWithoutSuffix, "[")
prefix = tuple[0]
resetPrefix = false
// example : nodename-1-[001-004
if strings.Contains(tuple[1], "-") {
nr := strings.Split(tuple[1], "-")
Expand Down Expand Up @@ -153,10 +164,10 @@ func deCompressNodeNames(nodeList string) ([]string, error) {
// example: 100], nodename-2-89, 90
if len(prefix) > 0 { //prefix exists, so must be a suffix.
if strings.HasSuffix(entry, "]") { //if suffix has ], reset prefix
nv := strings.Split(entry, "]")
nodeName = prefix + nv[0]
prefix = ""
} else if strings.Contains(entry, "-") { // suffix containing range of nodes
entry = strings.TrimSuffix(entry, "]")
resetPrefix = true
}
if strings.Contains(entry, "-") { // suffix containing range of nodes
// example: 100-102]
nr := strings.Split(entry, "-")
w := len(nr[0])
Expand All @@ -173,11 +184,17 @@ func deCompressNodeNames(nodeList string) ([]string, error) {
nodeName = prefix + suffixNum
nodeArr = append(nodeArr, nodeName)
}
if resetPrefix {
prefix = ""
}
// avoid another nodename append at the end
continue
} else {
//example: 90
nodeName = prefix + entry
if resetPrefix {
prefix = ""
}
}
} else { // no prefix yet, must be whole nodename
//example: nodename-2-89
Expand All @@ -189,22 +206,16 @@ func deCompressNodeNames(nodeList string) ([]string, error) {
return nodeArr, nil
}

// getClusterOutput reads output from nodeInfo and populates the structs
func getClusterOutput(ctx context.Context, domainMap map[string]domain, nodes []string, cmd string) error {
args := []string{"-R", "ssh", "-w", strings.Join(nodes, ","), cmd}
stdout, err := exec.Exec(ctx, "pdsh", args, nil)
if err != nil {
return fmt.Errorf("exec error while pdsh: %v", err)
}

func populateDomains(stdout *bytes.Buffer) (map[string]domain, error) {
domainMap := make(map[string]domain) // domainID: domain
scanner := bufio.NewScanner(stdout)
cliqueId := ""
clusterUUID := ""
domainName := ""
for scanner.Scan() {
nodeLine := scanner.Text()
arr := strings.Split(nodeLine, ":")
nodeName := arr[0]
nodeName := strings.TrimSpace(arr[0])
itemName := strings.TrimSpace(arr[1])
if itemName == "CliqueId" {
cliqueId = strings.TrimSpace(arr[2])
Expand All @@ -221,10 +232,19 @@ func getClusterOutput(ctx context.Context, domainMap map[string]domain, nodes []
nodeMap[nodeName] = true
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("scanner error while reading pdsh output: %v", err)
return nil, fmt.Errorf("scanner error while reading pdsh output: %v", err)
}
return domainMap, nil
}

return nil
// getClusterOutput runs cmd on every node in nodes through pdsh (selecting
// ssh with -R) and hands the captured output to populateDomains, returning
// the resulting map of domainID -> domain.
//
// An error is returned when the pdsh invocation fails or when
// populateDomains reports an error. The exec error is wrapped with %w so
// callers can still inspect the underlying cause with errors.Is / errors.As.
func getClusterOutput(ctx context.Context, nodes []string, cmd string) (map[string]domain, error) {
	// pdsh receives its target hosts as one comma-separated list after -w.
	args := []string{"-R", "ssh", "-w", strings.Join(nodes, ","), cmd}
	stdout, err := exec.Exec(ctx, "pdsh", args, nil)
	if err != nil {
		return nil, fmt.Errorf("exec error while pdsh: %w", err)
	}
	return populateDomains(stdout)
}

func toGraph(domainMap map[string]domain, treeRoot *topology.Vertex) *topology.Vertex {
Expand All @@ -251,9 +271,9 @@ func toGraph(domainMap map[string]domain, treeRoot *topology.Vertex) *topology.V
}

func generateTopologyConfig(ctx context.Context, cis []topology.ComputeInstances) (*topology.Vertex, error) {
domainMap := make(map[string]domain) // domainID: domain

nodes := getNodeList(cis)
err := getClusterOutput(ctx, domainMap, nodes, `nvidia-smi -q | grep "ClusterUUID\|CliqueId"`)
domainMap, err := getClusterOutput(ctx, nodes, `nvidia-smi -q | grep "ClusterUUID\|CliqueId"`)
if err != nil {
return nil, fmt.Errorf("getClusterOutput failed: %v", err)
}
Expand Down
92 changes: 92 additions & 0 deletions pkg/providers/baremetal/topology_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package baremetal

import (
"bytes"
"testing"

"github.com/stretchr/testify/require"
)

// TestClique feeds canned per-node CliqueId/ClusterUUID output to
// populateDomains and checks that the nodes end up grouped under the
// expected domain IDs.
func TestClique(t *testing.T) {
	// Each node reports its CliqueId/ClusterUUID pair twice in this sample.
	const pdshOutput = `node-10: CliqueId : 4000000004
node-10: ClusterUUID : 50000000-0000-0000-0000-000000000005
node-10: CliqueId : 4000000004
node-10: ClusterUUID : 50000000-0000-0000-0000-000000000005
node-07: CliqueId : 4000000005
node-07: ClusterUUID : 50000000-0000-0000-0000-000000000004
node-07: CliqueId : 4000000005
node-07: ClusterUUID : 50000000-0000-0000-0000-000000000004
node-08: CliqueId : 4000000005
node-08: ClusterUUID : 50000000-0000-0000-0000-000000000004
node-08: CliqueId : 4000000005
node-08: ClusterUUID : 50000000-0000-0000-0000-000000000004
node-09: CliqueId : 4000000005
node-09: ClusterUUID : 50000000-0000-0000-0000-000000000005
node-09: CliqueId : 4000000005
node-09: ClusterUUID : 50000000-0000-0000-0000-000000000005`

	// nodeSet builds a domain whose nodeMap marks every given node as present.
	nodeSet := func(names ...string) domain {
		members := make(map[string]bool, len(names))
		for _, name := range names {
			members[name] = true
		}
		return domain{nodeMap: members}
	}

	// Keys are the ClusterUUID immediately followed by the CliqueId.
	want := map[string]domain{
		"50000000-0000-0000-0000-0000000000044000000005": nodeSet("node-07", "node-08"),
		"50000000-0000-0000-0000-0000000000054000000004": nodeSet("node-10"),
		"50000000-0000-0000-0000-0000000000054000000005": nodeSet("node-09"),
	}

	got, err := populateDomains(bytes.NewBufferString(pdshOutput))
	require.NoError(t, err)
	require.Equal(t, want, got)
}

// TestSlurmPartition feeds canned sinfo-style partition lines to
// populatePartitions and checks the expanded node list recorded for each
// partition, in the order the lines appear.
func TestSlurmPartition(t *testing.T) {
	const sinfoOutput = `cq up 6:00:00 1 down* node2-14
cq up 6:00:00 1 drain node1-01
cq up 6:00:00 30 idle node1-[02-16],node2-[01-13,15-16]
c1q up 8:00:00 1 drain node1-01
c1q up 8:00:00 15 idle node1-[02-16]
c2q up 8:00:00 1 down* node2-14
c2q up 8:00:00 15 idle node2-[01-13,15-16]`

	want := map[string][]string{
		// down node, drained node, then the idle node1 and node2 ranges.
		"cq": {
			"node2-14",
			"node1-01",
			"node1-02", "node1-03", "node1-04", "node1-05", "node1-06",
			"node1-07", "node1-08", "node1-09", "node1-10", "node1-11",
			"node1-12", "node1-13", "node1-14", "node1-15", "node1-16",
			"node2-01", "node2-02", "node2-03", "node2-04", "node2-05",
			"node2-06", "node2-07", "node2-08", "node2-09", "node2-10",
			"node2-11", "node2-12", "node2-13",
			"node2-15", "node2-16",
		},
		"c1q": {
			"node1-01",
			"node1-02", "node1-03", "node1-04", "node1-05", "node1-06",
			"node1-07", "node1-08", "node1-09", "node1-10", "node1-11",
			"node1-12", "node1-13", "node1-14", "node1-15", "node1-16",
		},
		"c2q": {
			"node2-14",
			"node2-01", "node2-02", "node2-03", "node2-04", "node2-05",
			"node2-06", "node2-07", "node2-08", "node2-09", "node2-10",
			"node2-11", "node2-12", "node2-13",
			"node2-15", "node2-16",
		},
	}

	got, err := populatePartitions(bytes.NewBufferString(sinfoOutput))
	require.NoError(t, err)
	require.Equal(t, want, got)
}
Loading

0 comments on commit 9166017

Please sign in to comment.