Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions tests/test_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ def test_check_filename_empty():
assert "empty name" in error


def test_check_filename_none():
"""Test check_filename with None"""
error = checker.check_filename(None)
assert "empty name" in error


# ============================================================================
# check_kind tests
# ============================================================================
Expand Down Expand Up @@ -110,6 +116,21 @@ def test_check_kind_symlink(tmp_path):
assert checker.check_kind(str(link), follow_symlink=False) == "symlink"


def test_check_kind_special_files(tmp_path):
"""Test check_kind with special file types"""
special_files = {
"char_device": 0o020666, # S_IFCHR
"block_device": 0o060666, # S_IFBLK
"socket": 0o140666, # S_IFSOCK
"fifo": 0o010666, # S_IFIFO
}

for kind, mode in special_files.items():
with patch("os.stat") as mock_stat:
mock_stat.return_value = MagicMock(st_mode=mode)
assert checker.check_kind("fake_path") == kind


# ============================================================================
# check_daemon tests
# ============================================================================
Expand Down Expand Up @@ -352,6 +373,34 @@ def test_validate_time_field_logic_range_out_of_bounds_end():
assert any("range end" in e and "out of bounds" in e for e in errors)


def test_validate_time_field_logic_duplicate_value():
"""Test validate_time_field_logic with duplicate values in a list"""
errors = checker.validate_time_field_logic("1,2,1", "minutes", 0, 59)
assert len(errors) > 0
assert any("duplicate value" in e for e in errors)


def test_validate_time_field_logic_empty_value():
"""Test validate_time_field_logic with an empty value in a list"""
errors = checker.validate_time_field_logic("1,,2", "minutes", 0, 59)
assert len(errors) > 0
assert any("empty value" in e for e in errors)


def test_validate_time_field_logic_invalid_step_format():
"""Test validate_time_field_logic with an invalid step format"""
errors = checker.validate_time_field_logic("*/a", "minutes", 0, 59)
assert len(errors) > 0
assert any("invalid step value" in e for e in errors)


def test_validate_time_field_logic_invalid_range():
"""Test validate_time_field_logic with an invalid range where start > end"""
errors = checker.validate_time_field_logic("5-1", "minutes", 0, 59)
assert len(errors) > 0
assert any("invalid range" in e for e in errors)


# ============================================================================
# check_user_exists tests
# ============================================================================
Expand Down Expand Up @@ -435,6 +484,56 @@ def test_check_command_missing():
assert "missing command" in errors[0]


def test_check_dangerous_commands():
"""Test check_dangerous_commands with dangerous commands"""
dangerous_commands = [
"rm -rf /",
"rm -rf /root",
"rm -rf / && ls",
"rm -rf / || ls",
"rm -rf / ; ls",
]
for command in dangerous_commands:
errors = checker.check_dangerous_commands(command)
assert len(errors) > 0
assert any("dangerous command" in e for e in errors)


def test_check_minutes_invalid_format():
"""Test check_minutes with an invalid format"""
errors = checker.check_minutes("a", is_system_crontab=False)
assert len(errors) > 0
assert any("invalid minute format" in e for e in errors)


def test_check_hours_invalid_format():
"""Test check_hours with an invalid format"""
errors = checker.check_hours("a")
assert len(errors) > 0
assert any("invalid hour format" in e for e in errors)


def test_check_day_of_month_invalid_format():
"""Test check_day_of_month with an invalid format"""
errors = checker.check_day_of_month("a")
assert len(errors) > 0
assert any("invalid day of month format" in e for e in errors)


def test_check_month_invalid_format():
"""Test check_month with an invalid format"""
errors = checker.check_month("a")
assert len(errors) > 0
assert any("invalid month format" in e for e in errors)


def test_check_day_of_week_invalid_format():
"""Test check_day_of_week with an invalid format"""
errors = checker.check_day_of_week("a")
assert len(errors) > 0
assert any("invalid day of week format" in e for e in errors)


# ============================================================================
# Legacy functions tests
# ============================================================================
Expand Down Expand Up @@ -463,6 +562,32 @@ def test_check_line_special_legacy():
# ============================================================================


def test_check_line_insufficient_fields_special():
"""Test check_line with insufficient fields for a special keyword"""
errors, warnings = checker.check_line("@reboot", 1, "test.txt")
assert len(errors) > 0
assert any("insufficient fields" in e for e in errors)


def test_check_line_insufficient_fields_regular():
"""Test check_line with insufficient fields for a regular crontab line"""
errors, warnings = checker.check_line("* * * *", 1, "test.txt")
assert len(errors) > 0
assert any("insufficient fields" in e for e in errors)


def test_check_line_system_extra_field_in_command():
"""Test check_line with an extra field in the command for a system crontab line"""
errors, warnings = checker.check_line(
"* * * * * root extra /usr/bin/backup.sh",
1,
"test.txt",
is_system_crontab=True,
)
assert len(errors) > 0
assert any("extra field" in e for e in errors)


