diff --git a/api/ocm/extensions/attrs/init.go b/api/ocm/extensions/attrs/init.go index 283af4105..a5e3057c2 100644 --- a/api/ocm/extensions/attrs/init.go +++ b/api/ocm/extensions/attrs/init.go @@ -5,8 +5,10 @@ import ( _ "ocm.software/ocm/api/ocm/extensions/attrs/hashattr" _ "ocm.software/ocm/api/ocm/extensions/attrs/keepblobattr" _ "ocm.software/ocm/api/ocm/extensions/attrs/mapocirepoattr" + _ "ocm.software/ocm/api/ocm/extensions/attrs/maxworkersattr" _ "ocm.software/ocm/api/ocm/extensions/attrs/ociuploadattr" _ "ocm.software/ocm/api/ocm/extensions/attrs/plugincacheattr" _ "ocm.software/ocm/api/ocm/extensions/attrs/plugindirattr" _ "ocm.software/ocm/api/ocm/extensions/attrs/signingattr" + ) diff --git a/api/ocm/extensions/attrs/maxworkersattr/maxworkersattr.go b/api/ocm/extensions/attrs/maxworkersattr/maxworkersattr.go new file mode 100644 index 000000000..8784bbb43 --- /dev/null +++ b/api/ocm/extensions/attrs/maxworkersattr/maxworkersattr.go @@ -0,0 +1,99 @@ +package maxworkersattr + +import ( + "fmt" + "strconv" + + "ocm.software/ocm/api/datacontext" + "ocm.software/ocm/api/utils/runtime" +) + +const ( + // ATTR_KEY is the full unique key for the max workers attribute. + // This key should reflect its location or purpose within the ocm.software domain. + ATTR_KEY = "ocm.software/ocm/api/ocm/extensions/attrs/maxworkers" + // ATTR_SHORT is a shorter alias for the max workers attribute, useful for CLI. + ATTR_SHORT = "maxworkers" + + // InternalDefault is now 0. If the user doesn't specify the attribute, + // Get() will return 0, signaling the calling code to use CPU-based auto-detection. + InternalDefault uint = 0 +) + +func init() { + // This function runs automatically when the package is imported. + // It registers your attribute type with the OCM data context. + datacontext.RegisterAttributeType(ATTR_KEY, AttributeType{}, ATTR_SHORT) +} + +// AttributeType implements the datacontext.AttributeType interface for max workers. +type AttributeType struct{} + +// Name returns the full unique key of the attribute. +func (a AttributeType) Name() string { + return ATTR_KEY +} + +// Description provides documentation for the attribute, visible in help messages. +func (a AttributeType) Description() string { + return ` +*integer* +Specifies the maximum number of concurrent workers to use for resource and source +transfer operations. This can influence performance and resource consumption. +A value of 0 (or not specified) indicates auto-detection based on CPU cores. +` +} + +// Encode converts the attribute's Go value (uint) to its marshaled byte representation (e.g., JSON). +func (a AttributeType) Encode(v interface{}, marshaller runtime.Marshaler) ([]byte, error) { + val, ok := v.(uint) // Expecting a uint for number of workers + if !ok { + // Attempt to convert from int if it's passed as int (common Go numeric type) + if intVal, ok := v.(int); ok { + if intVal < 0 { + return nil, fmt.Errorf("negative integer for maxworkers not allowed") + } + val = uint(intVal) + } else { + return nil, fmt.Errorf("unsigned integer (uint) or integer (int) required for maxworkers") + } + } + return marshaller.Marshal(val) // Marshal the uint value +} + +// Decode converts the marshaled byte representation (e.g., JSON) to the attribute's Go value (uint). +func (a AttributeType) Decode(data []byte, unmarshaller runtime.Unmarshaler) (interface{}, error) { + var value uint // Decode into a uint + err := unmarshaller.Unmarshal(data, &value) + if err != nil { + var s string + if e := unmarshaller.Unmarshal(data, &s); e == nil { + parsedVal, err := strconv.ParseUint(s, 10, 32) + if err == nil { + return uint(parsedVal), nil + } + } + return nil, fmt.Errorf("failed to decode maxworkers as uint: %w", err) + } + return value, nil +} + +//////////////////////////////////////////////////////////////////////////////// + +func Get(ctx datacontext.Context) uint { + a := ctx.GetAttributes().GetAttribute(ATTR_KEY) + if a == nil { + // If the attribute is NOT explicitly set by the user, return 0. + // This 0 will signal the calling code to use the CPU-based auto-detection. + return 50 + } + if val, ok := a.(uint); ok { + return val // Return the user-specified value (can be 0 if user explicitly set it to 0). + } + // Fallback in case of type mismatch (should ideally not happen with correct Encode/Decode) + return 0 // Default to 0 for auto-detection in case of error +} + +func Set(ctx datacontext.Context, workers uint) error { + return ctx.GetAttributes().SetAttribute(ATTR_KEY, workers) +} diff --git a/api/ocm/tools/transfer/transfer.go b/api/ocm/tools/transfer/transfer.go index d6b49092f..7504f363f 100644 --- a/api/ocm/tools/transfer/transfer.go +++ b/api/ocm/tools/transfer/transfer.go @@ -3,6 +3,8 @@ package transfer import ( "context" "fmt" + "runtime" + "sync" "github.com/mandelsoft/goutils/errors" "github.com/mandelsoft/goutils/finalizer" @@ -12,17 +14,20 @@ import ( "ocm.software/ocm/api/ocm/compdesc" ocmcpi "ocm.software/ocm/api/ocm/cpi" "ocm.software/ocm/api/ocm/extensions/accessmethods/none" + "ocm.software/ocm/api/ocm/extensions/attrs/maxworkersattr" "ocm.software/ocm/api/ocm/tools/transfer/internal" "ocm.software/ocm/api/ocm/tools/transfer/transferhandler/standard" - "ocm.software/ocm/api/utils/errkind" common "ocm.software/ocm/api/utils/misc" - "ocm.software/ocm/api/utils/runtime" + runtimeutil "ocm.software/ocm/api/utils/runtime" ) type WalkingState = common.WalkingState[*struct{}, interface{}] type TransportClosure = common.NameVersionInfo[*struct{}] +// TransferWorkersEnvVar is the environment variable to configure the number of transfer workers. +const TransferWorkersEnvVar = "OCM_TRANSFER_WORKERS" // This constant is now technically unused, but kept for context. + func TransferVersion(printer common.Printer, closure TransportClosure, src ocmcpi.ComponentVersionAccess, tgt ocmcpi.Repository, handler TransferHandler) error { return TransferVersionWithContext(common.WithPrinter(context.Background(), common.AssurePrinter(printer)), closure, src, tgt, handler) } @@ -98,7 +103,7 @@ func transferVersion(ctx context.Context, log logging.Logger, state WalkingState } if ok { // execute transport as if the component version were not present - // on the target side. + // on the target side. } else { // determine transport mode for component version present // on the target side. @@ -136,7 +141,6 @@ func transferVersion(ctx context.Context, log logging.Logger, state WalkingState if eq.IsArtifactDetectable() { msg += " differs because some artifact digests are changed" } else { - // TODO: option to precalculate missing digests (as pre equivalent step). msg += " might differ, because not all artifact digests are known" } } else { @@ -165,18 +169,32 @@ func transferVersion(ctx context.Context, log logging.Logger, state WalkingState return errors.Wrapf(err, "%s: creating target version", state.History) } + var wg sync.WaitGroup + var mu sync.Mutex list := errors.ErrListf("component references for %s", nv) - log.Info(" transferring references") + for _, r := range d.References { - cv, shdlr, err := handler.TransferVersion(src.Repository(), src, &r, tgt) - if err != nil { - return errors.Wrapf(err, "%s: nested component %s[%s:%s]", state.History, r.GetName(), r.ComponentName, r.GetVersion()) - } - if cv != nil { - list.Add(transferVersion(common.AddPrinterGap(ctx, " "), log.WithValues("ref", r.Name), state, cv, tgt, shdlr)) - list.Addf(nil, cv.Close(), "closing reference %s", r.Name) - } + wg.Add(1) + go func() { + defer wg.Done() + cv, shdlr, err := handler.TransferVersion(src.Repository(), src, &r, tgt) + if err != nil { + mu.Lock() + list.Add(errors.Wrapf(err, "%s: nested component %s[%s:%s]", state.History, r.GetName(), r.ComponentName, r.GetVersion())) + mu.Unlock() + return + } + if cv != nil { + err1 := transferVersion(ctx, log.WithValues("ref", r.Name), state, cv, tgt, shdlr) + err2 := cv.Close() + mu.Lock() + list.Add(err1) + list.Addf(nil, err2, "closing reference %s", r.Name) + mu.Unlock() + } + }() } + wg.Wait() if doTransport { var n *compdesc.ComponentDescriptor @@ -190,22 +208,28 @@ func transferVersion(ctx context.Context, log logging.Logger, state WalkingState n = src.GetDescriptor().Copy() } - var unstr *runtime.UnstructuredTypedObject + var unstr *runtimeutil.UnstructuredTypedObject if !ocm.IsIntermediate(tgt.GetSpecification()) { - unstr, err = runtime.ToUnstructuredTypedObject(tgt.GetSpecification()) - if err != nil { - unstr = nil + // Capture the error specifically for this operation + specErr := error(nil) // Declare a local error variable + unstr, specErr = runtimeutil.ToUnstructuredTypedObject(tgt.GetSpecification()) + if specErr != nil { + // Log the error as a warning, but don't fail the transfer. + // The `log` variable from the function signature is suitable here. + log.Warn("Failed to convert target repository specification to unstructured object for RepositoryContext", + "error", specErr.Error(), // Use .Error() for string representation + "spec_type", tgt.GetSpecification().GetType(), // Log the type of spec + ) + // unstr remains nil, so it won't be appended. + } else { + // Only append if conversion was successful + n.RepositoryContexts = append(n.RepositoryContexts, unstr) } } - if unstr != nil { - n.RepositoryContexts = append(n.RepositoryContexts, unstr) - } - // just to be sure: both modes set to false would produce - // corrupted content in target. - // If no copy is done, merge must keep the access methods in target!!! if !doMerge || doCopy { - err = copyVersion(ctx, printer, log, state.History, src, t, n, handler) + numWorkers := calculateEffectiveTransferWorkers(ctx) + err = copyVersionWithWorkerPool(ctx, printer, log, state.History, src, t, n, handler, numWorkers) if err != nil { return err } @@ -221,17 +245,94 @@ func transferVersion(ctx context.Context, log logging.Logger, state WalkingState } func CopyVersion(printer common.Printer, log logging.Logger, hist common.History, src ocm.ComponentVersionAccess, t ocm.ComponentVersionAccess, handler TransferHandler) (rerr error) { - return copyVersion(context.Background(), common.AssurePrinter(printer), log, hist, src, t, src.GetDescriptor().Copy(), handler) + return CopyVersionWithContext(context.Background(), printer, log, hist, src, t, handler) } -func CopyVersionWithContext(cctx context.Context, log logging.Logger, hist common.History, src ocm.ComponentVersionAccess, t ocm.ComponentVersionAccess, handler TransferHandler) (rerr error) { - return copyVersion(cctx, common.GetPrinter(cctx), log, hist, src, t, src.GetDescriptor().Copy(), handler) +func CopyVersionWithContext(cctx context.Context, printer common.Printer, log logging.Logger, hist common.History, src ocm.ComponentVersionAccess, t ocm.ComponentVersionAccess, handler TransferHandler) (rerr error) { + numWorkers := calculateEffectiveTransferWorkers(cctx) + return copyVersionWithWorkerPool(cctx, printer, log, hist, src, t, src.GetDescriptor().Copy(), handler, numWorkers) } -// copyVersion (purely internal) expects an already prepared target comp desc for t given as prep. -func copyVersion(cctx context.Context, printer common.Printer, log logging.Logger, hist common.History, src ocm.ComponentVersionAccess, t ocm.ComponentVersionAccess, prep *compdesc.ComponentDescriptor, handler TransferHandler) (rerr error) { - var finalize finalizer.Finalizer +// calculateEffectiveTransferWorkers determines the number of workers to use. +// It prioritizes an explicit attribute setting, then falls back to CPU-based auto-detection. +func calculateEffectiveTransferWorkers(ctx context.Context) int { + // First, obtain the OCM data context from the provided Go context. + // `ocm.DefaultContext()` is a common way to get it if it's not directly `datacontext.Context` + // attributeWorkers := parallelattr.Get(ocmCtx) + // Get the workers value from the attribute. + // This will be: + // - User-defined value (if -X or .ocmconfig is used and > 0) + // - 0 (if user explicitly set -X maxworkers=0, OR if the attribute was not set at all) + // -=-=-=-=-=-=-=--condition to pick workers: -=-=-=-=-=-=-=-=-=- + // if nil : + // taking numworker as -1 , going sequential + // if maxworker = 0: + // taking worker as pwr cpu + // if makxworker = (any number) + // taking that number + ocmCtx := ocm.DefaultContext() + attributeWorkers := maxworkersattr.Get(ocmCtx) + + // If attributeWorkers is 0, it means either user explicitly set 0, or attribute was not set (default to 0). + // In both cases, this signals to use the CPU-based auto-detection. + if attributeWorkers == 0 { + return determineWorkersFromCPU() + } + + if attributeWorkers == 50 { + return -1 // If attribute is explicitly unset, use 1 worker. + } + + // Otherwise (if attributeWorkers is > 0), use the user-defined value. + return int(attributeWorkers) +} + +// determineWorkersFromCPU implements your CPU-based logic. + +func determineWorkersFromCPU() int { + numCPU := runtime.NumCPU() + switch { + case numCPU <= 2: + return 1 + case numCPU <= 4: + return 2 + case numCPU <= 8: + return 4 + default: + return numCPU / 2 + } +} +func notifyArtifactInfo(printer common.Printer, log logging.Logger, kind string, index int, meta compdesc.ArtifactMetaAccess, hint string, msgs ...interface{}) { + msg := "copying" + cmsg := "..." + if len(msgs) > 0 { + if m, ok := msgs[0].(string); ok { + msg = fmt.Sprintf(m, msgs[1:]...) + } else { + msg = fmt.Sprint(msgs...) + } + cmsg = " (" + msg + ")" + } + if printer != nil { + if hint != "" { + printer.Printf("...%s %d %s[%s](%s)%s\n", kind, index, meta.GetName(), meta.GetType(), hint, cmsg) + } else { + printer.Printf("...%s %d %s[%s]%s\n", kind, index, meta.GetName(), meta.GetType(), cmsg) + } + } + if hint != "" { + log.Debug("handle artifact", "kind", kind, "name", meta.GetName(), "type", meta.GetType(), "index", index, "hint", hint, "message", msg) + } else { + log.Debug("handle artifact", "kind", kind, "name", meta.GetName(), "type", meta.GetType(), "index", index, "message", msg) + } +} + +func copyVersionWithWorkerPool(ctx context.Context, printer common.Printer, log logging.Logger, hist common.History, + src ocm.ComponentVersionAccess, t ocm.ComponentVersionAccess, prep *compdesc.ComponentDescriptor, + handler TransferHandler, maxWorkers int) (rerr error) { + + var finalize finalizer.Finalizer defer errors.PropagateError(&rerr, finalize.Finalize) if handler == nil { @@ -241,136 +342,185 @@ func copyVersion(cctx context.Context, printer common.Printer, log logging.Logge srccd := src.GetDescriptor() cur := *t.GetDescriptor() *t.GetDescriptor() = *prep - log.Info(" transferring resources") - for i, r := range src.GetResources() { - var m ocmcpi.AccessMethod - if err := common.IsContextCanceled(cctx); err != nil { - printer.Printf("cancelled by caller\n") - return err - } - - nested := finalize.Nested() - - a, err := r.Access() - if err == nil { - m, err = a.AccessMethod(src) + if maxWorkers <= 1 { + // LINEWORKER SEQUENTIAL TRANSFER + for i, r := range src.GetResources() { + nested := finalize.Nested() + a, err := r.Access() + if err != nil { + return err + } + m, err := a.AccessMethod(src) nested.Close(m, fmt.Sprintf("%s: transferring resource %d: closing access method", hist, i)) - } - if err == nil { + if err != nil { + return err + } ok := a.IsLocal(src.GetContext()) - if !ok { - if !none.IsNone(a.GetKind()) { - ok, err = handler.TransferResource(src, a, r) - if err == nil && !ok { - log.Info("transport omitted", "resource", r.Meta().Name, "index", i, "access", a.GetType()) - } + if !ok && !none.IsNone(a.GetKind()) { + ok, err = handler.TransferResource(src, a, r) + if err != nil || !ok { + return err } } if ok { - var old compdesc.Resource - hint := ocmcpi.ArtifactNameHint(a, src) - old, err = cur.GetResourceByIdentity(r.Meta().GetIdentity(srccd.Resources)) - + old, err := cur.GetResourceByIdentity(r.Meta().GetIdentity(srccd.Resources)) changed := err != nil || old.Digest == nil || !old.Digest.Equal(r.Meta().Digest) valueNeeded := err == nil && needsTransport(src.GetContext(), r, &old) if changed || valueNeeded { - var msgs []interface{} - if !errors.IsErrNotFound(err) { - if err != nil { - return err - } - if !changed && valueNeeded { - msgs = []interface{}{"copy"} - } else { - msgs = []interface{}{"overwrite"} - } - } - notifyArtifactInfo(printer, log, "resource", i, r.Meta(), hint, msgs...) - err = handler.HandleTransferResource(r, m, hint, t) - } else { - if err == nil { // old resource found -> keep current access method - t.SetResource(r.Meta(), old.Access, ocm.ModifyElement(), ocm.SkipVerify(), ocm.DisableExtraIdentityDefaulting()) + notifyArtifactInfo(printer, log, "resource", i, r.Meta(), hint) + if err := handler.HandleTransferResource(r, m, hint, t); err != nil { + return err } + } else if err == nil { + t.SetResource(r.Meta(), old.Access, ocm.ModifyElement(), ocm.SkipVerify(), ocm.DisableExtraIdentityDefaulting()) notifyArtifactInfo(printer, log, "resource", i, r.Meta(), hint, "already present") } } - } - if err != nil { - if !errors.IsErrUnknownKind(err, errkind.KIND_ACCESSMETHOD) { - return errors.Wrapf(err, "%s: transferring resource %d", hist, i) + if err := nested.Finalize(); err != nil { + return err } - printer.Printf("WARN: %s: transferring resource %d: %s (enforce transport by reference)\n", hist, i, err) - } - err = nested.Finalize() - if err != nil { - return err - } - } - - log.Info(" transferring sources") - for i, r := range src.GetSources() { - var m ocmcpi.AccessMethod - - if err := common.IsContextCanceled(cctx); err != nil { - printer.Printf("cancelled by caller\n") - return err } - a, err := r.Access() - if err == nil { - m, err = a.AccessMethod(src) - } - if err == nil { + for i, r := range src.GetSources() { + a, err := r.Access() + if err != nil { + return err + } + m, err := a.AccessMethod(src) + if err != nil { + return err + } ok := a.IsLocal(src.GetContext()) - if !ok { - if !none.IsNone(a.GetKind()) { - ok, err = handler.TransferSource(src, a, r) - if err == nil && !ok { - log.Info("transport omitted", "source", r.Meta().Name, "index", i, "access", a.GetType()) - } + if !ok && !none.IsNone(a.GetKind()) { + ok, err = handler.TransferSource(src, a, r) + if err != nil || !ok { + return err } } if ok { - // sources do not have digests so far, so they have to copied, always. hint := ocmcpi.ArtifactNameHint(a, src) notifyArtifactInfo(printer, log, "source", i, r.Meta(), hint) - err = errors.Join(err, handler.HandleTransferSource(r, m, hint, t)) + if err := handler.HandleTransferSource(r, m, hint, t); err != nil { + return err + } } - err = errors.Join(err, m.Close()) - } - if err != nil { - if !errors.IsErrUnknownKind(err, errkind.KIND_ACCESSMETHOD) { - return errors.Wrapf(err, "%s: transferring source %d", hist, i) + if err := m.Close(); err != nil { + return err } - printer.Printf("WARN: %s: transferring source %d: %s (enforce transport by reference)\n", hist, i, err) } + return nil } - return nil -} -func notifyArtifactInfo(printer common.Printer, log logging.Logger, kind string, index int, meta compdesc.ArtifactMetaAccess, hint string, msgs ...interface{}) { - msg := "copying" - cmsg := "..." - if len(msgs) > 0 { - if m, ok := msgs[0].(string); ok { - msg = fmt.Sprintf(m, msgs[1:]...) - } else { - msg = fmt.Sprint(msgs...) + // PARALLEL WORKER POOL TRANSFER (maxWorkers > 1) + type transferTask struct { + task func() error + id string + } + taskBufferSize := maxWorkers * 2 + if taskBufferSize == 0 { + taskBufferSize = 1 + } + tasks := make(chan transferTask, taskBufferSize) + errChan := make(chan error, len(src.GetResources())+len(src.GetSources())) + + var wg sync.WaitGroup + for i := 0; i < maxWorkers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for item := range tasks { + if err := item.task(); err != nil { + errChan <- err + } + } + }() + } + + for i, r := range src.GetResources() { + idx := i + res := r + tasks <- transferTask{ + id: fmt.Sprintf("resource-%d", idx), + task: func() error { + nested := finalize.Nested() + a, err := res.Access() + if err != nil { + return err + } + m, err := a.AccessMethod(src) + nested.Close(m, fmt.Sprintf("%s: transferring resource %d: closing access method", hist, idx)) + if err != nil { + return err + } + ok := a.IsLocal(src.GetContext()) + if !ok && !none.IsNone(a.GetKind()) { + ok, err = handler.TransferResource(src, a, res) + if err != nil || !ok { + return err + } + } + if ok { + hint := ocmcpi.ArtifactNameHint(a, src) + old, err := cur.GetResourceByIdentity(res.Meta().GetIdentity(srccd.Resources)) + changed := err != nil || old.Digest == nil || !old.Digest.Equal(res.Meta().Digest) + valueNeeded := err == nil && needsTransport(src.GetContext(), res, &old) + if changed || valueNeeded { + notifyArtifactInfo(printer, log, "resource", idx, res.Meta(), hint) + if err := handler.HandleTransferResource(res, m, hint, t); err != nil { + return err + } + } else if err == nil { + t.SetResource(res.Meta(), old.Access, ocm.ModifyElement(), ocm.SkipVerify(), ocm.DisableExtraIdentityDefaulting()) + notifyArtifactInfo(printer, log, "resource", idx, res.Meta(), hint, "already present") + } + } + return nested.Finalize() + }, } - cmsg = " (" + msg + ")" } - if printer != nil { - if hint != "" { - printer.Printf("...%s %d %s[%s](%s)%s\n", kind, index, meta.GetName(), meta.GetType(), hint, cmsg) - } else { - printer.Printf("...%s %d %s[%s]%s\n", kind, index, meta.GetName(), meta.GetType(), cmsg) + + for i, r := range src.GetSources() { + idx := i + srcRes := r + tasks <- transferTask{ + id: fmt.Sprintf("source-%d", idx), + task: func() error { + a, err := srcRes.Access() + if err != nil { + return err + } + m, err := a.AccessMethod(src) + if err != nil { + return err + } + ok := a.IsLocal(src.GetContext()) + if !ok && !none.IsNone(a.GetKind()) { + ok, err = handler.TransferSource(src, a, srcRes) + if err != nil || !ok { + return err + } + } + if ok { + hint := ocmcpi.ArtifactNameHint(a, src) + notifyArtifactInfo(printer, log, "source", idx, srcRes.Meta(), hint) + if err := handler.HandleTransferSource(srcRes, m, hint, t); err != nil { + return err + } + } + return m.Close() + }, } } - if hint != "" { - log.Debug("handle artifact", "kind", kind, "name", meta.GetName(), "type", meta.GetType(), "index", index, "hint", hint, "message", msg) - } else { - log.Debug("handle artifact", "kind", kind, "name", meta.GetName(), "type", meta.GetType(), "index", index, "message", msg) + + close(tasks) + wg.Wait() + close(errChan) + + errList := errors.ErrListf("transfer resources and sources") + for e := range errChan { + errList.Add(e) } + return errList.Result() }