diff --git a/cmd/cmd_run.go b/cmd/cmd_run.go index a0c36fd..60f506d 100644 --- a/cmd/cmd_run.go +++ b/cmd/cmd_run.go @@ -15,8 +15,13 @@ import ( "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/service" "github.com/spf13/cobra" "gopkg.in/ini.v1" + + "github.com/getlantern/lantern-box/adapter" + "github.com/getlantern/lantern-box/tracker/clientcontext" + "github.com/getlantern/lantern-box/tracker/datacap" ) func init() { @@ -25,6 +30,7 @@ func init() { runCmd.Flags().String("geo-city-url", "https://lanterngeo.lantern.io/GeoLite2-City.mmdb.tar.gz", "URL for downloading GeoLite2-City database") runCmd.Flags().String("city-database-name", "GeoLite2-City.mmdb", "Filename for storing GeoLite2-City database") runCmd.Flags().String("telemetry-endpoint", "telemetry.iantem.io:443", "Telemetry endpoint for OpenTelemetry exporter") + runCmd.Flags().String("datacap-url", "", "Datacap server URL") } var runCmd = &cobra.Command{ @@ -35,7 +41,11 @@ var runCmd = &cobra.Command{ if err != nil { return fmt.Errorf("get config flag: %w", err) } - return run(path) + datacapURL, err := cmd.Flags().GetString("datacap-url") + if err != nil { + return fmt.Errorf("get datacap-url flag: %w", err) + } + return run(path, datacapURL) }, } @@ -64,7 +74,7 @@ func readConfig(path string) (option.Options, error) { return options, nil } -func create(configPath string) (*box.Box, context.CancelFunc, error) { +func create(configPath string, datacapURL string) (*box.Box, context.CancelFunc, error) { options, err := readConfig(configPath) if err != nil { return nil, nil, fmt.Errorf("read config: %w", err) @@ -79,6 +89,29 @@ func create(configPath string) (*box.Box, context.CancelFunc, error) { return nil, nil, fmt.Errorf("create service: %w", err) } + if datacapURL != "" { + // Add datacap tracker + clientCtxMgr := clientcontext.NewManager(clientcontext.MatchBounds{ + Inbound: []string{""}, + Outbound: 
[]string{""}, + }, log.NewNOPFactory().NewLogger("tracker")) + instance.Router().AppendTracker(clientCtxMgr) + service.MustRegister[adapter.ClientContextManager](ctx, clientCtxMgr) + + datacapTracker, err := datacap.NewDatacapTracker( + datacap.Options{ + URL: datacapURL, + }, + log.NewNOPFactory().NewLogger("datacap-tracker"), + ) + if err != nil { + return nil, nil, fmt.Errorf("create datacap tracker: %w", err) + } + clientCtxMgr.AppendTracker(datacapTracker) + } else { + log.Warn("Datacap URL not provided, datacap tracking disabled") + } + osSignals := make(chan os.Signal, 1) signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) defer func() { @@ -112,13 +145,13 @@ func closeMonitor(ctx context.Context) { log.Fatal("sing-box did not close!") } -func run(configPath string) error { +func run(configPath string, datacapURL string) error { log.Info("build info: version ", version, ", commit ", commit) osSignals := make(chan os.Signal, 1) signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) defer signal.Stop(osSignals) for { - instance, cancel, err := create(configPath) + instance, cancel, err := create(configPath, datacapURL) if err != nil { return err } diff --git a/tracker/datacap/client.go b/tracker/datacap/client.go new file mode 100644 index 0000000..b1fd830 --- /dev/null +++ b/tracker/datacap/client.go @@ -0,0 +1,150 @@ +package datacap + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" +) + +// Client handles communication with the datacap sidecar service. +type Client struct { + httpClient *http.Client + baseURL string +} + +// ClientConfig holds configuration for the datacap client. +type ClientConfig struct { + BaseURL string + Timeout time.Duration + InsecureSkipVerify bool +} + +// NewClient creates a new datacap client. +// The baseURL can be overridden by the DATACAP_URL environment variable. +// Supports both HTTP and HTTPS. 
For HTTPS, uses system's trusted certificates by default. +func NewClient(baseURL string, timeout time.Duration) *Client { + return NewClientWithConfig(ClientConfig{ + BaseURL: baseURL, + Timeout: timeout, + InsecureSkipVerify: false, + }) +} + +// NewClientWithConfig creates a new datacap client with advanced configuration. +func NewClientWithConfig(config ClientConfig) *Client { + // Check for environment variable override + if envURL := os.Getenv("DATACAP_URL"); envURL != "" { + config.BaseURL = envURL + } + + // Ensure HTTPS if not explicitly HTTP + if config.BaseURL != "" && !strings.HasPrefix(config.BaseURL, "http://") && !strings.HasPrefix(config.BaseURL, "https://") { + config.BaseURL = "https://" + config.BaseURL + } + + // Create HTTP client with TLS configuration + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: config.InsecureSkipVerify, + }, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + } + + return &Client{ + httpClient: &http.Client{ + Timeout: config.Timeout, + Transport: transport, + }, + baseURL: config.BaseURL, + } +} + +// DataCapStatus represents the response from the GET /data-cap/{deviceId} endpoint. +type DataCapStatus struct { + Throttle bool `json:"throttle"` + RemainingBytes int64 `json:"remainingBytes"` + CapLimit int64 `json:"capLimit"` + ExpiryTime int64 `json:"expiryTime"` +} + +// DataCapReport represents the request body for POST /data-cap/ endpoint. 
+type DataCapReport struct { + DeviceID string `json:"deviceId"` + CountryCode string `json:"countryCode"` + Platform string `json:"platform"` + BytesUsed int64 `json:"bytesUsed"` +} + +func (c *Client) GetDataCapStatus(ctx context.Context, deviceID string) (*DataCapStatus, error) { + url := fmt.Sprintf("%s/data-cap/%s", c.baseURL, deviceID) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to query datacap status: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("datacap status request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var status DataCapStatus + if err := json.NewDecoder(resp.Body).Decode(&status); err != nil { + return nil, fmt.Errorf("failed to decode datacap status: %w", err) + } + + return &status, nil +} + +// ReportDataCapConsumption sends data consumption report to the sidecar. +// Endpoint: POST /data-cap/ +// This tracks usage and returns updated cap status. 
+func (c *Client) ReportDataCapConsumption(ctx context.Context, report *DataCapReport) (*DataCapStatus, error) { + url := fmt.Sprintf("%s/data-cap/", c.baseURL) + + jsonData, err := json.Marshal(report) + if err != nil { + return nil, fmt.Errorf("failed to marshal report: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to report datacap consumption: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("datacap report request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var status DataCapStatus + if err := json.NewDecoder(resp.Body).Decode(&status); err != nil { + return nil, fmt.Errorf("failed to decode datacap status: %w", err) + } + + return &status, nil +} diff --git a/tracker/datacap/client_test.go b/tracker/datacap/client_test.go new file mode 100644 index 0000000..e6a25f0 --- /dev/null +++ b/tracker/datacap/client_test.go @@ -0,0 +1,231 @@ +package datacap + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestDataCapClient(t *testing.T) { + // Create a test server + statusCalled := false + reportCalled := false + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/data-cap/test-device-123" && r.Method == http.MethodGet { + statusCalled = true + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":false,"remainingBytes":9663676416,"capLimit":10737418240,"expiryTime":1700179200}`)) + } else if 
r.URL.Path == "/data-cap/" && r.Method == http.MethodPost {
			reportCalled = true
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(`{"throttle":false,"remainingBytes":9663676416,"capLimit":10737418240,"expiryTime":1700179200}`))
		} else {
			t.Errorf("unexpected path: %s %s", r.Method, r.URL.Path)
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer server.Close()

	// Create client
	client := NewClient(server.URL, 5*time.Second)

	// Test GetDataCapStatus
	ctx := context.Background()
	status, err := client.GetDataCapStatus(ctx, "test-device-123")
	if err != nil {
		t.Fatalf("GetDataCapStatus failed: %v", err)
	}

	if !statusCalled {
		t.Error("status endpoint was not called")
	}

	// Assert directly instead of an empty then-branch with the error in else.
	if status.Throttle {
		t.Error("expected throttle=false, got true")
	}

	if status.RemainingBytes != 9663676416 {
		t.Errorf("expected remainingBytes=9663676416, got %d", status.RemainingBytes)
	}

	if status.CapLimit != 10737418240 {
		t.Errorf("expected capLimit=10737418240, got %d", status.CapLimit)
	}

	if status.ExpiryTime != 1700179200 {
		t.Errorf("expected expiryTime=1700179200, got %d", status.ExpiryTime)
	}

	// Test ReportDataCapConsumption
	report := &DataCapReport{
		DeviceID:    "test-device-123",
		CountryCode: "US",
		Platform:    "android",
		BytesUsed:   1048576,
	}

	status, err = client.ReportDataCapConsumption(ctx, report)
	if err != nil {
		t.Fatalf("ReportDataCapConsumption failed: %v", err)
	}

	if !reportCalled {
		t.Error("report endpoint was not called")
	}

	if status == nil {
		t.Fatal("expected status response, got nil")
	}
}

// TestDataCapClientInvalidURL verifies that an unreachable host surfaces as
// an error rather than a nil-status success.
func TestDataCapClientInvalidURL(t *testing.T) {
	client := NewClient("http://invalid-url-that-does-not-exist:9999", 1*time.Second)

	ctx := context.Background()
	_, err := client.GetDataCapStatus(ctx, "test-device")
	if err == nil {
		t.Error("expected error for invalid URL, got nil")
	}
}
+func TestDataCapClientTimeout(t *testing.T) { + // Create a server that sleeps longer than timeout + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(2 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Create client with short timeout + client := NewClient(server.URL, 100*time.Millisecond) + + ctx := context.Background() + _, err := client.GetDataCapStatus(ctx, "test-device") + if err == nil { + t.Error("expected timeout error, got nil") + } +} + +func TestDataCapClientThrottleTrue(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + // Device is over cap, should be throttled + w.Write([]byte(`{"throttle":true,"remainingBytes":0,"capLimit":1073741824,"expiryTime":1700179200}`)) + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + ctx := context.Background() + + status, err := client.GetDataCapStatus(ctx, "test-device") + if err != nil { + t.Fatalf("GetDataCapStatus failed: %v", err) + } + + if !status.Throttle { + t.Error("expected throttle=true, got false") + } + + if status.RemainingBytes != 0 { + t.Errorf("expected remainingBytes=0, got %d", status.RemainingBytes) + } +} + +func TestDataCapClientAcceptHeader(t *testing.T) { + acceptHeaderReceived := false + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify Accept header is set + if r.Header.Get("Accept") == "application/json" { + acceptHeaderReceived = true + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":false,"remainingBytes":1000000,"capLimit":1073741824,"expiryTime":1700179200}`)) + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + ctx := context.Background() + + _, err := 
client.GetDataCapStatus(ctx, "test-device") + if err != nil { + t.Fatalf("GetDataCapStatus failed: %v", err) + } + + if !acceptHeaderReceived { + t.Error("Accept: application/json header was not sent") + } +} + +func TestDataCapReportWithPlatform(t *testing.T) { + platformReceived := false + var receivedPlatform string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + // Parse the request body + var report DataCapReport + if err := json.NewDecoder(r.Body).Decode(&report); err == nil { + if report.Platform != "" { + platformReceived = true + receivedPlatform = report.Platform + } + } + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":false,"remainingBytes":9663676416,"capLimit":10737418240,"expiryTime":1700179200}`)) + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + ctx := context.Background() + + report := &DataCapReport{ + DeviceID: "test-device", + CountryCode: "US", + Platform: "ios", + BytesUsed: 1048576, + } + + _, err := client.ReportDataCapConsumption(ctx, report) + if err != nil { + t.Fatalf("ReportDataCapConsumption failed: %v", err) + } + + if !platformReceived { + t.Error("platform field was not received") + } + + if receivedPlatform != "ios" { + t.Errorf("expected platform=ios, got %s", receivedPlatform) + } +} + +func TestDataCapClientWithHTTPS(t *testing.T) { + client := NewClient("datacap-sidecar.example.com", 5*time.Second) + + if client.baseURL != "https://datacap-sidecar.example.com" { + t.Errorf("expected HTTPS URL, got %s", client.baseURL) + } +} + +func TestDataCapClientEnvironmentVariable(t *testing.T) { + t.Setenv("DATACAP_URL", "https://env-override.example.com") + + client := NewClient("https://default.example.com", 5*time.Second) + + if client.baseURL != "https://env-override.example.com" { + t.Errorf("expected environment variable URL, got %s", 
client.baseURL) + } +} diff --git a/tracker/datacap/conn.go b/tracker/datacap/conn.go new file mode 100644 index 0000000..b06249c --- /dev/null +++ b/tracker/datacap/conn.go @@ -0,0 +1,258 @@ +package datacap + +import ( + "context" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/sagernet/sing-box/log" + + "github.com/getlantern/lantern-box/tracker/clientcontext" +) + +// Throttle speed constants for datacap enforcement +const ( + + // Default upload speed (not throttled to allow user uploads even when capped) + defaultUploadSpeedBytesPerSec = 5 * 1024 * 1024 // 5 MB/s + + // Throttle speed tiers based on remaining percentage + highRemainingThresholdPct = 0.2 // 20% remaining + mediumRemainingThresholdPct = 0.1 // 10% remaining + highTierSpeedBytesPerSec = 5 * 1024 * 1024 // 5 MB/s + mediumTierSpeedBytesPerSec = 2 * 1024 * 1024 // 2 MB/s + lowTierSpeedBytesPerSec = 128 * 1024 // 128 KB/s +) + +// Conn wraps a net.Conn and tracks data consumption for datacap reporting. +type Conn struct { + net.Conn + ctx context.Context + cancel context.CancelFunc + client *Client + logger log.ContextLogger + + clientInfo clientcontext.ClientInfo + + // Atomic counters for thread-safe tracking + bytesSent atomic.Int64 + bytesReceived atomic.Int64 + + // Reporting control + reportTicker *time.Ticker + reportMutex sync.Mutex + closed atomic.Bool + wg sync.WaitGroup + // Throttling + throttler *Throttler + throttlingEnabled bool +} + +// ConnConfig holds configuration for creating a datacap-tracked connection. +type ConnConfig struct { + Conn net.Conn + Client *Client + Logger log.ContextLogger + ClientInfo clientcontext.ClientInfo + ReportInterval time.Duration + EnableThrottling bool + ThrottleSpeed int64 +} + +// NewConn creates a new datacap-tracked connection wrapper. 
func NewConn(config ConnConfig) *Conn {
	// Connection-scoped context; cancelled in Close to stop periodicReport.
	ctx, cancel := context.WithCancel(context.Background())

	// Default report interval to 30 seconds if not specified
	if config.ReportInterval == 0 {
		config.ReportInterval = 30 * time.Second
	}

	conn := &Conn{
		Conn:              config.Conn,
		ctx:               ctx,
		cancel:            cancel,
		client:            config.Client,
		logger:            config.Logger,
		clientInfo:        config.ClientInfo,
		reportTicker:      time.NewTicker(config.ReportInterval),
		throttler:         NewThrottler(config.ThrottleSpeed),
		throttlingEnabled: config.EnableThrottling,
	}

	// Start periodic reporting goroutine; the WaitGroup lets Close block
	// until it has exited before sending the final report.
	conn.wg.Add(1)
	go conn.periodicReport()

	return conn
}

// Read tracks bytes received and applies throttling if enabled.
//
// NOTE(review): the token-bucket wait happens AFTER the underlying Read, so
// each call can burst up to len(b) unthrottled bytes before pausing.
func (c *Conn) Read(b []byte) (n int, err error) {
	n, err = c.Conn.Read(b)
	if n > 0 {
		c.bytesReceived.Add(int64(n))

		// Apply throttling after read (token bucket wait)
		if c.throttler.IsEnabled() {
			if waitErr := c.throttler.WaitRead(c.ctx, n); waitErr != nil {
				// Context cancelled, but we already read the data
				// Return the data and the error
				return n, waitErr
			}
		}
	}
	return
}

// Write tracks bytes sent and applies throttling if enabled.
//
// Mirrors Read: the byte count is recorded and the throttle wait is applied
// after the underlying Write has already completed.
func (c *Conn) Write(b []byte) (n int, err error) {
	n, err = c.Conn.Write(b)
	if n > 0 {
		c.bytesSent.Add(int64(n))

		// Apply throttling after write (token bucket wait)
		if c.throttler.IsEnabled() {
			if waitErr := c.throttler.WaitWrite(c.ctx, n); waitErr != nil {
				// Context cancelled, but we already wrote the data
				// Return the bytes written and the error
				return n, waitErr
			}
		}
	}
	return
}

// Close stops reporting and closes the underlying connection.
+func (c *Conn) Close() error { + if c.closed.Swap(true) { + return nil // Already closed + } + + // Stop the reporting ticker + c.reportTicker.Stop() + + // Cancel context to signal goroutines to stop + c.cancel() + + // Wait for all goroutines to finish + c.wg.Wait() + + // Send final report + c.sendReport() + + return c.Conn.Close() +} + +// periodicReport runs in a goroutine and periodically reports data consumption. +func (c *Conn) periodicReport() { + defer c.wg.Done() + for { + select { + case <-c.ctx.Done(): + return + case <-c.reportTicker.C: + c.sendReport() + } + } +} + +// sendReport sends the current consumption data to the sidecar. +func (c *Conn) sendReport() { + c.reportMutex.Lock() + defer c.reportMutex.Unlock() + + // Skip if client is nil (datacap disabled) + if c.client == nil { + return + } + + sent := c.bytesSent.Load() + received := c.bytesReceived.Load() + totalConsumed := sent + received + + // Only report if there's data to report + if totalConsumed == 0 { + return + } + + report := &DataCapReport{ + DeviceID: c.clientInfo.DeviceID, + CountryCode: c.clientInfo.CountryCode, + Platform: c.clientInfo.Platform, + BytesUsed: totalConsumed, + } + + // Use the client's configured timeout for consistency + timeout := c.client.httpClient.Timeout + if timeout == 0 { + timeout = 10 * time.Second // Fallback if client has no timeout set + } + reportCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + status, err := c.client.ReportDataCapConsumption(reportCtx, report) + if err != nil { + // Just log the error, don't fail the connection + c.logger.Debug("failed to report datacap consumption (non-fatal): ", err) + } else { + c.logger.Debug("reported datacap consumption: ", totalConsumed, " bytes (sent: ", sent, ", received: ", received, ") for device ", c.clientInfo.DeviceID) + // Update internal state with response from sidecar + if status != nil { + c.updateThrottleState(status) + } + } +} + +// updateThrottleState 
updates the throttling configuration based on the current status. +func (c *Conn) updateThrottleState(status *DataCapStatus) { + if !c.throttlingEnabled || c.throttler == nil { + return + } + + if status.Throttle && status.RemainingBytes > 0 && status.CapLimit > 0 { + // Calculate remaining percentage + remainingPct := float64(status.RemainingBytes) / float64(status.CapLimit) + + // Adjust throttle speed based on remaining percentage tiers + var throttleSpeed int64 + if remainingPct > highRemainingThresholdPct { + throttleSpeed = highTierSpeedBytesPerSec + } else if remainingPct > mediumRemainingThresholdPct { + throttleSpeed = mediumTierSpeedBytesPerSec + } else { + throttleSpeed = lowTierSpeedBytesPerSec + } + + c.throttler.EnableWithRates(throttleSpeed, defaultUploadSpeedBytesPerSec) + c.logger.Debug("updated throttle speed to ", throttleSpeed, " bytes/s (remaining: ", remainingPct*100, "%)") + } else { + c.throttler.Disable() + c.logger.Debug("throttling disabled by sidecar") + } +} + +// GetStatus queries the sidecar for current data cap status. +func (c *Conn) GetStatus() (*DataCapStatus, error) { + // Skip if client is nil (datacap disabled) + if c.client == nil { + return nil, nil + } + + // Use the client's configured timeout for consistency + timeout := c.client.httpClient.Timeout + if timeout == 0 { + timeout = 5 * time.Second // Fallback if client has no timeout set + } + statusCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + return c.client.GetDataCapStatus(statusCtx, c.clientInfo.DeviceID) +} + +// GetBytesConsumed returns the total bytes consumed by this connection. 
+func (c *Conn) GetBytesConsumed() int64 { + return c.bytesSent.Load() + c.bytesReceived.Load() +} diff --git a/tracker/datacap/integration_test.go b/tracker/datacap/integration_test.go new file mode 100644 index 0000000..227777b --- /dev/null +++ b/tracker/datacap/integration_test.go @@ -0,0 +1,1135 @@ +package datacap + +import ( + "encoding/json" + "io" + "net" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/sagernet/sing-box/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/getlantern/lantern-box/tracker/clientcontext" +) + +// mockConn implements net.Conn for testing +type mockConn struct { + readData []byte + readPos int + writeData []byte + closed bool + mu sync.Mutex +} + +func newMockConn(readData []byte) *mockConn { + return &mockConn{ + readData: readData, + } +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed { + return 0, io.EOF + } + + if m.readPos >= len(m.readData) { + return 0, io.EOF + } + + n = copy(b, m.readData[m.readPos:]) + m.readPos += n + return n, nil +} + +func (m *mockConn) Write(b []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed { + return 0, io.ErrClosedPipe + } + + m.writeData = append(m.writeData, b...) 
+ return len(b), nil +} + +func (m *mockConn) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + m.closed = true + return nil +} + +func (m *mockConn) LocalAddr() net.Addr { return &net.TCPAddr{} } +func (m *mockConn) RemoteAddr() net.Addr { return &net.TCPAddr{} } +func (m *mockConn) SetDeadline(t time.Time) error { return nil } +func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } + +func (m *mockConn) GetWrittenData() []byte { + m.mu.Lock() + defer m.mu.Unlock() + return m.writeData +} + +var noopLogger = log.NewNOPFactory().Logger() + +// TestDataCapEndToEndNoThrottling tests the complete datacap workflow without throttling +func TestDataCapEndToEndNoThrottling(t *testing.T) { + // Track reports received + var reportCount atomic.Int32 + var lastBytesUsed atomic.Int64 + + // Mock sidecar server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet && r.URL.Path == "/data-cap/test-device" { + // Status check - no throttling + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":false,"remainingBytes":10737418240,"capLimit":10737418240,"expiryTime":1700179200}`)) + } else if r.Method == http.MethodPost && r.URL.Path == "/data-cap/" { + // Consumption report + reportCount.Add(1) + + var report DataCapReport + if err := json.NewDecoder(r.Body).Decode(&report); err == nil { + lastBytesUsed.Store(report.BytesUsed) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":false,"remainingBytes":10737418240,"capLimit":10737418240,"expiryTime":1700179200}`)) + } + })) + defer server.Close() + + // Create datacap client + client := NewClient(server.URL, 5*time.Second) + + // Create mock connection with test data + testData := make([]byte, 1024*100) // 100 KB + for i := range testData { + 
testData[i] = byte(i % 256) + } + mockConn := newMockConn(testData) + + // Create datacap-wrapped connection + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + CountryCode: "US", + Platform: "android", + }, + Logger: noopLogger, + ReportInterval: 100 * time.Millisecond, // Short interval for testing + EnableThrottling: false, + } + conn := NewConn(config) + + // Read data from connection + buffer := make([]byte, 1024) + totalRead := 0 + for { + n, err := conn.Read(buffer) + totalRead += n + if err == io.EOF { + break + } + require.NoError(t, err) + } + + // Verify we read all data + assert.Equal(t, len(testData), totalRead, "should read all test data") + + // Write data to connection + writeData := make([]byte, 1024*50) // 50 KB + n, err := conn.Write(writeData) + require.NoError(t, err) + assert.Equal(t, len(writeData), n, "should write all data") + + // Wait for at least one periodic report + time.Sleep(200 * time.Millisecond) + + // Close connection (triggers final report) + require.NoError(t, conn.Close()) + + // Wait a bit for final report to complete + time.Sleep(100 * time.Millisecond) + + // Verify reports were sent + assert.GreaterOrEqual(t, reportCount.Load(), int32(1), "should have sent at least one report") + + // Verify last report included all bytes + expectedBytes := int64(totalRead + len(writeData)) + reportedBytes := lastBytesUsed.Load() + assert.Equal(t, expectedBytes, reportedBytes, "last report should include all bytes used") + + // Verify bytes consumed tracking + consumed := conn.GetBytesConsumed() + assert.Equal(t, expectedBytes, consumed, "GetBytesConsumed should match total bytes used") +} + +// TestDataCapEndToEndWithThrottling tests datacap workflow with throttling enabled +func TestDataCapEndToEndWithThrottling(t *testing.T) { + var reportCount atomic.Int32 + + // Mock sidecar server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { + if r.Method == http.MethodPost && r.URL.Path == "/data-cap/" { + reportCount.Add(1) + // Report response with throttling enabled + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":true,"remainingBytes":1073741824,"capLimit":10737418240,"expiryTime":1700179200}`)) + } + })) + defer server.Close() + + // Create datacap client + client := NewClient(server.URL, 5*time.Second) + + // Create mock connection + testData := make([]byte, 1024*10) // 10 KB + mockConn := newMockConn(testData) + + // Create datacap-wrapped connection with throttling + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + CountryCode: "US", + Platform: "android", + }, + Logger: noopLogger, + ReportInterval: 100 * time.Millisecond, + EnableThrottling: true, + ThrottleSpeed: 1024 * 10, // 10 KB/s (slow for testing) + } + conn := NewConn(config) + + // Measure time to read data with throttling + startTime := time.Now() + buffer := make([]byte, 1024) + totalRead := 0 + for { + n, err := conn.Read(buffer) + totalRead += n + if err == io.EOF { + break + } + require.NoError(t, err) + } + duration := time.Since(startTime) + + // With 10 KB data and 10 KB/s throttle, should take at least ~1 second + // (accounting for token bucket refill) + if duration < 500*time.Millisecond { + t.Logf("Warning: Read completed in %v, expected throttling to slow it down", duration) + } + + // Wait for periodic report which will update throttle state + time.Sleep(150 * time.Millisecond) + + // Verify at least one report was sent (which includes throttle status in response) + assert.GreaterOrEqual(t, reportCount.Load(), int32(1), "should have sent at least one report") + + // Close connection + conn.Close() + + // Verify reports were sent + assert.GreaterOrEqual(t, reportCount.Load(), int32(1), "should have sent at least one report") +} + +// 
TestDataCapThrottleSpeedAdjustment tests dynamic throttle speed adjustment +func TestDataCapThrottleSpeedAdjustment(t *testing.T) { + testCases := []struct { + name string + remainingBytes int64 + capLimit int64 + expectedThrottle bool + expectedSpeedTier string // "high", "medium", "low" + }{ + { + name: "High remaining (>20%)", + remainingBytes: 3000000000, // 3 GB + capLimit: 10000000000, // 10 GB + expectedThrottle: true, + expectedSpeedTier: "high", // 5 Mbps + }, + { + name: "Medium remaining (10-20%)", + remainingBytes: 1500000000, // 1.5 GB + capLimit: 10000000000, // 10 GB + expectedThrottle: true, + expectedSpeedTier: "medium", // 2 Mbps + }, + { + name: "Low remaining (<10%)", + remainingBytes: 500000000, // 500 MB + capLimit: 10000000000, // 10 GB + expectedThrottle: true, + expectedSpeedTier: "low", // 128 KB/s + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create mock connection + mockConn := newMockConn(nil) + + // Create datacap-wrapped connection + config := ConnConfig{ + Conn: mockConn, + Client: nil, // No client needed for this test + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + }, + Logger: noopLogger, + EnableThrottling: true, + } + conn := NewConn(config) + defer conn.Close() + + // Simulate status update + status := &DataCapStatus{ + Throttle: tc.expectedThrottle, + RemainingBytes: tc.remainingBytes, + CapLimit: tc.capLimit, + } + + conn.updateThrottleState(status) + + // Verify throttler is enabled + assert.True(t, conn.throttler.IsEnabled(), "throttling should be enabled") + + // Verify appropriate speed was set (we can't easily check exact speed, + // but we can verify the throttler is active) + t.Logf("Throttle state updated for %s: remaining=%.2f%%", + tc.expectedSpeedTier, + float64(tc.remainingBytes)/float64(tc.capLimit)*100) + }) + } +} + +// TestDataCapPeriodicReporting tests that reports are sent periodically +func TestDataCapPeriodicReporting(t *testing.T) { + var 
reportTimes []time.Time + var mu sync.Mutex + + // Mock sidecar server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && r.URL.Path == "/data-cap/" { + mu.Lock() + reportTimes = append(reportTimes, time.Now()) + mu.Unlock() + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":false,"remainingBytes":10737418240,"capLimit":10737418240,"expiryTime":1700179200}`)) + } + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + + // Create connection with short report interval + testData := make([]byte, 1024) + mockConn := newMockConn(testData) + + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + CountryCode: "US", + Platform: "android", + }, + Logger: noopLogger, + ReportInterval: 50 * time.Millisecond, // Very short for testing + } + conn := NewConn(config) + + // Do some I/O to generate data + buffer := make([]byte, 100) + conn.Read(buffer) + conn.Write(buffer) + + // Wait for multiple report intervals + time.Sleep(200 * time.Millisecond) + + conn.Close() + time.Sleep(50 * time.Millisecond) + + // Verify multiple reports were sent + mu.Lock() + count := len(reportTimes) + mu.Unlock() + + assert.GreaterOrEqual(t, count, 2, "should have sent at least 2 reports") + + // Verify reports were spaced approximately by the interval + if count >= 2 { + mu.Lock() + interval := reportTimes[1].Sub(reportTimes[0]) + mu.Unlock() + + if interval < 40*time.Millisecond || interval > 100*time.Millisecond { + t.Logf("Report interval was %v, expected ~50ms (some variance is normal)", interval) + } + } +} + +// TestDataCapFinalReportOnClose tests that a final report is sent when connection closes +func TestDataCapFinalReportOnClose(t *testing.T) { + var finalReport *DataCapReport + var mu sync.Mutex + + // Mock sidecar server + server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && r.URL.Path == "/data-cap/" { + mu.Lock() + defer mu.Unlock() + + var report DataCapReport + if err := json.NewDecoder(r.Body).Decode(&report); err == nil { + // Store the last report (final report) + finalReport = &report + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":false,"remainingBytes":10737418240,"capLimit":10737418240,"expiryTime":1700179200}`)) + } + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + + testData := make([]byte, 5000) // 5 KB + mockConn := newMockConn(testData) + + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device-final", + CountryCode: "US", + Platform: "ios", + }, + Logger: noopLogger, + ReportInterval: time.Hour, // Long interval so only final report happens + } + conn := NewConn(config) + + // Read all data + buffer := make([]byte, 1024) + totalRead := 0 + for { + n, err := conn.Read(buffer) + totalRead += n + if err == io.EOF { + break + } + } + + // Close connection immediately (before periodic report) + conn.Close() + time.Sleep(100 * time.Millisecond) + + // Verify final report was sent with correct data + mu.Lock() + defer mu.Unlock() + + require.NotNil(t, finalReport, "final report should not be nil") + + assert.Equal(t, "test-device-final", finalReport.DeviceID, "device ID should match") + assert.Equal(t, "ios", finalReport.Platform, "platform should match") + assert.Equal(t, int64(totalRead), finalReport.BytesUsed, "bytes used should match total read") +} + +// TestDataCapSidecarUnreachable tests behavior when sidecar is unreachable +func TestDataCapSidecarUnreachable(t *testing.T) { + // Use invalid URL + client := NewClient("http://localhost:99999", 100*time.Millisecond) + + testData := make([]byte, 1024) + mockConn := newMockConn(testData) 
+ + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + }, + Logger: noopLogger, + ReportInterval: 50 * time.Millisecond, + } + conn := NewConn(config) + + // Should still work even if sidecar is down + buffer := make([]byte, 512) + n, err := conn.Read(buffer) + if err == io.EOF { + err = nil + } + require.NoError(t, err, "read should work even if sidecar is down") + require.NotEqual(t, 0, n, "expected to read some data") + + // Wait for report attempt (will fail silently) + time.Sleep(100 * time.Millisecond) + + // Connection should still close properly + assert.NoError(t, conn.Close(), "close should succeed even if sidecar is down") + assert.Equal(t, int64(n), conn.GetBytesConsumed(), "bytes should be tracked even if sidecar is down") +} + +// TestDataCapSidecarReturnsError tests behavior when sidecar returns HTTP errors +func TestDataCapSidecarReturnsError(t *testing.T) { + errorCount := 0 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + errorCount++ + // Return 500 error + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"internal server error"}`)) + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + + testData := make([]byte, 1024) + mockConn := newMockConn(testData) + + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + }, + Logger: noopLogger, + ReportInterval: 50 * time.Millisecond, + } + conn := NewConn(config) + + // Read data + buffer := make([]byte, 512) + conn.Read(buffer) + + // Wait for report attempt + time.Sleep(100 * time.Millisecond) + + conn.Close() + + // Errors should be logged but not crash + assert.GreaterOrEqual(t, errorCount, 1, "should have received at least one error from sidecar") +} + +// TestDataCapNilClient tests behavior when client is nil (datacap disabled) +func 
TestDataCapNilClient(t *testing.T) { + testData := make([]byte, 1024) + mockConn := newMockConn(testData) + + config := ConnConfig{ + Conn: mockConn, + Client: nil, // Datacap disabled + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + }, + Logger: noopLogger, + ReportInterval: 50 * time.Millisecond, + } + conn := NewConn(config) + + // Should still work normally + buffer := make([]byte, 512) + n, err := conn.Read(buffer) + if err == io.EOF { + err = nil + } + require.NoError(t, err, "read should succeed with nil client") + + // Bytes should still be tracked + assert.Equal(t, int64(n), conn.GetBytesConsumed(), "bytes should be tracked even with nil client") + assert.NoError(t, conn.Close(), "close should succeed with nil client") +} + +// TestDataCapZeroBytes tests that zero-byte reports are not sent +func TestDataCapZeroBytes(t *testing.T) { + reportCount := 0 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + reportCount++ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":false,"remainingBytes":10737418240,"capLimit":10737418240,"expiryTime":1700179200}`)) + } + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + + testData := make([]byte, 0) // Empty data + mockConn := newMockConn(testData) + + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + }, + Logger: noopLogger, + ReportInterval: 50 * time.Millisecond, + } + conn := NewConn(config) + + // Don't read or write any data + time.Sleep(150 * time.Millisecond) + + conn.Close() + time.Sleep(50 * time.Millisecond) + + // Should not send reports for zero bytes + assert.Equal(t, 0, reportCount, "should not send reports for zero bytes") +} + +// TestDataCapConcurrentReadWrite tests concurrent reads and writes +func TestDataCapConcurrentReadWrite(t 
*testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":false,"remainingBytes":10737418240,"capLimit":10737418240,"expiryTime":1700179200}`)) + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + + testData := make([]byte, 10000) + mockConn := newMockConn(testData) + + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + }, + Logger: noopLogger, + ReportInterval: 200 * time.Millisecond, + } + conn := NewConn(config) + + var wg sync.WaitGroup + totalRead := atomic.Int64{} + totalWritten := atomic.Int64{} + + // Concurrent readers + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + buffer := make([]byte, 100) + for j := 0; j < 10; j++ { + n, err := conn.Read(buffer) + if err != nil && err != io.EOF { + return + } + totalRead.Add(int64(n)) + if n == 0 { + break + } + } + }() + } + + // Concurrent writers + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + buffer := make([]byte, 100) + for j := 0; j < 10; j++ { + n, err := conn.Write(buffer) + if err != nil { + return + } + totalWritten.Add(int64(n)) + } + }() + } + + wg.Wait() + time.Sleep(100 * time.Millisecond) + conn.Close() + + // Verify atomic counters handled concurrent access correctly + expected := totalRead.Load() + totalWritten.Load() + actual := conn.GetBytesConsumed() + assert.Equal(t, expected, actual, "GetBytesConsumed should match total read + written bytes") +} + +// TestDataCapMultipleClose tests that closing multiple times is safe +func TestDataCapMultipleClose(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + 
w.Write([]byte(`{"throttle":false,"remainingBytes":10737418240,"capLimit":10737418240,"expiryTime":1700179200}`)) + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + + testData := make([]byte, 1024) + mockConn := newMockConn(testData) + + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + }, + Logger: noopLogger, + ReportInterval: time.Hour, + } + conn := NewConn(config) + + // Read some data + buffer := make([]byte, 512) + conn.Read(buffer) + + // Close multiple times + // First close should succeed + assert.NoError(t, conn.Close(), "first close should succeed") + // Subsequent closes should be no-op (return nil) + assert.NoError(t, conn.Close(), "second close should be no-op") + assert.NoError(t, conn.Close(), "third close should be no-op") +} + +// TestDataCapThrottleDisableAfterEnable tests disabling throttle after it was enabled +func TestDataCapThrottleDisableAfterEnable(t *testing.T) { + testData := make([]byte, 1024) + mockConn := newMockConn(testData) + + config := ConnConfig{ + Conn: mockConn, + Client: nil, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + }, + Logger: noopLogger, + EnableThrottling: true, + } + conn := NewConn(config) + defer conn.Close() + + // Enable throttling + status1 := &DataCapStatus{ + Throttle: true, + RemainingBytes: 1000000000, + CapLimit: 10000000000, + } + conn.updateThrottleState(status1) + + assert.True(t, conn.throttler.IsEnabled(), "throttling should be enabled") + + // Disable throttling + status2 := &DataCapStatus{ + Throttle: false, + RemainingBytes: 5000000000, + CapLimit: 10000000000, + } + conn.updateThrottleState(status2) + + assert.False(t, conn.throttler.IsEnabled(), "throttling should be disabled") +} + +// TestDataCapEmptyDeviceID tests behavior with empty device ID +func TestDataCapEmptyDeviceID(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, 
r *http.Request) { + if r.Method == http.MethodPost { + var report DataCapReport + json.NewDecoder(r.Body).Decode(&report) + + // Verify empty device ID is sent + assert.Empty(t, report.DeviceID, "deviceId should be empty") + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":false,"remainingBytes":10737418240,"capLimit":10737418240,"expiryTime":1700179200}`)) + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + + testData := make([]byte, 1024) + mockConn := newMockConn(testData) + + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "", // Empty device ID + }, + Logger: noopLogger, + ReportInterval: 50 * time.Millisecond, + } + conn := NewConn(config) + + buffer := make([]byte, 512) + conn.Read(buffer) + time.Sleep(100 * time.Millisecond) + conn.Close() + + // Should handle empty device ID gracefully +} + +// TestDataCapLargeDataTransfer tests with large data transfers +func TestDataCapLargeDataTransfer(t *testing.T) { + var lastReport *DataCapReport + var mu sync.Mutex + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + mu.Lock() + var report DataCapReport + if err := json.NewDecoder(r.Body).Decode(&report); err == nil { + lastReport = &report + } + mu.Unlock() + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":false,"remainingBytes":10737418240,"capLimit":10737418240,"expiryTime":1700179200}`)) + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + + // 10 MB of data + largeData := make([]byte, 10*1024*1024) + mockConn := newMockConn(largeData) + + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + }, + Logger: noopLogger, + ReportInterval: 100 * time.Millisecond, 
+ } + conn := NewConn(config) + + // Read all data + buffer := make([]byte, 64*1024) // 64 KB chunks + totalRead := 0 + for { + n, err := conn.Read(buffer) + totalRead += n + if err == io.EOF { + break + } + } + + time.Sleep(150 * time.Millisecond) + conn.Close() + time.Sleep(50 * time.Millisecond) + + // Verify large amounts are tracked correctly + mu.Lock() + defer mu.Unlock() + + require.NotNil(t, lastReport, "last report not sent") + assert.Equal(t, int64(totalRead), lastReport.BytesUsed, "bytes used should match total read") +} + +// TestDataCapRapidOpenClose tests rapid connection open/close cycles +func TestDataCapRapidOpenClose(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":false,"remainingBytes":10737418240,"capLimit":10737418240,"expiryTime":1700179200}`)) + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + + // Open and close many connections rapidly + for i := 0; i < 50; i++ { + testData := make([]byte, 100) + mockConn := newMockConn(testData) + + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + }, + Logger: noopLogger, + ReportInterval: time.Hour, + } + conn := NewConn(config) + + // Quick read + buffer := make([]byte, 50) + conn.Read(buffer) + + // Immediate close + conn.Close() + } + + // Should not panic or cause issues + time.Sleep(100 * time.Millisecond) +} + +// TestDataCapStatusCheckAfterReport tests that status is updated after reporting +func TestDataCapStatusCheckAfterReport(t *testing.T) { + responseThrottle := false + var mu sync.Mutex + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + defer mu.Unlock() + + if r.Method == http.MethodPost { + // After first report, start throttling + responseThrottle = 
true + } + + throttleStr := "false" + if responseThrottle { + throttleStr = "true" + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":` + throttleStr + `,"remainingBytes":500000000,"capLimit":10737418240,"expiryTime":1700179200}`)) + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + + testData := make([]byte, 1024) + mockConn := newMockConn(testData) + + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + }, + Logger: noopLogger, + ReportInterval: 50 * time.Millisecond, + EnableThrottling: true, + } + conn := NewConn(config) + + // Read data + buffer := make([]byte, 512) + conn.Read(buffer) + + // Wait for report (which will get throttle=true response) + time.Sleep(100 * time.Millisecond) + + // Throttle should now be enabled based on report response + assert.True(t, conn.throttler.IsEnabled(), "throttling should be enabled after report response") + + conn.Close() +} + +// TestDataCapDifferentPlatforms tests different platform values +func TestDataCapDifferentPlatforms(t *testing.T) { + platforms := []string{"android", "ios", "windows", "macos", "linux", ""} + + for _, platform := range platforms { + t.Run("Platform_"+platform, func(t *testing.T) { + var receivedPlatform string + var mu sync.Mutex + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + mu.Lock() + var report DataCapReport + if err := json.NewDecoder(r.Body).Decode(&report); err == nil { + receivedPlatform = report.Platform + } + mu.Unlock() + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":false,"remainingBytes":10737418240,"capLimit":10737418240,"expiryTime":1700179200}`)) + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + + testData := make([]byte, 
1024) + mockConn := newMockConn(testData) + + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + Platform: platform, + }, + Logger: noopLogger, + ReportInterval: 50 * time.Millisecond, + } + conn := NewConn(config) + + buffer := make([]byte, 512) + conn.Read(buffer) + time.Sleep(100 * time.Millisecond) + conn.Close() + + mu.Lock() + assert.Equal(t, platform, receivedPlatform, "platform should match") + mu.Unlock() + }) + } +} + +// TestDataCapCountryCodeVariations tests different country codes +func TestDataCapCountryCodeVariations(t *testing.T) { + countryCodes := []string{"US", "GB", "CN", "IN", "BR", ""} + + for _, code := range countryCodes { + t.Run("Country_"+code, func(t *testing.T) { + var receivedCode string + var mu sync.Mutex + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + mu.Lock() + var report DataCapReport + if err := json.NewDecoder(r.Body).Decode(&report); err == nil { + receivedCode = report.CountryCode + } + mu.Unlock() + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"throttle":false,"remainingBytes":10737418240,"capLimit":10737418240,"expiryTime":1700179200}`)) + })) + defer server.Close() + + client := NewClient(server.URL, 5*time.Second) + + testData := make([]byte, 1024) + mockConn := newMockConn(testData) + + config := ConnConfig{ + Conn: mockConn, + Client: client, + ClientInfo: clientcontext.ClientInfo{ + DeviceID: "test-device", + CountryCode: code, + }, + Logger: noopLogger, + ReportInterval: 50 * time.Millisecond, + } + conn := NewConn(config) + + buffer := make([]byte, 512) + conn.Read(buffer) + time.Sleep(100 * time.Millisecond) + conn.Close() + + mu.Lock() + assert.Equal(t, code, receivedCode, "country code should match") + mu.Unlock() + }) + } +} + +// TestDataCapReadWriteErrors tests handling of read/write errors 
func TestDataCapReadWriteErrors(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"throttle":false,"remainingBytes":10737418240,"capLimit":10737418240,"expiryTime":1700179200}`))
	}))
	defer server.Close()

	client := NewClient(server.URL, 5*time.Second)

	testData := make([]byte, 1024)
	mockConn := newMockConn(testData)

	config := ConnConfig{
		Conn:   mockConn,
		Client: client,
		ClientInfo: clientcontext.ClientInfo{
			DeviceID: "test-device",
		},
		Logger:         noopLogger,
		ReportInterval: time.Hour,
	}
	conn := NewConn(config)

	// Read until EOF
	buffer := make([]byte, 512)
	for {
		_, err := conn.Read(buffer)
		if err == io.EOF {
			break
		}
		require.NoError(t, err)
	}

	// Try to read after EOF
	n, err := conn.Read(buffer)
	assert.ErrorIs(t, err, io.EOF, "expected EOF error on read after EOF")
	assert.Equal(t, 0, n, "expected 0 bytes after EOF")

	conn.Close()
}

// TestDataCapContextCancellation tests behavior when context is cancelled
func TestDataCapContextCancellation(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Slow response — deliberately longer than the time between the read
		// below and Close, so a report may still be in flight when the
		// connection's context is cancelled.
		time.Sleep(200 * time.Millisecond)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"throttle":false,"remainingBytes":10737418240,"capLimit":10737418240,"expiryTime":1700179200}`))
	}))
	defer server.Close()

	client := NewClient(server.URL, 5*time.Second)

	testData := make([]byte, 1024)
	mockConn := newMockConn(testData)

	config := ConnConfig{
		Conn:   mockConn,
		Client: client,
		ClientInfo: clientcontext.ClientInfo{
			DeviceID: "test-device",
		},
		Logger:         noopLogger,
		ReportInterval: 50 * time.Millisecond,
	}
	conn := NewConn(config)

	buffer := make([]byte, 512)
	conn.Read(buffer)

	// Close immediately (cancels context)
	conn.Close()

	// Should not panic or hang
	time.Sleep(50 * time.Millisecond)
}
diff --git a/tracker/datacap/packet_conn.go b/tracker/datacap/packet_conn.go
new file mode 100644
index 0000000..45d1de3
--- /dev/null
+++ b/tracker/datacap/packet_conn.go
@@ -0,0 +1,249 @@
package datacap

