1
1
import * as os from 'node:os'
2
2
import * as path from 'node:path'
3
- import type { CLINode } from '@/workflow/components/nodes/CLI_Node'
4
3
import type { LLMNode } from '@/workflow/components/nodes/LLM_Node'
5
4
import type { LoopStartNode } from '@/workflow/components/nodes/LoopStart_Node'
6
5
import type { SearchContextNode } from '@/workflow/components/nodes/SearchContext_Node'
@@ -138,18 +137,8 @@ export async function executeWorkflow(
138
137
switch ( node . type ) {
139
138
case NodeType . CLI : {
140
139
try {
141
- const inputs = combineParentOutputsByConnectionOrder ( node . id , context )
142
- /*.map(
143
- output => sanitizeForShell(output)
144
- )*/
145
- const command = ( node as CLINode ) . data . content
146
- ? replaceIndexedInputs ( ( node as CLINode ) . data . content , inputs , context )
147
- : ''
148
140
result = await executeCLINode (
149
- {
150
- ...( node as CLINode ) ,
151
- data : { ...( node as CLINode ) . data , content : command } ,
152
- } ,
141
+ node ,
153
142
abortSignal ,
154
143
persistentShell ,
155
144
webview ,
@@ -176,50 +165,26 @@ export async function executeWorkflow(
176
165
break
177
166
}
178
167
case NodeType . LLM : {
179
- const inputs = combineParentOutputsByConnectionOrder ( node . id , context ) . map ( input =>
180
- sanitizeForPrompt ( input )
181
- )
182
- const prompt = node . data . content
183
- ? replaceIndexedInputs ( node . data . content , inputs , context )
184
- : ''
185
-
186
- const oldTemperature = await chatClient . getTemperature ( )
187
- await chatClient . setTemperature ( ( node as LLMNode ) . data . temperature )
188
- result = await executeLLMNode (
189
- { ...node , data : { ...node . data , content : prompt } } ,
190
- chatClient ,
191
- abortSignal
192
- )
193
- await chatClient . setTemperature ( oldTemperature )
168
+ try {
169
+ result = await executeLLMNode ( node , chatClient , abortSignal , context )
170
+ } catch ( error ) {
171
+ console . error ( 'Error in LLM Node:' , error )
172
+ throw error
173
+ }
194
174
break
195
175
}
196
176
case NodeType . PREVIEW : {
197
- const inputs = combineParentOutputsByConnectionOrder ( node . id , context )
198
- result = await executePreviewNode ( inputs . join ( '\n' ) , node . id , webview , context )
177
+ result = await executePreviewNode ( node . id , webview , context )
199
178
break
200
179
}
201
180
202
181
case NodeType . INPUT : {
203
- const inputs = combineParentOutputsByConnectionOrder ( node . id , context )
204
- const text = node . data . content
205
- ? replaceIndexedInputs ( node . data . content , inputs , context )
206
- : ''
207
- result = await executeInputNode ( text )
182
+ result = await executeInputNode ( node , context )
208
183
break
209
184
}
210
185
211
186
case NodeType . SEARCH_CONTEXT : {
212
- const inputs = combineParentOutputsByConnectionOrder ( node . id , context )
213
- const text = node . data . content
214
- ? replaceIndexedInputs ( node . data . content , inputs , context )
215
- : ''
216
- const allowRemoteContext = ( node as SearchContextNode ) . data . local_remote
217
- result = await executeSearchContextNode (
218
- text ,
219
- contextRetriever ,
220
- abortSignal ,
221
- allowRemoteContext
222
- )
187
+ result = await executeSearchContextNode ( node , contextRetriever , abortSignal , context )
223
188
break
224
189
}
225
190
case NodeType . CODY_OUTPUT : {
@@ -276,8 +241,7 @@ export async function executeWorkflow(
276
241
break
277
242
}
278
243
case NodeType . LOOP_END : {
279
- const inputs = combineParentOutputsByConnectionOrder ( node . id , context )
280
- result = await executePreviewNode ( inputs . join ( '\n' ) , node . id , webview , context )
244
+ result = await executePreviewNode ( node . id , webview , context )
281
245
break
282
246
}
283
247
@@ -459,13 +423,13 @@ export function replaceIndexedInputs(
459
423
*/
460
424
export function combineParentOutputsByConnectionOrder (
461
425
nodeId : string ,
462
- context : IndexedExecutionContext
426
+ context ? : IndexedExecutionContext
463
427
) : string [ ] {
464
- const parentEdges = context . edgeIndex . byTarget . get ( nodeId ) || [ ]
428
+ const parentEdges = context ? .edgeIndex . byTarget . get ( nodeId ) || [ ]
465
429
466
430
return parentEdges
467
431
. map ( edge => {
468
- let output = context . nodeOutputs . get ( edge . source )
432
+ let output = context ? .nodeOutputs . get ( edge . source )
469
433
if ( Array . isArray ( output ) ) {
470
434
output = output . join ( '\n' )
471
435
}
@@ -504,15 +468,14 @@ export async function executeCLINode(
504
468
if ( ! vscode . env . shell || ! vscode . workspace . isTrusted ) {
505
469
throw new Error ( 'Shell command is not supported in your current workspace.' )
506
470
}
507
- // Add validation for empty commands
508
- if ( ! node . data . content ?. trim ( ) ) {
471
+ const inputs = combineParentOutputsByConnectionOrder ( node . id , context )
472
+ const command = node . data . content ? replaceIndexedInputs ( node . data . content , inputs , context ) : ''
473
+ if ( ! command . trim ( ) ) {
509
474
throw new Error ( 'CLI Node requires a non-empty command' )
510
475
}
511
476
512
477
const homeDir = os . homedir ( ) || process . env . HOME || process . env . USERPROFILE || ''
513
-
514
- let filteredCommand =
515
- ( node as CLINode ) . data . content ?. replaceAll ( / ( \s ~ \/ ) / g, ` ${ homeDir } ${ path . sep } ` ) || ''
478
+ let filteredCommand = command . replaceAll ( / ( \s ~ \/ ) / g, ` ${ homeDir } ${ path . sep } ` ) || ''
516
479
517
480
if ( node . data . needsUserApproval ) {
518
481
await webview . postMessage ( {
@@ -536,7 +499,7 @@ export async function executeCLINode(
536
499
537
500
try {
538
501
const { output, exitCode } = await persistentShell . execute ( filteredCommand , abortSignal )
539
- if ( exitCode !== '0' && ( node as CLINode ) . data . shouldAbort ) {
502
+ if ( exitCode !== '0' && node . data . shouldAbort ) {
540
503
throw new Error ( output )
541
504
}
542
505
context ?. cliMetadata ?. set ( node . id , { exitCode : exitCode } )
@@ -568,10 +531,18 @@ export async function executeCLINode(
568
531
async function executeLLMNode (
569
532
node : WorkflowNodes ,
570
533
chatClient : ChatClient ,
571
- abortSignal ?: AbortSignal
534
+ abortSignal ?: AbortSignal ,
535
+ context ?: IndexedExecutionContext
572
536
) : Promise < string > {
573
537
abortSignal ?. throwIfAborted ( )
574
- if ( ! node . data . content ) {
538
+ const oldTemperature = await chatClient . getTemperature ( )
539
+ await chatClient . setTemperature ( ( node as LLMNode ) . data . temperature )
540
+
541
+ const inputs = combineParentOutputsByConnectionOrder ( node . id , context ) . map ( input =>
542
+ sanitizeForPrompt ( input )
543
+ )
544
+ const prompt = node . data . content ? replaceIndexedInputs ( node . data . content , inputs , context ) : ''
545
+ if ( ! prompt || prompt . trim ( ) === '' ) {
575
546
throw new Error ( `No prompt specified for LLM node ${ node . id } with ${ node . data . title } ` )
576
547
}
577
548
@@ -589,7 +560,7 @@ async function executeLLMNode(
589
560
...preamble ,
590
561
{
591
562
speaker : 'human' ,
592
- text : PromptString . unsafe_fromUserQuery ( node . data . content ) ,
563
+ text : PromptString . unsafe_fromUserQuery ( prompt ) ,
593
564
} ,
594
565
]
595
566
@@ -628,14 +599,16 @@ async function executeLLMNode(
628
599
}
629
600
}
630
601
} catch ( error ) {
602
+ await chatClient . setTemperature ( oldTemperature )
631
603
reject ( error )
632
604
}
633
605
} )
634
606
. catch ( reject )
635
607
} )
636
-
608
+ await chatClient . setTemperature ( oldTemperature )
637
609
return await Promise . race ( [ streamPromise , timeout ] )
638
610
} catch ( error ) {
611
+ await chatClient . setTemperature ( oldTemperature )
639
612
if ( error instanceof Error ) {
640
613
if ( error . name === 'AbortError' ) {
641
614
throw new Error ( 'Workflow execution aborted' )
@@ -658,11 +631,11 @@ async function executeLLMNode(
658
631
* @returns The trimmed input string.
659
632
*/
660
633
async function executePreviewNode (
661
- input : string ,
662
634
nodeId : string ,
663
635
webview : vscode . Webview ,
664
636
context : IndexedExecutionContext
665
637
) : Promise < string > {
638
+ const input = combineParentOutputsByConnectionOrder ( nodeId , context ) . join ( '\n' )
666
639
const processedInput = replaceIndexedInputs ( input , [ ] , context )
667
640
const trimmedInput = processedInput . trim ( )
668
641
const tokenCount = await TokenCounterUtils . encode ( trimmedInput )
@@ -686,8 +659,10 @@ async function executePreviewNode(
686
659
* @param input - The input string to be processed.
687
660
* @returns The trimmed input string.
688
661
*/
689
- async function executeInputNode ( input : string ) : Promise < string > {
690
- return input . trim ( )
662
+ async function executeInputNode ( node : WorkflowNode , context : IndexedExecutionContext ) : Promise < string > {
663
+ const inputs = combineParentOutputsByConnectionOrder ( node . id , context )
664
+ const text = node . data . content ? replaceIndexedInputs ( node . data . content , inputs , context ) : ''
665
+ return text . trim ( )
691
666
}
692
667
693
668
// #region 4 Search Context Node Execution */
@@ -700,12 +675,15 @@ async function executeInputNode(input: string): Promise<string> {
700
675
* @returns An array of strings, where each string represents a formatted context item (path + newline + content).
701
676
*/
702
677
async function executeSearchContextNode (
703
- input : string ,
678
+ node : WorkflowNode ,
704
679
contextRetriever : Pick < ContextRetriever , 'retrieveContext' > ,
705
680
abortSignal : AbortSignal ,
706
- allowRemoteContext : boolean
681
+ context : IndexedExecutionContext
707
682
) : Promise < string > {
708
683
abortSignal . throwIfAborted ( )
684
+ const inputs = combineParentOutputsByConnectionOrder ( node . id , context )
685
+ const text = node . data . content ? replaceIndexedInputs ( node . data . content , inputs , context ) : ''
686
+ const allowRemoteContext = ( node as SearchContextNode ) . data . local_remote
709
687
const corpusItems = await firstValueFrom ( getCorpusContextItemsForEditorState ( allowRemoteContext ) )
710
688
if ( corpusItems === pendingOperation || corpusItems . length === 0 ) {
711
689
return ''
@@ -715,15 +693,15 @@ async function executeSearchContextNode(
715
693
return ''
716
694
}
717
695
const span = tracer . startSpan ( 'chat.submit' )
718
- const context = await contextRetriever . retrieveContext (
696
+ const fetchedContext = await contextRetriever . retrieveContext (
719
697
toStructuredMentions ( corpusItems ) ,
720
- PromptString . unsafe_fromLLMResponse ( input ) ,
698
+ PromptString . unsafe_fromLLMResponse ( text ) ,
721
699
span ,
722
700
abortSignal ,
723
701
false
724
702
)
725
703
span . end ( )
726
- const result = context . map ( item => {
704
+ const result = fetchedContext . map ( item => {
727
705
// Format each context item as path + newline + content
728
706
return `${ item . uri . path } \n${ item . content || '' } `
729
707
} )
0 commit comments