@@ -845,7 +845,7 @@ def _get_eagle_module_inputs(
         rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1])
 
         attn_mask = attention_mask.clone().detach()
-        attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:]
+        attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
         attn_mask[:, :, -1, :] = True
         attn_mask[:, :, :, -1] = True
 
@@ -914,6 +914,55 @@ def _get_eagle_module_inputs(
             eagle_inputs["attention_mask"] = attn_mask
             eagle_inputs["position_ids"] = position_ids
             eagle_inputs["rotary_pos_emb"] = rotary_pos_emb
+
+            if self.config.sequence_parallel:
+                gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states)
+            else:
+                gathered_hidden_states = hidden_states
+            eagle_inputs["hidden_states"] = gathered_hidden_states
+
+            for i in range(self.eagle_config.parallel_draft_step - 1):
+                eagle_inputs["input_ids"] = torch.cat(
+                    (
+                        eagle_inputs["input_ids"],
+                        torch.full(
+                            padded_input_ids.shape,
+                            getattr(self, f"mask_token_{i}"),
+                            device=padded_input_ids.device,
+                            dtype=padded_input_ids.dtype,
+                        ),
+                    ),
+                    dim=-1,
+                )
+
+                eagle_inputs["hidden_states"] = torch.cat(
+                    (
+                        eagle_inputs["hidden_states"],
+                        torch.zeros(
+                            (1 + i, b, h), dtype=hidden_states.dtype, device=hidden_states.device
+                        ),
+                        gathered_hidden_states[: -(1 + i)],
+                    ),
+                    dim=0,
+                )
+
+                eagle_inputs["position_ids"] = torch.cat(
+                    (eagle_inputs["position_ids"], position_ids), dim=-1
+                )
+
+                if rotary_pos_emb is not None:
+                    eagle_inputs["rotary_pos_emb"] = torch.cat(
+                        (eagle_inputs["rotary_pos_emb"], rotary_pos_emb), dim=0
+                    )
+
+            if self.config.sequence_parallel:
+                eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region(
+                    eagle_inputs["hidden_states"]
+                )
+
+            eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
+                attn_mask, self.eagle_config.parallel_draft_step
+            )
         elif features.shape[0] == hidden_states.shape[0]:
             eagle_inputs["input_ids"] = torch.cat(
                 (padded_input_ids, padded_input_ids),
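For reference, the new parallel-draft branch grows the EAGLE module inputs by one full sequence copy per extra draft step: input_ids gets a copy filled with the step's mask token, hidden_states gets a zero-padded right-shifted copy, and position_ids simply repeats. The following is a minimal standalone sketch of that shape bookkeeping, using hypothetical sizes and a placeholder mask-token id; the real code instead uses the learned mask_token_{i} parameters, the sequence-parallel gather/scatter, and set_multi_step_attention_mask shown in the diff above.

import torch

# Hypothetical sizes: batch b, sequence length s, hidden size h.
b, s, h = 2, 8, 16
parallel_draft_step = 3
mask_token_id = 0  # placeholder; the module uses learned mask_token_{i} ids

input_ids = torch.randint(1, 1000, (b, s))
hidden_states = torch.randn(s, b, h)  # Megatron layout: [s, b, h]
position_ids = torch.arange(s).unsqueeze(0).expand(b, -1)

eagle_input_ids = input_ids
eagle_hidden_states = hidden_states
eagle_position_ids = position_ids

for i in range(parallel_draft_step - 1):
    # Append a full-length copy of the sequence filled with the i-th mask token.
    eagle_input_ids = torch.cat(
        (eagle_input_ids, torch.full_like(input_ids, mask_token_id)), dim=-1
    )
    # Append hidden states shifted right by (1 + i) steps, zero-padded at the
    # front, mirroring the gathered_hidden_states[: -(1 + i)] slice in the diff.
    shifted = torch.cat(
        (torch.zeros(1 + i, b, h, dtype=hidden_states.dtype), hidden_states[: -(1 + i)]),
        dim=0,
    )
    eagle_hidden_states = torch.cat((eagle_hidden_states, shifted), dim=0)
    # Position ids repeat once per extra draft step.
    eagle_position_ids = torch.cat((eagle_position_ids, position_ids), dim=-1)

print(eagle_input_ids.shape)      # (b, s * parallel_draft_step)
print(eagle_hidden_states.shape)  # (s * parallel_draft_step, b, h)
print(eagle_position_ids.shape)   # (b, s * parallel_draft_step)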