import (
	"context"
	"sync"
	"sync/atomic"
	"time"

	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing/common/buf"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"

	"github.com/getlantern/lantern-box/tracker/clientcontext"
)

// PacketConn wraps a sing-box network.PacketConn and tracks data consumption for datacap reporting.
// A background goroutine (periodicReport) sends usage reports at ReportInterval;
// Close stops it, sends one final report, and closes the underlying connection.
type PacketConn struct {
	N.PacketConn
	ctx    context.Context    // lifetime of the reporting goroutine; cancelled by Close
	cancel context.CancelFunc // cancels ctx
	client *Client            // sidecar client; nil means datacap reporting is disabled
	logger log.ContextLogger

	clientInfo clientcontext.ClientInfo // identity attached to every report

	// Atomic counters for thread-safe tracking
	bytesSent     atomic.Int64
	bytesReceived atomic.Int64

	// Reporting control
	reportTicker *time.Ticker // drives periodicReport
	reportMutex  sync.Mutex   // serializes sendReport calls
	closed       atomic.Bool  // makes Close idempotent
	wg           sync.WaitGroup

	// Throttling
	throttler         *Throttler
	throttlingEnabled bool // master switch; updateThrottleState is a no-op when false
}

// PacketConnConfig holds configuration for creating a datacap-tracked packet connection.
type PacketConnConfig struct {
	Conn             N.PacketConn
	Client           *Client // nil disables reporting (bytes are still counted)
	Logger           log.ContextLogger
	ClientInfo       clientcontext.ClientInfo
	ReportInterval   time.Duration // zero defaults to 30s (see NewPacketConn)
	EnableThrottling bool
	ThrottleSpeed    int64 // initial throttle rate in bytes/sec for the Throttler
}