@patch("checkcrontab.checker.os.lstat")
@patch("checkcrontab.checker.os.path.exists")
@patch("checkcrontab.checker.os.path.lexists")
Expand Down Expand Up @@ -558,6 +683,31 @@ def test_check_line_multiline_continuation_with_backslash():
pass # Will test through integration tests


def test_check_user_invalid_format():
"""Test check_user with various invalid formats"""
invalid_usernames = ["", "#user", '"user"', "user@", "user name", "-user"]
for username in invalid_usernames:
errors, warnings = checker.check_user(username)
assert len(errors) > 0
assert any("invalid user format" in e for e in errors)


def test_check_special_invalid_keyword():
"""Test check_special with an invalid keyword"""
errors = checker.check_special(
"@invalid", ["@invalid", "command"], is_system_crontab=False
)
assert len(errors) > 0
assert any("invalid special keyword" in e for e in errors)


def test_check_line_with_env_var():
"""Test check_line with an environment variable"""
errors, warnings = checker.check_line("MAILTO=test@example.com", 1, "test.txt")
assert errors == []
assert warnings == []


# ============================================================================
# Tests for special file types (char device, block device, socket, fifo)
# ============================================================================
Expand Down
100 changes: 100 additions & 0 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python3

import logging
import platform
import sys
from unittest.mock import MagicMock, patch

import pytest
from checkcrontab.logger import ColoredFormatter, setup_logging

# Mock the getwindowsversion function if it doesn't exist (i.e., on non-Windows platforms)
if not hasattr(sys, "getwindowsversion"):

class MockWindowsVersion:
def __init__(self, major, build):
self.major = major
self.build = build

sys.getwindowsversion = MagicMock(return_value=MockWindowsVersion(10, 10586))


class TestColoredFormatter:
@pytest.fixture
def log_record(self):
return logging.LogRecord(
"test", logging.INFO, "/test", 1, "test message", None, None
)

def test_formatter_no_color(self, log_record):
formatter = ColoredFormatter(fmt="%(message)s", use_colors=False)
formatted_message = formatter.format(log_record)
assert formatted_message == "test message"
assert "\033" not in formatted_message

def test_formatter_with_color(self, log_record):
formatter = ColoredFormatter(fmt="%(levelname)s: %(message)s", use_colors=True)
with patch("platform.system", return_value="Linux"):
formatted_message = formatter.format(log_record)
assert "test message" in formatted_message
assert "\033[0;32mINFO\033[0m" in formatted_message

@patch("platform.system", return_value="Windows")
def test_windows_color_compatibility_supported(self, mock_system, log_record):
with patch(
"sys.getwindowsversion",
return_value=MagicMock(major=10, build=10586),
):
formatter = ColoredFormatter(
fmt="%(levelname)s: %(message)s", use_colors=True
)
assert formatter._get_color_compatibility() is True
formatted_message = formatter.format(log_record)
assert "\033[0;32mINFO\033[0m" in formatted_message

@patch("platform.system", return_value="Windows")
def test_windows_color_compatibility_not_supported(self, mock_system, log_record):
with patch(
"sys.getwindowsversion",
return_value=MagicMock(major=10, build=10000),
):
formatter = ColoredFormatter(fmt="%(message)s", use_colors=True)
assert formatter._get_color_compatibility() is False
formatted_message = formatter.format(log_record)
assert "\033" not in formatted_message

@patch("platform.system", return_value="Windows")
def test_windows_color_compatibility_exception(self, mock_system, log_record):
with patch("sys.getwindowsversion", side_effect=Exception()):
formatter = ColoredFormatter(fmt="%(message)s", use_colors=True)
assert formatter._get_color_compatibility() is False
formatted_message = formatter.format(log_record)
assert "\033" not in formatted_message


class TestSetupLogging:
@patch("logging.basicConfig")
def test_setup_logging_debug(self, mock_basic_config):
setup_logging(debug=True)
mock_basic_config.assert_called_once()
assert mock_basic_config.call_args[1]["level"] == logging.DEBUG

@patch("logging.basicConfig")
def test_setup_logging_info(self, mock_basic_config):
setup_logging(debug=False)
mock_basic_config.assert_called_once()
assert mock_basic_config.call_args[1]["level"] == logging.INFO

@patch("logging.StreamHandler", MagicMock())
@patch("logging.basicConfig")
def test_setup_logging_stderr(self, mock_basic_config):
with patch("sys.stderr") as mock_stderr:
setup_logging(use_stderr=True)
logging.StreamHandler.assert_called_with(mock_stderr)

@patch("logging.StreamHandler", MagicMock())
@patch("logging.basicConfig")
def test_setup_logging_stdout(self, mock_basic_config):
with patch("sys.stdout") as mock_stdout:
setup_logging(use_stderr=False)
logging.StreamHandler.assert_called_with(mock_stdout)
Loading