Skip to content

Commit

Permalink
Discover MNNVL topology with single blocksize (#9)
Browse files Browse the repository at this point in the history
* Discover MNNVL topology with single blocksize

Signed-off-by: Ritika Srivastava <[email protected]>

* Change metadata to map

Signed-off-by: Ritika Srivastava <[email protected]>

* Proto files

Signed-off-by: Ritika Srivastava <[email protected]>

* Address errors

Signed-off-by: Ritika Srivastava <[email protected]>

* Remove blank identifier

Signed-off-by: Ritika Srivastava <[email protected]>

* Rename function

Signed-off-by: Ritika Srivastava <[email protected]>

---------

Signed-off-by: Ritika Srivastava <[email protected]>
  • Loading branch information
ritikasrivastava authored Oct 18, 2024
1 parent 31c08a9 commit 676ad8f
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 60 deletions.
1 change: 1 addition & 0 deletions pkg/common/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const (
ProviderOCI = "oci"
ProviderGCP = "gcp"
ProviderCW = "cw"
ProviderBM = "baremetal"
ProviderTest = "test"

EngineSLURM = "slurm"
Expand Down
1 change: 1 addition & 0 deletions pkg/common/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type Vertex struct {
Name string
ID string
Vertices map[string]*Vertex
Metadata map[string]string
}

func (v *Vertex) String() string {
Expand Down
72 changes: 14 additions & 58 deletions pkg/protos/topology.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/protos/topology_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

101 changes: 101 additions & 0 deletions pkg/providers/baremetal/mnnvl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package baremetal

import (
"bufio"
"context"
"fmt"
"github.com/NVIDIA/topograph/pkg/common"
"github.com/NVIDIA/topograph/pkg/utils"
"strconv"
"strings"
)

// domain contains map of each domainID(clusterUUID) -> list of nodeNames in that domain
// Each domain will be a separate NVL Domain
type domain struct {
nodeMap map[string]bool // nodeName: true
}

// getNodeList retrieves all the nodenames on the cluster
func getNodeList(cis []common.ComputeInstances) []string {
nodes := []string{}
for _, ci := range cis {
for _, node := range ci.Instances {
nodes = append(nodes, node)
}
}
return nodes
}

// Check if domainID exists in the map
func domainIDExists(id string, domainMap map[string]domain) bool {
if _, exists := domainMap[id]; exists {
return true
}
return false
}

// getClusterOutput reads output from nodeInfo and populates the structs
func getClusterOutput(ctx context.Context, domainMap map[string]domain, nodes []string, cmd string) error {
args := []string{"-R", "ssh", "-w", strings.Join(nodes, ","), cmd}
stdout, err := utils.Exec(ctx, "pdsh", args, nil)
if err != nil {
return fmt.Errorf("Exec error while pdsh\n")
}

scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
nodeLine := scanner.Text()
arr := strings.Split(nodeLine, ":")
nodeName := arr[0]
clusterUUID := strings.TrimSpace(arr[2])
if !domainIDExists(clusterUUID, domainMap) {
domainMap[clusterUUID] = domain{
nodeMap: make(map[string]bool),
}
}
nodeMap := domainMap[clusterUUID].nodeMap
nodeMap[nodeName] = true
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("Scanner error while reading pdsh output\n")
}
return nil
}
func toGraph(domainMap map[string]domain) *common.Vertex {
root := &common.Vertex{
Vertices: make(map[string]*common.Vertex),
Metadata: make(map[string]string),
}
blockSize := -1
for domainName, domain := range domainMap {
tree := &common.Vertex{
ID: domainName,
Vertices: make(map[string]*common.Vertex),
}
for node := range domain.nodeMap {
tree.Vertices[node] = &common.Vertex{Name: node, ID: node}
if blockSize == -1 {
blockSize = len(domain.nodeMap)
} else {
fmt.Printf("blockSize different between NVL domains")
}
}
root.Vertices[domainName] = tree
}
// add root metadata
root.Metadata["engine"] = "slurm"
root.Metadata["plugin"] = "topology/block"
root.Metadata["blocksize"] = strconv.Itoa(blockSize)
return root
}

func generateTopologyConfig(ctx context.Context, cis []common.ComputeInstances) (*common.Vertex, error) {
domainMap := make(map[string]domain) // domainID: domain
nodes := getNodeList(cis)
err := getClusterOutput(ctx, domainMap, nodes, "nvidia-smi -q | grep ClusterUUID")
if err != nil {
return nil, fmt.Errorf("getClusterOutput failed: %v\n", err)
}
return toGraph(domainMap), nil
}
49 changes: 49 additions & 0 deletions pkg/providers/baremetal/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package baremetal

import (
"context"
"fmt"

"k8s.io/klog/v2"

"github.com/NVIDIA/topograph/pkg/common"
"github.com/NVIDIA/topograph/pkg/engines/slurm"
)

type Provider struct{}

func GetProvider() (*Provider, error) {
return &Provider{}, nil
}

func (p *Provider) GetCredentials(_ *common.Credentials) (interface{}, error) {
return nil, nil
}

func (p *Provider) GetComputeInstances(ctx context.Context, engine common.Engine) ([]common.ComputeInstances, error) {
klog.InfoS("Getting compute instances", "provider", common.ProviderBM, "engine", engine)

switch engine.(type) {
case *slurm.SlurmEngine:
nodes, err := slurm.GetNodeList(ctx)
if err != nil {
return nil, err
}
i2n := make(map[string]string)
for _, node := range nodes {
i2n[node] = node
}
return []common.ComputeInstances{{Instances: i2n}}, nil
default:
return nil, fmt.Errorf("unsupported engine %q", engine)
}
}

func (p *Provider) GenerateTopologyConfig(ctx context.Context, _ interface{}, _ int, instances []common.ComputeInstances) (*common.Vertex, error) {
if len(instances) > 1 {
return nil, fmt.Errorf("On-prem does not support multi-region topology requests")
}

//call mnnvl code from here
return generateTopologyConfig(ctx, instances)
}
2 changes: 1 addition & 1 deletion pkg/server/http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func parseQuery(vals url.Values) (string, string, map[string]string, error) {

func validate(tr *TopologyRequest) error {
switch tr.provider {
case common.ProviderAWS, common.ProviderOCI, common.ProviderGCP, common.ProviderCW, common.ProviderTest:
case common.ProviderAWS, common.ProviderOCI, common.ProviderGCP, common.ProviderCW, common.ProviderTest, common.ProviderBM:
//nop
default:
return fmt.Errorf("unsupported provider %s", tr.provider)
Expand Down

0 comments on commit 676ad8f

Please sign in to comment.