Skip to content

Commit e7a7055

Browse files
committed
fix: fix guided decoding state corruption in turbomind when tp>1
1 parent 8258be5 commit e7a7055

File tree

10 files changed

+31
-13
lines changed

10 files changed

+31
-13
lines changed

src/turbomind/engine/model_request.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
 
 namespace turbomind {
 
-ModelRequest::ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim):
+ModelRequest::ModelRequest(
+    Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim, int tp_size):
     gateway_{gateway},
     data_type_{data_type},
     session_len_{session_len},
     vocab_size_{vocab_size},
-    hidden_dim_{hidden_dim}
+    hidden_dim_{hidden_dim},
+    tp_size_{tp_size}
 {
 }
 
@@ -128,7 +130,10 @@ auto ModelRequest::Forward(InputParam param, std::function<void()> cb) -> Output
     r->sequence_length = outputs_->at("sequence_length");
 
     if (grammar_) {
-        r->matcher = std::make_shared<xgrammar::GrammarMatcher>(*grammar_);
+        r->matchers.clear();
+        for (int i = 0; i < tp_size_; ++i) {
+            r->matchers.push_back(std::make_shared<xgrammar::GrammarMatcher>(*grammar_));
+        }
     }
 
     // Keep a weak reference for canceling the request

src/turbomind/engine/model_request.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class ModelRequest {
 public:
     virtual ~ModelRequest() = default;
 
-    ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim);
+    ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim, int tp_size);
 
     // Cancel running request
     void Cancel();
@@ -50,6 +50,7 @@ class ModelRequest {
     const int session_len_;
     const int hidden_dim_;
     const int vocab_size_;
+    const int tp_size_;
 
     uint64_t session_id_;
 

src/turbomind/engine/request.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ struct Request {
         kInconsistency = 9,  // Inconsistent request parameters, e.g. prefix caching is not allowed in interactive mode
     };
 
-    std::shared_ptr<xgrammar::GrammarMatcher> matcher;
+    std::vector<std::shared_ptr<xgrammar::GrammarMatcher>> matchers;  // GrammarMatchers for different threads (tp_size)
 };
 
 inline void UpdateState(Request& r, int status, int seq_len)

src/turbomind/layers/BaseDynamicDecodeLayer.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class BaseDynamicDecodeLayer {
         int                   vocab_size_padded;
         cudaStream_t          stream;
         const cudaDeviceProp* device_prop;
+        int                   tp_rank;
     };
 
     virtual ~BaseDynamicDecodeLayer() = default;
@@ -42,6 +43,7 @@ class BaseDynamicDecodeLayer {
         vocab_size_padded_ = param.vocab_size_padded;
         stream_            = param.stream;
         device_prop_       = param.device_prop;
+        tp_rank_           = param.tp_rank;
     };
 
     virtual void Setup(const std::vector<const Request*>& rs, const TensorMap& args) = 0;
@@ -54,6 +56,7 @@ class BaseDynamicDecodeLayer {
    int                   vocab_size_padded_;
    cudaStream_t          stream_;
    const cudaDeviceProp* device_prop_;
+   int                   tp_rank_;
 };
 
 }  // namespace turbomind

src/turbomind/layers/DynamicDecodeLayer.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,14 @@ DynamicDecodeLayer::DynamicDecodeLayer(DataType dtype,
                                        int                   vocab_size,
                                        int                   vocab_size_padded,
                                        cudaStream_t          stream,
-                                       const cudaDeviceProp* device_prop)
+                                       const cudaDeviceProp* device_prop,
+                                       int                   tp_rank):
+    tp_rank_{tp_rank}
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     TM_CHECK(dtype == kFloat32);
-    BaseDynamicDecodeLayer::BaseParam param{max_batch_size, vocab_size, vocab_size_padded, stream, device_prop};
+    BaseDynamicDecodeLayer::BaseParam param{
+        max_batch_size, vocab_size, vocab_size_padded, stream, device_prop, tp_rank};
     layers_.emplace_back(new LogitsProcessorLayer<float>{param});
     layers_.emplace_back(new GuidedDecodeMaskLayer<float>{param});
     layers_.emplace_back(new SamplingLayer<float>{param});

src/turbomind/layers/DynamicDecodeLayer.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ class DynamicDecodeLayer {
                       int                   vocab_size,
                       int                   vocab_size_padded,
                       cudaStream_t          stream,
-                      const cudaDeviceProp* device_prop);
+                      const cudaDeviceProp* device_prop,
+                      int                   tp_rank);
 
     ~DynamicDecodeLayer();
 
@@ -42,6 +43,7 @@ class DynamicDecodeLayer {
     void Forward(TensorMap& args);
 
 private:
+    int tp_rank_;
     std::vector<std::unique_ptr<BaseDynamicDecodeLayer>> layers_;
 };
 

src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ void GuidedDecodeMaskLayer<T>::Setup(const std::vector<const Request*>& rs, cons
     TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
     matchers_.clear();
     for (const auto& r : rs) {
-        matchers_.push_back(r->matcher);
+        matchers_.push_back(r->matchers[tp_rank_]);
     }
 }
 

src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ void GuidedDecodeUpdateLayer<T>::Setup(const std::vector<const Request*>& rs, co
     TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
     matchers_.clear();
     for (const auto& r : rs) {
-        matchers_.push_back(r->matcher);
+        matchers_.push_back(r->matchers[tp_rank_]);
     }
 }
 
3535

src/turbomind/models/llama/LlamaV2.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ LlamaV2::LlamaV2(DataType dtype,
 
     // using float to avoid data overflow
     dynamic_decode_ = std::make_unique<DynamicDecodeLayer>(
-        kFloat32, max_batch_size, vocab_size_, vocab_size_padded_, stream_, &ctx.device_prop);
+        kFloat32, max_batch_size, vocab_size_, vocab_size_padded_, stream_, &ctx.device_prop, engine.mlp_tp_rank);
 }
 
 void LlamaV2::updateEmbedding(char* decoder_input,

src/turbomind/triton_backend/llama/LlamaTritonModel.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,12 @@ std::unique_ptr<ModelRequest> LlamaTritonModel::createModelInstance(int device_i
 {
     FT_CHECK(engines_[device_id] != nullptr);
 
-    return std::make_unique<ModelRequest>(
-        gateway_.get(), dtype_, engine_param_.session_len, model_param_.vocab_size, model_param_.hidden_units);
+    return std::make_unique<ModelRequest>(gateway_.get(),
+                                          dtype_,
+                                          engine_param_.session_len,
+                                          model_param_.vocab_size,
+                                          model_param_.hidden_units,
+                                          comm_size_);
 }
 
 void LlamaTritonModel::createSharedWeights(int device_id, int rank)

0 commit comments

Comments (0)