@@ -2,13 +2,15 @@ package fourslash
22
33import (
44 "context"
5+ "errors"
56 "fmt"
67 "io"
78 "maps"
89 "runtime"
910 "slices"
1011 "strconv"
1112 "strings"
13+ "sync"
1214 "testing"
1315 "unicode/utf8"
1416
@@ -36,6 +38,7 @@ import (
3638 "github.com/microsoft/typescript-go/internal/vfs"
3739 "github.com/microsoft/typescript-go/internal/vfs/iovfs"
3840 "github.com/microsoft/typescript-go/internal/vfs/vfstest"
41+ "golang.org/x/sync/errgroup"
3942 "gotest.tools/v3/assert"
4043)
4144
@@ -62,6 +65,10 @@ type FourslashTest struct {
6265 selectionEnd * lsproto.Position
6366
6467 isStradaServer bool // Whether this is a fourslash server test in Strada. !!! Remove once we don't need to diff baselines.
68+
69+ // Async message handling
70+ pendingRequests map [lsproto.ID ]chan * lsproto.ResponseMessage
71+ pendingRequestsMu sync.Mutex
6572}
6673
6774type scriptInfo struct {
@@ -137,13 +144,14 @@ var parseCache = project.ParseCache{
137144 },
138145}
139146
140- func NewFourslash (t * testing.T , capabilities * lsproto.ClientCapabilities , content string ) * FourslashTest {
147+ func NewFourslash (t * testing.T , capabilities * lsproto.ClientCapabilities , content string ) ( * FourslashTest , func ()) {
141148 repo .SkipIfNoTypeScriptSubmodule (t )
142149 if ! bundled .Embedded {
143150 // Without embedding, we'd need to read all of the lib files out from disk into the MapFS.
144151 // Just skip this for now.
145152 t .Skip ("bundled files are not embedded" )
146153 }
154+
147155 fileName := getBaseFileNameFromTest (t ) + tspath .ExtensionTs
148156 testfs := make (map [string ]any )
149157 scriptInfos := make (map [string ]* scriptInfo )
@@ -211,16 +219,6 @@ func NewFourslash(t *testing.T, capabilities *lsproto.ClientCapabilities, conten
211219 ParseCache : & parseCache ,
212220 })
213221
214- go func () {
215- defer func () {
216- outputWriter .Close ()
217- }()
218- err := server .Run (context .TODO ())
219- if err != nil {
220- t .Error ("server error:" , err )
221- }
222- }()
223-
224222 converters := lsconv .NewConverters (lsproto .PositionEncodingKindUTF8 , func (fileName string ) * lsconv.LSPLineMap {
225223 scriptInfo , ok := scriptInfos [fileName ]
226224 if ! ok {
@@ -240,11 +238,26 @@ func NewFourslash(t *testing.T, capabilities *lsproto.ClientCapabilities, conten
240238 converters : converters ,
241239 baselines : make (map [baselineCommand ]* strings.Builder ),
242240 openFiles : make (map [string ]struct {}),
241+ pendingRequests : make (map [lsproto.ID ]chan * lsproto.ResponseMessage ),
243242 }
244243
244+ ctx , cancel := context .WithCancel (t .Context ())
245+ g , ctx := errgroup .WithContext (ctx )
246+
247+ // Start server goroutine
248+ g .Go (func () error {
249+ defer outputWriter .Close ()
250+ return server .Run (ctx )
251+ })
252+
253+ // Start async message router
254+ g .Go (func () error {
255+ return f .messageRouter (ctx )
256+ })
257+
245258 // !!! temporary; remove when we have `handleDidChangeConfiguration`/implicit project config support
246259 // !!! replace with a proper request *after initialize*
247- f .server .SetCompilerOptionsForInferredProjects (t . Context () , compilerOptions )
260+ f .server .SetCompilerOptionsForInferredProjects (ctx , compilerOptions )
248261 f .initialize (t , capabilities )
249262
250263 if testData .isStateBaseliningEnabled () {
@@ -258,11 +271,132 @@ func NewFourslash(t *testing.T, capabilities *lsproto.ClientCapabilities, conten
258271 }
259272
260273 _ , testPath , _ , _ := runtime .Caller (1 )
261- t .Cleanup (func () {
274+ return f , func () {
275+ t .Helper ()
276+ cancel ()
262277 inputWriter .Close ()
278+ if err := g .Wait (); err != nil && ! errors .Is (err , context .Canceled ) {
279+ t .Errorf ("goroutine error: %v" , err )
280+ }
263281 f .verifyBaselines (t , testPath )
264- })
265- return f
282+ }
283+ }
284+
285+ // messageRouter runs in a goroutine and routes incoming messages from the server.
286+ // It handles responses to client requests and server-initiated requests.
287+ func (f * FourslashTest ) messageRouter (ctx context.Context ) error {
288+ for {
289+ if ctx .Err () != nil {
290+ return nil
291+ }
292+
293+ msg , err := f .out .Read ()
294+ if err != nil {
295+ if errors .Is (err , io .EOF ) || ctx .Err () != nil {
296+ return nil
297+ }
298+ return fmt .Errorf ("failed to read message: %w" , err )
299+ }
300+
301+ // Validate message can be marshaled
302+ if err := json .MarshalWrite (io .Discard , msg ); err != nil {
303+ if ctx .Err () != nil {
304+ return nil
305+ }
306+
307+ return fmt .Errorf ("failed to encode message as JSON: %w" , err )
308+ }
309+
310+ switch msg .Kind {
311+ case lsproto .MessageKindResponse :
312+ f .handleResponse (ctx , msg .AsResponse ())
313+ case lsproto .MessageKindRequest :
314+ if err := f .handleServerRequest (ctx , msg .AsRequest ()); err != nil {
315+ return err
316+ }
317+ case lsproto .MessageKindNotification :
318+ // Server-initiated notifications (e.g., publishDiagnostics) are currently ignored
319+ // in fourslash tests
320+ }
321+ }
322+ }
323+
324+ // handleResponse routes a response message to the waiting request goroutine.
325+ func (f * FourslashTest ) handleResponse (ctx context.Context , resp * lsproto.ResponseMessage ) {
326+ if resp .ID == nil {
327+ return
328+ }
329+
330+ f .pendingRequestsMu .Lock ()
331+ respChan , ok := f .pendingRequests [* resp .ID ]
332+ if ok {
333+ delete (f .pendingRequests , * resp .ID )
334+ }
335+ f .pendingRequestsMu .Unlock ()
336+
337+ if ok {
338+ select {
339+ case respChan <- resp :
340+ // sent response
341+ case <- ctx .Done ():
342+ // context cancelled
343+ }
344+ }
345+ }
346+
347+ // handleServerRequest handles requests initiated by the server (e.g., workspace/configuration).
348+ func (f * FourslashTest ) handleServerRequest (ctx context.Context , req * lsproto.RequestMessage ) error {
349+ var response * lsproto.ResponseMessage
350+
351+ switch req .Method {
352+ case lsproto .MethodWorkspaceConfiguration :
353+ // Return current user preferences
354+ response = & lsproto.ResponseMessage {
355+ ID : req .ID ,
356+ JSONRPC : req .JSONRPC ,
357+ Result : []any {f .userPreferences },
358+ }
359+
360+ case lsproto .MethodClientRegisterCapability :
361+ // Accept all capability registrations
362+ response = & lsproto.ResponseMessage {
363+ ID : req .ID ,
364+ JSONRPC : req .JSONRPC ,
365+ Result : lsproto.Null {},
366+ }
367+
368+ case lsproto .MethodClientUnregisterCapability :
369+ // Accept all capability unregistrations
370+ response = & lsproto.ResponseMessage {
371+ ID : req .ID ,
372+ JSONRPC : req .JSONRPC ,
373+ Result : lsproto.Null {},
374+ }
375+
376+ default :
377+ // Unknown server request
378+ response = & lsproto.ResponseMessage {
379+ ID : req .ID ,
380+ JSONRPC : req .JSONRPC ,
381+ Error : & lsproto.ResponseError {
382+ Code : int32 (lsproto .ErrorCodeMethodNotFound ),
383+ Message : fmt .Sprintf ("Unknown method: %s" , req .Method ),
384+ },
385+ }
386+ }
387+
388+ // Send response back to server
389+ if ctx .Err () != nil {
390+ return nil
391+ }
392+
393+ if err := f .in .Write (response .Message ()); err != nil {
394+ if ctx .Err () != nil {
395+ return nil
396+ }
397+ return fmt .Errorf ("failed to write server request response: %w" , err )
398+ }
399+ return nil
266400}
267401
268402func getBaseFileNameFromTest (t * testing.T ) string {
@@ -300,16 +434,22 @@ func (f *FourslashTest) initialize(t *testing.T, capabilities *lsproto.ClientCap
300434 params := & lsproto.InitializeParams {
301435 Locale : ptrTo ("en-US" ),
302436 InitializationOptions : & lsproto.InitializationOptions {
303- // Hack: disable push diagnostics entirely, since the fourslash runner does not
304- // yet gracefully handle non-request messages.
305- DisablePushDiagnostics : ptrTo (true ),
306437 CodeLensShowLocationsCommandName : ptrTo (showCodeLensLocationsCommandName ),
307438 },
308439 }
309440 params .Capabilities = getCapabilitiesWithDefaults (capabilities )
310- // !!! check for errors?
311- sendRequestWorker (t , f , lsproto .InitializeInfo , params )
441+ resp , _ , ok := sendRequestWorker (t , f , lsproto .InitializeInfo , params )
442+ if ! ok {
443+ t .Fatalf ("Initialize request failed" )
444+ }
445+ if resp .AsResponse ().Error != nil {
446+ t .Fatalf ("Initialize request returned error: %s" , resp .AsResponse ().Error .String ())
447+ }
312448 sendNotificationWorker (t , f , lsproto .InitializedInfo , & lsproto.InitializedParams {})
449+
450+ // Wait for the initial configuration exchange to complete
451+ // The server will send workspace/configuration as part of handleInitialized
452+ <- f .server .InitComplete ()
313453}
314454
315455var (
@@ -412,45 +552,36 @@ func getCapabilitiesWithDefaults(capabilities *lsproto.ClientCapabilities) *lspr
412552
413553func sendRequestWorker [Params , Resp any ](t * testing.T , f * FourslashTest , info lsproto.RequestInfo [Params , Resp ], params Params ) (* lsproto.Message , Resp , bool ) {
414554 id := f .nextID ()
415- req := info .NewRequestMessage (
416- lsproto .NewID (lsproto.IntegerOrString {Integer : & id }),
417- params ,
418- )
419- f .writeMsg (t , req .Message ())
420- resp := f .readMsg (t )
421- if resp == nil {
422- return nil , * new (Resp ), false
423- }
555+ reqID := lsproto .NewID (lsproto.IntegerOrString {Integer : & id })
556+ req := info .NewRequestMessage (reqID , params )
424557
425- // currently, the only request that may be sent by the server during a client request is one `config` request
426- // !!! remove if `config` is handled in initialization and there are no other server-initiated requests
427- if resp .Kind == lsproto .MessageKindRequest {
428- req := resp .AsRequest ()
558+ // Create response channel and register it
559+ responseChan := make (chan * lsproto.ResponseMessage , 1 )
560+ f .pendingRequestsMu .Lock ()
561+ f .pendingRequests [* reqID ] = responseChan
562+ f .pendingRequestsMu .Unlock ()
429563
430- assert .Equal (t , req .Method , lsproto .MethodWorkspaceConfiguration , "Unexpected request received: %s" , req .Method )
431- res := lsproto.ResponseMessage {
432- ID : req .ID ,
433- JSONRPC : req .JSONRPC ,
434- Result : []any {f .userPreferences },
435- }
436- f .writeMsg (t , res .Message ())
437- req = f .readMsg (t ).AsRequest ()
564+ // Send the request
565+ f .writeMsg (t , req .Message ())
438566
439- assert .Equal (t , req .Method , lsproto .MethodClientRegisterCapability , "Unexpected request received: %s" , req .Method )
440- res = lsproto.ResponseMessage {
441- ID : req .ID ,
442- JSONRPC : req .JSONRPC ,
443- Result : lsproto.Null {},
567+ // Wait for response with context
568+ ctx := t .Context ()
569+ var resp * lsproto.ResponseMessage
570+ select {
571+ case <- ctx .Done ():
572+ f .pendingRequestsMu .Lock ()
573+ delete (f .pendingRequests , * reqID )
574+ f .pendingRequestsMu .Unlock ()
575+ t .Fatalf ("Request cancelled: %v" , ctx .Err ())
576+ return nil , * new (Resp ), false
577+ case resp = <- responseChan :
578+ if resp == nil {
579+ return nil , * new (Resp ), false
444580 }
445- f .writeMsg (t , res .Message ())
446- resp = f .readMsg (t )
447581 }
448582
449- if resp == nil {
450- return nil , * new (Resp ), false
451- }
452- result , ok := resp .AsResponse ().Result .(Resp )
453- return resp , result , ok
583+ result , ok := resp .Result .(Resp )
584+ return resp .Message (), result , ok
454585}
455586
456587func sendNotificationWorker [Params any ](t * testing.T , f * FourslashTest , info lsproto.NotificationInfo [Params ], params Params ) {
@@ -467,16 +598,6 @@ func (f *FourslashTest) writeMsg(t *testing.T, msg *lsproto.Message) {
467598 }
468599}
469600
470- func (f * FourslashTest ) readMsg (t * testing.T ) * lsproto.Message {
471- // !!! filter out response by id etc
472- msg , err := f .out .Read ()
473- if err != nil {
474- t .Fatalf ("failed to read message: %v" , err )
475- }
476- assert .NilError (t , json .MarshalWrite (io .Discard , msg ), "failed to encode message as JSON" )
477- return msg
478- }
479-
480601func sendRequest [Params , Resp any ](t * testing.T , f * FourslashTest , info lsproto.RequestInfo [Params , Resp ], params Params ) Resp {
481602 t .Helper ()
482603 prefix := f .getCurrentPositionPrefix ()
0 commit comments