@@ -914,55 +914,6 @@ def _get_eagle_module_inputs(
             eagle_inputs["attention_mask"] = attn_mask
             eagle_inputs["position_ids"] = position_ids
             eagle_inputs["rotary_pos_emb"] = rotary_pos_emb
-
-            if self.config.sequence_parallel:
-                gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states)
-            else:
-                gathered_hidden_states = hidden_states
-            eagle_inputs["hidden_states"] = gathered_hidden_states
-
-            for i in range(self.eagle_config.parallel_draft_step - 1):
-                eagle_inputs["input_ids"] = torch.cat(
-                    (
-                        eagle_inputs["input_ids"],
-                        torch.full(
-                            padded_input_ids.shape,
-                            getattr(self, f"mask_token_{i}"),
-                            device=padded_input_ids.device,
-                            dtype=padded_input_ids.dtype,
-                        ),
-                    ),
-                    dim=-1,
-                )
-
-                eagle_inputs["hidden_states"] = torch.cat(
-                    (
-                        eagle_inputs["hidden_states"],
-                        torch.zeros(
-                            (1 + i, b, h), dtype=hidden_states.dtype, device=hidden_states.device
-                        ),
-                        gathered_hidden_states[: -(1 + i)],
-                    ),
-                    dim=0,
-                )
-
-                eagle_inputs["position_ids"] = torch.cat(
-                    (eagle_inputs["position_ids"], position_ids), dim=-1
-                )
-
-                if rotary_pos_emb is not None:
-                    eagle_inputs["rotary_pos_emb"] = torch.cat(
-                        (eagle_inputs["rotary_pos_emb"], rotary_pos_emb), dim=0
-                    )
-
-            if self.config.sequence_parallel:
-                eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region(
-                    eagle_inputs["hidden_states"]
-                )
-
-            eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
-                attn_mask, self.eagle_config.parallel_draft_step
-            )
-
         elif features.shape[0] == hidden_states.shape[0]:
             eagle_inputs["input_ids"] = torch.cat(
                 (padded_input_ids, padded_input_ids),
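Note: the block removed above built the extra EAGLE inputs used for parallel drafting. For each additional draft step it appended a block of a learned per-step mask token to `input_ids` and re-appended the base `hidden_states` shifted down by one more position (zero-padded at the top), then replaced the attention mask with a multi-step mask. Below is a minimal standalone sketch of that concatenation pattern only, with toy shapes and stand-in mask-token ids; it omits sequence parallelism and the attention-mask handling and is not the modelopt implementation.

```python
# Sketch (assumed toy values, not the modelopt code) of the removed parallel-draft
# input construction: per extra draft step, extend input_ids with a mask-token block
# and append the hidden states shifted down by (1 + i) positions.
import torch

s, b, h = 8, 2, 16               # sequence length, batch size, hidden size (assumed)
parallel_draft_step = 3          # number of tokens drafted in parallel (assumed)
mask_tokens = [100 + i for i in range(parallel_draft_step - 1)]  # stand-in mask ids

input_ids = torch.randint(0, 100, (b, s))
hidden_states = torch.randn(s, b, h)   # Megatron layout: [s, b, h]

eagle_input_ids = input_ids
eagle_hidden_states = hidden_states
for i in range(parallel_draft_step - 1):
    # Append a full block of the i-th mask token after the tokens accumulated so far.
    eagle_input_ids = torch.cat(
        (eagle_input_ids, torch.full_like(input_ids, mask_tokens[i])), dim=-1
    )
    # Append the base hidden states shifted down by (1 + i) positions, zero-padded on top.
    eagle_hidden_states = torch.cat(
        (
            eagle_hidden_states,
            torch.zeros((1 + i, b, h), dtype=hidden_states.dtype),
            hidden_states[: -(1 + i)],
        ),
        dim=0,
    )

print(eagle_input_ids.shape)      # (b, s * parallel_draft_step)
print(eagle_hidden_states.shape)  # (s * parallel_draft_step, b, h)
```

With these toy values the ids grow to (b, s * parallel_draft_step) and the hidden states to (s * parallel_draft_step, b, h), matching the per-step concatenation in the removed block.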