Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 148 additions & 17 deletions internal/clients/msgraph_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@ package clients
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sort"
"strings"
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
Expand All @@ -14,11 +19,26 @@ const (
moduleName = "resource"
moduleVersion = "v0.1.0"
nextLinkKey = "@odata.nextLink"
listCacheTTL = 10 * time.Second
)

type listCacheEntry struct {
body interface{}
fetchedAt time.Time
}

type listInflight struct {
done chan struct{}
body interface{}
err error
}

type MSGraphClient struct {
host string
pl runtime.Pipeline
host string
pl runtime.Pipeline
cacheMu sync.Mutex
listCache map[string]listCacheEntry
inFlight map[string]*listInflight
}

func NewMSGraphClient(credential azcore.TokenCredential, opt *policy.ClientOptions) (*MSGraphClient, error) {
Expand All @@ -33,13 +53,14 @@ func NewMSGraphClient(credential azcore.TokenCredential, opt *policy.ClientOptio
Tracing: runtime.TracingOptions{},
}, opt)
return &MSGraphClient{
host: "https://graph.microsoft.com",
pl: pl,
host: "https://graph.microsoft.com",
pl: pl,
listCache: make(map[string]listCacheEntry),
inFlight: make(map[string]*listInflight),
}, nil
}

func (client *MSGraphClient) Read(ctx context.Context, url string, apiVersion string, options RequestOptions) (interface{}, error) {
// apply per-request retry options via context
if options.RetryOptions != nil {
ctx = policy.WithRetryOptions(ctx, *options.RetryOptions)
}
Expand Down Expand Up @@ -69,7 +90,6 @@ func (client *MSGraphClient) Read(ctx context.Context, url string, apiVersion st
return nil, err
}

// if response has nextLink, follow the link and return the final response
if responseBodyMap, ok := responseBody.(map[string]interface{}); ok {
if nextLink := responseBodyMap["@odata.nextLink"]; nextLink != nil {
return client.List(ctx, url, apiVersion, options)
Expand All @@ -79,8 +99,58 @@ func (client *MSGraphClient) Read(ctx context.Context, url string, apiVersion st
return responseBody, nil
}

func (client *MSGraphClient) cachedList(ctx context.Context, url, apiVersion string, options RequestOptions) (interface{}, error) {
key := listCacheKey(url, apiVersion, options.QueryParameters)

client.cacheMu.Lock()
if entry, ok := client.listCache[key]; ok && time.Since(entry.fetchedAt) < listCacheTTL {
client.cacheMu.Unlock()
return entry.body, nil
}
if f, ok := client.inFlight[key]; ok {
client.cacheMu.Unlock()
select {
case <-f.done:
case <-ctx.Done():
return nil, ctx.Err()
}
return f.body, f.err
}
f := &listInflight{done: make(chan struct{})}
client.inFlight[key] = f
client.cacheMu.Unlock()

body, err := client.List(ctx, url, apiVersion, options)

client.cacheMu.Lock()
delete(client.inFlight, key)
if err == nil {
client.listCache[key] = listCacheEntry{body: body, fetchedAt: time.Now()}
}
client.cacheMu.Unlock()

f.body = body
f.err = err
close(f.done)

return body, err
}

func listCacheKey(url, apiVersion string, params map[string]string) string {
keys := make([]string, 0, len(params))
for k := range params {
keys = append(keys, k)
}
sort.Strings(keys)
key := fmt.Sprintf("%s|%s", apiVersion, url)
for _, k := range keys {
key += fmt.Sprintf("|%s=%s", k, params[k])
}
return key
}

func (client *MSGraphClient) ListRefIDs(ctx context.Context, url string, apiVersion string, options RequestOptions) ([]string, error) {
responseBody, err := client.List(ctx, url, apiVersion, options)
responseBody, err := client.cachedList(ctx, url, apiVersion, options)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -185,15 +255,13 @@ func (client *MSGraphClient) List(ctx context.Context, url string, apiVersion st
continue
}
}
// copy all fields except for nextLinkKey and value
for key, val := range pageMap {
if key != nextLinkKey && key != "value" {
out[key] = val
}
}
}

// if response doesn't follow the paging guideline, return the response as is
return page, nil
}

Expand All @@ -202,6 +270,71 @@ func (client *MSGraphClient) List(ctx context.Context, url string, apiVersion st
return out, nil
}

func (client *MSGraphClient) invalidateListCache(collectionUrl, apiVersion string) {
prefix := fmt.Sprintf("%s|%s", apiVersion, collectionUrl)
client.cacheMu.Lock()
defer client.cacheMu.Unlock()
for key := range client.listCache {
if strings.HasPrefix(key, prefix) {
delete(client.listCache, key)
}
}
}

func parentCollectionUrl(itemUrl string) string {
u := strings.TrimSuffix(itemUrl, "/$ref")
if idx := strings.LastIndex(u, "/"); idx >= 0 {
return u[:idx]
}
return u
}

func (client *MSGraphClient) ReadFromList(ctx context.Context, collectionUrl string, id string, apiVersion string, options RequestOptions) (interface{}, error) {
responseBody, err := client.cachedList(ctx, collectionUrl, apiVersion, options)
if err != nil {
return nil, err
}
return findItemInList(responseBody, id)
}

func (client *MSGraphClient) ReadFromListWithWait(ctx context.Context, collectionUrl, id, apiVersion string, options RequestOptions) (interface{}, error) {
for {
body, err := client.cachedList(ctx, collectionUrl, apiVersion, options)
if err != nil {
return nil, err
}
if item, err := findItemInList(body, id); err == nil {
return item, nil
}
select {
case <-ctx.Done():
return nil, fmt.Errorf("timed out waiting for resource %q to appear in collection %s", id, collectionUrl)
case <-time.After(listCacheTTL):
}
}
}

func findItemInList(body interface{}, id string) (interface{}, error) {
bodyMap, ok := body.(map[string]interface{})
if !ok {
return nil, &azcore.ResponseError{StatusCode: http.StatusNotFound}
}
values, ok := bodyMap["value"].([]interface{})
if !ok {
return nil, &azcore.ResponseError{StatusCode: http.StatusNotFound}
}
for _, item := range values {
itemMap, ok := item.(map[string]interface{})
if !ok {
continue
}
if itemId, ok := itemMap["id"].(string); ok && itemId == id {
return item, nil
}
}
return nil, &azcore.ResponseError{StatusCode: http.StatusNotFound}
}

func (client *MSGraphClient) Create(ctx context.Context, url string, apiVersion string, body interface{}, options RequestOptions) (interface{}, error) {
if options.RetryOptions != nil {
ctx = policy.WithRetryOptions(ctx, *options.RetryOptions)
Expand Down Expand Up @@ -230,12 +363,11 @@ func (client *MSGraphClient) Create(ctx context.Context, url string, apiVersion
return nil, runtime.NewResponseError(resp)
}

// TODO: Handle long-running operations if needed

var responseBody interface{}
if err := runtime.UnmarshalAsJSON(resp, &responseBody); err != nil {
return nil, err
}
client.invalidateListCache(url, apiVersion)
return responseBody, nil
}

Expand Down Expand Up @@ -267,12 +399,11 @@ func (client *MSGraphClient) Update(ctx context.Context, url string, apiVersion
return nil, runtime.NewResponseError(resp)
}

// TODO: Handle long-running operations if needed

var responseBody interface{}
if err := runtime.UnmarshalAsJSON(resp, &responseBody); err != nil {
return nil, err
}
client.invalidateListCache(parentCollectionUrl(url), apiVersion)
return responseBody, nil
}

Expand All @@ -298,16 +429,14 @@ func (client *MSGraphClient) Delete(ctx context.Context, url string, apiVersion
return err
}

// TODO: Handle long-running operations if needed

if !runtime.HasStatusCode(resp, http.StatusOK, http.StatusAccepted, http.StatusNoContent) {
return runtime.NewResponseError(resp)
}
client.invalidateListCache(parentCollectionUrl(url), apiVersion)
return nil
}

