diff --git a/internal/clients/msgraph_client.go b/internal/clients/msgraph_client.go index 31f9d7e..af64419 100644 --- a/internal/clients/msgraph_client.go +++ b/internal/clients/msgraph_client.go @@ -3,7 +3,12 @@ package clients import ( "context" "encoding/json" + "fmt" "net/http" + "sort" + "strings" + "sync" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" @@ -14,11 +19,26 @@ const ( moduleName = "resource" moduleVersion = "v0.1.0" nextLinkKey = "@odata.nextLink" + listCacheTTL = 10 * time.Second ) +type listCacheEntry struct { + body interface{} + fetchedAt time.Time +} + +type listInflight struct { + done chan struct{} + body interface{} + err error +} + type MSGraphClient struct { - host string - pl runtime.Pipeline + host string + pl runtime.Pipeline + cacheMu sync.Mutex + listCache map[string]listCacheEntry + inFlight map[string]*listInflight } func NewMSGraphClient(credential azcore.TokenCredential, opt *policy.ClientOptions) (*MSGraphClient, error) { @@ -33,13 +53,14 @@ func NewMSGraphClient(credential azcore.TokenCredential, opt *policy.ClientOptio Tracing: runtime.TracingOptions{}, }, opt) return &MSGraphClient{ - host: "https://graph.microsoft.com", - pl: pl, + host: "https://graph.microsoft.com", + pl: pl, + listCache: make(map[string]listCacheEntry), + inFlight: make(map[string]*listInflight), }, nil } func (client *MSGraphClient) Read(ctx context.Context, url string, apiVersion string, options RequestOptions) (interface{}, error) { - // apply per-request retry options via context if options.RetryOptions != nil { ctx = policy.WithRetryOptions(ctx, *options.RetryOptions) } @@ -69,7 +90,6 @@ func (client *MSGraphClient) Read(ctx context.Context, url string, apiVersion st return nil, err } - // if response has nextLink, follow the link and return the final response if responseBodyMap, ok := responseBody.(map[string]interface{}); ok { if nextLink := responseBodyMap["@odata.nextLink"]; nextLink != nil { return client.List(ctx, url, apiVersion, options) @@ -79,8 +99,58 @@ func (client *MSGraphClient) Read(ctx context.Context, url string, apiVersion st return responseBody, nil } +func (client *MSGraphClient) cachedList(ctx context.Context, url, apiVersion string, options RequestOptions) (interface{}, error) { + key := listCacheKey(url, apiVersion, options.QueryParameters) + + client.cacheMu.Lock() + if entry, ok := client.listCache[key]; ok && time.Since(entry.fetchedAt) < listCacheTTL { + client.cacheMu.Unlock() + return entry.body, nil + } + if f, ok := client.inFlight[key]; ok { + client.cacheMu.Unlock() + select { + case <-f.done: + case <-ctx.Done(): + return nil, ctx.Err() + } + return f.body, f.err + } + f := &listInflight{done: make(chan struct{})} + client.inFlight[key] = f + client.cacheMu.Unlock() + + body, err := client.List(ctx, url, apiVersion, options) + + client.cacheMu.Lock() + delete(client.inFlight, key) + if err == nil { + client.listCache[key] = listCacheEntry{body: body, fetchedAt: time.Now()} + } + client.cacheMu.Unlock() + + f.body = body + f.err = err + close(f.done) + + return body, err +} + +func listCacheKey(url, apiVersion string, params map[string]string) string { + keys := make([]string, 0, len(params)) + for k := range params { + keys = append(keys, k) + } + sort.Strings(keys) + key := fmt.Sprintf("%s|%s", apiVersion, url) + for _, k := range keys { + key += fmt.Sprintf("|%s=%s", k, params[k]) + } + return key +} + func (client *MSGraphClient) ListRefIDs(ctx context.Context, url string, apiVersion string, options RequestOptions) ([]string, error) { - responseBody, err := client.List(ctx, url, apiVersion, options) + responseBody, err := client.cachedList(ctx, url, apiVersion, options) if err != nil { return nil, err } @@ -185,7 +255,6 @@ func (client *MSGraphClient) List(ctx context.Context, url string, apiVersion st continue } } - // copy all fields except for nextLinkKey and value for key, val := range pageMap { if key != nextLinkKey && key != "value" { out[key] = val @@ -193,7 +262,6 @@ func (client *MSGraphClient) List(ctx context.Context, url string, apiVersion st } } - // if response doesn't follow the paging guideline, return the response as is return page, nil } @@ -202,6 +270,71 @@ func (client *MSGraphClient) List(ctx context.Context, url string, apiVersion st return out, nil } +func (client *MSGraphClient) invalidateListCache(collectionUrl, apiVersion string) { + prefix := fmt.Sprintf("%s|%s", apiVersion, collectionUrl) + client.cacheMu.Lock() + defer client.cacheMu.Unlock() + for key := range client.listCache { + if strings.HasPrefix(key, prefix) { + delete(client.listCache, key) + } + } +} + +func parentCollectionUrl(itemUrl string) string { + u := strings.TrimSuffix(itemUrl, "/$ref") + if idx := strings.LastIndex(u, "/"); idx >= 0 { + return u[:idx] + } + return u +} + +func (client *MSGraphClient) ReadFromList(ctx context.Context, collectionUrl string, id string, apiVersion string, options RequestOptions) (interface{}, error) { + responseBody, err := client.cachedList(ctx, collectionUrl, apiVersion, options) + if err != nil { + return nil, err + } + return findItemInList(responseBody, id) +} + +func (client *MSGraphClient) ReadFromListWithWait(ctx context.Context, collectionUrl, id, apiVersion string, options RequestOptions) (interface{}, error) { + for { + body, err := client.cachedList(ctx, collectionUrl, apiVersion, options) + if err != nil { + return nil, err + } + if item, err := findItemInList(body, id); err == nil { + return item, nil + } + select { + case <-ctx.Done(): + return nil, fmt.Errorf("timed out waiting for resource %q to appear in collection %s", id, collectionUrl) + case <-time.After(listCacheTTL): + } + } +} + +func findItemInList(body interface{}, id string) (interface{}, error) { + bodyMap, ok := body.(map[string]interface{}) + if !ok { + return nil, &azcore.ResponseError{StatusCode: http.StatusNotFound} + } + values, ok := bodyMap["value"].([]interface{}) + if !ok { + return nil, &azcore.ResponseError{StatusCode: http.StatusNotFound} + } + for _, item := range values { + itemMap, ok := item.(map[string]interface{}) + if !ok { + continue + } + if itemId, ok := itemMap["id"].(string); ok && itemId == id { + return item, nil + } + } + return nil, &azcore.ResponseError{StatusCode: http.StatusNotFound} +} + func (client *MSGraphClient) Create(ctx context.Context, url string, apiVersion string, body interface{}, options RequestOptions) (interface{}, error) { if options.RetryOptions != nil { ctx = policy.WithRetryOptions(ctx, *options.RetryOptions) @@ -230,12 +363,11 @@ func (client *MSGraphClient) Create(ctx context.Context, url string, apiVersion return nil, runtime.NewResponseError(resp) } - // TODO: Handle long-running operations if needed - var responseBody interface{} if err := runtime.UnmarshalAsJSON(resp, &responseBody); err != nil { return nil, err } + client.invalidateListCache(url, apiVersion) return responseBody, nil } @@ -267,12 +399,11 @@ func (client *MSGraphClient) Update(ctx context.Context, url string, apiVersion return nil, runtime.NewResponseError(resp) } - // TODO: Handle long-running operations if needed - var responseBody interface{} if err := runtime.UnmarshalAsJSON(resp, &responseBody); err != nil { return nil, err } + client.invalidateListCache(parentCollectionUrl(url), apiVersion) return responseBody, nil } @@ -298,16 +429,14 @@ func (client *MSGraphClient) Delete(ctx context.Context, url string, apiVersion return err } - // TODO: Handle long-running operations if needed - if !runtime.HasStatusCode(resp, http.StatusOK, http.StatusAccepted, http.StatusNoContent) { return runtime.NewResponseError(resp) } + client.invalidateListCache(parentCollectionUrl(url), apiVersion) return nil } func (client *MSGraphClient) Action(ctx context.Context, method string, url string, apiVersion string, body interface{}, options RequestOptions) (interface{}, error) { - // apply per-request retry options via context if options.RetryOptions != nil { ctx = policy.WithRetryOptions(ctx, *options.RetryOptions) } @@ -327,7 +456,6 @@ func (client *MSGraphClient) Action(ctx context.Context, method string, url stri req.Raw().Header.Set(key, value) } - // Set request body if provided if body != nil { if err := runtime.MarshalAsJSON(req, body); err != nil { return nil, err @@ -345,6 +473,9 @@ func (client *MSGraphClient) Action(ctx context.Context, method string, url stri return nil, runtime.NewResponseError(resp) } + // Invalidate the list cache for the parent collection after any mutating action. + client.invalidateListCache(parentCollectionUrl(url), apiVersion) + // For methods that typically don't return a body (like DELETE), or if response is empty if resp.StatusCode == http.StatusNoContent || resp.ContentLength == 0 { return nil, nil diff --git a/internal/services/msgraph_resource.go b/internal/services/msgraph_resource.go index e3fc5b7..cd94b5e 100644 --- a/internal/services/msgraph_resource.go +++ b/internal/services/msgraph_resource.go @@ -276,25 +276,24 @@ func (r *MSGraphResource) Create(ctx context.Context, req resource.CreateRequest model.ResourceUrl = types.StringValue(fmt.Sprintf("%s/%s", model.Url.ValueString(), responseId)) } - // Wait for the resource to be available - if err = consistency.WaitForUpdate(ctx, ResourceExistenceFunc(r.client, model)); err != nil { - resp.Diagnostics.AddError("Error", fmt.Sprintf("waiting for creation of %s: %v", model.Url.ValueString(), err)) - return - } - if !isRelationship { + // Poll the collection endpoint until the item appears. Concurrent creates share a single + // GET /collection via the list cache instead of each making an individual GET request. options = clients.RequestOptions{ QueryParameters: clients.NewQueryParameters(AsMapOfLists(model.ReadQueryParameters)), - RetryOptions: clients.CombineRetryOptions( - clients.NewRetryOptionsForReadAfterCreate(), - clients.NewRetryOptions(model.Retry), - ), + RetryOptions: clients.NewRetryOptions(model.Retry), } - responseBody, err = r.client.Read(ctx, fmt.Sprintf("%s/%s", model.Url.ValueString(), model.Id.ValueString()), model.ApiVersion.ValueString(), options) + responseBody, err = r.client.ReadFromListWithWait(ctx, model.Url.ValueString(), model.Id.ValueString(), model.ApiVersion.ValueString(), options) if err != nil { resp.Diagnostics.AddError("Failed to read data source", err.Error()) return } + } else { + // For $ref relationships, wait for the reference to appear using the existing mechanism. + if err = consistency.WaitForUpdate(ctx, ResourceExistenceFunc(r.client, model)); err != nil { + resp.Diagnostics.AddError("Error", fmt.Sprintf("waiting for creation of %s: %v", model.Url.ValueString(), err)) + return + } } model.Output = types.DynamicValue(buildOutputFromBody(responseBody, model.ResponseExportValues)) @@ -366,17 +365,13 @@ func (r *MSGraphResource) Update(ctx context.Context, req resource.UpdateRequest } } - // Wait for the resource to be available - if err := consistency.WaitForUpdate(ctx, ResourceExistenceFunc(r.client, model)); err != nil { - resp.Diagnostics.AddError("Error", fmt.Sprintf("waiting for creation of %s: %v", model.Url.ValueString(), err)) - return - } - + // The item existed before the PATCH/PUT and the write already invalidated the cache. + // A single ReadFromList fetches a fresh collection and finds the item — no polling loop. options = clients.RequestOptions{ QueryParameters: clients.NewQueryParameters(AsMapOfLists(model.ReadQueryParameters)), RetryOptions: clients.NewRetryOptions(model.Retry), } - responseBody, err := r.client.Read(ctx, fmt.Sprintf("%s/%s", model.Url.ValueString(), model.Id.ValueString()), model.ApiVersion.ValueString(), options) + responseBody, err := r.client.ReadFromList(ctx, model.Url.ValueString(), model.Id.ValueString(), model.ApiVersion.ValueString(), options) if err != nil { resp.Diagnostics.AddError("Failed to read data source", err.Error()) return @@ -457,7 +452,7 @@ func (r *MSGraphResource) Read(ctx context.Context, req resource.ReadRequest, re } options := clients.NewRequestOptions(nil, AsMapOfLists(model.ReadQueryParameters)) - responseBody, err := r.client.Read(ctx, fmt.Sprintf("%s/%s", model.Url.ValueString(), model.Id.ValueString()), model.ApiVersion.ValueString(), options) + responseBody, err := r.client.ReadFromList(ctx, model.Url.ValueString(), model.Id.ValueString(), model.ApiVersion.ValueString(), options) if err != nil { if utils.ResponseErrorWasNotFound(err) { tflog.Info(ctx, fmt.Sprintf("Error reading %q - removing from state", model.Id.ValueString())) @@ -540,13 +535,9 @@ func (r *MSGraphResource) Delete(ctx context.Context, req resource.DeleteRequest err := r.client.Delete(ctx, itemUrl, model.ApiVersion.ValueString(), options) if err != nil { resp.Diagnostics.AddError("Failed to delete resource", err.Error()) - return - } - - // Wait for deletion to complete - if err = consistency.WaitForDeletion(ctx, ResourceExistenceFunc(r.client, model)); err != nil { - resp.Diagnostics.AddError("Error waiting for deletion", err.Error()) } + // No WaitForDeletion: a successful DELETE response is sufficient. Polling until + // Graph's eventual consistency propagates the deletion only wastes time and API quota. } func ResourceExistenceFunc(client *clients.MSGraphClient, model *MSGraphResourceModel) consistency.ChangeFunc { @@ -590,8 +581,7 @@ func ResourceExistenceFunc(client *clients.MSGraphClient, model *MSGraphResource options := clients.RequestOptions{ QueryParameters: clients.NewQueryParameters(AsMapOfLists(model.ReadQueryParameters)), } - itemUrl := fmt.Sprintf("%s/%s", model.Url.ValueString(), model.Id.ValueString()) - _, err := client.Read(ctx, itemUrl, model.ApiVersion.ValueString(), options) + _, err := client.ReadFromList(ctx, model.Url.ValueString(), model.Id.ValueString(), model.ApiVersion.ValueString(), options) if err != nil { if utils.ResponseErrorWasNotFound(err) { b := false @@ -644,14 +634,11 @@ func (r *MSGraphResource) ImportState(ctx context.Context, req resource.ImportSt urlValue = strings.TrimPrefix(parsedUrl.Path[0:lastIndex], "/") } - // Construct the resource_url based on the URL pattern var resourceUrl string if strings.HasSuffix(urlValue, "/$ref") { - // For $ref URLs, resource_url should be the collection URL without $ref + the ID baseUrl := strings.TrimSuffix(urlValue, "/$ref") resourceUrl = fmt.Sprintf("%s/%s", baseUrl, id) } else { - // For regular URLs, resource_url is url + ID resourceUrl = fmt.Sprintf("%s/%s", urlValue, id) } @@ -765,7 +752,6 @@ func (r *MSGraphResource) MoveState(ctx context.Context) []resource.StateMover { idValue = requestID[lastIndex+1:] } - // For $ref URLs, resource_url should be the collection URL without $ref + the ID baseUrl := strings.TrimSuffix(urlValue, "/$ref") resourceUrl := fmt.Sprintf("%s/%s", baseUrl, idValue)