Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/mongoproxy/plugins/authz/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ func (p *AuthzPlugin) resourcesForCommand(r *plugins.Request, c command.Command)
}

case *command.GetMore:
cursorResources := r.CursorCache.GetCursor(cmd.CursorID).Map[contextKeyResources]
cursorResources := r.CursorCache.GetCursor(cmd.CursorID, r.GetClientInfo()).Map[contextKeyResources]
if cr, ok := cursorResources.(map[authzlib.AuthorizationMethod][]authzlib.Resource); ok {
return cr
}
Expand Down Expand Up @@ -592,7 +592,7 @@ func (p *AuthzPlugin) Process(ctx context.Context, r *plugins.Request, next plug
result, err := next(ctx, r)
if cursorIDRaw, ok := bsonutil.Lookup(result, "cursor", "id"); ok {
if cursorID, ok := cursorIDRaw.(int64); ok && cursorID > 0 {
r.CursorCache.GetCursor(cursorID).Map[contextKeyResources] = resourceMap
r.CursorCache.GetCursor(cursorID, r.GetClientInfo()).Map[contextKeyResources] = resourceMap
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/mongoproxy/plugins/authz/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestPluginGetMore(t *testing.T) {
p := plugins.BuildPipeline([]plugins.Plugin{d}, func(_ context.Context, r *plugins.Request) (bson.D, error) {
switch r.Command.(type) {
case *command.Find:
r.CursorCache.GetCursor(cursorID)
r.CursorCache.GetCursor(cursorID, r.GetClientInfo())
return bson.D{
{"ok", 1},
{"cursor", bson.D{{"id", cursorID}}},
Expand Down
31 changes: 28 additions & 3 deletions pkg/mongoproxy/plugins/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package plugins
import (
"context"
"net"
"strings"

"go.mongodb.org/mongo-driver/bson"

Expand All @@ -25,20 +26,22 @@ type Plugin interface {
Process(context.Context, *Request, PipelineFunc) (bson.D, error)
}

func NewCursorCacheEntry(id int64) *CursorCacheEntry {
func NewCursorCacheEntry(id int64, clientInfo string) *CursorCacheEntry {
return &CursorCacheEntry{
ID: id,
ClientInfo: clientInfo,
Map: map[interface{}]interface{}{},
}
}

type CursorCache interface {
GetCursor(cursorID int64) *CursorCacheEntry
CloseCursor(cursorID int64)
GetCursor(cursorID int64, clientInfo string) *CursorCacheEntry
CloseCursor(cursorID int64, clientInfo string)
}

type CursorCacheEntry struct {
ID int64
ClientInfo string
CursorConsumed int

// Map is storage that resets on cursor change
Expand All @@ -59,6 +62,10 @@ type Request struct {
Map map[string]interface{}
}

func (r *Request) GetClientInfo() string {
return r.CC.GetClientInfo()
}

func (r *Request) Close() {}

func NewClientConnection() *ClientConnection {
Expand All @@ -77,6 +84,19 @@ type ClientConnection struct {
// Map is storage that resets on cursor change
Map map[interface{}]interface{}
}
func (c *ClientConnection) GetUsername() string {
var usernames []string
for _, identity := range c.Identities {
usernames = append(usernames, identity.User())
}
var username string
if len(usernames) > 0 {
username = strings.Join(usernames, ",")
} else {
username = "unknown"
}
return username
}

func (c *ClientConnection) GetAddr() string {
if c.Addr == nil {
Expand All @@ -85,6 +105,11 @@ func (c *ClientConnection) GetAddr() string {
return c.Addr.String()
}

func (c *ClientConnection) GetClientInfo() string {
//todo @jiapeng use username
return c.GetAddr()
}

func (c *ClientConnection) Close() {}

type ClientIdentity interface {
Expand Down
8 changes: 4 additions & 4 deletions pkg/mongoproxy/plugins/mongo/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ func (p *MongoPlugin) Process(ctx context.Context, r *plugins.Request, next plug
if cursorID, ok := cursorIDRaw.(int64); ok && cursorID > 0 {
logrus.Tracef("Store cursor: %v %v", cursorID, cmdServer)
// TODO: TTL from cmd
r.CursorCache.GetCursor(cursorID).Map[contextKeyServer] = cmdServer
r.CursorCache.GetCursor(cursorID, r.GetClientInfo()).Map[contextKeyServer] = cmdServer
}
}
}
Expand Down Expand Up @@ -473,7 +473,7 @@ func (p *MongoPlugin) Process(ctx context.Context, r *plugins.Request, next plug
cmd.Database = ""

// TODO: move into runCommand?
v, ok := r.CursorCache.GetCursor(cmd.CursorID).Map[contextKeyServer]
v, ok := r.CursorCache.GetCursor(cmd.CursorID, r.GetClientInfo()).Map[contextKeyServer]
if !ok {
return mongoerror.CursorNotFound.ErrMessage("Cursor not found."), nil
}
Expand All @@ -482,7 +482,7 @@ func (p *MongoPlugin) Process(ctx context.Context, r *plugins.Request, next plug

if cursorIDRaw, ok := bsonutil.Lookup(result, "cursor", "id"); ok {
if cursorID, ok := cursorIDRaw.(int64); ok && cursorID == 0 {
r.CursorCache.CloseCursor(cmd.CursorID)
r.CursorCache.CloseCursor(cmd.CursorID, r.GetClientInfo())
}
}

Expand Down Expand Up @@ -522,7 +522,7 @@ func (p *MongoPlugin) Process(ctx context.Context, r *plugins.Request, next plug
if !ok {
return nil, fmt.Errorf("invalid cursorID")
}
v, ok := r.CursorCache.GetCursor(cursorID).Map[contextKeyServer]
v, ok := r.CursorCache.GetCursor(cursorID, r.GetClientInfo()).Map[contextKeyServer]
if !ok {
return mongoerror.CursorNotFound.ErrMessage("Cursor not found."), nil
}
Expand Down
40 changes: 32 additions & 8 deletions pkg/mongoproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os"
"reflect"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -47,6 +48,14 @@ var (
Name: "mongoproxy_client_message_total",
Help: "The total number of messages from clients",
}, []string{"opcode"})
clientCursorGauge = promauto.NewGaugeVec(prometheus.GaugeOpts{
Name: "mongoproxy_client_cursors_open",
Help: "The current number of open cursors",
}, []string{"clientInfo"})
clientCursorCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "mongoproxy_client_cursors_total",
Help: "The total number of client cursors",
}, []string{"clientInfo"})

ErrServerClosed = errors.New("server closed")
SKIP_RECOVER = false
Expand Down Expand Up @@ -91,17 +100,31 @@ func NewProxy(l net.Listener, cfg *config.Config) (*Proxy, error) {

// Set up cursorCache
p.cursorCache.SetTTL(p.cfg.IdleCursorTimeout) // default TTL -- config
p.cursorCache.SetLoaderFunction(func(key string) (interface{}, time.Duration, error) {
p.cursorCache.SetLoaderFunction(func(clientKey string) (interface{}, time.Duration, error) {
keys := strings.SplitN(clientKey, "_", 2)
if keys == nil || len(keys) != 2 {
return nil, time.Duration(0), errors.New("Illegal cursor key " + clientKey)
}
key := keys[0]
clientInfo := keys[1]
cursorID, err := strconv.ParseInt(key, 10, 64)
if err != nil {
return nil, time.Duration(0), err
}
clientCursorCounter.WithLabelValues(clientInfo).Inc()
clientCursorGauge.WithLabelValues(clientInfo).Inc()

return plugins.NewCursorCacheEntry(cursorID), time.Duration(0), nil
return plugins.NewCursorCacheEntry(cursorID, clientInfo), time.Duration(0), nil
})
// expiration handler to send killCursor commands
p.cursorCache.SetExpirationReasonCallback(func(key string, reason ttlcache.EvictionReason, value interface{}) {
logrus.Tracef("expire cursor %s", key)
p.cursorCache.SetExpirationReasonCallback(func(clientKey string, reason ttlcache.EvictionReason, value interface{}) {
logrus.Tracef("expire cursor %s", clientKey)
keys := strings.SplitN(clientKey, "_", 2)
if keys == nil || len(keys) != 2 {
return
}
key := keys[0]
clientInfo := keys[1]
i, err := strconv.ParseInt(key, 10, 64)
if err != nil {
return
Expand All @@ -114,6 +137,7 @@ func NewProxy(l net.Listener, cfg *config.Config) (*Proxy, error) {
{"cursors", primitive.A{i}},
})
}
clientCursorGauge.WithLabelValues(clientInfo).Dec()
})

return p, nil
Expand All @@ -138,17 +162,17 @@ type Proxy struct {
internalCC *plugins.ClientConnection
}

func (p *Proxy) GetCursor(cursorID int64) *plugins.CursorCacheEntry {
v, err := p.cursorCache.Get(strconv.FormatInt(cursorID, 10))
func (p *Proxy) GetCursor(cursorID int64, clientInfo string) *plugins.CursorCacheEntry {
v, err := p.cursorCache.Get(strconv.FormatInt(cursorID, 10) + "_" + clientInfo)
if err == ttlcache.ErrNotFound {
panic("can't get cursor")
}

return v.(*plugins.CursorCacheEntry)
}

func (p *Proxy) CloseCursor(cursorID int64) {
p.cursorCache.Remove(strconv.FormatInt(cursorID, 10))
func (p *Proxy) CloseCursor(cursorID int64, clientInfo string) {
p.cursorCache.Remove(strconv.FormatInt(cursorID, 10) + "_" + clientInfo)
}

func (p *Proxy) Addr() string {
Expand Down
12 changes: 6 additions & 6 deletions pkg/mongoproxy/proxy_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (p *Proxy) handleOpQuery(ctx context.Context, cc *plugins.ClientConnection,
}

if cursorID, ok := bsonutil.Lookup(result, "cursor", "id"); ok {
p.GetCursor(cursorID.(int64)).CursorConsumed += len(v.(primitive.A))
p.GetCursor(cursorID.(int64), cc.GetClientInfo()).CursorConsumed += len(v.(primitive.A))
}

reply.Documents = []bson.D{{{"ok", 1}, {"result", v}}}
Expand All @@ -164,7 +164,7 @@ func (p *Proxy) handleOpQuery(ctx context.Context, cc *plugins.ClientConnection,
}
var cursorEntry *plugins.CursorCacheEntry
if cursorID, ok := bsonutil.Lookup(cursorData, "id"); ok {
cursorEntry = p.GetCursor(cursorID.(int64))
cursorEntry = p.GetCursor(cursorID.(int64), cc.GetClientInfo())
}
firstBatchRaw, ok := bsonutil.Lookup(cursorData, "firstBatch")
if ok {
Expand Down Expand Up @@ -283,7 +283,7 @@ func (p *Proxy) handleOpQuery(ctx context.Context, cc *plugins.ClientConnection,
}
var cursorEntry *plugins.CursorCacheEntry
if cursorID, ok := bsonutil.Lookup(cursorData, "id"); ok {
cursorEntry = p.GetCursor(cursorID.(int64))
cursorEntry = p.GetCursor(cursorID.(int64), cc.GetClientInfo())
reply.CursorID = cursorID.(int64)
}
firstBatchRaw, ok := bsonutil.Lookup(cursorData, "firstBatch")
Expand Down Expand Up @@ -318,7 +318,7 @@ func (p *Proxy) handleOpKillCursors(ctx context.Context, cc *plugins.ClientConne
if err == nil {
if bsonutil.Ok(result) {
for _, cursorID := range q.CursorIDs {
p.CloseCursor(cursorID)
p.CloseCursor(cursorID, request.GetClientInfo())
}
}
}
Expand All @@ -339,7 +339,7 @@ func (p *Proxy) handleOpGetMore(ctx context.Context, cc *plugins.ClientConnectio
Header: q.Header,
}

cursorEntry := p.GetCursor(q.CursorID)
cursorEntry := p.GetCursor(q.CursorID, cc.GetClientInfo())

result, err := p.HandleMongo(ctx, request, []primitive.E{
{Key: "getMore", Value: q.CursorID},
Expand All @@ -360,7 +360,7 @@ func (p *Proxy) handleOpGetMore(ctx context.Context, cc *plugins.ClientConnectio
if cursorID, ok := bsonutil.Lookup(cursorData, "id"); ok {
reply.CursorID = cursorID.(int64)
if cursorID.(int64) == 0 {
p.CloseCursor(cursorEntry.ID)
p.CloseCursor(cursorEntry.ID, request.GetClientInfo())
}
}
nextBatchRaw, ok := bsonutil.Lookup(cursorData, "nextBatch")
Expand Down