Skip to content

Commit

Permalink
Fix initialization of runtime_min_p (#2461)
Browse files Browse the repository at this point in the history
* fix minp

* better test params
  • Loading branch information
irexyc authored Sep 12, 2024
1 parent f98d152 commit 64fe4c5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/turbomind/layers/sampling_layers/SamplingLayer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void set_runtime_args(int batch_size,
for (int i = 0; i < batch_size; i++) {
int topk = top_ks_size > 1 ? top_ks[i] : top_k;
float topp = top_ps_size > 1 ? top_ps[i] : top_p;
float minp = min_ps_size > 1 ? top_ps[i] : min_p;
float minp = min_ps_size > 1 ? min_ps[i] : min_p;

if (topk == 0 && topp == 0.f) {
topk = 1;
Expand Down
10 changes: 5 additions & 5 deletions tests/csrc/unittests/test_sampling_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -506,22 +506,22 @@ TYPED_TEST_SUITE(TopPMinPFilterTest, FloatType);

TYPED_TEST(TopPMinPFilterTest, OnlyTopP)
{
float top_ps[] = {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f};
float top_ps[] = {0.8f, 0.82f, 0.84f, 0.86f, 0.88f, 0.90f, 0.92f, 0.94f, 0.96f, 0.98f, 1.0f};
float min_ps[] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
this->runTest(sizeof(top_ps) / sizeof(float), top_ps, min_ps, 200);
};

TYPED_TEST(TopPMinPFilterTest, OnlyMinP)
{
float min_ps[] = {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f};
float top_ps[] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
float min_ps[] = {0.0f, 0.002f, 0.004f, 0.006f, 0.008f, 0.01f, 0.012f, 0.014f, 0.016f, 0.018f, 0.02f};
float top_ps[] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f};
this->runTest(sizeof(top_ps) / sizeof(float), top_ps, min_ps, 200);
};

TYPED_TEST(TopPMinPFilterTest, MixedTopPMinP)
{
float min_ps[] = {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f};
float top_ps[] = {1.f, 0.9f, 0.8f, 0.7f, 0.6f, 0.5f, 0.4f, 0.3f, 0.2f, 0.1f, 0.f};
float min_ps[] = {0.0f, 0.002f, 0.004f, 0.006f, 0.008f, 0.01f, 0.012f, 0.014f, 0.016f, 0.018f, 0.02f};
float top_ps[] = {0.8f, 0.82f, 0.84f, 0.86f, 0.88f, 0.90f, 0.92f, 0.94f, 0.96f, 0.98f, 1.0f};
this->runTest(sizeof(top_ps) / sizeof(float), top_ps, min_ps, 200);
};

Expand Down

0 comments on commit 64fe4c5

Please sign in to comment.