scripts/data/filtering_and_updates/filter_dataset_by_keywords.py
Motivated by: realizing the SFT mix has lots of "I am DeepSeek" snippets.

Run with:
python scripts/data/filtering_and_updates/filter_dataset_by_keywords.py --input-dataset allenai/tulu-3-sft-mixture --column messages
"""

import argparse
import os
import re

# Must be set before the datasets import below takes effect
os.environ["HF_DATASETS_DISABLE_CACHING"] = "1"

from datasets import disable_caching, load_dataset

disable_caching()

# Popular model providers
PROVIDERS = [
    "OpenAI", "Open AI", "Claude", "Gemini", "Qwen", "DeepSeek", "Anthropic", "Meta AI", "Meta's", "ChatGPT",
    "Cohere", "Mistral AI", "Mistral's", "xAI", "Perplexity"  # "Google AI", "Google's", "Microsoft", "HuggingFace", "Hugging Face"
]

# Regex patterns for filtering (all case-insensitive via the (?i) flag)
PATTERNS = [
    # Pattern: "I'm [model name], an AI assistant made by {provider}"
    # Bounded quantifiers ({0,100}) keep each match within a single sentence
    r"(?i)\bI'?m\s+(" + "|".join(PROVIDERS) + r"),?\s+an?\s+AI\s+(?:assistant|model)[^.!?]{0,100}(?:made|developed|created|trained)\s+by\s+(" + "|".join(PROVIDERS) + r")\b[^.!?]{0,100}[.!?]",

# Pattern: "[Model name] is an AI assistant developed by {provider}"
r"(?i)(" + "|".join(PROVIDERS) + r")\s+is\s+an?\s+ai\s+(?:assistant|model)[^.!?]*?(?:developed|created|made|trained)\s+by\s+(" + "|".join(PROVIDERS) + r")\b[^.!?]*?[.!?]",
# Restored full pattern
r"(?i)\b(" + "|".join(PROVIDERS) + r")\s+is\s+an?\s+AI\s+(?:assistant|model)[^.!?]{0,100}(?:developed|created|made|trained)\s+by\s+(" + "|".join(PROVIDERS) + r")\b[^.!?]{0,100}[.!?]",

# Pattern: "as a [AI model/assistant/chatbot] ... {provider}"
r"(?i)as\s+a\s+(?:language\s+model|ai\s+model|assistant|chatbot|model)[^.!?]*?\b(" + "|".join(PROVIDERS) + r")\b[^.!?]*?[.!?]",
# Kept greedy to match more
r"(?i)\bas\s+an?\s+(?:language\s+model|AI\s+model|assistant|chatbot|model)[^.!?]{0,100}\b(" + "|".join(PROVIDERS) + r")\b[^.!?]{0,100}[.!?]",

# Pattern: "as an AI developed by {provider}"
r"(?i)as\s+an\s+ai\s+(?:developed|created|made|trained)\s+by\s+(" + "|".join(PROVIDERS) + r")\b[^.!?]*?[.!?]",
# Kept full range
r"(?i)\bas\s+an\s+AI\s+(?:developed|created|made|trained)\s+by\s+(" + "|".join(PROVIDERS) + r")\b[^.!?]{0,100}[.!?]",

# Pattern: "I am [model type] ... {provider}"
r"(?i)i\s+am\s+(?:a\s+)?(?:language\s+model|ai\s+model|assistant|chatbot|model)[^.!?]*?\b(" + "|".join(PROVIDERS) + r")\b[^.!?]*?[.!?]",
# Kept greedy for full matches
r"(?i)\bI\s+am\s+(?:a\s+)?(?:language\s+model|AI\s+model|assistant|chatbot|model)[^.!?]{0,100}\b(" + "|".join(PROVIDERS) + r")\b[^.!?]{0,100}[.!?]",

# Pattern: "trained by ... {provider}" within one sentence
r"(?i)trained\s+by\s+[^.!?]*?\b(" + "|".join(PROVIDERS) + r")\b[^.!?]*?[.!?]",
# Pattern: "I am called [provider]"
r"(?i)\bI\s+am\s+called\s+\b(" + "|".join(PROVIDERS) + r")\b[^.!?]{0,100}[.!?]",

# Pattern: "I'm [provider]" or "I am [provider]"
r"(?i)\b(?:I'?m|I\s+am)\s+\b(" + "|".join(PROVIDERS) + r")\b[^.!?]{0,100}[.!?]",

# Pattern: "trained by ... {provider}" within one sentence
# Made middle section non-greedy but kept full ranges
r"(?i)\btrained\s+by\s+[^.!?]{0,100}?\b(" + "|".join(PROVIDERS) + r")\b[^.!?]{0,100}[.!?]",

# Pattern: "developed by ... {provider}" within one sentence
r"(?i)developed\s+by\s+[^.!?]*?\b(" + "|".join(PROVIDERS) + r")\b[^.!?]*?[.!?]",
r"(?i)\bdeveloped\s+by\s+[^.!?]{0,100}?\b(" + "|".join(PROVIDERS) + r")\b[^.!?]{0,100}[.!?]",

# Pattern: "created by ... {provider}" within one sentence
r"(?i)created\s+by\s+[^.!?]*?\b(" + "|".join(PROVIDERS) + r")\b[^.!?]*?[.!?]",
r"(?i)\bcreated\s+by\s+[^.!?]{0,100}?\b(" + "|".join(PROVIDERS) + r")\b[^.!?]{0,100}[.!?]",

# Pattern: "made by ... {provider}" within one sentence
r"(?i)made\s+by\s+[^.!?]*?\b(" + "|".join(PROVIDERS) + r")\b[^.!?]*?[.!?]",
r"(?i)\bmade\s+by\s+[^.!?]{0,100}?\b(" + "|".join(PROVIDERS) + r")\b[^.!?]{0,100}[.!?]",

# Pattern: "against {provider}'s use-case policy" or similar policy references
r"(?i)against\s+(" + "|".join(PROVIDERS) + r")(?:'s|'s)?\s+(?:use-case\s+)?(?:policy|policies|guidelines|terms)[^.!?]*?[.!?]",
r"(?i)\bagainst\s+(" + "|".join(PROVIDERS) + r")(?:'s|'s)?\s+(?:use-case\s+)?(?:policy|policies|guidelines|terms)[^.!?]{0,100}[.!?]",

# Pattern: "{provider}'s policy" or "{provider}'s guidelines"
r"(?i)\b(" + "|".join(PROVIDERS) + r")(?:'s|'s)\s+(?:policy|policies|guidelines|terms|use-case)[^.!?]*?[.!?]",
]
r"(?i)\b(" + "|".join(PROVIDERS) + r")(?:'s|'s)\s+(?:policy|policies|guidelines|terms|use-case)[^.!?]{0,100}[.!?]",

# Pattern: Any sentence containing "DeepSeek-R1" or "DeepSeek R1" (case-insensitive)
# Less restrictive: bounded but allows more at the start
r"(?i)[^.!?]{0,500}?\bDeepSeek[\s-]?R1\b[^.!?]{0,100}[.!?]",

# Pattern: Anything with the word "Qwen" (case-insensitive)
# Less restrictive: bounded but allows more at the start
r"(?i)[^.!?]{0,500}?\bQwen\b[^.!?]{0,100}[.!?]",

# Pattern: Any sentence containing "Alibaba Qwen" (case-insensitive) or Alibaba Cloud
# Less restrictive: bounded but allows more at the start
r"(?i)[^.!?]{0,500}?\bAlibaba\s+Qwen\b[^.!?]{0,100}[.!?]",
r"(?i)[^.!?]{0,500}?\bAlibaba\s+Cloud\b[^.!?]{0,100}[.!?]",
]
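

# A minimal self-check sketch (not part of the original script; the helper name
# `_demo_pattern_match` is hypothetical) showing how a provider-identity snippet
# trips one of the PATTERNS above.
def _demo_pattern_match():
    sample = "I'm Claude, an AI assistant made by Anthropic."
    for i, pattern in enumerate(PATTERNS):
        match = re.search(pattern, sample)
        if match:
            # e.g. (0, "I'm Claude, an AI assistant made by Anthropic.")
            return i, match.group(0)
    return None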

def should_be_filtered_by_advanced_patterns(example, column="messages", verbose=False, filter_user_turns=False):
    """Filter by more sophisticated patterns like 'as a ... OpenAI' or 'trained by ... Google'."""

    for message in example[column]:
        # Skip user messages unless explicitly enabled
        if message["role"] == "user" and not filter_user_turns:
            continue
        if message["role"] != "assistant" and message["role"] != "user":
            continue

        content = message["content"]  # Keep original case

        # Treat missing content as filterable
        if content is None:
            return True

        for pattern in PATTERNS:
            if re.search(pattern, content):
                if verbose:
                    # Verbose diagnostics are elided in the diff; a minimal report:
                    print(f"Filtered: matched a pattern in a {message['role']} message")
                return True

    return False


def should_be_filtered_combined(example, column="messages", verbose=False, filter_user_turns=False):
    """Combined filtering function."""
    return should_be_filtered_by_advanced_patterns(example, column=column, verbose=verbose, filter_user_turns=filter_user_turns)
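

# Illustrative usage sketch (assumed, not in the original script): an assistant
# turn that reveals a provider identity is flagged; a plain answer is not.
def _demo_should_filter():
    flagged = {"messages": [{"role": "assistant", "content": "I am ChatGPT, created by OpenAI."}]}
    clean = {"messages": [{"role": "assistant", "content": "Paris is the capital of France."}]}
    assert should_be_filtered_combined(flagged) is True
    assert should_be_filtered_combined(clean) is False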

def load_dataset_from_parquet(dataset_name):
    """Load dataset directly from parquet files."""
    # (implementation elided in the diff)
    ...


def main():
    parser = argparse.ArgumentParser(description="Filter a dataset by provider keyword patterns.")
    # Assumed definition: --input-dataset is referenced below, but its
    # add_argument call is elided in the diff
    parser.add_argument("--input-dataset", type=str, required=True,
                        help="Input dataset to filter.")
parser.add_argument("--filter-user-turns", action="store_true",
help="Also filter based on user messages (default: only filter assistant messages)")
parser.add_argument("--output-entity", type=str, help="Output entity (org/user) for the filtered dataset. If not provided, uses the same entity as the input dataset.")

parser.add_argument("--column", type=str, default="messages",
help="Column name containing the messages (default: messages)")

    args = parser.parse_args()

    input_dataset = args.input_dataset
    filter_user_turns = args.filter_user_turns

    # Output repo naming is elided in the diff; this reconstruction is an assumption
    output_dataset = input_dataset.split("/")[-1] + "-filtered"
    if args.output_entity:
        output_dataset = f"{args.output_entity}/{output_dataset}"

    print(f"Loading dataset {input_dataset}...")
    try:
        dataset = load_dataset(input_dataset, split="train")
    except Exception:
        # Fall back to reading the repo's parquet files directly
        try:
            dataset = load_dataset_from_parquet(input_dataset)
        except Exception as e:
            print(f"Failed to load dataset: {e}")
            raise

print(f"Dataset loaded with {len(dataset)} examples")

# Keep track of filtered examples
filtered_examples = []

# Filter function
def filter_fn(example):
should_filter = should_be_filtered_combined(example, verbose=True, filter_user_turns=filter_user_turns)
if should_filter and len(filtered_examples) < 3:
# Find which pattern matched and extract the matching text
for message in example["messages"]:
# Apply same filtering logic for finding matched text
if message["role"] == "user" and not filter_user_turns:
continue
if message["role"] != "assistant" and message["role"] != "user":
continue

content = message["content"] # Keep original case

for pattern in PATTERNS:
match = re.search(pattern, content)
if match:
example["_matched_text"] = match.group(0)
example["_matched_role"] = message["role"]
break
if "_matched_text" in example:
break

filtered_examples.append(example)
return not should_filter


print("Filtering dataset...")
filtered_dataset = dataset.filter(filter_fn)
    # First filter without debugging.
    # Cap parallelism at the available cores: the diff hardcoded num_proc to 96,
    # which (per review feedback) can exhaust or hang machines with fewer cores.
    num_proc = max(1, min(96, os.cpu_count() or 1))
    filtered_dataset = dataset.filter(
        lambda ex: not should_be_filtered_combined(ex, column=args.column, verbose=False, filter_user_turns=filter_user_turns),
        num_proc=num_proc,
    )
print(f"Filtered size: {len(filtered_dataset)}")
print(f"Removed {len(dataset) - len(filtered_dataset)} examples")

    # Then collect a few filtered examples serially for inspection
    if len(dataset) > len(filtered_dataset):
        print("\nCollecting example filtered instances...")
        examples_found = 0
        print_within = min(10000, len(dataset))
        for example in dataset.select(range(print_within)):
            if should_be_filtered_combined(example, column=args.column, verbose=True, filter_user_turns=filter_user_turns):
                # verbose=True above prints the match details; just count here
                examples_found += 1
                if examples_found >= 10:
                    break
        print("--- End of examples ---\n")


    # Upload
    full_name = f"{output_dataset}"
    # The push call itself is elided in the diff; push_to_hub is the standard datasets API
    filtered_dataset.push_to_hub(full_name)
    print("Done!")

if __name__ == "__main__":
main()
main()