Skip to content

Commit

Permalink
Log the empty leaf node path (as opposed to its dataset)
Browse files Browse the repository at this point in the history
This adds a test and fixes a bug with logging of the empty sdg leaf
nodes, as it was trying to log the actual empty dataset instead of the
leaf node path.

Signed-off-by: Ben Browning <[email protected]>
  • Loading branch information
bbrowning committed Sep 19, 2024
1 parent 81bca55 commit 59ba8a5
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,8 @@ def generate_data(
logger.debug("Dataset: %s", ds)
new_generated_data = pipe.generate(ds, leaf_node_path)
if len(new_generated_data) == 0:
empty_sdg_leaf_nodes.append(new_generated_data)
logger.warning(
"Empty dataset for qna node: %s", leaf_node[0]["taxonomy_path"]
)
empty_sdg_leaf_nodes.append(leaf_node_path)
logger.warning("Empty dataset for qna node: %s", leaf_node_path)
continue
generated_data = (
[new_generated_data]
Expand Down
61 changes: 61 additions & 0 deletions tests/test_generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import glob
import json
import os
import re
import shutil
import tempfile
import unittest
Expand Down Expand Up @@ -267,6 +268,11 @@ def strip_q(q):
return output


def _empty_llmblock_generate(self, samples):
"""Return an empty set of generated samples."""
return []


@patch.object(LLMBlock, "_generate", _noop_llmblock_generate)
class TestGenerateCompositionalData(unittest.TestCase):
@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -452,6 +458,61 @@ def __exit__(self, *args):
self.teardown()


@patch.object(LLMBlock, "_generate", _empty_llmblock_generate)
class TestGenerateEmptyDataset(unittest.TestCase):
    """Check that ``generate_data`` warns when generation yields no samples.

    ``LLMBlock._generate`` is patched for the whole class with
    ``_empty_llmblock_generate``, so every call to it returns an empty list.
    """

    @pytest.fixture(autouse=True)
    def _init_taxonomy(self, taxonomy_dir):
        """Keep the ``taxonomy_dir`` fixture value on the test instance."""
        self.test_taxonomy = taxonomy_dir

    def setUp(self):
        """Pick an output path and populate the test taxonomy.

        The same skill data loaded from ``test_valid_knowledge_skill.yaml``
        is registered twice: once through ``add_tracked`` and once through
        ``create_untracked``.
        """
        # Only the generated path *name* is kept; the TemporaryDirectory
        # object itself is discarded right away.
        # NOTE(review): presumably generate_data creates this directory
        # itself -- confirm before switching to tempfile.mkdtemp().
        self.tmp_path = tempfile.TemporaryDirectory().name
        test_valid_knowledge_skill_file = os.path.join(
            TEST_DATA_DIR, "test_valid_knowledge_skill.yaml"
        )
        # NOTE(review): the trailing space in "knowledge " looks accidental,
        # but it is part of the path the test creates -- confirm before
        # changing it.
        tracked_knowledge_file = os.path.join("knowledge ", "tracked", "qna.yaml")
        untracked_knowledge_file = os.path.join("knowledge", "new", "qna.yaml")
        test_valid_knowledge_skill = load_test_skills(test_valid_knowledge_skill_file)
        self.test_taxonomy.add_tracked(
            tracked_knowledge_file, test_valid_knowledge_skill
        )
        self.test_taxonomy.create_untracked(
            untracked_knowledge_file, test_valid_knowledge_skill
        )

    def tearDown(self):
        """Remove the output directory after each test.

        ``unittest`` only calls the camel-case ``tearDown`` hook, so the
        lower-case ``teardown`` method was never run and the output
        directory was left behind after every test.
        """
        self.teardown()

    def test_generate(self):
        """Run ``generate_data`` and assert the empty-output warning.

        The last ``warning`` call on the mocked logger must have a first
        positional argument matching ``empty sdg output: knowledge_new``.
        """
        with patch("logging.Logger.info") as mocked_logger:
            generate_data(
                client=MagicMock(),
                logger=mocked_logger,
                model_family="merlinite",
                model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
                num_instructions_to_generate=10,
                taxonomy=self.test_taxonomy.root,
                taxonomy_base=TEST_TAXONOMY_BASE,
                output_dir=self.tmp_path,
                chunk_word_count=1000,
                server_ctx_size=4096,
                pipeline="simple",
            )
        mocked_logger.warning.assert_called()
        assert re.search(
            "empty sdg output: knowledge_new", mocked_logger.warning.call_args.args[0]
        )

    def teardown(self) -> None:
        """Recursively remove the temporary output directory and all of its
        subdirectories and files.

        A missing directory is ignored so that cleanup after a test that
        failed early does not raise a second, unrelated error.
        """
        shutil.rmtree(self.tmp_path, ignore_errors=True)

    def __enter__(self):
        """Allow use as a context manager; returns the test instance."""
        return self

    def __exit__(self, *args):
        """Clean up the output directory when the ``with`` block exits."""
        self.teardown()


def test_context_init_batch_size_optional():
"""Test that the _context_init function can handle a missing batch size by
delegating to the default in PipelineContext.
Expand Down

0 comments on commit 59ba8a5

Please sign in to comment.