// NewPacketConn creates a new datacap-tracked packet connection wrapper.
+func NewPacketConn(config PacketConnConfig) *PacketConn { + ctx, cancel := context.WithCancel(context.Background()) + + // Default report interval to 30 seconds if not specified + if config.ReportInterval == 0 { + config.ReportInterval = 30 * time.Second + } + + conn := &PacketConn{ + PacketConn: config.Conn, + ctx: ctx, + cancel: cancel, + client: config.Client, + logger: config.Logger, + clientInfo: config.ClientInfo, + reportTicker: time.NewTicker(config.ReportInterval), + throttler: NewThrottler(config.ThrottleSpeed), + throttlingEnabled: config.EnableThrottling, + } + + // Start periodic reporting goroutine + conn.wg.Add(1) + go conn.periodicReport() + + return conn +} + +// ReadPacket tracks bytes received from packet reading and applies throttling. +func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + dest, err := c.PacketConn.ReadPacket(buffer) + if err != nil { + return dest, err + } + + if buffer.Len() > 0 { + c.bytesReceived.Add(int64(buffer.Len())) + + // Apply throttling after read (token bucket wait) + if c.throttler.IsEnabled() { + if waitErr := c.throttler.WaitRead(c.ctx, buffer.Len()); waitErr != nil { + // Context cancelled, but we already read the data + return dest, waitErr + } + } + } + return dest, nil +} + +// WritePacket tracks bytes sent from packet writing and applies throttling. +func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + packetSize := buffer.Len() + + if packetSize > 0 { + c.bytesSent.Add(int64(packetSize)) + + // Apply throttling before write (token bucket wait) + if c.throttler.IsEnabled() { + if waitErr := c.throttler.WaitWrite(c.ctx, packetSize); waitErr != nil { + return waitErr + } + } + } + + return c.PacketConn.WritePacket(buffer, destination) +} + +// Close stops reporting and closes the underlying packet connection. 
+func (c *PacketConn) Close() error { + if c.closed.Swap(true) { + return nil // Already closed + } + + // Stop the reporting ticker + c.reportTicker.Stop() + + // Cancel context to signal goroutines to stop + c.cancel() + + // Wait for all goroutines to finish + c.wg.Wait() + + // Send final report + c.sendReport() + + return c.PacketConn.Close() +} + +// periodicReport runs in a goroutine and periodically reports data consumption. +func (c *PacketConn) periodicReport() { + defer c.wg.Done() + for { + select { + case <-c.ctx.Done(): + return + case <-c.reportTicker.C: + c.sendReport() + } + } +} + +// sendReport sends the current consumption data to the sidecar. +func (c *PacketConn) sendReport() { + c.reportMutex.Lock() + defer c.reportMutex.Unlock() + + // Skip if client is nil (datacap disabled) + if c.client == nil { + return + } + + sent := c.bytesSent.Load() + received := c.bytesReceived.Load() + totalConsumed := sent + received + + // Only report if there's data to report + if totalConsumed == 0 { + return + } + + report := &DataCapReport{ + DeviceID: c.clientInfo.DeviceID, + CountryCode: c.clientInfo.CountryCode, + Platform: c.clientInfo.Platform, + BytesUsed: totalConsumed, + } + + // Use the client's configured timeout for consistency + timeout := c.client.httpClient.Timeout + if timeout == 0 { + timeout = 10 * time.Second // Fallback if client has no timeout set + } + reportCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + status, err := c.client.ReportDataCapConsumption(reportCtx, report) + if err != nil { + // Just log the error, don't fail the connection + c.logger.Debug("failed to report datacap consumption (non-fatal): ", err) + } else { + c.logger.Debug("reported datacap consumption: ", totalConsumed, " bytes (sent: ", sent, ", received: ", received, ") for device ", c.clientInfo.DeviceID) + // Update internal state with response from sidecar + if status != nil { + c.updateThrottleState(status) + } + } +} + +// 
updateThrottleState updates the throttling configuration based on the current status. +func (c *PacketConn) updateThrottleState(status *DataCapStatus) { + if !c.throttlingEnabled || c.throttler == nil { + return + } + + if status.Throttle && status.RemainingBytes > 0 && status.CapLimit > 0 { + // Calculate remaining percentage + remainingPct := float64(status.RemainingBytes) / float64(status.CapLimit) + + var throttleSpeed int64 + if remainingPct > highRemainingThresholdPct { + throttleSpeed = highTierSpeedBytesPerSec + } else if remainingPct > mediumRemainingThresholdPct { + throttleSpeed = mediumTierSpeedBytesPerSec + } else { + throttleSpeed = lowTierSpeedBytesPerSec + } + + c.throttler.EnableWithRates(throttleSpeed, defaultUploadSpeedBytesPerSec) + c.logger.Debug("updated throttle speed to ", throttleSpeed, " bytes/s (remaining: ", remainingPct*100, "%)") + } else { + c.throttler.Disable() + c.logger.Debug("throttling disabled by sidecar") + } +} + +// GetStatus queries the sidecar for current data cap status. +func (c *PacketConn) GetStatus() (*DataCapStatus, error) { + // Skip if client is nil (datacap disabled) + if c.client == nil { + return nil, nil + } + + // Use the client's configured timeout for consistency + timeout := c.client.httpClient.Timeout + if timeout == 0 { + timeout = 5 * time.Second // Fallback if client has no timeout set + } + statusCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + return c.client.GetDataCapStatus(statusCtx, c.clientInfo.DeviceID) +} + +// GetBytesConsumed returns the total bytes consumed by this packet connection. 
+func (c *PacketConn) GetBytesConsumed() int64 { + return c.bytesSent.Load() + c.bytesReceived.Load() +} diff --git a/tracker/datacap/throttler.go b/tracker/datacap/throttler.go new file mode 100644 index 0000000..cac321e --- /dev/null +++ b/tracker/datacap/throttler.go @@ -0,0 +1,192 @@ +package datacap + +import ( + "context" + "sync" + "time" +) + +// Throttler implements bandwidth throttling using a token bucket algorithm. +// This matches the approach used in http-proxy-lantern for consistent behavior. +type Throttler struct { + mu sync.RWMutex + + enabled bool + readRate int64 // Bytes per second for reads + writeRate int64 // Bytes per second for writes + + // Token bucket for reads + readTokens float64 + readCapacity float64 + readLastRefill time.Time + + // Token bucket for writes + writeTokens float64 + writeCapacity float64 + writeLastRefill time.Time +} + +// NewThrottler creates a new throttler with the specified bytes per second limit. +// Both read and write use the same rate initially. +func NewThrottler(bytesPerSec int64) *Throttler { + return NewThrottlerWithRates(bytesPerSec, bytesPerSec) +} + +// NewThrottlerWithRates creates a throttler with separate read and write rates. +// This allows asymmetric throttling (e.g., throttle downloads but not uploads). +func NewThrottlerWithRates(readBytesPerSec, writeBytesPerSec int64) *Throttler { + now := time.Now() + t := &Throttler{ + enabled: readBytesPerSec > 0 || writeBytesPerSec > 0, + readRate: readBytesPerSec, + writeRate: writeBytesPerSec, + readTokens: float64(readBytesPerSec), // Start with full bucket + readCapacity: float64(readBytesPerSec), // 1 second worth of bytes + readLastRefill: now, + writeTokens: float64(writeBytesPerSec), + writeCapacity: float64(writeBytesPerSec), + writeLastRefill: now, + } + return t +} + +// Enable enables throttling with the specified rate for both read and write. 
+func (t *Throttler) Enable(bytesPerSec int64) { + t.EnableWithRates(bytesPerSec, bytesPerSec) +} + +// EnableWithRates enables throttling with separate read and write rates. +func (t *Throttler) EnableWithRates(readBytesPerSec, writeBytesPerSec int64) { + t.mu.Lock() + defer t.mu.Unlock() + + now := time.Now() + t.enabled = true + t.readRate = readBytesPerSec + t.writeRate = writeBytesPerSec + t.readTokens = float64(readBytesPerSec) + t.readCapacity = float64(readBytesPerSec) + t.readLastRefill = now + t.writeTokens = float64(writeBytesPerSec) + t.writeCapacity = float64(writeBytesPerSec) + t.writeLastRefill = now +} + +// Disable disables throttling. +func (t *Throttler) Disable() { + t.mu.Lock() + defer t.mu.Unlock() + t.enabled = false +} + +// IsEnabled returns whether throttling is enabled. +func (t *Throttler) IsEnabled() bool { + t.mu.RLock() + defer t.mu.RUnlock() + return t.enabled +} + +// WaitRead waits until n bytes can be read according to the rate limit. +// This uses the token bucket algorithm: tokens are added continuously at the +// configured rate, and operations consume tokens. If not enough tokens are +// available, the operation blocks until sufficient tokens accumulate. +func (t *Throttler) WaitRead(ctx context.Context, n int) error { + return t.wait(ctx, n, true) +} + +// WaitWrite waits until n bytes can be written according to the rate limit. +func (t *Throttler) WaitWrite(ctx context.Context, n int) error { + return t.wait(ctx, n, false) +} + +// wait implements the token bucket algorithm for rate limiting. 
+func (t *Throttler) wait(ctx context.Context, n int, isRead bool) error { + if n <= 0 { + return nil + } + + t.mu.Lock() + + if !t.enabled { + t.mu.Unlock() + return nil + } + + // Select which bucket to use + var tokens *float64 + var capacity *float64 + var lastRefill *time.Time + var rate int64 + + if isRead { + tokens = &t.readTokens + capacity = &t.readCapacity + lastRefill = &t.readLastRefill + rate = t.readRate + } else { + tokens = &t.writeTokens + capacity = &t.writeCapacity + lastRefill = &t.writeLastRefill + rate = t.writeRate + } + + // If rate is 0 or negative, no throttling + if rate <= 0 { + t.mu.Unlock() + return nil + } + + // Refill tokens based on time elapsed + now := time.Now() + elapsed := now.Sub(*lastRefill) + tokensToAdd := elapsed.Seconds() * float64(rate) + *tokens += tokensToAdd + if *tokens > *capacity { + *tokens = *capacity + } + *lastRefill = now + + // Check if we have enough tokens + required := float64(n) + if *tokens >= required { + // Consume tokens and proceed immediately + *tokens -= required + t.mu.Unlock() + return nil + } + + // Not enough tokens - calculate wait time + deficit := required - *tokens + waitTime := time.Duration(deficit / float64(rate) * float64(time.Second)) + + // Consume all available tokens + *tokens = 0 + t.mu.Unlock() + + // Wait for the required time + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(waitTime): + return nil + } +} + +// GetReadRate returns the current read throttle rate in bytes per second. +func (t *Throttler) GetReadRate() int64 { + t.mu.RLock() + defer t.mu.RUnlock() + return t.readRate +} + +// GetWriteRate returns the current write throttle rate in bytes per second. +func (t *Throttler) GetWriteRate() int64 { + t.mu.RLock() + defer t.mu.RUnlock() + return t.writeRate +} + +// GetBytesPerSecond returns the read throttle rate (for backward compatibility). 
+func (t *Throttler) GetBytesPerSecond() int64 { + return t.GetReadRate() +} diff --git a/tracker/datacap/throttler_test.go b/tracker/datacap/throttler_test.go new file mode 100644 index 0000000..c469052 --- /dev/null +++ b/tracker/datacap/throttler_test.go @@ -0,0 +1,243 @@ +package datacap + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestThrottler_ZeroOrNegativeRates(t *testing.T) { + ctx := context.Background() + + t.Run("zero rate", func(t *testing.T) { + throttler := NewThrottler(0) + err := throttler.WaitRead(ctx, 100) + assert.NoError(t, err) + }) + + t.Run("negative rate", func(t *testing.T) { + throttler := NewThrottler(-100) + err := throttler.WaitRead(ctx, 100) + assert.NoError(t, err) + }) +} + +func TestThrottler_ContextCancellation(t *testing.T) { + // create a throttler with a very slow rate + throttler := NewThrottler(10) // 10 bytes/sec + + // consume initial tokens + _ = throttler.WaitRead(context.Background(), 10) + + // try to consume more, which should block + ctx, cancel := context.WithCancel(context.Background()) + start := time.Now() + + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + err := throttler.WaitRead(ctx, 100) + + duration := time.Since(start) + + assert.Error(t, err) + assert.Equal(t, context.Canceled, err) + assert.Less(t, duration, 2*time.Second, "Should have cancelled quickly") +} + +func TestThrottler_ConcurrentAccess(t *testing.T) { + rate := int64(1024 * 1024) // 1MB/s + throttler := NewThrottler(rate) + ctx := context.Background() + + var wg sync.WaitGroup + workers := 10 + iterations := 100 + + // Start concurrent readers and writers + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < iterations; j++ { + err := throttler.WaitRead(ctx, 10) + assert.NoError(t, err) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < iterations; j++ { + err := throttler.WaitWrite(ctx, 10) + 
assert.NoError(t, err) + } + }() + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Success + case <-time.After(5 * time.Second): + t.Fatal("Test timed out - potential deadlock") + } +} + +func TestThrottler_TokenRefill(t *testing.T) { + // Rate: 100 bytes/sec + rate := int64(100) + throttler := NewThrottler(rate) + ctx := context.Background() + + // Initial state: bucket full (100 tokens) + // Consume 100 bytes - should be immediate + start := time.Now() + err := throttler.WaitRead(ctx, 100) + assert.NoError(t, err) + assert.WithinDuration(t, start, time.Now(), 10*time.Millisecond) + + // Bucket empty now. + // Consume 50 bytes. + // We need 50 tokens. At 100 bytes/sec, that takes 0.5 seconds. + start = time.Now() + err = throttler.WaitRead(ctx, 50) + assert.NoError(t, err) + + elapsed := time.Since(start) + // Allow small margin of error for scheduler + assert.GreaterOrEqual(t, elapsed.Milliseconds(), int64(450), "Should wait at least ~0.5s") +} + +func TestThrottler_SeparateRates(t *testing.T) { + readRate := int64(1000) + writeRate := int64(10) // Very slow write + + throttler := NewThrottlerWithRates(readRate, writeRate) + ctx := context.Background() + + // Read should be fast + start := time.Now() + err := throttler.WaitRead(ctx, 100) + assert.NoError(t, err) + assert.Less(t, time.Since(start), 100*time.Millisecond) + + // Drain write bucket + _ = throttler.WaitWrite(ctx, 10) + + // Write should be slow (needs 1s for 10 bytes) + start = time.Now() + err = throttler.WaitWrite(ctx, 10) + assert.NoError(t, err) + assert.GreaterOrEqual(t, time.Since(start).Milliseconds(), int64(900)) +} + +func TestThrottler_Disable(t *testing.T) { + throttler := NewThrottler(10) // Slow rate + ctx := context.Background() + + // Prove it's slow first + _ = throttler.WaitRead(ctx, 10) // Drain + start := time.Now() + _ = throttler.WaitRead(ctx, 10) // Wait 1s + assert.GreaterOrEqual(t, 
time.Since(start).Milliseconds(), int64(900)) + + // Disable it + throttler.Disable() + assert.False(t, throttler.IsEnabled()) + + // Should be instant now + start = time.Now() + err := throttler.WaitRead(ctx, 1000) + assert.NoError(t, err) + assert.Less(t, time.Since(start), 10*time.Millisecond) +} + +func TestThrottler_LargeRead(t *testing.T) { + // Rate: 100 bytes/sec + throttler := NewThrottler(100) + ctx := context.Background() + + // Consume initial full bucket (100 tokens) + _ = throttler.WaitRead(ctx, 100) + + // rate is 100B/s, so reading 300B requires 3s wait + start := time.Now() + err := throttler.WaitRead(ctx, 300) + assert.NoError(t, err) + + elapsed := time.Since(start) + // Expect ~3 seconds + assert.InDelta(t, float64(3000), float64(elapsed.Milliseconds()), 200, "Should wait approximately 3 seconds for large read") +} + +func TestThrottler_EnableDisableRapidly(t *testing.T) { + throttler := NewThrottler(10) + ctx := context.Background() + + for i := 0; i < 1000; i++ { + throttler.Disable() + assert.False(t, throttler.IsEnabled()) + err := throttler.WaitRead(ctx, 100) + assert.NoError(t, err) // Should be instant + + throttler.Enable(10) + assert.True(t, throttler.IsEnabled()) + // Don't wait here otherwise test takes forever, just verify state consistency + } +} + +func TestThrottler_ZeroWait(t *testing.T) { + throttler := NewThrottler(100) + err := throttler.WaitRead(context.Background(), 0) + assert.NoError(t, err) + // Should do nothing and return instantly +} + +func TestThrottler_WaitPrecision(t *testing.T) { + rate := int64(1000) // 1000 bytes/s = 1 byte/ms + throttler := NewThrottler(rate) + ctx := context.Background() + + // Drain bucket + _ = throttler.WaitRead(ctx, 1000) + + // Wait for 500 bytes (should take 500ms) + start := time.Now() + err := throttler.WaitRead(ctx, 500) + assert.NoError(t, err) + + elapsed := time.Since(start) + assert.GreaterOrEqual(t, elapsed.Milliseconds(), int64(450)) + assert.Less(t, elapsed.Milliseconds(), 
int64(600))
+}
+
+func TestThrottler_RefillCap(t *testing.T) {
+	// Rate 1000. Bucket capacity 1000.
+	throttler := NewThrottler(1000)
+
+	// Wait 2 seconds; the refill must cap tokens at capacity (1000), not accumulate 2000.
+	time.Sleep(2 * time.Second)
+
+	// consume 1500 bytes. if the bucket grew unbounded to 2000+ tokens (2s * 1000B/s),
+	// this would return instantly. since it's capped at capacity (1000),
+	// we should drain 1000 and wait for the remaining 500 (~0.5s).
+
+	start := time.Now()
+	err := throttler.WaitRead(context.Background(), 1500)
+	assert.NoError(t, err)
+
+	elapsed := time.Since(start)
+	// Should wait ~0.5s. If it waited 0, then refill cap logic is broken.
+	assert.Greater(t, elapsed.Milliseconds(), int64(400), "Should have capped token refill")
+}
diff --git a/tracker/datacap/tracker.go b/tracker/datacap/tracker.go
new file mode 100644
index 0000000..3ef49a9
--- /dev/null
+++ b/tracker/datacap/tracker.go
@@ -0,0 +1,102 @@
+package datacap
+
+import (
+	"context"
+	"net"
+	"time"
+
+	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/log"
+	E "github.com/sagernet/sing/common/exceptions"
+	N "github.com/sagernet/sing/common/network"
+
+	"github.com/getlantern/lantern-box/tracker/clientcontext"
+)
+
+var _ (adapter.ConnectionTracker) = (*DatacapTracker)(nil)
+
+type DatacapTracker struct {
+	client           *Client
+	logger           log.ContextLogger
+	reportInterval   time.Duration
+	enableThrottling bool
+	throttleSpeed    int64
+}
+
+type Options struct {
+	URL              string `json:"url,omitempty"`
+	ReportInterval   string `json:"report_interval,omitempty"`
+	HTTPTimeout      string `json:"http_timeout,omitempty"`
+	EnableThrottling bool   `json:"enable_throttling,omitempty"`
+	ThrottleSpeed    int64  `json:"throttle_speed,omitempty"`
+}
+
+func NewDatacapTracker(options Options, logger log.ContextLogger) (*DatacapTracker, error) {
+	if options.URL == "" {
+		return nil, E.New("datacap url not defined")
+	}
+	// Parse the report interval and HTTP timeout, falling back to defaults when unset
+	reportInterval := 30 * time.Second
+	if 
options.ReportInterval != "" { + interval, err := time.ParseDuration(options.ReportInterval) + if err != nil { + return nil, E.New("invalid report_interval: ", err) + } + reportInterval = interval + } + + httpTimeout := 10 * time.Second + if options.HTTPTimeout != "" { + timeout, err := time.ParseDuration(options.HTTPTimeout) + if err != nil { + return nil, E.New("invalid http_timeout: ", err) + } + httpTimeout = timeout + } + return &DatacapTracker{ + client: NewClient(options.URL, httpTimeout), + reportInterval: reportInterval, + enableThrottling: options.EnableThrottling, + throttleSpeed: options.ThrottleSpeed, + logger: logger, + }, nil +} + +func (t *DatacapTracker) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn { + info, ok := clientcontext.ClientInfoFromContext(ctx) + if !ok { + // conn is not from a clientcontext-aware client (e.g., not radiance) + return conn + } + if info.IsPro { + return conn + } + return NewConn(ConnConfig{ + Conn: conn, + Client: t.client, + Logger: t.logger, + ClientInfo: info, + ReportInterval: t.reportInterval, + EnableThrottling: t.enableThrottling, + ThrottleSpeed: t.throttleSpeed, + }) +} +func (t *DatacapTracker) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) N.PacketConn { + info, ok := clientcontext.ClientInfoFromContext(ctx) + if !ok { + // conn is not from a clientcontext-aware client (e.g., not radiance) + return conn + } + if info.IsPro { + return conn + } + return NewPacketConn(PacketConnConfig{ + Conn: conn, + Client: t.client, + Logger: t.logger, + ClientInfo: info, + ReportInterval: t.reportInterval, + EnableThrottling: t.enableThrottling, + ThrottleSpeed: t.throttleSpeed, + }) +} diff --git a/tracker/datacap/tracker_test.go b/tracker/datacap/tracker_test.go new file mode 100644 index 0000000..1f48952 --- 
/dev/null +++ b/tracker/datacap/tracker_test.go @@ -0,0 +1,110 @@ +package datacap + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/getlantern/lantern-box/tracker/clientcontext" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Scenario 1: NewDatacapTracker returns error if URL is empty +func TestNewDatacapTracker_MissingURL_ReturnsError(t *testing.T) { + _, err := NewDatacapTracker(Options{URL: ""}, log.NewNOPFactory().Logger()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "url not defined") +} + +// Scenario 2: Datacap URL is present & Client is Pro +func TestRoutedConnection_ProClient_SkipsTracking(t *testing.T) { + tracker, err := NewDatacapTracker(Options{URL: "http://example.com"}, log.NewNOPFactory().Logger()) + require.NoError(t, err) + + mockConn := newMockConn(nil) + ctx := service.ContextWithPtr(context.Background(), &clientcontext.ClientInfo{ + IsPro: true, + }) + + routedConn := tracker.RoutedConnection(ctx, mockConn, adapter.InboundContext{}, nil, nil) + // Should return original connection (skipped) + assert.Equal(t, mockConn, routedConn) +} + +// Scenario 3: Datacap URL present & Free Client & No Cap (Throttle: false) +func TestRoutedConnection_FreeUserNoCap_DisablesThrottling(t *testing.T) { + // Mock server returning Throttle: false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"throttle":false, "remainingBytes": 1000, "capLimit": 1000}`)) + })) + defer server.Close() + + tracker, err := NewDatacapTracker(Options{URL: server.URL, ReportInterval: "100ms"}, log.NewNOPFactory().Logger()) + require.NoError(t, err) + + mockConn := newMockConn(make([]byte, 1024)) + ctx := service.ContextWithPtr(context.Background(), 
&clientcontext.ClientInfo{
+		IsPro:       false,
+		DeviceID:    "device-free-nocap",
+		Platform:    "test",
+		CountryCode: "US",
+	})
+
+	routedConn := tracker.RoutedConnection(ctx, mockConn, adapter.InboundContext{}, nil, nil)
+	// Should return wrapped connection
+	assert.NotEqual(t, mockConn, routedConn)
+
+	conn := routedConn.(*Conn)
+	// Read some data to trigger reporting
+	_, _ = conn.Read(make([]byte, 10))
+
+	// Give the periodic reporter time to fetch the datacap status from the mock server
+	time.Sleep(200 * time.Millisecond)
+
+	// Verify throttler is disabled
+	assert.False(t, conn.throttler.IsEnabled(), "Throttler should be disabled for uncapped user")
+	conn.Close()
+}
+
+// Scenario 4: Datacap URL present & Free Client & With Cap (Throttle: true)
+func TestRoutedConnection_FreeUserWithCap_EnablesThrottling(t *testing.T) {
+	// Mock server returning Throttle: true
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		w.Write([]byte(`{"throttle":true, "remainingBytes": 100, "capLimit": 1000}`))
+	}))
+	defer server.Close()
+
+	tracker, err := NewDatacapTracker(Options{URL: server.URL, ReportInterval: "100ms", EnableThrottling: true}, log.NewNOPFactory().Logger())
+	require.NoError(t, err)
+
+	mockConn := newMockConn(make([]byte, 1024))
+	ctx := service.ContextWithPtr(context.Background(), &clientcontext.ClientInfo{
+		IsPro:       false,
+		DeviceID:    "device-free-capped",
+		Platform:    "test",
+		CountryCode: "US",
+	})
+
+	routedConn := tracker.RoutedConnection(ctx, mockConn, adapter.InboundContext{}, nil, nil)
+	// Should return wrapped connection
+	assert.NotEqual(t, mockConn, routedConn)
+
+	conn := routedConn.(*Conn)
+	// Read some data to trigger reporting
+	_, _ = conn.Read(make([]byte, 10))
+
+	// Give the periodic reporter time to fetch the datacap status from the mock server
+	time.Sleep(200 * time.Millisecond)
+
+	// Verify throttler is enabled
+	assert.True(t, conn.throttler.IsEnabled(), "Throttler should be enabled for capped user")
+	
conn.Close() +}