@@ -463,11 +463,13 @@ def search(self,t):
463463 return ops .argmax (t ,- 1 )
464464 def call (self , inputs , ** kwargs ):
465465 hidden_state ,update_index ,out_ids ,flags = inputs [:]
466+
466467 y = self .search (hidden_state )
467468 t = ops .full_like (y ,self .end_token )
468469 y = ops .where (flags ,y ,t )
469470 start = [0 ,update_index ]
470471 flags = y != self .end_token
472+
471473 return ops .slice_update (out_ids ,start ,ops .cast (y ,out_ids .dtype )),flags
472474
473475class TopkSearch (GreedySearch ):
@@ -943,10 +945,10 @@ def build(self, input_shape):
943945 def call (self , inputs ):
944946 """如果custom_position_ids,那么第二个输入为自定义的位置id
945947 """
946- if self .custom_position_ids :
948+ flag = isinstance (inputs ,list )
949+ if self .custom_position_ids or flag :
947950 inputs , position_ids = inputs
948- if 'int' not in K .dtype (position_ids ):
949- position_ids = ops .cast (position_ids , 'int32' )
951+ position_ids = ops .cast (position_ids , 'int32' )
950952 else :
951953 input_shape = ops .shape (inputs )
952954 batch_size , seq_len = input_shape [0 ], input_shape [1 ]
@@ -960,8 +962,8 @@ def call(self, inputs):
960962 embeddings_y = ops .take (embeddings , position_ids % self .input_dim )
961963 embeddings = alpha * embeddings_x + (1 - alpha ) * embeddings_y
962964 else :
963- if self .custom_position_ids :
964- embeddings = ops .take (self .embeddings , position_ids )
965+ if self .custom_position_ids or flag :
966+ embeddings = ops .take (self .embeddings , position_ids , axis = 0 )
965967 else :
966968 embeddings = self .embeddings [None , :seq_len ]
967969
0 commit comments