Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Discover MNNVL topology with single blocksize
Browse files Browse the repository at this point in the history
ritikasrivastava committed Oct 18, 2024
1 parent 31c08a9 commit c81fc75
Showing 5 changed files with 157 additions and 1 deletion.
1 change: 1 addition & 0 deletions pkg/common/const.go
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@ const (
ProviderOCI = "oci"
ProviderGCP = "gcp"
ProviderCW = "cw"
ProviderBM = "baremetal"
ProviderTest = "test"

EngineSLURM = "slurm"
1 change: 1 addition & 0 deletions pkg/common/types.go
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@ type Vertex struct {
Name string
ID string
Vertices map[string]*Vertex
Metadata string
}

func (v *Vertex) String() string {
105 changes: 105 additions & 0 deletions pkg/providers/baremetal/mnnvl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package baremetal

Check failure on line 1 in pkg/providers/baremetal/mnnvl.go

GitHub Actions / check

: # github.com/NVIDIA/topograph/pkg/providers/baremetal

import (
"bufio"
"context"
"fmt"
"github.com/NVIDIA/topograph/pkg/common"
"github.com/NVIDIA/topograph/pkg/utils"
"strconv"
"strings"
)

const (
// BlockTopologyHeader is a comment banner identifying a Slurm network
// topology configuration file for use with the topology/block plugin.
// Every line of the banner starts with '#', and the text ends with a
// newline.
// NOTE(review): presumably emitted at the top of the generated topology
// file; no use of it is visible in this file — confirm against callers.
BlockTopologyHeader = `##################################################################
# Slurm's network topology configuration file for use with the
# topology/block plugin
##################################################################
`
)

// domain holds the set of node names that belong to one NVL domain.
// Domains are kept in a map keyed by domain ID (the ClusterUUID reported
// by the nodes); each map entry is a separate NVL domain.
type domain struct {
nodeMap map[string]bool // set of node names: nodeName -> true
}

// getNodeList flattens the values of every entry's Instances collection into
// a single slice of node names. The result is never nil; it is empty when
// there are no instances.
func getNodeList(cis []common.ComputeInstances) []string {
	names := make([]string, 0)
	for i := range cis {
		for _, nodeName := range cis[i].Instances {
			names = append(names, nodeName)
		}
	}
	return names
}

// domainIDExists reports whether domainMap already holds an entry keyed by id.
func domainIDExists(id string, domainMap map[string]domain) bool {
	_, found := domainMap[id]
	return found
}

// getClusterOutput runs cmd on all of nodes with a single pdsh invocation
// (pdsh -R ssh -w <comma-separated nodes> <cmd>) and groups the nodes by the
// cluster UUID each one reports.
//
// Every non-blank output line must contain at least three colon-separated
// fields. The first field is taken as the node name and the third as the
// cluster UUID (both whitespace-trimmed). The node is added to the domain
// keyed by that UUID, and the domain is created in domainMap on first use.
//
// It returns an error if pdsh fails, if a line has fewer than three fields,
// or if reading the output fails. domainMap may be partially populated when
// an error is returned.
func getClusterOutput(ctx context.Context, domainMap map[string]domain, nodes []string, cmd string) error {
	args := []string{"-R", "ssh", "-w", strings.Join(nodes, ","), cmd}
	stdout, err := utils.Exec(ctx, "pdsh", args, nil)
	if err != nil {
		return fmt.Errorf("pdsh exec: %w", err)
	}

	scanner := bufio.NewScanner(stdout)
	for scanner.Scan() {
		nodeLine := scanner.Text()
		if strings.TrimSpace(nodeLine) == "" {
			continue
		}
		arr := strings.Split(nodeLine, ":")
		// Guard the arr[2] access below: a short line would otherwise panic.
		if len(arr) < 3 {
			return fmt.Errorf("unexpected pdsh output line %q: want \"<node>: <label> : <clusterUUID>\"", nodeLine)
		}
		nodeName := strings.TrimSpace(arr[0])
		clusterUUID := strings.TrimSpace(arr[2])
		if !domainIDExists(clusterUUID, domainMap) {
			domainMap[clusterUUID] = domain{
				nodeMap: make(map[string]bool),
			}
		}
		// The struct is copied out of the map, but nodeMap is a reference,
		// so this write lands in the stored domain.
		domainMap[clusterUUID].nodeMap[nodeName] = true
	}
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("reading pdsh output: %w", err)
	}
	return nil
}
// toSlurm converts the discovered NVL domains into a two-level vertex tree.
// The root has one child per domain, keyed by the domain ID and carrying it
// as ID. Each domain vertex has one leaf per node, with Name and ID both set
// to the node name.
//
// root.Metadata holds the block size as a decimal string: the node count of
// the first domain visited. A warning is printed for every later domain whose
// node count differs. Map iteration order is unspecified, so when domains
// differ in size the recorded block size is whichever domain was visited
// first. For an empty domainMap the metadata is "-1".
func toSlurm(domainMap map[string]domain) *common.Vertex {
	root := &common.Vertex{
		Vertices: make(map[string]*common.Vertex, len(domainMap)),
	}
	blockSize := -1
	for domainName, d := range domainMap {
		size := len(d.nodeMap)
		tree := &common.Vertex{
			ID:       domainName,
			Vertices: make(map[string]*common.Vertex, size),
		}
		for node := range d.nodeMap {
			tree.Vertices[node] = &common.Vertex{Name: node, ID: node}
		}
		root.Vertices[domainName] = tree

		// Compare once per domain, and only warn on a real size mismatch.
		switch {
		case blockSize == -1:
			blockSize = size
		case size != blockSize:
			fmt.Printf("blockSize different between NVL domains: domain %q has %d nodes, expected %d\n",
				domainName, size, blockSize)
		}
	}
	root.Metadata = strconv.Itoa(blockSize)
	return root
}

// generateTopologyConfig discovers the NVL domains of the nodes listed in cis
// and returns them as a vertex tree. It collects the node names, runs
// "nvidia-smi -q | grep ClusterUUID" on them via getClusterOutput to group
// nodes by cluster UUID, and converts the grouping with toSlurm.
//
// On failure it returns a nil vertex and an error wrapping the underlying
// cause.
func generateTopologyConfig(ctx context.Context, cis []common.ComputeInstances) (*common.Vertex, error) {
	domainMap := make(map[string]domain) // domainID: domain
	nodes := getNodeList(cis)
	if err := getClusterOutput(ctx, domainMap, nodes, "nvidia-smi -q | grep ClusterUUID"); err != nil {
		return nil, fmt.Errorf("getClusterOutput failed: %w", err)
	}
	return toSlurm(domainMap), nil
}
49 changes: 49 additions & 0 deletions pkg/providers/baremetal/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package baremetal

import (
"context"
"fmt"

"k8s.io/klog/v2"

"github.com/NVIDIA/topograph/pkg/common"
"github.com/NVIDIA/topograph/pkg/engines/slurm"
)

// Provider is the baremetal topology provider. It has no fields.
type Provider struct{}

// GetProvider returns a newly allocated baremetal Provider. The returned
// error is always nil.
func GetProvider() (*Provider, error) {
	p := new(Provider)
	return p, nil
}

// GetCredentials ignores its argument and returns a nil credentials value
// together with a nil error.
func (p *Provider) GetCredentials(_ *common.Credentials) (interface{}, error) {
	var creds interface{}
	return creds, nil
}

func (p *Provider) GetComputeInstances(ctx context.Context, engine common.Engine) ([]common.ComputeInstances, error) {
klog.InfoS("Getting compute instances", "provider", common.ProviderOnPrem, "engine", engine)

Check failure on line 24 in pkg/providers/baremetal/provider.go

GitHub Actions / check

undefined: common.ProviderOnPrem (typecheck)

Check failure on line 24 in pkg/providers/baremetal/provider.go

GitHub Actions / test

undefined: common.ProviderOnPrem

switch engine.(type) {
case *slurm.SlurmEngine:
nodes, err := slurm.GetNodeList(ctx)
if err != nil {
return nil, err
}
i2n := make(map[string]string)
for _, node := range nodes {
i2n[node] = node
}
return []common.ComputeInstances{{Instances: i2n}}, nil
default:
return nil, fmt.Errorf("unsupported engine %q", engine)
}
}

// GenerateTopologyConfig builds the topology tree for the given compute
// instances. The second and third arguments are ignored. At most one
// ComputeInstances entry is accepted; more than one is rejected with an
// error. Otherwise the work is delegated to generateTopologyConfig.
func (p *Provider) GenerateTopologyConfig(ctx context.Context, _ interface{}, _ int, instances []common.ComputeInstances) (*common.Vertex, error) {
	if len(instances) <= 1 {
		// Hand off to the MNNVL discovery code.
		return generateTopologyConfig(ctx, instances)
	}
	return nil, fmt.Errorf("On-prem does not support multi-region topology requests")
}
2 changes: 1 addition & 1 deletion pkg/server/http_server.go
Original file line number Diff line number Diff line change
@@ -170,7 +170,7 @@ func parseQuery(vals url.Values) (string, string, map[string]string, error) {

func validate(tr *TopologyRequest) error {
switch tr.provider {
case common.ProviderAWS, common.ProviderOCI, common.ProviderGCP, common.ProviderCW, common.ProviderTest:
case common.ProviderAWS, common.ProviderOCI, common.ProviderGCP, common.ProviderCW, common.ProviderTest, common.ProviderBM:
//nop
default:
return fmt.Errorf("unsupported provider %s", tr.provider)

0 comments on commit c81fc75

Please sign in to comment.