Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Discover MNNVL topology with single blocksize
Browse files Browse the repository at this point in the history
ritikasrivastava committed Oct 18, 2024

Verified

This commit was signed with the committer’s verified signature.
straight-shoota Johannes Müller
1 parent 31c08a9 commit c81fc75
Showing 5 changed files with 157 additions and 1 deletion.
1 change: 1 addition & 0 deletions pkg/common/const.go
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@ const (
ProviderOCI = "oci"
ProviderGCP = "gcp"
ProviderCW = "cw"
ProviderBM = "baremetal"
ProviderTest = "test"

EngineSLURM = "slurm"
1 change: 1 addition & 0 deletions pkg/common/types.go
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@ type Vertex struct {
Name string
ID string
Vertices map[string]*Vertex
Metadata string
}

func (v *Vertex) String() string {
105 changes: 105 additions & 0 deletions pkg/providers/baremetal/mnnvl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package baremetal

Check failure on line 1 in pkg/providers/baremetal/mnnvl.go

GitHub Actions / check

: # github.com/NVIDIA/topograph/pkg/providers/baremetal

import (
"bufio"
"context"
"fmt"
"github.com/NVIDIA/topograph/pkg/common"
"github.com/NVIDIA/topograph/pkg/utils"
"strconv"
"strings"
)

const (
BlockTopologyHeader = `##################################################################
# Slurm's network topology configuration file for use with the
# topology/block plugin
##################################################################
`
)

// domain contains map of each domainID(clusterUUID) -> list of nodeNames in that domain
// Each domain will be a separate NVL Domain
type domain struct {
nodeMap map[string]bool // nodeName: true
}

// getNodeList retrieves all the nodenames on the cluster
func getNodeList(cis []common.ComputeInstances) []string {
nodes := []string{}
for _, ci := range cis {
for _, node := range ci.Instances {
nodes = append(nodes, node)
}
}
return nodes
}

// Check if domainID exists in the map
func domainIDExists(id string, domainMap map[string]domain) bool {
if _, exists := domainMap[id]; exists {
return true
}
return false
}

// getClusterOutput reads output from nodeInfo and populates the structs
func getClusterOutput(ctx context.Context, domainMap map[string]domain, nodes []string, cmd string) error {
args := []string{"-R", "ssh", "-w", strings.Join(nodes, ","), cmd}
stdout, err := utils.Exec(ctx, "pdsh", args, nil)
if err != nil {
return fmt.Errorf("Exec error while pdsh\n")
}

scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
nodeLine := scanner.Text()
arr := strings.Split(nodeLine, ":")
nodeName := arr[0]
clusterUUID := strings.TrimSpace(arr[2])
if !domainIDExists(clusterUUID, domainMap) {
domainMap[clusterUUID] = domain{
nodeMap: make(map[string]bool),
}
}
nodeMap := domainMap[clusterUUID].nodeMap
nodeMap[nodeName] = true
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("Scanner error while reading pdsh output\n")
}
return nil
}
func toSlurm(domainMap map[string]domain) *common.Vertex {
root := &common.Vertex{
Vertices: make(map[string]*common.Vertex),
}
blockSize := -1
for domainName, domain := range domainMap {
tree := &common.Vertex{
ID: domainName,
Vertices: make(map[string]*common.Vertex),
}
for node, _ := range domain.nodeMap {
tree.Vertices[node] = &common.Vertex{Name: node, ID: node}
if blockSize == -1 {
blockSize = len(domain.nodeMap)
} else {
fmt.Printf("blockSize different between NVL domains")
}
}
root.Vertices[domainName] = tree
}
root.Metadata = strconv.Itoa(blockSize)
return root
}

func generateTopologyConfig(ctx context.Context, cis []common.ComputeInstances) (*common.Vertex, error) {
domainMap := make(map[string]domain) // domainID: domain
nodes := getNodeList(cis)
err := getClusterOutput(ctx, domainMap, nodes, "nvidia-smi -q | grep ClusterUUID")
if err != nil {
return nil, fmt.Errorf("getClusterOutput failed: %v\n", err)
}
return toSlurm(domainMap), nil
}
49 changes: 49 additions & 0 deletions pkg/providers/baremetal/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package baremetal

import (
"context"
"fmt"

"k8s.io/klog/v2"

"github.com/NVIDIA/topograph/pkg/common"
"github.com/NVIDIA/topograph/pkg/engines/slurm"
)

type Provider struct{}

func GetProvider() (*Provider, error) {
return &Provider{}, nil
}

func (p *Provider) GetCredentials(_ *common.Credentials) (interface{}, error) {
return nil, nil
}

func (p *Provider) GetComputeInstances(ctx context.Context, engine common.Engine) ([]common.ComputeInstances, error) {
klog.InfoS("Getting compute instances", "provider", common.ProviderOnPrem, "engine", engine)

Check failure on line 24 in pkg/providers/baremetal/provider.go

GitHub Actions / check

undefined: common.ProviderOnPrem (typecheck)

Check failure on line 24 in pkg/providers/baremetal/provider.go

GitHub Actions / test

undefined: common.ProviderOnPrem

switch engine.(type) {
case *slurm.SlurmEngine:
nodes, err := slurm.GetNodeList(ctx)
if err != nil {
return nil, err
}
i2n := make(map[string]string)
for _, node := range nodes {
i2n[node] = node
}
return []common.ComputeInstances{{Instances: i2n}}, nil
default:
return nil, fmt.Errorf("unsupported engine %q", engine)
}
}

func (p *Provider) GenerateTopologyConfig(ctx context.Context, _ interface{}, _ int, instances []common.ComputeInstances) (*common.Vertex, error) {
if len(instances) > 1 {
return nil, fmt.Errorf("On-prem does not support multi-region topology requests")
}

//call mnnvl code from here
return generateTopologyConfig(ctx, instances)
}
2 changes: 1 addition & 1 deletion pkg/server/http_server.go
Original file line number Diff line number Diff line change
@@ -170,7 +170,7 @@ func parseQuery(vals url.Values) (string, string, map[string]string, error) {

func validate(tr *TopologyRequest) error {
switch tr.provider {
case common.ProviderAWS, common.ProviderOCI, common.ProviderGCP, common.ProviderCW, common.ProviderTest:
case common.ProviderAWS, common.ProviderOCI, common.ProviderGCP, common.ProviderCW, common.ProviderTest, common.ProviderBM:
//nop
default:
return fmt.Errorf("unsupported provider %s", tr.provider)

0 comments on commit c81fc75

Please sign in to comment.