func (client *MSGraphClient) Action(ctx context.Context, method string, url string, apiVersion string, body interface{}, options RequestOptions) (interface{}, error) {
// apply per-request retry options via context
if options.RetryOptions != nil {
ctx = policy.WithRetryOptions(ctx, *options.RetryOptions)
}
Expand All @@ -327,7 +456,6 @@ func (client *MSGraphClient) Action(ctx context.Context, method string, url stri
req.Raw().Header.Set(key, value)
}

// Set request body if provided
if body != nil {
if err := runtime.MarshalAsJSON(req, body); err != nil {
return nil, err
Expand All @@ -345,6 +473,9 @@ func (client *MSGraphClient) Action(ctx context.Context, method string, url stri
return nil, runtime.NewResponseError(resp)
}

// Invalidate the list cache for the parent collection after any mutating action.
client.invalidateListCache(parentCollectionUrl(url), apiVersion)

// For methods that typically don't return a body (like DELETE), or if response is empty
if resp.StatusCode == http.StatusNoContent || resp.ContentLength == 0 {
return nil, nil
Expand Down
48 changes: 17 additions & 31 deletions internal/services/msgraph_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,25 +276,24 @@ func (r *MSGraphResource) Create(ctx context.Context, req resource.CreateRequest
model.ResourceUrl = types.StringValue(fmt.Sprintf("%s/%s", model.Url.ValueString(), responseId))
}

// Wait for the resource to be available
if err = consistency.WaitForUpdate(ctx, ResourceExistenceFunc(r.client, model)); err != nil {
resp.Diagnostics.AddError("Error", fmt.Sprintf("waiting for creation of %s: %v", model.Url.ValueString(), err))
return
}

if !isRelationship {
// Poll the collection endpoint until the item appears. Concurrent creates share a single
// GET /collection via the list cache instead of each making an individual GET request.
options = clients.RequestOptions{
QueryParameters: clients.NewQueryParameters(AsMapOfLists(model.ReadQueryParameters)),
RetryOptions: clients.CombineRetryOptions(
clients.NewRetryOptionsForReadAfterCreate(),
clients.NewRetryOptions(model.Retry),
),
RetryOptions: clients.NewRetryOptions(model.Retry),
}
responseBody, err = r.client.Read(ctx, fmt.Sprintf("%s/%s", model.Url.ValueString(), model.Id.ValueString()), model.ApiVersion.ValueString(), options)
responseBody, err = r.client.ReadFromListWithWait(ctx, model.Url.ValueString(), model.Id.ValueString(), model.ApiVersion.ValueString(), options)
if err != nil {
resp.Diagnostics.AddError("Failed to read data source", err.Error())
return
}
} else {
// For $ref relationships, wait for the reference to appear using the existing mechanism.
if err = consistency.WaitForUpdate(ctx, ResourceExistenceFunc(r.client, model)); err != nil {
resp.Diagnostics.AddError("Error", fmt.Sprintf("waiting for creation of %s: %v", model.Url.ValueString(), err))
return
}
}

model.Output = types.DynamicValue(buildOutputFromBody(responseBody, model.ResponseExportValues))
Expand Down Expand Up @@ -366,17 +365,13 @@ func (r *MSGraphResource) Update(ctx context.Context, req resource.UpdateRequest
}
}

// Wait for the resource to be available
if err := consistency.WaitForUpdate(ctx, ResourceExistenceFunc(r.client, model)); err != nil {
resp.Diagnostics.AddError("Error", fmt.Sprintf("waiting for creation of %s: %v", model.Url.ValueString(), err))
return
}

// The item existed before the PATCH/PUT and the write already invalidated the cache.
// A single ReadFromList fetches a fresh collection and finds the item — no polling loop.
options = clients.RequestOptions{
QueryParameters: clients.NewQueryParameters(AsMapOfLists(model.ReadQueryParameters)),
RetryOptions: clients.NewRetryOptions(model.Retry),
}
responseBody, err := r.client.Read(ctx, fmt.Sprintf("%s/%s", model.Url.ValueString(), model.Id.ValueString()), model.ApiVersion.ValueString(), options)
responseBody, err := r.client.ReadFromList(ctx, model.Url.ValueString(), model.Id.ValueString(), model.ApiVersion.ValueString(), options)
if err != nil {
resp.Diagnostics.AddError("Failed to read data source", err.Error())
return
Expand Down Expand Up @@ -457,7 +452,7 @@ func (r *MSGraphResource) Read(ctx context.Context, req resource.ReadRequest, re
}

options := clients.NewRequestOptions(nil, AsMapOfLists(model.ReadQueryParameters))
responseBody, err := r.client.Read(ctx, fmt.Sprintf("%s/%s", model.Url.ValueString(), model.Id.ValueString()), model.ApiVersion.ValueString(), options)
responseBody, err := r.client.ReadFromList(ctx, model.Url.ValueString(), model.Id.ValueString(), model.ApiVersion.ValueString(), options)
if err != nil {
if utils.ResponseErrorWasNotFound(err) {
tflog.Info(ctx, fmt.Sprintf("Error reading %q - removing from state", model.Id.ValueString()))
Expand Down Expand Up @@ -540,13 +535,9 @@ func (r *MSGraphResource) Delete(ctx context.Context, req resource.DeleteRequest
err := r.client.Delete(ctx, itemUrl, model.ApiVersion.ValueString(), options)
if err != nil {
resp.Diagnostics.AddError("Failed to delete resource", err.Error())
return
}

// Wait for deletion to complete
if err = consistency.WaitForDeletion(ctx, ResourceExistenceFunc(r.client, model)); err != nil {
resp.Diagnostics.AddError("Error waiting for deletion", err.Error())
}
// No WaitForDeletion: a successful DELETE response is sufficient. Polling until
// Graph's eventual consistency propagates the deletion only wastes time and API quota.
}

func ResourceExistenceFunc(client *clients.MSGraphClient, model *MSGraphResourceModel) consistency.ChangeFunc {
Expand Down Expand Up @@ -590,8 +581,7 @@ func ResourceExistenceFunc(client *clients.MSGraphClient, model *MSGraphResource
options := clients.RequestOptions{
QueryParameters: clients.NewQueryParameters(AsMapOfLists(model.ReadQueryParameters)),
}
itemUrl := fmt.Sprintf("%s/%s", model.Url.ValueString(), model.Id.ValueString())
_, err := client.Read(ctx, itemUrl, model.ApiVersion.ValueString(), options)
_, err := client.ReadFromList(ctx, model.Url.ValueString(), model.Id.ValueString(), model.ApiVersion.ValueString(), options)
if err != nil {
if utils.ResponseErrorWasNotFound(err) {
b := false
Expand Down Expand Up @@ -644,14 +634,11 @@ func (r *MSGraphResource) ImportState(ctx context.Context, req resource.ImportSt
urlValue = strings.TrimPrefix(parsedUrl.Path[0:lastIndex], "/")
}

// Construct the resource_url based on the URL pattern
var resourceUrl string
if strings.HasSuffix(urlValue, "/$ref") {
// For $ref URLs, resource_url should be the collection URL without $ref + the ID
baseUrl := strings.TrimSuffix(urlValue, "/$ref")
resourceUrl = fmt.Sprintf("%s/%s", baseUrl, id)
} else {
// For regular URLs, resource_url is url + ID
resourceUrl = fmt.Sprintf("%s/%s", urlValue, id)
}

Expand Down Expand Up @@ -765,7 +752,6 @@ func (r *MSGraphResource) MoveState(ctx context.Context) []resource.StateMover {
idValue = requestID[lastIndex+1:]
}

// For $ref URLs, resource_url should be the collection URL without $ref + the ID
baseUrl := strings.TrimSuffix(urlValue, "/$ref")
resourceUrl := fmt.Sprintf("%s/%s", baseUrl, idValue)

Expand Down