Distributor ingestion limits #3879

Merged · 7 commits · Feb 4, 2025
50 changes: 42 additions & 8 deletions pkg/distributor/distributor.go
@@ -37,6 +37,7 @@ import (
connectapi "github.com/grafana/pyroscope/pkg/api/connect"
"github.com/grafana/pyroscope/pkg/clientpool"
"github.com/grafana/pyroscope/pkg/distributor/aggregator"
"github.com/grafana/pyroscope/pkg/distributor/ingest_limits"
distributormodel "github.com/grafana/pyroscope/pkg/distributor/model"
writepath "github.com/grafana/pyroscope/pkg/distributor/write_path"
phlaremodel "github.com/grafana/pyroscope/pkg/model"
@@ -99,6 +100,7 @@ type Distributor struct {
ingestionRateLimiter *limiter.RateLimiter
aggregator *aggregator.MultiTenantAggregator[*pprof.ProfileMerge]
asyncRequests sync.WaitGroup
ingestionLimitsSampler *ingest_limits.Sampler

subservices *services.Manager
subservicesWatcher *services.FailureWatcher
Expand All @@ -117,6 +119,7 @@ type Distributor struct {
type Limits interface {
IngestionRateBytes(tenantID string) float64
IngestionBurstSizeBytes(tenantID string) int
IngestionLimit(tenantID string) *ingest_limits.Config
IngestionTenantShardSize(tenantID string) int
MaxLabelNameLength(tenantID string) int
MaxLabelValueLength(tenantID string) int
@@ -187,7 +190,9 @@ func New(
return nil, err
}

d.ingestionLimitsSampler = ingest_limits.NewSampler(distributorsRing)

subservices = append(subservices, distributorsLifecycler, distributorsRing, d.aggregator, d.ingestionLimitsSampler)

d.ingestionRateLimiter = limiter.NewRateLimiter(newGlobalRateStrategy(newIngestionRateStrategy(limits), d), 10*time.Second)
d.distributorsLifecycler = distributorsLifecycler
@@ -302,6 +307,12 @@ func (d *Distributor) PushParsed(ctx context.Context, req *distributormodel.Push
d.metrics.receivedCompressedBytes.WithLabelValues(string(profName), tenantID).Observe(float64(req.RawProfileSize))
}

d.calculateRequestSize(req)

if err := d.checkIngestLimit(tenantID, req); err != nil {
return nil, err
}

if err := d.rateLimit(tenantID, req); err != nil {
return nil, err
}
@@ -310,7 +321,7 @@

for _, series := range req.Series {
profName := phlaremodel.Labels(series.Labels).Get(ProfileName)
groups := usageGroups.GetUsageGroups(tenantID, series.Labels)
profLanguage := d.GetProfileLanguage(series)

for _, raw := range series.Samples {
@@ -709,6 +720,17 @@ func (d *Distributor) limitMaxSessionsPerSeries(maxSessionsPerSeries int, labels
}

func (d *Distributor) rateLimit(tenantID string, req *distributormodel.PushRequest) error {
if !d.ingestionRateLimiter.AllowN(time.Now(), tenantID, int(req.TotalBytesUncompressed)) {
validation.DiscardedProfiles.WithLabelValues(string(validation.RateLimited), tenantID).Add(float64(req.TotalProfiles))
validation.DiscardedBytes.WithLabelValues(string(validation.RateLimited), tenantID).Add(float64(req.TotalBytesUncompressed))
return connect.NewError(connect.CodeResourceExhausted,
fmt.Errorf("push rate limit (%s) exceeded while adding %s", humanize.IBytes(uint64(d.limits.IngestionRateBytes(tenantID))), humanize.IBytes(uint64(req.TotalBytesUncompressed))),
)
}
return nil
}

func (d *Distributor) calculateRequestSize(req *distributormodel.PushRequest) {
for _, series := range req.Series {
// include the labels in the size calculation
for _, lbs := range series.Labels {
@@ -720,14 +742,26 @@ func (d *Distributor) rateLimit(tenantID string, req *distributormodel.PushReque
req.TotalBytesUncompressed += int64(raw.Profile.SizeVT())
}
}
}

func (d *Distributor) checkIngestLimit(tenantID string, req *distributormodel.PushRequest) error {
l := d.limits.IngestionLimit(tenantID)
if l == nil {
return nil
}

if l.LimitReached {
// we want to allow a very small portion of the traffic after reaching the limit
if d.ingestionLimitsSampler.AllowRequest(tenantID, l.Sampling) {
return nil
}
limitResetTime := time.Unix(l.LimitResetTime, 0).UTC().Format(time.RFC3339)
validation.DiscardedProfiles.WithLabelValues(string(validation.IngestLimitReached), tenantID).Add(float64(req.TotalProfiles))
validation.DiscardedBytes.WithLabelValues(string(validation.IngestLimitReached), tenantID).Add(float64(req.TotalBytesUncompressed))
return connect.NewError(connect.CodeResourceExhausted,
fmt.Errorf("limit of %s/%s reached, next reset at %s", humanize.IBytes(uint64(l.PeriodLimitMb*1024*1024)), l.PeriodType, limitResetTime))
}

return nil
}

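To make the new rejection concrete, here is a small sketch of the message that checkIngestLimit builds; the limit size, period, and reset timestamp are invented for illustration, and fmt, time, and github.com/dustin/go-humanize are assumed to be imported.

// Sketch only: hypothetical values, not taken from this change.
msg := fmt.Sprintf("limit of %s/%s reached, next reset at %s",
	humanize.IBytes(uint64(128*1024*1024)),
	"hour",
	time.Unix(1737721086, 0).UTC().Format(time.RFC3339))
// msg == "limit of 128 MiB/hour reached, next reset at 2025-01-24T12:18:06Z"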
21 changes: 21 additions & 0 deletions pkg/distributor/distributor_test.go
@@ -27,6 +27,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/grafana/pyroscope/pkg/distributor/ingest_limits"
testhelper2 "github.com/grafana/pyroscope/pkg/pprof/testhelper"

profilev1 "github.com/grafana/pyroscope/api/gen/proto/go/google/v1"
@@ -301,6 +302,26 @@ func Test_Limits(t *testing.T) {
expectedCode: connect.CodeInvalidArgument,
expectedValidationReason: validation.LabelNameTooLong,
},
{
description: "ingest_limit_reached",
pushReq: &pushv1.PushRequest{},
overrides: validation.MockOverrides(func(defaults *validation.Limits, tenantLimits map[string]*validation.Limits) {
l := validation.MockDefaultLimits()
l.IngestionLimit = &ingest_limits.Config{
PeriodType: "hour",
PeriodLimitMb: 128,
LimitResetTime: 1737721086,
LimitReached: true,
Sampling: ingest_limits.SamplingConfig{
NumRequests: 0,
Period: time.Minute,
},
}
tenantLimits["user-1"] = l
}),
expectedCode: connect.CodeResourceExhausted,
expectedValidationReason: validation.IngestLimitReached,
},
}

for _, tc := range testCases {
25 changes: 25 additions & 0 deletions pkg/distributor/ingest_limits/config.go
@@ -0,0 +1,25 @@
package ingest_limits

import "time"

type Config struct {
// PeriodType provides the limit period / interval (e.g., "hour"). Used in error messages only.
PeriodType string `yaml:"period_type" json:"period_type"`
// PeriodLimitMb provides the limit that is being set in MB. Used in error messages only.
PeriodLimitMb int `yaml:"period_limit_mb" json:"period_limit_mb"`
// LimitResetTime provides the time (Unix seconds) when the limit will reset. Used in error messages only.
LimitResetTime int64 `yaml:"limit_reset_time" json:"limit_reset_time"`
// LimitReached instructs distributors to allow or reject profiles.
LimitReached bool `yaml:"limit_reached" json:"limit_reached"`
// Sampling controls the sampling parameters when the limit is reached.
Sampling SamplingConfig `yaml:"sampling" json:"sampling"`
}

// SamplingConfig describes the params of a simple probabilistic sampling mechanism.
//
// Distributors should allow up to NumRequests requests through and then apply a cooldown (Period) after which
// more requests can be let through.
type SamplingConfig struct {
NumRequests int `yaml:"num_requests" json:"num_requests"`
Period time.Duration `yaml:"period" json:"period"`
}
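For a rough picture of how these fields fit together, below is a sketch of the limit a tenant override might carry, constructed in Go; all values are invented for the example and the time package is assumed to be imported.

// Sketch only: a tenant that has exhausted a 4 GiB hourly limit, with a small sampling allowance.
cfg := &ingest_limits.Config{
	PeriodType:     "hour",
	PeriodLimitMb:  4096,
	LimitResetTime: time.Now().Add(30 * time.Minute).Unix(),
	LimitReached:   true,
	Sampling: ingest_limits.SamplingConfig{
		NumRequests: 10,          // allow a handful of requests through per sampling period
		Period:      time.Minute, // cooldown before more sampled requests are allowed
	},
}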
142 changes: 142 additions & 0 deletions pkg/distributor/ingest_limits/sampler.go
@@ -0,0 +1,142 @@
package ingest_limits

import (
"context"
"math/rand"
"sync"
"time"

"github.com/grafana/dskit/services"
)

type tenantTracker struct {
mu sync.Mutex
lastRequestTime time.Time
remainingRequests int
}

// Sampler provides a very simple time-based probabilistic sampling,
// intended to be used when a tenant limit has been reached.
//
// The sampler will allow a number of requests in a time interval.
// Once the interval is over, the number of allowed requests resets.
//
// We introduce a probability function for a request to be allowed defined as 1 / num_replicas,
// to account for the size of the cluster and because tracking is done in memory.
type Sampler struct {
*services.BasicService

mu sync.RWMutex
tenants map[string]*tenantTracker

// needed for adjusting the probability function with the number of replicas
instanceCountProvider InstanceCountProvider

// cleanup of the tenants map to prevent build-up
cleanupInterval time.Duration
maxAge time.Duration
closeOnce sync.Once
stop chan struct{}
done chan struct{}
}

type InstanceCountProvider interface {
InstancesCount() int
}

func NewSampler(instanceCount InstanceCountProvider) *Sampler {
s := &Sampler{
tenants: make(map[string]*tenantTracker),
instanceCountProvider: instanceCount,
cleanupInterval: 1 * time.Hour,
maxAge: 24 * time.Hour,
stop: make(chan struct{}),
done: make(chan struct{}),
}
s.BasicService = services.NewBasicService(
s.starting,
s.running,
s.stopping,
)

return s
}

func (s *Sampler) starting(_ context.Context) error { return nil }

func (s *Sampler) stopping(_ error) error {
s.closeOnce.Do(func() {
close(s.stop)
<-s.done
})
return nil
}

func (s *Sampler) running(ctx context.Context) error {
t := time.NewTicker(s.cleanupInterval)
defer func() {
t.Stop()
close(s.done)
}()
for {
select {
case <-t.C:
s.removeStaleTenants()
case <-s.stop:
return nil
case <-ctx.Done():
return nil
}
}
}

func (s *Sampler) AllowRequest(tenantID string, config SamplingConfig) bool {
s.mu.Lock()
tracker, exists := s.tenants[tenantID]
if !exists {
tracker = &tenantTracker{
lastRequestTime: time.Now(),
remainingRequests: config.NumRequests,
}
s.tenants[tenantID] = tracker
}
s.mu.Unlock()

return tracker.AllowRequest(s.instanceCountProvider.InstancesCount(), config.Period, config.NumRequests)
}

func (b *tenantTracker) AllowRequest(replicaCount int, windowDuration time.Duration, maxRequests int) bool {
b.mu.Lock()
defer b.mu.Unlock()

now := time.Now()

// reset tracking data if enough time has passed
if now.Sub(b.lastRequestTime) >= windowDuration {
b.lastRequestTime = now
b.remainingRequests = maxRequests
}

if b.remainingRequests > 0 {
// random chance of allowing request, adjusting for the number of replicas
shouldAllow := rand.Float64() < float64(maxRequests)/float64(replicaCount)

if shouldAllow {
b.remainingRequests--
return true
}
}

return false
}

func (s *Sampler) removeStaleTenants() {
s.mu.Lock()
cutoff := time.Now().Add(-s.maxAge)
for tenantID, tracker := range s.tenants {
if tracker.lastRequestTime.Before(cutoff) {
delete(s.tenants, tenantID)
}
}
s.mu.Unlock()
}
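For context, a minimal, self-contained sketch of how a caller might wire the sampler into a service and query it; the staticReplicas stub and all values are assumptions for illustration, not code from this PR.

package main

import (
	"context"
	"time"

	"github.com/grafana/dskit/services"

	"github.com/grafana/pyroscope/pkg/distributor/ingest_limits"
)

// staticReplicas is a stub InstanceCountProvider that reports a fixed replica count.
type staticReplicas int

func (s staticReplicas) InstancesCount() int { return int(s) }

func main() {
	ctx := context.Background()

	// Build the sampler against the stub instead of the distributors ring.
	s := ingest_limits.NewSampler(staticReplicas(3))
	if err := services.StartAndAwaitRunning(ctx, s); err != nil {
		panic(err)
	}
	defer services.StopAndAwaitTerminated(ctx, s) //nolint:errcheck

	// Ask whether a request for this tenant should still be let through
	// even though its ingestion limit has been reached.
	sampling := ingest_limits.SamplingConfig{NumRequests: 5, Period: time.Minute}
	if s.AllowRequest("tenant-a", sampling) {
		// the request would be forwarded despite the limit being reached
	}
}

Because the tenants map is only cleaned up while the service is running, starting the sampler through dskit's service lifecycle, as the distributor does via its subservices, matters even in a toy setup like this one.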