diff --git a/.envrc b/.envrc new file mode 100644 index 000000000..118c1901d --- /dev/null +++ b/.envrc @@ -0,0 +1,2 @@ +use asdf +layout python diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 232f1e0d6..8867ecf58 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -4,7 +4,7 @@ on: [push, pull_request] jobs: build: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 services: postgres: @@ -35,7 +35,7 @@ jobs: steps: - uses: actions/checkout@v1 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -58,12 +58,12 @@ jobs: python orm migrate --connection mysql make test lint: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 name: Lint steps: - uses: actions/checkout@v1 - name: Set up Python 3.6 - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: 3.6 - name: Install Flake8 diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml index dc4afe8ef..bc1301967 100644 --- a/.github/workflows/pythonpublish.yml +++ b/.github/workflows/pythonpublish.yml @@ -6,7 +6,7 @@ on: jobs: build: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 services: postgres: @@ -37,7 +37,7 @@ jobs: steps: - uses: actions/checkout@v1 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.github/workflows/radon-analysis.yml b/.github/workflows/radon-analysis.yml new file mode 100644 index 000000000..f81c58e4d --- /dev/null +++ b/.github/workflows/radon-analysis.yml @@ -0,0 +1,66 @@ +name: Radon Analysis + +on: + pull_request: + paths: + - '**/*.py' # Only trigger for Python files + +jobs: + radon_analysis: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v3 + with: + fetch-depth: 0 # Ensure full history is fetched for git diff to work properly + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.9' + + - name: Install Radon + run: pip install radon + + - name: Get list of changed files + id: get_changed_files + run: | + CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} | grep '.py$' || true) + echo "Changed files: $CHANGED_FILES" + echo "CHANGED_FILES<<EOF" >> $GITHUB_ENV + echo "$CHANGED_FILES" >> $GITHUB_ENV + echo "EOF" >> $GITHUB_ENV + + - name: Debug changed files + run: | + echo "Changed files from env: ${{ env.CHANGED_FILES }}" + shell: bash + + - name: Run Radon analysis + id: radon_analysis + run: | + if [ -z "${{ env.CHANGED_FILES }}" ]; then + echo "No Python files changed." + echo "RESULTS=None" >> $GITHUB_ENV + else + RESULTS=$(echo "${{ env.CHANGED_FILES }}" | xargs radon cc -s -n A || true) + echo "RESULTS<<EOF" >> $GITHUB_ENV + echo "$RESULTS" >> $GITHUB_ENV + echo "EOF" >> $GITHUB_ENV + fi + + - name: Comment on PR + if: env.RESULTS != 'None' + uses: marocchino/sticky-pull-request-comment@v2 + with: + header: Radon Complexity Analysis + message: | + **Radon Analysis Results:** + ``` + ${{ env.RESULTS }} + ``` + + - name: Log if no Python files found + if: env.RESULTS == 'None' + run: echo "No Python files changed in this PR."
diff --git a/.gitignore b/.gitignore index f593b05cf..3dee2e72d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ venv +.direnv .python-version .vscode .pytest_* @@ -15,4 +16,6 @@ htmlcov/* coverage.xml .coverage *.log -build \ No newline at end of file +build +/orm.sqlite3 +/.bootstrapped-pip diff --git a/.tool-versions b/.tool-versions new file mode 100644 index 000000000..8b869bd71 --- /dev/null +++ b/.tool-versions @@ -0,0 +1 @@ +python 3.8.10 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 654a27f10..2a118b384 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -12,7 +12,7 @@ If you are interested in the project then it would be a great idea to read the " ## Issues -Everything really should start with opening an issue or finding an issue. If you feel you have an idea for how the project can be improved, no matter how small, you should open an issue so we can have an open dicussion with the maintainers of the project. +Everything really should start with opening an issue or finding an issue. If you feel you have an idea for how the project can be improved, no matter how small, you should open an issue so we can have an open discussion with the maintainers of the project. We can discuss in that issue the solution to the problem or feature you have. If we do not feel it fits within the project then we will close the issue. Feel free to open a new issue if new information comes up. diff --git a/cc.py b/cc.py index 44b2a5e48..5193c8e11 100644 --- a/cc.py +++ b/cc.py @@ -16,8 +16,9 @@ # print(builder.where("id", 1).or_where(lambda q: q.where('id', 2).or_where('id', 3)).get()) class User(Model): - __connection__ = "sqlite" + __connection__ = "mysql" __table__ = "users" + __dates__ = ["verified_at"] @has_many("id", "user_id") def articles(self): @@ -28,7 +29,9 @@ class Article(Model): # user = User.create({"name": "phill", "email": "phill"}) # print(inspect.isclass(User)) -print(User.find(1).with_("articles").first().serialize()) +user = User.first() +user.update({"verified_at": None, "updated_at": None}) +print(user.first().serialize()) # print(user.serialize()) # print(User.first()) \ No newline at end of file diff --git a/config/test-database.py b/config/test-database.py new file mode 100644 index 000000000..a05058cbd --- /dev/null +++ b/config/test-database.py @@ -0,0 +1,35 @@ +from src.masoniteorm.connections import ConnectionResolver + +DATABASES = { + "default": "mysql", + "mysql": { + "host": "127.0.0.1", + "driver": "mysql", + "database": "masonite", + "user": "root", + "password": "", + "port": 3306, + "log_queries": False, + "options": { + # + } + }, + "postgres": { + "host": "127.0.0.1", + "driver": "postgres", + "database": "masonite", + "user": "root", + "password": "", + "port": 5432, + "log_queries": False, + "options": { + # + } + }, + "sqlite": { + "driver": "sqlite", + "database": "masonite.sqlite3", + } +} + +DB = ConnectionResolver().set_connection_details(DATABASES) diff --git a/makefile b/makefile index 683335f7c..137524336 100644 --- a/makefile +++ b/makefile @@ -1,22 +1,25 @@ -init: +init: .env .bootstrapped-pip + +.bootstrapped-pip: requirements.txt requirements.dev + pip install -r requirements.txt -r requirements.dev + touch .bootstrapped-pip + +.env: cp .env-example .env - pip install -r requirements.txt - pip install . 
- # Create MySQL Database - # Create Postgres Database -test: + +# Create MySQL Database +# Create Postgres Database +test: init python -m pytest tests ci: make test +check: format sort lint lint: python -m flake8 src/masoniteorm/ --ignore=E501,F401,E203,E128,E402,E731,F821,E712,W503,F811 -format: - black src/masoniteorm - black tests/ - make lint -sort: - isort tests - isort src/masoniteorm +format: init + black src/masoniteorm tests/ +sort: init + isort src/masoniteorm tests/ coverage: python -m pytest --cov-report term --cov-report xml --cov=src/masoniteorm tests/ python -m coveralls diff --git a/orm b/orm index 588410707..73c273fdc 100644 --- a/orm +++ b/orm @@ -10,6 +10,7 @@ from src.masoniteorm.commands import ( MigrateCommand, MigrateRollbackCommand, MigrateRefreshCommand, + MigrateFreshCommand, MakeMigrationCommand, MakeObserverCommand, MakeModelCommand, @@ -25,6 +26,7 @@ application = Application("ORM Version:", 0.1) application.add(MigrateCommand()) application.add(MigrateRollbackCommand()) application.add(MigrateRefreshCommand()) +application.add(MigrateFreshCommand()) application.add(MakeMigrationCommand()) application.add(MakeModelCommand()) application.add(MakeModelDocstringCommand()) diff --git a/orm.sqlite3 b/orm.sqlite3 index a0b6a9b11..f62e36cb5 100644 Binary files a/orm.sqlite3 and b/orm.sqlite3 differ diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..fb1fb3add --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +env = + D:DB_CONFIG_PATH=config/test-database \ No newline at end of file diff --git a/requirements.dev b/requirements.dev new file mode 100644 index 000000000..15b093441 --- /dev/null +++ b/requirements.dev @@ -0,0 +1,8 @@ +flake8==3.7.9 +black +faker +pytest +pytest-cov +pytest-env +pymysql +isort \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 398264052..aba4fd5b5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,6 @@ -flake8==3.7.9 -black==19.3b0 -faker -pytest -pytest-cov -pymysql -isort inflection==0.3.1 psycopg2-binary -python-dotenv==0.14.0 pyodbc -pendulum>=2.1,<2.2 -cleo>=0.8.0,<0.9 \ No newline at end of file +pendulum>=2.1,<3.1 +cleo>=0.8.0,<0.9 +python-dotenv==0.14.0 \ No newline at end of file diff --git a/setup.py b/setup.py index bb3caf5d1..f62fc8e81 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ # Versions should comply with PEP440. For a discussion on single-sourcing # the version across setup.py and the project code, see # https://packaging.python.org/en/latest/single_source_version.html - version="2.18.1", + version="2.23.1", package_dir={"": "src"}, description="The Official Masonite ORM", long_description=long_description, @@ -30,7 +30,7 @@ # https://packaging.python.org/en/latest/requirements.html install_requires=[ "inflection>=0.3,<0.6", - "pendulum>=2.1,<2.2", + "pendulum>=2.1,<3.1", "faker>=4.1.0,<14.0", "cleo>=0.8.0,<0.9", ], diff --git a/src/masoniteorm/collection/Collection.py b/src/masoniteorm/collection/Collection.py index efadd444e..e238db8a5 100644 --- a/src/masoniteorm/collection/Collection.py +++ b/src/masoniteorm/collection/Collection.py @@ -40,6 +40,7 @@ def first(self, callback=None): if callback: filtered = self.filter(callback) response = None + if filtered: response = filtered[0] return response @@ -91,15 +92,15 @@ def avg(self, key=None): return result def max(self, key=None): - """Returns the average of the items. + """Returns the max of the items. - If a key is given it will return the average of all the values of the key. 
+ If a key is given it will return the max of all the values of the key. Keyword Arguments: - key {string} -- The key to use to find the average of all the values of that key. (default: {None}) + key {string} -- The key to use to find the max of all the values of that key. (default: {None}) Returns: - int -- Returns the average. + int -- Returns the max. """ result = 0 items = self._get_value(key) or self._items @@ -110,6 +111,26 @@ def max(self, key=None): pass return result + def min(self, key=None): + """Returns the min of the items. + + If a key is given it will return the min of all the values of the key. + + Keyword Arguments: + key {string} -- The key to use to find the min of all the values of that key. (default: {None}) + + Returns: + int -- Returns the min. + """ + result = 0 + items = self._get_value(key) or self._items + + try: + return min(items) + except (TypeError, ValueError): + pass + return result + def chunk(self, size: int): """Chunks the items. @@ -280,7 +301,7 @@ def push(self, value): self._items.append(value) def put(self, key, value): - self[key] = value + self._items[key] = value return self def random(self, count=None): @@ -323,7 +344,7 @@ def _serialize(item): def add_relation(self, result=None): for model in self._items: - model.add_relations(result or {}) + model.add_relation(result or {}) return self @@ -351,7 +372,6 @@ def to_json(self, **kwargs): return json.dumps(self.serialize(), **kwargs) def group_by(self, key): - from itertools import groupby self.sort(key) @@ -410,7 +430,6 @@ def where(self, key, *args): return self.__class__(attributes) def where_in(self, key, args: list) -> "Collection": - attributes = [] for item in self._items: @@ -514,8 +533,13 @@ def __eq__(self, other): def __getitem__(self, item): if isinstance(item, slice): return self.__class__(self._items[item]) + if isinstance(item, dict): + return self._items.get(item, None) - return self._items[item] + try: + return self._items[item] + except KeyError: + return None def __setitem__(self, key, value): self._items[key] = value diff --git a/src/masoniteorm/commands/Entry.py b/src/masoniteorm/commands/Entry.py index c704f5f56..c46089048 100644 --- a/src/masoniteorm/commands/Entry.py +++ b/src/masoniteorm/commands/Entry.py @@ -10,6 +10,7 @@ MigrateCommand, MigrateRollbackCommand, MigrateRefreshCommand, + MigrateFreshCommand, MakeMigrationCommand, MakeModelCommand, MakeModelDocstringCommand, @@ -26,6 +27,7 @@ application.add(MigrateCommand()) application.add(MigrateRollbackCommand()) application.add(MigrateRefreshCommand()) +application.add(MigrateFreshCommand()) application.add(MakeMigrationCommand()) application.add(MakeModelCommand()) application.add(MakeModelDocstringCommand()) diff --git a/src/masoniteorm/commands/MigrateFreshCommand.py b/src/masoniteorm/commands/MigrateFreshCommand.py new file mode 100644 index 000000000..fb90f52ab --- /dev/null +++ b/src/masoniteorm/commands/MigrateFreshCommand.py @@ -0,0 +1,42 @@ +from ..migrations import Migration + +from .Command import Command + + +class MigrateFreshCommand(Command): + """ + Drops all tables and migrates them again. + + migrate:fresh + {--c|connection=default : The connection you want to run migrations on} + {--d|directory=databases/migrations : The location of the migration directory} + {--f|ignore-fk=? : The connection you want to run migrations on} + {--s|seed=? : Seed database after fresh. The seeder to be ran can be provided in argument} + {--schema=? 
: Sets the schema to be migrated} + {--D|seed-directory=databases/seeds : The location of the seed directory if seed option is used.} + """ + + def handle(self): + migration = Migration( + command_class=self, + connection=self.option("connection"), + migration_directory=self.option("directory"), + config_path=self.option("config"), + schema=self.option("schema"), + ) + + migration.fresh(ignore_fk=self.option("ignore-fk")) + + self.line("") + + if self.option("seed") == "null": + self.call( + "seed:run", + f"None --directory {self.option('seed-directory')} --connection {self.option('connection')}", + ) + + elif self.option("seed"): + self.call( + "seed:run", + f"{self.option('seed')} --directory {self.option('seed-directory')} --connection {self.option('connection')}", + ) diff --git a/src/masoniteorm/commands/MigrateRefreshCommand.py b/src/masoniteorm/commands/MigrateRefreshCommand.py index b56e5d377..c65735643 100644 --- a/src/masoniteorm/commands/MigrateRefreshCommand.py +++ b/src/masoniteorm/commands/MigrateRefreshCommand.py @@ -17,7 +17,6 @@ class MigrateRefreshCommand(Command): """ def handle(self): - migration = Migration( command_class=self, connection=self.option("connection"), @@ -33,11 +32,11 @@ def handle(self): if self.option("seed") == "null": self.call( "seed:run", - f"None --directory {self.option('seed-directory')} --connection {self.option('connection', 'default')}", + f"None --directory {self.option('seed-directory')} --connection {self.option('connection')}", ) elif self.option("seed"): self.call( "seed:run", - f"{self.option('seed')} --directory {self.option('seed-directory')} --connection {self.option('connection', 'default')}", + f"{self.option('seed')} --directory {self.option('seed-directory')} --connection {self.option('connection')}", ) diff --git a/src/masoniteorm/commands/MigrateStatusCommand.py b/src/masoniteorm/commands/MigrateStatusCommand.py index 8cdbd7068..7ddbfdd32 100644 --- a/src/masoniteorm/commands/MigrateStatusCommand.py +++ b/src/masoniteorm/commands/MigrateStatusCommand.py @@ -22,17 +22,28 @@ def handle(self): ) migration.create_table_if_not_exists() table = self.table() - table.set_header_row(["Ran?", "Migration"]) + table.set_header_row(["Ran?", "Migration", "Batch"]) migrations = [] - for migration_file in migration.get_ran_migrations(): + for migration_data in migration.get_ran_migrations(): + migration_file = migration_data["migration_file"] + batch = migration_data["batch"] + migrations.append( - ["Y", f"{migration_file}"] + [ + "Y", + f"{migration_file}", + f"{batch}", + ] ) for migration_file in migration.get_unran_migrations(): migrations.append( - ["N", f"{migration_file}"] + [ + "N", + f"{migration_file}", + "-", + ] ) table.set_rows(migrations) diff --git a/src/masoniteorm/commands/SeedRunCommand.py b/src/masoniteorm/commands/SeedRunCommand.py index 3ae6b585f..89e90359d 100644 --- a/src/masoniteorm/commands/SeedRunCommand.py +++ b/src/masoniteorm/commands/SeedRunCommand.py @@ -27,7 +27,6 @@ def handle(self): seeder_seeded = "Database Seeder" else: - table = self.argument("table") seeder_file = ( f"{underscore(table)}_table_seeder.{camelize(table)}TableSeeder" diff --git a/src/masoniteorm/commands/__init__.py b/src/masoniteorm/commands/__init__.py index 380b9a442..454ef577d 100644 --- a/src/masoniteorm/commands/__init__.py +++ b/src/masoniteorm/commands/__init__.py @@ -6,6 +6,7 @@ from .MigrateCommand import MigrateCommand from .MigrateRollbackCommand import MigrateRollbackCommand from .MigrateRefreshCommand import MigrateRefreshCommand 
+from .MigrateFreshCommand import MigrateFreshCommand from .MigrateResetCommand import MigrateResetCommand from .MakeModelCommand import MakeModelCommand from .MakeModelDocstringCommand import MakeModelDocstringCommand diff --git a/src/masoniteorm/config.py b/src/masoniteorm/config.py index 6e37ef0cf..52ce4eb37 100644 --- a/src/masoniteorm/config.py +++ b/src/masoniteorm/config.py @@ -2,7 +2,8 @@ import pydoc import urllib.parse as urlparse -from .exceptions import ConfigurationNotFound, InvalidUrlConfiguration +from .exceptions import ConfigurationNotFound +from .exceptions import InvalidUrlConfiguration def load_config(config_path=None): @@ -12,8 +13,11 @@ def load_config(config_path=None): 2. else try to load from default config_path: config/database """ selected_config_path = ( - config_path or os.getenv("DB_CONFIG_PATH", None) or "config/database" + os.getenv("DB_CONFIG_PATH", None) or config_path or "config/database" ) + + os.environ["DB_CONFIG_PATH"] = selected_config_path + # format path as python module if needed selected_config_path = ( selected_config_path.replace("/", ".").replace("\\", ".").rstrip(".py") diff --git a/src/masoniteorm/connections/BaseConnection.py b/src/masoniteorm/connections/BaseConnection.py index 8419f70a4..3a420fd20 100644 --- a/src/masoniteorm/connections/BaseConnection.py +++ b/src/masoniteorm/connections/BaseConnection.py @@ -4,7 +4,6 @@ class BaseConnection: - _connection = None _cursor = None _dry = False @@ -78,5 +77,15 @@ def select_many(self, query, bindings, amount): result = self.format_cursor_results(self._cursor.fetchmany(amount)) + def enable_disable_foreign_keys(self): + foreign_keys = self.full_details.get("foreign_keys") + platform = self.get_default_platform()() + + if foreign_keys: + self._connection.execute(platform.enable_foreign_key_constraints()) + elif foreign_keys is not None: + self._connection.execute(platform.disable_foreign_key_constraints()) + def get_row_count(self): return self._cursor.rowcount + diff --git a/src/masoniteorm/connections/ConnectionFactory.py b/src/masoniteorm/connections/ConnectionFactory.py index 9b924a3db..b36292dd4 100644 --- a/src/masoniteorm/connections/ConnectionFactory.py +++ b/src/masoniteorm/connections/ConnectionFactory.py @@ -4,9 +4,10 @@ class ConnectionFactory: """Class for controlling the registration and creation of connection types.""" - _connections = { - # - } + _connections = {} + + def __init__(self, config_path=None): + self.config_path = config_path @classmethod def register(cls, key, connection): @@ -35,7 +36,7 @@ def make(self, key): masoniteorm.connection.BaseConnection -- Returns an instance of a BaseConnection class. 
""" - DB = load_config().DB + DB = load_config(config_path=self.config_path).DB connections = DB.get_connection_details() diff --git a/src/masoniteorm/connections/ConnectionResolver.py b/src/masoniteorm/connections/ConnectionResolver.py index d518c36ad..f408c09c3 100644 --- a/src/masoniteorm/connections/ConnectionResolver.py +++ b/src/masoniteorm/connections/ConnectionResolver.py @@ -2,12 +2,11 @@ class ConnectionResolver: - _connection_details = {} _connections = {} _morph_map = {} - def __init__(self): + def __init__(self, config_path=None): from ..connections import ( SQLiteConnection, PostgresConnection, @@ -15,10 +14,10 @@ def __init__(self): MSSQLConnection, ) + self.config_path = config_path from ..connections import ConnectionFactory - self.connection_factory = ConnectionFactory() - + self.connection_factory = ConnectionFactory(config_path=config_path) self.register(SQLiteConnection) self.register(PostgresConnection) self.register(MySQLConnection) diff --git a/src/masoniteorm/connections/MSSQLConnection.py b/src/masoniteorm/connections/MSSQLConnection.py index a55ac119a..6bc72cbc5 100644 --- a/src/masoniteorm/connections/MSSQLConnection.py +++ b/src/masoniteorm/connections/MSSQLConnection.py @@ -26,7 +26,6 @@ def __init__( full_details=None, name=None, ): - self.host = host if port: self.port = int(port) @@ -71,6 +70,8 @@ def make_connection(self): autocommit=True, ) + self.enable_disable_foreign_keys() + self.open = 1 return self @@ -151,7 +152,7 @@ def query(self, query, bindings=(), results="*"): return {} columnNames = [column[0] for column in cursor.description] result = cursor.fetchone() - return dict(zip(columnNames, result)) + return dict(zip(columnNames, result)) if result is not None else {} else: if not cursor.description: return {} diff --git a/src/masoniteorm/connections/MySQLConnection.py b/src/masoniteorm/connections/MySQLConnection.py index 7bc0dd8ee..6b284ab58 100644 --- a/src/masoniteorm/connections/MySQLConnection.py +++ b/src/masoniteorm/connections/MySQLConnection.py @@ -31,10 +31,16 @@ def __init__( if str(port).isdigit(): self.port = int(self.port) self.database = database + self.user = user self.password = password self.prefix = prefix self.full_details = full_details or {} + self.connection_pool_size = ( + full_details.get( + "connection_pooling_max_size", 100 + ) + ) self.options = options or {} self._cursor = None self.open = 0 @@ -48,39 +54,80 @@ def make_connection(self): if self._dry: return + if self.has_global_connection(): + return self.get_global_connection() + + # Check if there is an available connection in the pool + self._connection = self.create_connection() + self.enable_disable_foreign_keys() + + return self + + def close_connection(self): + if ( + self.full_details.get("connection_pooling_enabled") + and len(CONNECTION_POOL) < self.connection_pool_size + ): + CONNECTION_POOL.append(self._connection) + self.open = 0 + self._connection = None + + def create_connection(self, autocommit=True): + try: import pymysql except ModuleNotFoundError: raise DriverNotFound( - "You must have the 'pymysql' package installed to make a connection to MySQL. Please install it using 'pip install pymysql'" + "You must have the 'pymysql' package " + "installed to make a connection to MySQL. 
" + "Please install it using 'pip install pymysql'" ) + import pendulum + import pymysql.converters - try: - import pendulum - import pymysql.converters + pymysql.converters.conversions[pendulum.DateTime] = ( + pymysql.converters.escape_datetime + ) - pymysql.converters.conversions[ - pendulum.DateTime - ] = pymysql.converters.escape_datetime - except ImportError: - pass + # Initialize the connection pool if the option is set + initialize_size = self.full_details.get("connection_pooling_min_size") + if initialize_size and len(CONNECTION_POOL) < initialize_size: + for _ in range(initialize_size - len(CONNECTION_POOL)): + connection = pymysql.connect( + cursorclass=pymysql.cursors.DictCursor, + autocommit=autocommit, + host=self.host, + user=self.user, + password=self.password, + port=self.port, + database=self.database, + **self.options + ) + CONNECTION_POOL.append(connection) + + if ( + self.full_details.get("connection_pooling_enabled") + and CONNECTION_POOL + and len(CONNECTION_POOL) > 0 + ): + connection = CONNECTION_POOL.pop() + else: + connection = pymysql.connect( + cursorclass=pymysql.cursors.DictCursor, + autocommit=autocommit, + host=self.host, + user=self.user, + password=self.password, + port=self.port, + database=self.database, + **self.options + ) - if self.has_global_connection(): - return self.get_global_connection() + connection.close = self.close_connection - self._connection = pymysql.connect( - cursorclass=pymysql.cursors.DictCursor, - autocommit=True, - host=self.host, - user=self.user, - password=self.password, - port=self.port, - db=self.database, - **self.options - ) self.open = 1 - return self + return connection def reconnect(self): self._connection.connect() @@ -105,6 +152,9 @@ def commit(self): """Transaction""" self._connection.commit() self.transaction_level -= 1 + if self.get_transaction_level() <= 0: + self.open = 0 + self._connection.close() def dry(self): """Transaction""" @@ -121,6 +171,9 @@ def rollback(self): """Transaction""" self._connection.rollback() self.transaction_level -= 1 + if self.get_transaction_level() <= 0: + self.open = 0 + self._connection.close() def get_transaction_level(self): """Transaction""" @@ -130,15 +183,19 @@ def get_cursor(self): return self._cursor def query(self, query, bindings=(), results="*"): - """Make the actual query that will reach the database and come back with a result. + """Make the actual query that + will reach the database and come back with a result. Arguments: - query {string} -- A string query. This could be a qmarked string or a regular query. + query {string} -- A string query. + This could be a qmarked string or a regular query. bindings {tuple} -- A tuple of bindings Keyword Arguments: - results {str|1} -- If the results is equal to an asterisks it will call 'fetchAll' - else it will return 'fetchOne' and return a single record. (default: {"*"}) + results {str|1} -- If the results is equal to an + asterisks it will call 'fetchAll' + else it will return 'fetchOne' and + return a single record. 
(default: {"*"}) Returns: dict|None -- Returns a dictionary of results or None @@ -147,7 +204,10 @@ def query(self, query, bindings=(), results="*"): if self._dry: return {} - if not self._connection.open: + if not self.open: + if self._connection is None: + self._connection = self.create_connection() + self._connection.connect() self._cursor = self._connection.cursor() @@ -169,6 +229,7 @@ def query(self, query, bindings=(), results="*"): except Exception as e: raise QueryException(str(e)) from e finally: + self._cursor.close() if self.get_transaction_level() <= 0: self.open = 0 self._connection.close() diff --git a/src/masoniteorm/connections/PostgresConnection.py b/src/masoniteorm/connections/PostgresConnection.py index 5132ce3da..5d2146f5e 100644 --- a/src/masoniteorm/connections/PostgresConnection.py +++ b/src/masoniteorm/connections/PostgresConnection.py @@ -26,7 +26,6 @@ def __init__( full_details=None, name=None, ): - self.host = host if port: self.port = int(port) @@ -35,8 +34,10 @@ def __init__( self.database = database self.user = user self.password = password + self.prefix = prefix self.full_details = full_details or {} + self.connection_pool_size = full_details.get("connection_pooling_max_size", 100) self.options = options or {} self._cursor = None self.transaction_level = 0 @@ -57,23 +58,71 @@ def make_connection(self): if self.has_global_connection(): return self.get_global_connection() - schema = self.schema or self.full_details.get("schema") - - self._connection = psycopg2.connect( - database=self.database, - user=self.user, - password=self.password, - host=self.host, - port=self.port, - options=f"-c search_path={schema}" if schema else "", - ) + self._connection = self.create_connection() self._connection.autocommit = True + self.enable_disable_foreign_keys() + self.open = 1 return self + def create_connection(self): + import psycopg2 + + # Initialize the connection pool if the option is set + initialize_size = self.full_details.get("connection_pooling_min_size") + if ( + self.full_details.get("connection_pooling_enabled") + and initialize_size + and len(CONNECTION_POOL) < initialize_size + ): + for _ in range(initialize_size - len(CONNECTION_POOL)): + connection = psycopg2.connect( + database=self.database, + user=self.user, + password=self.password, + host=self.host, + port=self.port, + sslmode=self.options.get("sslmode"), + sslcert=self.options.get("sslcert"), + sslkey=self.options.get("sslkey"), + sslrootcert=self.options.get("sslrootcert"), + options=( + f"-c search_path={self.schema or self.full_details.get('schema')}" + if self.schema or self.full_details.get("schema") + else "" + ), + ) + CONNECTION_POOL.append(connection) + + if ( + self.full_details.get("connection_pooling_enabled") + and CONNECTION_POOL + and len(CONNECTION_POOL) > 0 + ): + connection = CONNECTION_POOL.pop() + else: + connection = psycopg2.connect( + database=self.database, + user=self.user, + password=self.password, + host=self.host, + port=self.port, + sslmode=self.options.get("sslmode"), + sslcert=self.options.get("sslcert"), + sslkey=self.options.get("sslkey"), + sslrootcert=self.options.get("sslrootcert"), + options=( + f"-c search_path={self.schema or self.full_details.get('schema')}" + if self.schema or self.full_details.get("schema") + else "" + ), + ) + + return connection + def get_database_name(self): return self.database @@ -92,6 +141,17 @@ def get_default_post_processor(cls): def reconnect(self): pass + def close_connection(self): + if ( + 
self.full_details.get("connection_pooling_enabled") + and len(CONNECTION_POOL) < self.connection_pool_size + ): + CONNECTION_POOL.append(self._connection) + else: + self._connection.close() + + self._connection = None + def commit(self): """Transaction""" if self.get_transaction_level() == 1: @@ -139,7 +199,7 @@ def query(self, query, bindings=(), results="*"): dict|None -- Returns a dictionary of results or None """ try: - if self._connection.closed: + if not self._connection or self._connection.closed: self.make_connection() self.set_cursor() @@ -163,4 +223,5 @@ def query(self, query, bindings=(), results="*"): finally: if self.get_transaction_level() <= 0: self.open = 0 - self._connection.close() + self.close_connection() + # self._connection.close() diff --git a/src/masoniteorm/connections/SQLiteConnection.py b/src/masoniteorm/connections/SQLiteConnection.py index 903291a3f..2bb9501cb 100644 --- a/src/masoniteorm/connections/SQLiteConnection.py +++ b/src/masoniteorm/connections/SQLiteConnection.py @@ -63,6 +63,9 @@ def make_connection(self): self._connection.create_function("REGEXP", 2, regexp) self._connection.row_factory = sqlite3.Row + + self.enable_disable_foreign_keys() + self.open = 1 return self diff --git a/src/masoniteorm/exceptions.py b/src/masoniteorm/exceptions.py index e2aed15fb..9d4594db4 100644 --- a/src/masoniteorm/exceptions.py +++ b/src/masoniteorm/exceptions.py @@ -32,3 +32,7 @@ class InvalidUrlConfiguration(Exception): class MultipleRecordsFound(Exception): pass + + +class InvalidArgument(Exception): + pass diff --git a/src/masoniteorm/factories/Factory.py b/src/masoniteorm/factories/Factory.py index d3f7741a3..0aef9ba19 100644 --- a/src/masoniteorm/factories/Factory.py +++ b/src/masoniteorm/factories/Factory.py @@ -3,7 +3,6 @@ class Factory: - _factories = {} _after_creates = {} _faker = None diff --git a/src/masoniteorm/migrations/Migration.py b/src/masoniteorm/migrations/Migration.py index d7dd237b0..01f653529 100644 --- a/src/masoniteorm/migrations/Migration.py +++ b/src/masoniteorm/migrations/Migration.py @@ -60,7 +60,9 @@ def get_unran_migrations(self): all_migrations = [ f.replace(".py", "") for f in listdir(directory_path) - if isfile(join(directory_path, f)) and f != "__init__.py" + if isfile(join(directory_path, f)) + and f != "__init__.py" + and not f.startswith(".") ] all_migrations.sort() unran_migrations = [] @@ -107,26 +109,34 @@ def get_ran_migrations(self): all_migrations = [ f.replace(".py", "") for f in listdir(directory_path) - if isfile(join(directory_path, f)) and f != "__init__.py" + if isfile(join(directory_path, f)) + and f != "__init__.py" + and not f.startswith(".") ] all_migrations.sort() ran = [] database_migrations = self.migration_model.all() for migration in all_migrations: - if migration in database_migrations.pluck("migration"): - ran.append(migration) + matched_migration = database_migrations.where( + "migration", migration + ).first() + if matched_migration: + ran.append( + { + "migration_file": matched_migration.migration, + "batch": matched_migration.batch, + } + ) return ran def migrate(self, migration="all", output=False): - default_migrations = self.get_unran_migrations() migrations = default_migrations if migration == "all" else [migration] batch = self.get_last_batch_number() + 1 for migration in migrations: - try: migration_class = self.locate(migration) @@ -173,7 +183,6 @@ def migrate(self, migration="all", output=False): ) def rollback(self, migration="all", output=False): - default_migrations = 
self.get_rollback_migrations() migrations = default_migrations if migration == "all" else [migration] @@ -279,3 +288,30 @@ def reset(self, migration="all"): def refresh(self, migration="all"): self.reset(migration) self.migrate(migration) + + def drop_all_tables(self, ignore_fk=False): + if self.command_class: + self.command_class.line("Dropping all tables") + + if ignore_fk: + self.schema.disable_foreign_key_constraints() + + for table in self.schema.get_all_tables(): + self.schema.drop(table) + + if ignore_fk: + self.schema.enable_foreign_key_constraints() + + if self.command_class: + self.command_class.line("All tables dropped") + + def fresh(self, ignore_fk=False, migration="all"): + self.drop_all_tables(ignore_fk=ignore_fk) + self.create_table_if_not_exists() + + if not self.get_unran_migrations(): + if self.command_class: + self.command_class.line("Nothing to migrate") + return + + self.migrate(migration) diff --git a/src/masoniteorm/models/MigrationModel.py b/src/masoniteorm/models/MigrationModel.py index d7677db4d..d915c5590 100644 --- a/src/masoniteorm/models/MigrationModel.py +++ b/src/masoniteorm/models/MigrationModel.py @@ -2,7 +2,6 @@ class MigrationModel(Model): - __table__ = "migrations" __fillable__ = ["migration", "batch"] __timestamps__ = None diff --git a/src/masoniteorm/models/Model.py b/src/masoniteorm/models/Model.py index 4656af10b..3c591494d 100644 --- a/src/masoniteorm/models/Model.py +++ b/src/masoniteorm/models/Model.py @@ -1,19 +1,21 @@ +import inspect import json -from datetime import datetime, date as datetimedate, time as datetimetime import logging +from datetime import date as datetimedate +from datetime import datetime +from datetime import time as datetimetime from decimal import Decimal - -from inflection import tableize, underscore -import inspect +from typing import Any, Dict import pendulum +from inflection import tableize, underscore -from ..query import QueryBuilder from ..collection import Collection -from ..observers import ObservesEvents -from ..scopes import TimeStampsMixin from ..config import load_config from ..exceptions import ModelNotFound +from ..observers import ObservesEvents +from ..query import QueryBuilder +from ..scopes import TimeStampsMixin """This is a magic class that will help using models like User.first() instead of having to instatiate a class like User().first() @@ -58,8 +60,11 @@ class JsonCast: """Casts a value to JSON""" def get(self, value): - if not isinstance(value, str): - return json.dumps(value) + if isinstance(value, str): + try: + return json.loads(value) + except ValueError: + return None return value @@ -69,7 +74,7 @@ def set(self, value): json.loads(value) return value - return json.dumps(value) + return json.dumps(value, default=str) class IntCast: @@ -130,7 +135,7 @@ class Model(TimeStampsMixin, ObservesEvents, metaclass=ModelMeta): """ __fillable__ = ["*"] - __guarded__ = ["*"] + __guarded__ = [] __dry__ = False __table__ = None __connection__ = "default" @@ -157,105 +162,113 @@ class Model(TimeStampsMixin, ObservesEvents, metaclass=ModelMeta): date_created_at = "created_at" date_updated_at = "updated_at" + builder: QueryBuilder + """Pass through will pass any method calls to the model directly through to the query builder. Anytime one of these methods are called on the model it will actually be called on the query builder class. 
""" - __passthrough__ = [ - "add_select", - "aggregate", - "all", - "avg", - "between", - "bulk_create", - "chunk", - "count", - "decrement", - "delete", - "distinct", - "doesnt_exist", - "doesnt_have", - "exists", - "find_or_404", - "find_or_fail", - "first_or_fail", - "first", - "first_where", - "first_or_create", - "force_update", - "from_", - "from_raw", - "get", - "get_table_schema", - "group_by_raw", - "group_by", - "has", - "having", - "having_raw", - "increment", - "in_random_order", - "join_on", - "join", - "joins", - "last", - "left_join", - "limit", - "lock_for_update", - "make_lock", - "max", - "min", - "new_from_builder", - "new", - "not_between", - "offset", - "on", - "or_where", - "or_where_null", - "order_by_raw", - "order_by", - "paginate", - "right_join", - "select_raw", - "select", - "set_global_scope", - "set_schema", - "shared_lock", - "simple_paginate", - "skip", - "statement", - "sum", - "table_raw", - "take", - "to_qmark", - "to_sql", - "truncate", - "update", - "when", - "where_between", - "where_column", - "where_date", - "or_where_doesnt_have", - "or_has", - "or_where_has", - "or_doesnt_have", - "or_where_not_exists", - "or_where_date", - "where_exists", - "where_from_builder", - "where_has", - "where_in", - "where_like", - "where_not_between", - "where_not_in", - "where_not_like", - "where_not_null", - "where_null", - "where_raw", - "without_global_scopes", - "where", - "where_doesnt_have", - "with_", - "with_count", - ] + __passthrough__ = set( + ( + "add_select", + "aggregate", + "all", + "avg", + "between", + "bulk_create", + "chunk", + "count", + "decrement", + "delete", + "distinct", + "doesnt_exist", + "doesnt_have", + "exists", + "find_or", + "find_or_404", + "find_or_fail", + "first_or_fail", + "first", + "first_where", + "first_or_create", + "force_update", + "from_", + "from_raw", + "get", + "get_table_schema", + "group_by_raw", + "group_by", + "has", + "having", + "having_raw", + "increment", + "in_random_order", + "join_on", + "join", + "joins", + "last", + "left_join", + "limit", + "lock_for_update", + "make_lock", + "max", + "min", + "new_from_builder", + "new", + "not_between", + "offset", + "on", + "or_where", + "or_where_null", + "order_by_raw", + "order_by", + "paginate", + "right_join", + "select_raw", + "select", + "set_global_scope", + "set_schema", + "shared_lock", + "simple_paginate", + "skip", + "statement", + "sum", + "table_raw", + "take", + "to_qmark", + "to_sql", + "truncate", + "update", + "when", + "where_between", + "where_column", + "where_date", + "or_where_doesnt_have", + "or_has", + "or_where_has", + "or_doesnt_have", + "or_where_not_exists", + "or_where_date", + "where_exists", + "where_from_builder", + "where_has", + "where_in", + "where_like", + "where_not_between", + "where_not_in", + "where_not_like", + "where_not_null", + "where_null", + "where_raw", + "without_global_scopes", + "where", + "where_doesnt_have", + "with_", + "with_count", + "latest", + "oldest", + "value", + ) + ) __cast_map__ = {} @@ -337,11 +350,14 @@ def get_builder(self): table=self.get_table_name(), # connection_details=self.get_connection_details(), model=self, - scopes=self._scopes, + scopes=self._scopes.get(self.__class__), dry=self.__dry__, ) - return self.builder.select(*self.__selects__) + return self.builder.select(*self.get_selects()) + + def get_selects(self): + return self.__selects__ @classmethod def get_columns(cls): @@ -359,6 +375,15 @@ def boot(self): if class_name.endswith("Mixin"): getattr(self, "boot_" + class_name)(self.get_builder()) 
+ elif ( + base_class != Model + and issubclass(base_class, Model) + and "__fillable__" in base_class.__dict__ + and "__guarded__" in base_class.__dict__ + ): + raise AttributeError( + f"{type(self).__name__} must specify either __fillable__ or __guarded__ properties, but not both." + ) self._booted = True self.observe_events(self, "booted") @@ -366,7 +391,7 @@ def boot(self): self.append_passthrough(list(self.get_builder()._macros.keys())) def append_passthrough(self, passthrough): - self.__passthrough__ += passthrough + self.__passthrough__.update(passthrough) return self @classmethod @@ -404,12 +429,12 @@ def find(cls, record_id, query=False): builder = cls().where(cls.get_primary_key(), record_id) if query: - return builder.to_sql() + return builder else: if isinstance(record_id, (list, tuple)): return builder.get() - return builder.first() + return builder.first() @classmethod def find_or_fail(cls, record_id, query=False): @@ -519,45 +544,33 @@ def new_collection(cls, data): return Collection(data) @classmethod - def create(cls, dictionary=None, query=False, cast=True, **kwargs): + + + def create(cls, dictionary=None, query=False, cast=True, **kwargs): """Creates new records based off of a dictionary as well as data set on the model such as fillable values. Args: dictionary (dict, optional): [description]. Defaults to {}. query (bool, optional): [description]. Defaults to False. + cast (bool, optional): [description]. Whether or not to cast passed values. Returns: self: A hydrated version of a model """ - - if not dictionary: - dictionary = kwargs - - if cls.__fillable__ != ["*"]: - d = {} - for x in cls.__fillable__: - if x in dictionary: - if cast == True: - d.update({x: cls._set_casted_value(x, dictionary[x])}) - else: - d.update({x: dictionary[x]}) - dictionary = d - - if cls.__guarded__ != ["*"]: - for x in cls.__guarded__: - if x in dictionary: - dictionary.pop(x) - if query: return cls.builder.create( - dictionary, query=True, id_key=cls.__primary_key__ - ).to_sql() + dictionary, query=True, cast=cast, **kwargs + ) - return cls.builder.create(dictionary, id_key=cls.__primary_key__) + return cls.builder.create(dictionary, cast=cast, **kwargs) @classmethod - def _set_casted_value(cls, attribute, value): + def cast_value(cls, attribute: str, value: Any): + """ + Given an attribute name and a value, casts the value using the model's registered caster. + If no registered caster exists, returns the unmodified value. + """ cast_method = cls.__casts__.get(attribute) cast_map = cls.get_cast_map(cls) @@ -571,6 +584,15 @@ def _set_casted_value(cls, attribute, value): return cast_method(value) return value + @classmethod + def cast_values(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: + """ + Runs provided dictionary through all model casters and returns the result. + + Does not mutate the passed dictionary. 
+ """ + return {x: cls.cast_value(x, dictionary[x]) for x in dictionary} + def fresh(self): return ( self.get_builder() @@ -637,7 +659,6 @@ def serialize(self, exclude=None, include=None): remove_keys = [] for key, value in serialized_dictionary.items(): - if key in self.__hidden__: remove_keys.append(key) if hasattr(value, "serialize"): @@ -660,7 +681,7 @@ def to_json(self): Returns: string """ - return json.dumps(self.serialize()) + return json.dumps(self.serialize(), default=str) @classmethod def first_or_create(cls, wheres, creates: dict = None): @@ -688,7 +709,7 @@ def update_or_create(cls, wheres, updates): total.update(updates) total.update(wheres) if not record: - return self.create(total, id_key=cls.get_primary_key()) + return self.create(total, id_key=cls.get_primary_key()).fresh() return self.where(wheres).update(total) @@ -788,6 +809,23 @@ def method(*args, **kwargs): return None + def only(self, attributes: list) -> dict: + if isinstance(attributes, str): + attributes = [attributes] + results: dict[str, Any] = {} + for attribute in attributes: + if " as " in attribute: + attribute, alias = attribute.split(" as ") + alias = alias.strip() + attribute = attribute.strip() + else: + alias = attribute.strip() + attribute = attribute.strip() + + results[alias] = self.get_raw_attribute(attribute) + + return results + def __setattr__(self, attribute, value): if hasattr(self, "set_" + attribute + "_attribute"): method = getattr(self, "set_" + attribute + "_attribute") @@ -840,12 +878,19 @@ def save(self, query=False): if not query: if self.is_loaded(): + + result = builder.update( + self.__dirty_attributes__, ignore_mass_assignment=True + ) + builder.update(self.__dirty_attributes__) + else: result = self.create( self.__dirty_attributes__, query=query, id_key=self.get_primary_key(), + ignore_mass_assignment=True, ) self.observe_events(self, "saved") self.__dirty_attributes__ = {} @@ -854,7 +899,9 @@ def save(self, query=False): return result if self.is_loaded(): - result = builder.update(self.__dirty_attributes__, dry=query).to_sql() + result = builder.update( + self.__dirty_attributes__, dry=query, ignore_mass_assignment=True + ) else: result = self.create(self.__dirty_attributes__, query=query) @@ -980,7 +1027,8 @@ def get_new_date(self, _datetime=None): def get_new_datetime_string(self, _datetime=None): """ - Get the attributes that should be converted to dates. + Given an optional datetime value, constructs and returns a new datetime string. + If no datetime is specified, returns the current time. :rtype: list """ @@ -1004,7 +1052,6 @@ def set_appends(self, appends): return self def save_many(self, relation, relating_records): - if isinstance(relating_records, Model): raise ValueError( "Saving many records requires an iterable like a collection or a list of models and not a Model object. To attach a model, use the 'attach' method." @@ -1020,7 +1067,6 @@ def save_many(self, relation, relating_records): related.attach_related(self, related_record) def detach_many(self, relation, relating_records): - if isinstance(relating_records, Model): raise ValueError( "Detaching many records requires an iterable like a collection or a list of models and not a Model object. To detach a model, use the 'detach' method." @@ -1063,6 +1109,44 @@ def detach(self, relation, related_record): return related.detach(self, related_record) + def save_quietly(self): + """This method calls the save method on a model without firing the saved & saving observer events. 
Saved/Saving + are toggled back on once save_quietly has been ran. + + Instead of calling: + + User().save(...) + + you can use this: + + User.save_quietly(...) + """ + self.without_events() + saved = self.save() + self.with_events() + return saved + + def delete_quietly(self): + """This method calls the delete method on a model without firing the delete & deleting observer events. + Instead of calling: + + User().delete(...) + + you can use this: + + User.delete_quietly(...) + + Returns: + self + """ + delete = ( + self.without_events() + .where(self.get_primary_key(), self.get_primary_key_value()) + .delete() + ) + self.with_events() + return delete + def attach_related(self, relation, related_record): related = getattr(self.__class__, relation) @@ -1072,3 +1156,35 @@ def attach_related(self, relation, related_record): related_record.save() return related.attach_related(self, related_record) + + @classmethod + def filter_fillable(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: + """ + Filters provided dictionary to only include fields specified in the model's __fillable__ property + + Passed dictionary is not mutated. + """ + if cls.__fillable__ != ["*"]: + dictionary = {x: dictionary[x] for x in cls.__fillable__ if x in dictionary} + return dictionary + + @classmethod + def filter_mass_assignment(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: + """ + Filters the provided dictionary in preparation for a mass-assignment operation + + Wrapper around filter_fillable() & filter_guarded(). Passed dictionary is not mutated. + """ + return cls.filter_guarded(cls.filter_fillable(dictionary)) + + @classmethod + def filter_guarded(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: + """ + Filters provided dictionary to exclude fields specified in the model's __guarded__ property + + Passed dictionary is not mutated. + """ + if cls.__guarded__ == ["*"]: + # If all fields are guarded, all data should be filtered + return {} + return {f: dictionary[f] for f in dictionary if f not in cls.__guarded__} diff --git a/src/masoniteorm/models/Model.pyi b/src/masoniteorm/models/Model.pyi index d80c2d3bd..1cb203636 100644 --- a/src/masoniteorm/models/Model.pyi +++ b/src/masoniteorm/models/Model.pyi @@ -1,12 +1,12 @@ -from typing import Any +from typing import Any, Dict + from typing_extensions import Self from ..query.QueryBuilder import QueryBuilder class Model: def add_select(alias: str, callable: Any): - """Specifies a select subquery. - """ + """Specifies a select subquery.""" pass def aggregate(aggregate: str, column: str, alias: str): """Helper function to aggregate. @@ -53,6 +53,19 @@ class Model: pass def bulk_create(creates: dict, query: bool = False): pass + def cast_value(attribute: str, value: Any): + """ + Given an attribute name and a value, casts the value using the model's registered caster. + If no registered caster exists, returns the unmodified value. + """ + pass + def cast_values(dictionary: Dict[str, Any]) -> Dict[str, Any]: + """ + Runs provided dictionary through all model casters and returns the result. + + Does not mutate the passed dictionary. + """ + pass def chunk(chunk_amount: str | int): pass def count(column: str = None): @@ -65,6 +78,24 @@ class Model: self """ pass + def create( + dictionary: Dict[str, Any] = None, + query: bool = False, + cast: bool = False, + **kwargs + ): + """Creates new records based off of a dictionary as well as data set on the model + such as fillable values. + + Args: + dictionary (dict, optional): [description]. Defaults to {}. 
+ query (bool, optional): [description]. Defaults to False. + cast (bool, optional): [description]. Whether or not to cast passed values. + + Returns: + self: A hydrated version of a model + """ + pass def decrement(column: str, value: int = 1): """Decrements a column's value. @@ -90,8 +121,7 @@ class Model: """ pass def distinct(boolean: bool = True): - """Species that the select query should be a SELECT DISTINCT query. - """ + """Species that the select query should be a SELECT DISTINCT query.""" pass def doesnt_exist() -> bool: """Determines if any rows exist for the current query. @@ -114,6 +144,27 @@ class Model: Bool - True or False """ pass + def filter_fillable(dictionary: Dict[str, Any]) -> Dict[str, Any]: + """ + Filters provided dictionary to only include fields specified in the model's __fillable__ property + + Passed dictionary is not mutated. + """ + pass + def filter_mass_assignment(dictionary: Dict[str, Any]) -> Dict[str, Any]: + """ + Filters the provided dictionary in preparation for a mass-assignment operation + + Wrapper around filter_fillable() & filter_guarded(). Passed dictionary is not mutated. + """ + pass + def filter_guarded(dictionary: Dict[str, Any]) -> Dict[str, Any]: + """ + Filters provided dictionary to exclude fields specified in the model's __guarded__ property + + Passed dictionary is not mutated. + """ + pass def find_or_404(record_id: str | int): """Finds a row by the primary key ID (Requires a model) or raise an 404 exception. @@ -458,8 +509,7 @@ class Model: def simple_paginate(per_page: int, page: int = 1): pass def skip(*args, **kwargs): - """Alias for limit method. - """ + """Alias for limit method.""" pass def statement(query: str, bindings: list = ()): pass @@ -502,16 +552,16 @@ class Model: pass def truncate(foreign_keys: bool = False): pass - def update(updates: dict, dry: bool = False, force: bool = False): + def update( + updates: dict, dry: bool = False, force: bool = False, cast: bool = False + ): """Specifies columns and values to be updated. Arguments: updates {dictionary} -- A dictionary of columns and values to update. - dry {bool} -- Whether a query should actually run - force {bool} -- Force the update even if there are no changes - - Keyword Arguments: - dry {bool} -- Whether the query should be executed. (default: {False}) + dry {bool, optional} -- Whether a query should actually run + force {bool, optional} -- Force the update even if there are no changes + cast {bool, optional} -- Run all values through model's casters Returns: self @@ -520,8 +570,7 @@ class Model: def when(conditional: bool, callback: callable): pass def where_between(*args, **kwargs): - """Alias for between - """ + """Alias for between""" pass def where_column(column1: str, column2: str): """Specifies where two columns equal eachother. @@ -619,8 +668,7 @@ class Model: """ pass def where_not_between(*args: Any, **kwargs: Any): - """Alias for not_between - """ + """Alias for not_between""" pass def where_not_in(column: str, wheres: list = []): """Specifies where a column does not contain a list of a values. diff --git a/src/masoniteorm/observers/ObservesEvents.py b/src/masoniteorm/observers/ObservesEvents.py index c9d21274d..3e2cb0406 100644 --- a/src/masoniteorm/observers/ObservesEvents.py +++ b/src/masoniteorm/observers/ObservesEvents.py @@ -16,14 +16,12 @@ def observe(cls, observer): @classmethod def without_events(cls): - """Sets __has_events__ attribute on model to false. 
- """ + """Sets __has_events__ attribute on model to false.""" cls.__has_events__ = False return cls @classmethod def with_events(cls): - """Sets __has_events__ attribute on model to True. - """ - cls.__has_events__ = False + """Sets __has_events__ attribute on model to True.""" + cls.__has_events__ = True return cls diff --git a/src/masoniteorm/query/QueryBuilder.py b/src/masoniteorm/query/QueryBuilder.py index 836a0814b..18c5143f4 100644 --- a/src/masoniteorm/query/QueryBuilder.py +++ b/src/masoniteorm/query/QueryBuilder.py @@ -1,38 +1,36 @@ import inspect from copy import deepcopy +from datetime import datetime +from typing import Any, Dict, List, Optional, Callable -from ..config import load_config from ..collection.Collection import Collection +from ..config import load_config +from ..exceptions import ( + HTTP404, + ConnectionNotRegistered, + ModelNotFound, + MultipleRecordsFound, + InvalidArgument, +) from ..expressions.expressions import ( - JoinClause, - SubGroupExpression, - SubSelectExpression, - SelectExpression, + AggregateExpression, BetweenExpression, + FromTable, GroupByExpression, - AggregateExpression, - QueryExpression, + HavingExpression, + JoinClause, OrderByExpression, + QueryExpression, + SelectExpression, + SubGroupExpression, + SubSelectExpression, UpdateQueryExpression, - HavingExpression, - FromTable, ) - -from ..scopes import BaseScope -from ..schema import Schema from ..observers import ObservesEvents -from ..exceptions import ( - ModelNotFound, - HTTP404, - ConnectionNotRegistered, - ModelNotFound, - MultipleRecordsFound, -) from ..pagination import LengthAwarePaginator, SimplePaginator -from .EagerRelation import EagerRelations -from datetime import datetime, date as datetimedate, time as datetimetime -import pendulum from ..schema import Schema +from ..scopes import BaseScope +from .EagerRelation import EagerRelations class QueryBuilder(ObservesEvents): @@ -50,6 +48,7 @@ def __init__( scopes=None, schema=None, dry=False, + config_path=None, ): """QueryBuilder initializer @@ -60,6 +59,7 @@ def __init__( connection {masoniteorm.connection.Connection} -- A connection class (default: {None}) table {str} -- the name of the table (default: {""}) """ + self.config_path = config_path self.grammar = grammar self.table(table) self.dry = dry @@ -106,7 +106,7 @@ def __init__( self.set_action("select") if not self._connection_details: - DB = load_config().DB + DB = load_config(config_path=self.config_path).DB self._connection_details = DB.get_connection_details() self.on(connection) @@ -385,7 +385,7 @@ def method(*args, **kwargs): ) def on(self, connection): - DB = load_config().DB + DB = load_config(self.config_path).DB if connection == "default": self.connection = self._connection_details.get("default") @@ -460,15 +460,24 @@ def select_raw(self, query): def get_processor(self): return self.connection_class.get_default_post_processor()() - def bulk_create(self, creates, query=False): - model = None + def bulk_create( + self, creates: List[Dict[str, Any]], query: bool = False, cast: bool = False + ): self.set_action("bulk_create") - - self._creates = creates + model = None if self._model: model = self._model + self._creates = [] + for unsorted_create in creates: + if model: + unsorted_create = model.filter_mass_assignment(unsorted_create) + if cast: + unsorted_create = model.cast_values(unsorted_create) + # sort the dicts by key so the values inserted align with the correct column + self._creates.append(dict(sorted(unsorted_create.items()))) + if query: return self @@ 
-487,7 +496,15 @@ def bulk_create(self, creates, query=False): return processed_results - def create(self, creates=None, query=False, id_key="id", **kwargs): + def create( + self, + creates: Optional[Dict[str, Any]] = None, + query: bool = False, + id_key: str = "id", + cast: bool = False, + ignore_mass_assignment: bool = False, + **kwargs, + ): """Specifies a dictionary that should be used to create new values. Arguments: @@ -496,18 +513,20 @@ def create(self, creates=None, query=False, id_key="id", **kwargs): Returns: self """ - self._creates = {} - - if not creates: - creates = kwargs - + self.set_action("insert") model = None + self._creates = creates if creates else kwargs if self._model: model = self._model - - self.set_action("insert") - self._creates.update(creates) + # Update values with related record's + self._creates.update(self._creates_related) + # Filter __fillable/__guarded__ fields + if not ignore_mass_assignment: + self._creates = model.filter_mass_assignment(self._creates) + # Cast values if necessary + if cast: + self._creates = model.cast_values(self._creates) if query: return self @@ -522,18 +541,6 @@ def create(self, creates=None, query=False, id_key="id", **kwargs): if not self.dry: connection = self.new_connection() - if model: - d = {} - for x in self._creates: - if x in self._creates: - if kwargs.get("cast") == True: - d.update( - {x: self._model._set_casted_value(x, self._creates[x])} - ) - else: - d.update({x: self._creates[x]}) - d.update(self._creates_related) - self._creates = d query_result = connection.query(self.to_qmark(), self._bindings, results=1) if model: @@ -615,7 +622,6 @@ def where(self, column, *args): ) elif isinstance(column, dict): for key, value in column.items(): - self._wheres += ((QueryExpression(key, "=", value, "value")),) elif isinstance(value, QueryBuilder): self._wheres += ( @@ -896,7 +902,6 @@ def or_where_null(self, column): def chunk(self, chunk_amount): chunk_connection = self.new_connection() for result in chunk_connection.select_many(self.to_sql(), (), chunk_amount): - yield self.prepare_result(result) def where_not_null(self, column: str): @@ -1382,14 +1387,22 @@ def skip(self, *args, **kwargs): """Alias for limit method""" return self.offset(*args, **kwargs) - def update(self, updates: dict, dry=False, force=False): + def update( + self, + updates: Dict[str, Any], + dry: bool = False, + force: bool = False, + cast: bool = False, + ignore_mass_assignment: bool = False, + ): """Specifies columns and values to be updated. Arguments: updates {dictionary} -- A dictionary of columns and values to update. - - Keyword Arguments: - dry {bool} -- Whether the query should be executed. 
(default: {False}) + dry {bool, optional}: Do everything except execute the query against the DB + force {bool, optional}: Force an update statement to be executed even if nothing was changed + cast {bool, optional}: Run all values through model's casters + ignore_mass_assignment {bool, optional}: Whether the update should ignore mass assignment on the model Returns: self @@ -1400,13 +1413,15 @@ def update(self, updates: dict, dry=False, force=False): if self._model: model = self._model + # Filter __fillable/__guarded__ fields + if not ignore_mass_assignment: + updates = model.filter_mass_assignment(updates) if model and model.is_loaded(): self.where(model.get_primary_key(), model.get_primary_key_value()) additional.update({model.get_primary_key(): model.get_primary_key_value()}) self.observe_events(model, "updating") - # update only attributes with changes if model and not model.__force_update__ and not force: changes = {} @@ -1591,32 +1606,6 @@ def count(self, column=None): else: return self - def cast_value(self, value): - - if isinstance(value, datetime): - return str(pendulum.instance(value)) - elif isinstance(value, datetimedate): - return str(pendulum.datetime(value.year, value.month, value.day)) - elif isinstance(value, datetimetime): - return str(pendulum.parse(f"{value.hour}:{value.minute}:{value.second}")) - - return value - - def cast_dates(self, result): - if isinstance(result, dict): - new_dict = {} - for key, value in result.items(): - new_dict.update({key: self.cast_value(value)}) - - return new_dict - elif isinstance(result, list): - new_list = [] - for res in result: - new_list.append(self.cast_dates(res)) - return new_list - - return result - def max(self, column): """Aggregates a columns values. @@ -1714,14 +1703,13 @@ def first(self, fields=None, query=False): if not fields: fields = [] - if fields: - self.select(fields) + self.select(fields).limit(1) if query: - return self.limit(1) + return self result = self.new_connection().query( - self.limit(1).to_qmark(), self._bindings, results=1 + self.to_qmark(), self._bindings, results=1 ) return self.prepare_result(result) @@ -1765,6 +1753,9 @@ def sole(self, query=False): return result.first() + def sole_value(self, column: str, query=False): + return self.sole()[column] + def first_where(self, column, *args): """Gets the first record with the given key / value pair""" if not args: @@ -1779,11 +1770,13 @@ def last(self, column=None, query=False): dictionary -- Returns a dictionary of results. """ _column = column if column else self._model.get_primary_key() + self.limit(1).order_by(_column, direction="DESC") + if query: - return self.limit(1).order_by(_column, direction="DESC") + return self result = self.new_connection().query( - self.limit(1).order_by(_column, direction="DESC").to_qmark(), + self.to_qmark(), self._bindings, results=1, ) @@ -1793,7 +1786,7 @@ def last(self, column=None, query=False): def _get_eager_load_result(self, related, collection): return related.eager_load_from_collection(collection) - def find(self, record_id): + def find(self, record_id, query=False): """Finds a row by the primary key ID. Requires a model Arguments: @@ -1802,8 +1795,36 @@ def find(self, record_id): Returns: Model|None """ + self.where(self._model.get_primary_key(), record_id) + + if query: + return self + + return self.first() + + def find_or(self, record_id: int, callback: Callable, args=None): + """Finds a row by the primary key ID (Requires a model) or raise a ModelNotFound exception. 
+ + Arguments: + record_id {int} -- The ID of the primary key to fetch. + callback {Callable} -- The function to call if no record is found. + + Returns: + Model|Callable + """ - return self.where(self._model.get_primary_key(), record_id).first() + if not callable(callback): + raise InvalidArgument("A callback must be callable.") + + result = self.find(record_id=record_id) + + if not result: + if not args: + return callback() + else: + return callback(*args) + + return result def find_or_fail(self, record_id): """Finds a row by the primary key ID (Requires a model) or raise a ModelNotFound exception. @@ -1845,7 +1866,7 @@ def first_or_fail(self, query=False): """ if query: - return self.limit(1) + return self.first(query=True) result = self.first() @@ -1928,25 +1949,36 @@ def _register_relationships_to_model( Returns: self """ - if isinstance(hydrated_model, Collection): + if related_result and isinstance(hydrated_model, Collection): + map_related = self._map_related(related_result, related) for model in hydrated_model: if isinstance(related_result, Collection): - related.register_related(relation_key, model, related_result) + related.register_related(relation_key, model, map_related) else: - model.add_relation({relation_key: related_result or None}) + model.add_relation({relation_key: map_related or None}) else: hydrated_model.add_relation({relation_key: related_result or None}) return self + def _map_related(self, related_result, related): + if related.__class__.__name__ == 'MorphTo': + return related_result + elif related.__class__.__name__ in ['HasOneThrough', 'HasManyThrough']: + return related_result.group_by(related.local_key) + + return related_result.group_by(related.foreign_key) + def all(self, selects=[], query=False): """Returns all records from the table. Returns: dictionary -- Returns a dictionary of results. """ + self.select(*selects) + if query: - return self.to_sql() + return self result = self.new_connection().query(self.to_qmark(), self._bindings) or [] @@ -2058,9 +2090,8 @@ def to_sql(self): Returns: self """ - for name, scope in self._global_scopes.get(self._action, {}).items(): - scope(self) + self.run_scopes() grammar = self.get_grammar() sql = grammar.compile(self._action, qmark=False).to_sql() return sql @@ -2087,11 +2118,8 @@ def to_qmark(self): Returns: self """ - grammar = self.get_grammar() - - for name, scope in self._global_scopes.get(self._action, {}).items(): - scope(self) + self.run_scopes() grammar = self.get_grammar() sql = grammar.compile(self._action, qmark=True).to_sql() @@ -2145,7 +2173,6 @@ def min(self, column): return self def _extract_operator_value(self, *args): - operators = [ "=", ">", @@ -2270,3 +2297,30 @@ def get_schema(self): return Schema( connection=self.connection, connection_details=self._connection_details ) + + def latest(self, *fields): + """Gets the latest record. + + Returns: + querybuilder + """ + + if not fields: + fields = ("created_at",) + + return self.order_by(column=",".join(fields), direction="DESC") + + def oldest(self, *fields): + """Gets the oldest record. 
+ + Returns: + querybuilder + """ + + if not fields: + fields = ("created_at",) + + return self.order_by(column=",".join(fields), direction="ASC") + + def value(self, column: str): + return self.get().first()[column] diff --git a/src/masoniteorm/query/grammars/BaseGrammar.py b/src/masoniteorm/query/grammars/BaseGrammar.py index bb2b24c6e..6dd248603 100644 --- a/src/masoniteorm/query/grammars/BaseGrammar.py +++ b/src/masoniteorm/query/grammars/BaseGrammar.py @@ -251,13 +251,13 @@ def process_joins(self, qmark=False): on_string += f"{keyword} {self._table_column_string(clause.column1)} {clause.equality} {self._table_column_string(clause.column2)} " else: if clause.value_type == "NULL": - sql_string = self.where_null_string() + sql_string = f"{self.where_null_string()} " on_string += sql_string.format( keyword=keyword, column=self.process_column(clause.column), ) elif clause.value_type == "NOT NULL": - sql_string = self.where_not_null_string() + sql_string = f"{self.where_not_null_string()} " on_string += sql_string.format( keyword=keyword, column=self.process_column(clause.column), @@ -292,7 +292,6 @@ def _compile_key_value_equals(self, qmark=False): """ sql = "" for update in self._updates: - if update.update_type == "increment": sql_string = self.increment_string() elif update.update_type == "decrement": @@ -304,7 +303,6 @@ def _compile_key_value_equals(self, qmark=False): value = update.value if isinstance(column, dict): for key, value in column.items(): - if hasattr(value, "expression"): sql += self.column_value_string().format( column=self._table_column_string(key), @@ -888,7 +886,6 @@ def _table_column_string(self, column, alias=None, separator=""): """ table = None if column and "." in column: - table, column = column.split(".") if column == "*": diff --git a/src/masoniteorm/query/grammars/MSSQLGrammar.py b/src/masoniteorm/query/grammars/MSSQLGrammar.py index 7a08c323c..50ec4f18c 100644 --- a/src/masoniteorm/query/grammars/MSSQLGrammar.py +++ b/src/masoniteorm/query/grammars/MSSQLGrammar.py @@ -10,7 +10,6 @@ class MSSQLGrammar(BaseGrammar): "MIN": "MIN", "AVG": "AVG", "COUNT": "COUNT", - "AVG": "AVG", } join_keywords = { @@ -37,7 +36,7 @@ def select_no_table(self): return "SELECT {columns}" def select_format(self): - return "SELECT {keyword} {limit} {columns} FROM {table} {lock} {joins} {wheres} {group_by} {order_by} {offset} {having}" + return "SELECT {keyword} {limit} {columns} FROM {table} {lock} {joins} {wheres} {group_by} {having} {order_by} {offset}" def update_format(self): return "UPDATE {table} SET {key_equals} {wheres}" @@ -110,6 +109,9 @@ def aggregate_string(self): def subquery_string(self): return "({query})" + def subquery_alias_string(self): + return "AS {alias}" + def where_group_string(self): return "{keyword} {value}" @@ -144,10 +146,10 @@ def offset_string(self): return "OFFSET {offset} ROWS FETCH NEXT {limit} ROWS ONLY" def increment_string(self): - return "{column} = {column} + {value}" + return "{column} = {column} + '{value}'{separator}" def decrement_string(self): - return "{column} = {column} - {value}" + return "{column} = {column} - '{value}'{separator}" def aggregate_string_with_alias(self): return "{aggregate_function}({column}) AS {alias}" diff --git a/src/masoniteorm/query/grammars/MySQLGrammar.py b/src/masoniteorm/query/grammars/MySQLGrammar.py index 4796a162b..5993798c7 100644 --- a/src/masoniteorm/query/grammars/MySQLGrammar.py +++ b/src/masoniteorm/query/grammars/MySQLGrammar.py @@ -10,7 +10,6 @@ class MySQLGrammar(BaseGrammar): "MIN": "MIN", "AVG": 
"AVG", "COUNT": "COUNT", - "AVG": "AVG", } join_keywords = { @@ -50,7 +49,7 @@ class MySQLGrammar(BaseGrammar): locks = {"share": "LOCK IN SHARE MODE", "update": "FOR UPDATE"} def select_format(self): - return "SELECT {keyword} {columns} FROM {table} {joins} {wheres} {group_by} {order_by} {limit} {offset} {having} {lock}" + return "SELECT {keyword} {columns} FROM {table} {joins} {wheres} {group_by} {having} {order_by} {limit} {offset} {lock}" def select_no_table(self): return "SELECT {columns} {lock}" @@ -140,10 +139,11 @@ def column_value_string(self): return "{column} = {value}{separator}" def increment_string(self): - return "{column} = {column} + {value}" + + return "{column} = {column} + '{value}'{separator}" def decrement_string(self): - return "{column} = {column} - {value}" + return "{column} = {column} - '{value}'{separator}" def create_column_string(self): return "{column} {data_type}{length}{nullable}{default_value}, " diff --git a/src/masoniteorm/query/grammars/PostgresGrammar.py b/src/masoniteorm/query/grammars/PostgresGrammar.py index b9c34bf30..926d04ab9 100644 --- a/src/masoniteorm/query/grammars/PostgresGrammar.py +++ b/src/masoniteorm/query/grammars/PostgresGrammar.py @@ -11,7 +11,6 @@ class PostgresGrammar(BaseGrammar): "MIN": "MIN", "AVG": "AVG", "COUNT": "COUNT", - "AVG": "AVG", } join_keywords = { @@ -38,7 +37,7 @@ def select_no_table(self): return "SELECT {columns} {lock}" def select_format(self): - return "SELECT {keyword} {columns} FROM {table} {joins} {wheres} {group_by} {order_by} {limit} {offset} {having} {lock}" + return "SELECT {keyword} {columns} FROM {table} {joins} {wheres} {group_by} {having} {order_by} {limit} {offset} {lock}" def update_format(self): return "UPDATE {table} SET {key_equals} {wheres}" @@ -101,10 +100,11 @@ def column_value_string(self): return "{column} = {value}{separator}" def increment_string(self): - return "{column} = {column} + {value}" + return "{column} = {column} + '{value}'{separator}" def decrement_string(self): - return "{column} = {column} - {value}" + return "{column} = {column} - '{value}'{separator}" + def create_column_string(self): return "{column} {data_type}{length}{nullable}, " diff --git a/src/masoniteorm/query/grammars/SQLiteGrammar.py b/src/masoniteorm/query/grammars/SQLiteGrammar.py index 6e7f1c4b2..6b61d3018 100644 --- a/src/masoniteorm/query/grammars/SQLiteGrammar.py +++ b/src/masoniteorm/query/grammars/SQLiteGrammar.py @@ -11,7 +11,6 @@ class SQLiteGrammar(BaseGrammar): "MIN": "MIN", "AVG": "AVG", "COUNT": "COUNT", - "AVG": "AVG", } join_keywords = { @@ -35,7 +34,7 @@ class SQLiteGrammar(BaseGrammar): locks = {"share": "", "update": ""} def select_format(self): - return "SELECT {keyword} {columns} FROM {table} {joins} {wheres} {group_by} {order_by} {limit} {offset} {having} {lock}" + return "SELECT {keyword} {columns} FROM {table} {joins} {wheres} {group_by} {having} {order_by} {limit} {offset} {lock}" def select_no_table(self): return "SELECT {columns} {lock}" @@ -98,10 +97,11 @@ def column_value_string(self): return "{column} = {value}{separator}" def increment_string(self): - return "{column} = {column} + {value}" + return "{column} = {column} + '{value}'{separator}" def decrement_string(self): - return "{column} = {column} - {value}" + return "{column} = {column} - '{value}'{separator}" + def column_exists_string(self): return "SELECT column_name FROM information_schema.columns WHERE table_name='{clean_table}' and column_name={value}" diff --git a/src/masoniteorm/relationships/BaseRelationship.py 
b/src/masoniteorm/relationships/BaseRelationship.py index a9922c0b0..7767199aa 100644 --- a/src/masoniteorm/relationships/BaseRelationship.py +++ b/src/masoniteorm/relationships/BaseRelationship.py @@ -1,4 +1,3 @@ -from distutils.command.build import build from ..collection import Collection @@ -53,7 +52,7 @@ def __get__(self, instance, owner): object -- Either returns a builder or a hydrated model. """ attribute = self.fn.__name__ - relationship = self.fn(self)() + relationship = self.fn(instance)() self.set_keys(instance, attribute) self._related_builder = relationship.builder @@ -153,7 +152,7 @@ def get_related(self, query, relation, eagers=None, callback=None): if isinstance(relation, Collection): return builder.where_in( f"{builder.get_table_name()}.{self.foreign_key}", - relation.pluck(self.local_key, keep_nulls=False).unique(), + Collection(relation._get_value(self.local_key)).unique(), ).get() else: return builder.where( diff --git a/src/masoniteorm/relationships/BelongsTo.py b/src/masoniteorm/relationships/BelongsTo.py index 1339124ac..b92af1199 100644 --- a/src/masoniteorm/relationships/BelongsTo.py +++ b/src/masoniteorm/relationships/BelongsTo.py @@ -53,7 +53,7 @@ def get_related(self, query, relation, eagers=(), callback=None): if isinstance(relation, Collection): return builder.where_in( f"{builder.get_table_name()}.{self.foreign_key}", - relation.pluck(self.local_key, keep_nulls=False).unique(), + Collection(relation._get_value(self.local_key)).unique(), ).get() else: @@ -63,8 +63,6 @@ def get_related(self, query, relation, eagers=(), callback=None): ).first() def register_related(self, key, model, collection): - related = collection.where( - self.foreign_key, getattr(model, self.local_key) - ).first() + related = collection.get(getattr(model, self.local_key), None) - model.add_relation({key: related or None}) + model.add_relation({key: related[0] if related else None}) diff --git a/src/masoniteorm/relationships/BelongsToMany.py b/src/masoniteorm/relationships/BelongsToMany.py index 36c7ec61b..897142e22 100644 --- a/src/masoniteorm/relationships/BelongsToMany.py +++ b/src/masoniteorm/relationships/BelongsToMany.py @@ -67,7 +67,7 @@ def apply_query(self, query, owner): self._table = "_".join(pivot_tables) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" - else: + elif self.local_key is None or self.foreign_key is None: pivot_table_1, pivot_table_2 = self._table.split("_", 1) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" @@ -185,7 +185,7 @@ def make_query(self, query, relation, eagers=None, callback=None): self._table = "_".join(pivot_tables) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" - else: + elif self.local_key is None or self.foreign_key is None: pivot_table_1, pivot_table_2 = self._table.split("_", 1) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" @@ -237,7 +237,7 @@ def make_query(self, query, relation, eagers=None, callback=None): if isinstance(relation, Collection): return result.where_in( self.local_owner_key, - relation.pluck(self.local_owner_key, keep_nulls=False), + Collection(relation._get_value(self.local_owner_key)).unique(), ).get() else: return result.where( @@ -302,7 +302,7 @@ def relate(self, related_record): self._table = "_".join(pivot_tables) self.foreign_key = 
self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" - else: + elif self.local_key is None or self.foreign_key is None: pivot_table_1, pivot_table_2 = self._table.split("_", 1) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" @@ -368,7 +368,7 @@ def joins(self, builder, clause=None): self._table = "_".join(pivot_tables) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" - else: + elif self.local_key is None or self.foreign_key is None: pivot_table_1, pivot_table_2 = self._table.split("_", 1) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" @@ -541,7 +541,7 @@ def detach(self, current_model, related_record): .table(self._table) .without_global_scopes() .where(data) - .update({self.foreign_key: None, self.local_key: None}) + .delete() ) def attach_related(self, current_model, related_record): diff --git a/src/masoniteorm/relationships/HasMany.py b/src/masoniteorm/relationships/HasMany.py index 6ebfef249..a940df047 100644 --- a/src/masoniteorm/relationships/HasMany.py +++ b/src/masoniteorm/relationships/HasMany.py @@ -28,5 +28,5 @@ def set_keys(self, owner, attribute): def register_related(self, key, model, collection): model.add_relation( - {key: collection.where(self.foreign_key, getattr(model, self.local_key))} + {key: collection.get(getattr(model, self.local_key)) or Collection()} ) diff --git a/src/masoniteorm/relationships/HasManyThrough.py b/src/masoniteorm/relationships/HasManyThrough.py index 9a256fe64..e044f9a7e 100644 --- a/src/masoniteorm/relationships/HasManyThrough.py +++ b/src/masoniteorm/relationships/HasManyThrough.py @@ -1,5 +1,5 @@ -from .BaseRelationship import BaseRelationship from ..collection import Collection +from .BaseRelationship import BaseRelationship class HasManyThrough(BaseRelationship): @@ -57,33 +57,46 @@ def __get__(self, instance, owner): if attribute in instance._relationships: return instance._relationships[attribute] - result = self.apply_query( + result = self.apply_related_query( self.distant_builder, self.intermediary_builder, instance ) return result else: return self - def apply_query(self, distant_builder, intermediary_builder, owner): - """Apply the query and return a dictionary to be hydrated. - Used during accessing a relationship on a model + def apply_related_query(self, distant_builder, intermediary_builder, owner): + """ + Apply the query to return a Collection of data for the distant models to be hydrated with. - Arguments: - query {oject} -- The relationship object - owner {object} -- The current model oject. + Method is used when accessing a relationship on a model if its not + already eager loaded - Returns: - dict -- A dictionary of data which will be hydrated. + Arguments + distant_builder (QueryBuilder): QueryBuilder attached to the distant table + intermediate_builder (QueryBuilder): QueryBuilder attached to the intermediate (linking) table + owner (Any): the model this relationship is starting from + + Returns + Collection: Collection of dicts which will be used for hydrating models. 
""" - # select * from `countries` inner join `ports` on `ports`.`country_id` = `countries`.`country_id` where `ports`.`port_id` is null and `countries`.`deleted_at` is null and `ports`.`deleted_at` is null - distant_builder.join( - f"{self.intermediary_builder.get_table_name()}", - f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}", - "=", - f"{distant_builder.get_table_name()}.{self.other_owner_key}", - ) - return self + distant_table = distant_builder.get_table_name() + intermediate_table = intermediary_builder.get_table_name() + + return ( + self.distant_builder.select(f"{distant_table}.*, {intermediate_table}.{self.local_key}") + .join( + f"{intermediate_table}", + f"{intermediate_table}.{self.foreign_key}", + "=", + f"{distant_table}.{self.other_owner_key}", + ) + .where( + f"{intermediate_table}.{self.local_key}", + getattr(owner, self.local_owner_key), + ) + .get() + ) def relate(self, related_model): return self.distant_builder.join( @@ -104,51 +117,144 @@ def make_builder(self, eagers=None): return builder - def get_related(self, query, relation, eagers=None, callback=None): - builder = self.distant_builder + def register_related(self, key, model, collection): + """ + Attach the related model to source models attribute + + Arguments + key (str): The attribute name + model (Any): The model instance + collection (Collection): The data for the related models + + Returns + None + """ + related = collection.get(getattr(model, self.local_owner_key), None) + if related and not isinstance(related, Collection): + related = Collection(related) + + model.add_relation({key: related if related else None}) + + def get_related(self, current_builder, relation, eagers=None, callback=None): + """ + Get a Collection to hydrate the models for the distant table with + Used when eager loading the model attribute + + Arguments + current_builder (QueryBuilder): The source models QueryBuilder object + relation (HasManyThrough): this relationship object + eagers (Any): + callback (Any): + + Returns + Collection the collection of dicts to hydrate the distant models with + """ + + distant_table = self.distant_builder.get_table_name() + intermediate_table = self.intermediary_builder.get_table_name() if callback: - callback(builder) + callback(current_builder) + + ( + self.distant_builder.select( + f"{distant_table}.*, {intermediate_table}.{self.local_key}" + ).join( + f"{intermediate_table}", + f"{intermediate_table}.{self.foreign_key}", + "=", + f"{distant_table}.{self.other_owner_key}", + ) + ) if isinstance(relation, Collection): - return builder.where_in( - f"{builder.get_table_name()}.{self.foreign_key}", - relation.pluck(self.local_key, keep_nulls=False).unique(), + return self.distant_builder.where_in( + f"{intermediate_table}.{self.local_key}", + Collection(relation._get_value(self.local_owner_key)).unique(), ).get() else: - return builder.where( - f"{builder.get_table_name()}.{self.foreign_key}", + return self.distant_builder.where( + f"{intermediate_table}.{self.local_key}", getattr(relation, self.local_owner_key), ).get() - def get_with_count_query(self, builder, callback): - query = self.distant_builder + def attach(self, current_model, related_record): + raise NotImplementedError( + "HasOneThrough relationship does not implement the attach method" + ) - if not builder._columns: - builder = builder.select("*") + def attach_related(self, current_model, related_record): + raise NotImplementedError( + "HasOneThrough relationship does not implement the attach_related method" + ) - 
return_query = builder.add_select( + def query_has(self, current_builder, method="where_exists"): + distant_table = self.distant_builder.get_table_name() + intermediate_table = self.intermediary_builder.get_table_name() + + getattr(current_builder, method)( + self.distant_builder.join( + f"{intermediate_table}", + f"{intermediate_table}.{self.foreign_key}", + "=", + f"{distant_table}.{self.other_owner_key}", + ).where_column( + f"{intermediate_table}.{self.local_key}", + f"{current_builder.get_table_name()}.{self.local_owner_key}", + ) + ) + + return self.distant_builder + + def query_where_exists(self, current_builder, callback, method="where_exists"): + distant_table = self.distant_builder.get_table_name() + intermediate_table = self.intermediary_builder.get_table_name() + + getattr(current_builder, method)( + self.distant_builder.join( + f"{intermediate_table}", + f"{intermediate_table}.{self.foreign_key}", + "=", + f"{distant_table}.{self.other_owner_key}", + ) + .where_column( + f"{intermediate_table}.{self.local_key}", + f"{current_builder.get_table_name()}.{self.local_owner_key}", + ) + .when(callback, lambda q: (callback(q))) + ) + + def get_with_count_query(self, current_builder, callback): + distant_table = self.distant_builder.get_table_name() + intermediate_table = self.intermediary_builder.get_table_name() + + if not current_builder._columns: + current_builder.select("*") + + return_query = current_builder.add_select( f"{self.attribute}_count", lambda q: ( ( q.count("*") .join( - f"{self.intermediary_builder.get_table_name()}", - f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}", + f"{intermediate_table}", + f"{intermediate_table}.{self.foreign_key}", "=", - f"{query.get_table_name()}.{self.other_owner_key}", + f"{distant_table}.{self.other_owner_key}", ) .where_column( - f"{builder.get_table_name()}.{self.local_owner_key}", - f"{self.intermediary_builder.get_table_name()}.{self.local_key}", + f"{intermediate_table}.{self.local_key}", + f"{current_builder.get_table_name()}.{self.local_owner_key}", ) - .table(query.get_table_name()) + .table(distant_table) .when( callback, lambda q: ( q.where_in( self.foreign_key, - callback(query.select(self.other_owner_key)), + callback( + self.distant_builder.select(self.other_owner_key) + ), ) ), ) @@ -157,47 +263,3 @@ def get_with_count_query(self, builder, callback): ) return return_query - - def attach(self, current_model, related_record): - raise NotImplementedError( - "HasOneThrough relationship does not implement the attach method" - ) - - def attach_related(self, current_model, related_record): - raise NotImplementedError( - "HasOneThrough relationship does not implement the attach_related method" - ) - - def query_has(self, current_query_builder, method="where_exists"): - related_builder = self.get_builder() - - getattr(current_query_builder, method)( - self.distant_builder.where_column( - f"{current_query_builder.get_table_name()}.{self.local_owner_key}", - f"{self.intermediary_builder.get_table_name()}.{self.local_key}", - ).join( - f"{self.intermediary_builder.get_table_name()}", - f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}", - "=", - f"{self.distant_builder.get_table_name()}.{self.other_owner_key}", - ) - ) - - return related_builder - - def query_where_exists( - self, current_query_builder, callback, method="where_exists" - ): - query = self.distant_builder - - getattr(current_query_builder, method)( - query.join( - f"{self.intermediary_builder.get_table_name()}", - 
f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}", - "=", - f"{query.get_table_name()}.{self.other_owner_key}", - ).where_column( - f"{current_query_builder.get_table_name()}.{self.local_owner_key}", - f"{self.intermediary_builder.get_table_name()}.{self.local_key}", - ) - ).when(callback, lambda q: (callback(q))) diff --git a/src/masoniteorm/relationships/HasOne.py b/src/masoniteorm/relationships/HasOne.py index a1a76c403..c60056402 100644 --- a/src/masoniteorm/relationships/HasOne.py +++ b/src/masoniteorm/relationships/HasOne.py @@ -54,7 +54,7 @@ def get_related(self, query, relation, eagers=(), callback=None): if isinstance(relation, Collection): return builder.where_in( f"{builder.get_table_name()}.{self.foreign_key}", - relation.pluck(self.local_key, keep_nulls=False).unique(), + Collection(relation._get_value(self.local_key)).unique(), ).get() else: return builder.where( diff --git a/src/masoniteorm/relationships/HasOneThrough.py b/src/masoniteorm/relationships/HasOneThrough.py index 3840da32c..0f9eb0878 100644 --- a/src/masoniteorm/relationships/HasOneThrough.py +++ b/src/masoniteorm/relationships/HasOneThrough.py @@ -26,6 +26,10 @@ def __init__( self.local_owner_key = local_owner_key or "id" self.other_owner_key = other_owner_key or "id" + def __getattr__(self, attribute): + relationship = self.fn(self)[1]() + return getattr(relationship.builder, attribute) + def set_keys(self, distant_builder, intermediary_builder, attribute): self.local_key = self.local_key or "id" self.foreign_key = self.foreign_key or f"{attribute}_id" @@ -34,17 +38,18 @@ def set_keys(self, distant_builder, intermediary_builder, attribute): return self def __get__(self, instance, owner): - """This method is called when the decorated method is accessed. + """ + This method is called when the decorated method is accessed. - Arguments: - instance {object|None} -- The instance we called. + Arguments + instance (object|None): The instance we called. If we didn't call the attribute and only accessed it then this will be None. + owner (object): The current model that the property was accessed on. - owner {object} -- The current model that the property was accessed on. - - Returns: - object -- Either returns a builder or a hydrated model. + Returns + QueryBuilder|Model: Either returns a builder or a hydrated model. """ + attribute = self.fn.__name__ self.attribute = attribute relationship1 = self.fn(self)[0]() @@ -57,43 +62,60 @@ def __get__(self, instance, owner): if attribute in instance._relationships: return instance._relationships[attribute] - result = self.apply_query( + return self.apply_relation_query( self.distant_builder, self.intermediary_builder, instance ) - return result else: return self - def apply_query(self, distant_builder, intermediary_builder, owner): - """Apply the query and return a dictionary to be hydrated. - Used during accessing a relationship on a model + def apply_relation_query(self, distant_builder, intermediary_builder, owner): + """ + Apply the query and return a dict of data for the distant model to be hydrated with. - Arguments: - query {oject} -- The relationship object - owner {object} -- The current model oject. + Method is used when accessing a relationship on a model if its not + already eager loaded - Returns: - dict -- A dictionary of data which will be hydrated. 
+ Arguments + distant_builder (QueryBuilder): QueryBuilder attached to the distant table + intermediate_builder (QueryBuilder): QueryBuilder attached to the intermediate (linking) table + owner (Any): the model this relationship is starting from + + Returns + dict: A dictionary of data which will be hydrated. """ - # select * from `countries` inner join `ports` on `ports`.`country_id` = `countries`.`country_id` where `ports`.`port_id` is null and `countries`.`deleted_at` is null and `ports`.`deleted_at` is null - distant_builder.join( - f"{self.intermediary_builder.get_table_name()}", - f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}", - "=", - f"{distant_builder.get_table_name()}.{self.other_owner_key}", - ) - return self + dist_table = distant_builder.get_table_name() + int_table = intermediary_builder.get_table_name() + + return ( + distant_builder.select( + f"{dist_table}.*, {int_table}.{self.local_owner_key} as {self.local_key}" + ) + .join( + f"{int_table}", + f"{int_table}.{self.foreign_key}", + "=", + f"{dist_table}.{self.other_owner_key}", + ) + .where( + f"{int_table}.{self.local_owner_key}", + getattr(owner, self.local_key), + ) + .first() + ) def relate(self, related_model): + dist_table = self.distant_builder.get_table_name() + int_table = self.intermediary_builder.get_table_name() + return self.distant_builder.join( - f"{self.intermediary_builder.get_table_name()}", - f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}", + f"{int_table}", + f"{int_table}.{self.foreign_key}", "=", - f"{self.distant_builder.get_table_name()}.{self.other_owner_key}", - ).where( - f"{self.intermediary_builder.get_table_name()}.{self.local_key}", - getattr(related_model, self.local_owner_key), + f"{dist_table}.{self.other_owner_key}", + ).where_column( + f"{int_table}.{self.local_owner_key}", + getattr(related_model, self.local_key), ) def get_builder(self): @@ -104,68 +126,139 @@ def make_builder(self, eagers=None): return builder - def get_related(self, query, relation, eagers=None, callback=None): - builder = self.distant_builder + def register_related(self, key, model, collection): + """ + Attach the related model to source models attribute + + Arguments + key (str): The attribute name + model (Any): The model instance + collection (Collection): The data for the related models + + Returns + None + """ + + related = collection.get(getattr(model, self.local_key), None) + model.add_relation({key: related[0] if related else None}) + + def get_related(self, current_builder, relation, eagers=None, callback=None): + """ + Get the data to hydrate the model for the distant table with + Used when eager loading the model attribute + + Arguments + query (QueryBuilder): The source models QueryBuilder object + relation (HasOneThrough): this relationship object + eagers (Any): + callback (Any): + + Returns + dict: the dict to hydrate the distant model with + """ + + dist_table = self.distant_builder.get_table_name() + int_table = self.intermediary_builder.get_table_name() if callback: - callback(builder) + callback(current_builder) + + (self.distant_builder.select(f"{dist_table}.*, {int_table}.{self.local_owner_key} as {self.local_key}") + .join( + f"{int_table}", + f"{int_table}.{self.foreign_key}", + "=", + f"{dist_table}.{self.other_owner_key}", + )) if isinstance(relation, Collection): - return builder.where_in( - f"{builder.get_table_name()}.{self.foreign_key}", - relation.pluck(self.local_key, keep_nulls=False).unique(), + return self.distant_builder.where_in( + 
f"{int_table}.{self.local_owner_key}", + Collection(relation._get_value(self.local_key)).unique(), ).get() else: - return builder.where( - f"{builder.get_table_name()}.{self.foreign_key}", - getattr(relation, self.local_owner_key), + return self.distant_builder.where( + f"{int_table}.{self.local_owner_key}", + getattr(relation, self.local_key), ).first() - def query_where_exists( - self, current_query_builder, callback, method="where_exists" - ): - query = self.distant_builder + def attach(self, current_model, related_record): + raise NotImplementedError( + "HasOneThrough relationship does not implement the attach method" + ) + + def attach_related(self, current_model, related_record): + raise NotImplementedError( + "HasOneThrough relationship does not implement the attach_related method" + ) - getattr(current_query_builder, method)( - query.join( - f"{self.intermediary_builder.get_table_name()}", - f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}", + def query_has(self, current_builder, method="where_exists"): + dist_table = self.distant_builder.get_table_name() + int_table = self.intermediary_builder.get_table_name() + + getattr(current_builder, method)( + self.distant_builder.join( + f"{int_table}", + f"{int_table}.{self.foreign_key}", "=", - f"{query.get_table_name()}.{self.other_owner_key}", + f"{dist_table}.{self.other_owner_key}", ).where_column( - f"{current_query_builder.get_table_name()}.{self.local_owner_key}", - f"{self.intermediary_builder.get_table_name()}.{self.local_key}", + f"{int_table}.{self.local_owner_key}", + f"{current_builder.get_table_name()}.{self.local_key}", + ) + ) + + return self.distant_builder + + def query_where_exists(self, current_builder, callback, method="where_exists"): + dist_table = self.distant_builder.get_table_name() + int_table = self.intermediary_builder.get_table_name() + + getattr(current_builder, method)( + self.distant_builder.join( + f"{int_table}", + f"{int_table}.{self.foreign_key}", + "=", + f"{dist_table}.{self.other_owner_key}", ) - ).when(callback, lambda q: (callback(q))) + .where_column( + f"{int_table}.{self.local_owner_key}", + f"{current_builder.get_table_name()}.{self.local_key}", + ) + .when(callback, lambda q: (callback(q))) + ) - def get_with_count_query(self, builder, callback): - query = self.distant_builder + def get_with_count_query(self, current_builder, callback): + dist_table = self.distant_builder.get_table_name() + int_table = self.intermediary_builder.get_table_name() - if not builder._columns: - builder = builder.select("*") + if not current_builder._columns: + current_builder.select("*") - return_query = builder.add_select( + return_query = current_builder.add_select( f"{self.attribute}_count", lambda q: ( ( q.count("*") .join( - f"{self.intermediary_builder.get_table_name()}", - f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}", + f"{int_table}", + f"{int_table}.{self.foreign_key}", "=", - f"{query.get_table_name()}.{self.other_owner_key}", + f"{dist_table}.{self.other_owner_key}", ) .where_column( - f"{builder.get_table_name()}.{self.local_owner_key}", - f"{self.intermediary_builder.get_table_name()}.{self.local_key}", + f"{int_table}.{self.local_owner_key}", + f"{current_builder.get_table_name()}.{self.local_key}", ) - .table(query.get_table_name()) + .table(dist_table) .when( callback, lambda q: ( q.where_in( self.foreign_key, - callback(query.select(self.other_owner_key)), + callback( + self.distant_builder.select(self.other_owner_key) + ), ) ), ) @@ -174,30 +267,3 @@ def 
get_with_count_query(self, builder, callback): ) return return_query - - def attach(self, current_model, related_record): - raise NotImplementedError( - "HasOneThrough relationship does not implement the attach method" - ) - - def attach_related(self, current_model, related_record): - raise NotImplementedError( - "HasOneThrough relationship does not implement the attach_related method" - ) - - def query_has(self, current_query_builder, method="where_exists"): - related_builder = self.get_builder() - - getattr(current_query_builder, method)( - self.distant_builder.where_column( - f"{current_query_builder.get_table_name()}.{self.local_owner_key}", - f"{self.intermediary_builder.get_table_name()}.{self.local_key}", - ).join( - f"{self.intermediary_builder.get_table_name()}", - f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}", - "=", - f"{self.distant_builder.get_table_name()}.{self.other_owner_key}", - ) - ) - - return related_builder diff --git a/src/masoniteorm/relationships/MorphOne.py b/src/masoniteorm/relationships/MorphOne.py index b5595e30c..62e98a1d9 100644 --- a/src/masoniteorm/relationships/MorphOne.py +++ b/src/masoniteorm/relationships/MorphOne.py @@ -1,6 +1,6 @@ from ..collection import Collection -from .BaseRelationship import BaseRelationship from ..config import load_config +from .BaseRelationship import BaseRelationship class MorphOne(BaseRelationship): @@ -67,8 +67,8 @@ def apply_query(self, builder, instance): polymorphic_builder = self.polymorphic_builder return ( - polymorphic_builder.where("record_type", polymorphic_key) - .where("record_id", instance.get_primary_key_value()) + polymorphic_builder.where(self.morph_key, polymorphic_key) + .where(self.morph_id, instance.get_primary_key_value()) .first() ) diff --git a/src/masoniteorm/relationships/MorphTo.py b/src/masoniteorm/relationships/MorphTo.py index f0abcf445..84ba8bb75 100644 --- a/src/masoniteorm/relationships/MorphTo.py +++ b/src/masoniteorm/relationships/MorphTo.py @@ -50,7 +50,7 @@ def __get__(self, instance, owner): def __getattr__(self, attribute): relationship = self.fn(self)() - return getattr(relationship.builder, attribute) + return getattr(relationship._related_builder, attribute) def apply_query(self, builder, instance): """Apply the query and return a dictionary to be hydrated diff --git a/src/masoniteorm/schema/Blueprint.py b/src/masoniteorm/schema/Blueprint.py index 3b20464de..b2865c54a 100644 --- a/src/masoniteorm/schema/Blueprint.py +++ b/src/masoniteorm/schema/Blueprint.py @@ -396,6 +396,7 @@ def decimal(self, column, length=17, precision=6, nullable=False): Returns: self """ + self._last_column = self.table.add_column( column, "decimal", @@ -483,6 +484,46 @@ def text(self, column, length=None, nullable=False): ) return self + def tiny_text(self, column, length=None, nullable=False): + """Sets a column to be the text representation for the table. + + Arguments: + column {string} -- The column name. + + Keyword Arguments: + length {int} -- The length of the column if any. (default: {False}) + nullable {bool} -- Whether the column is nullable. (default: {False}) + + Returns: + self + """ + self._last_column = self.table.add_column( + column, "tiny_text", length=length, nullable=nullable + ) + return self + + def unsigned_decimal(self, column, length=17, precision=6, nullable=False): + """Sets a column to be the text representation for the table. + + Arguments: + column {string} -- The column name. + + Keyword Arguments: + length {int} -- The length of the column if any. 
(default: {False}) + nullable {bool} -- Whether the column is nullable. (default: {False}) + + Returns: + self + """ + self._last_column = self.table.add_column( + column, + "decimal", + length="{length}, {precision}".format(length=length, precision=precision), + nullable=nullable, + ).unsigned() + return self + + def long_text(self, column, length=None, nullable=False): """Sets a column to be the long_text representation for the table. @@ -643,13 +684,12 @@ def unsigned(self, column=None, length=None, nullable=False): self """ if not column: - self._last_column.column_type += "_unsigned" - self._last_column.length = None + self._last_column.unsigned() return self self._last_column = self.table.add_column( column, "unsigned", length=length, nullable=nullable - ) + ).unsigned() return self def unsigned_integer(self, column, nullable=False): @@ -665,8 +705,8 @@ def unsigned_integer(self, column, nullable=False): self """ self._last_column = self.table.add_column( - column, "integer_unsigned", nullable=nullable - ) + column, "integer", nullable=nullable + ).unsigned() return self def morphs(self, column, nullable=False, indexes=True): @@ -684,8 +724,8 @@ def morphs(self, column, nullable=False, indexes=True): _columns = [] _columns.append( self.table.add_column( - "{}_id".format(column), "integer_unsigned", nullable=nullable - ) + "{}_id".format(column), "integer", nullable=nullable + ).unsigned() ) _columns.append( self.table.add_column( @@ -900,9 +940,11 @@ def foreign_id_for(self, model, column=None): """ clm = column if column else model.get_foreign_key() - return self.foreign_id(clm)\ - if model.get_primary_key_type() == 'int'\ + return ( + self.foreign_id(clm) + if model.get_primary_key_type() == "int" else self.foreign_uuid(column) + ) def references(self, column): """Sets the other column on the foreign table that the local column will use to reference.
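Not part of the diff — a minimal usage sketch of the Blueprint helpers introduced above (tiny_text, unsigned_decimal, and the chainable unsigned()); the connection details, table name, and column names are assumed for illustration only:

    from src.masoniteorm.schema import Schema

    # assumed minimal connection details for the sketch
    DATABASES = {
        "default": "dev",
        "dev": {"driver": "sqlite", "database": "orm.sqlite3"},
    }

    schema = Schema(connection="dev", connection_details=DATABASES)

    with schema.create("prices") as blueprint:
        blueprint.increments("id")
        # tiny_text maps to TINYTEXT on MySQL and falls back to TEXT on SQLite/Postgres
        blueprint.tiny_text("note", nullable=True)
        # unsigned_decimal builds a DECIMAL(17, 6) column flagged as unsigned
        blueprint.unsigned_decimal("amount", 17, 6)
        # unsigned() with no arguments now flags the previously added column via Column.unsigned()
        blueprint.integer("quantity")
        blueprint.unsigned()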
diff --git a/src/masoniteorm/schema/Column.py b/src/masoniteorm/schema/Column.py index 63783a2cd..8f1a6c07f 100644 --- a/src/masoniteorm/schema/Column.py +++ b/src/masoniteorm/schema/Column.py @@ -9,6 +9,7 @@ def __init__( values=None, nullable=False, default=None, + signed=None, default_is_raw=False, column_python_type=str, ): @@ -21,6 +22,7 @@ def __init__( self._after = None self.old_column = "" self.default = default + self._signed = signed self.default_is_raw = default_is_raw self.primary = False self.comment = None @@ -34,6 +36,24 @@ def nullable(self): self.is_null = True return self + def signed(self): + """Sets this column to be signed + + Returns: + self + """ + self._signed = "signed" + return self + + def unsigned(self): + """Sets this column to be unsigned + + Returns: + self + """ + self._signed = "unsigned" + return self + def not_nullable(self): """Sets this column to be not nullable diff --git a/src/masoniteorm/schema/Schema.py b/src/masoniteorm/schema/Schema.py index 6140c8c24..af1a17c37 100644 --- a/src/masoniteorm/schema/Schema.py +++ b/src/masoniteorm/schema/Schema.py @@ -6,7 +6,6 @@ class Schema: - _default_string_length = "255" _type_hints_map = { "string": str, @@ -58,6 +57,7 @@ def __init__( grammar=None, connection_details=None, schema=None, + config_path=None, ): self._dry = dry self.connection = connection @@ -69,6 +69,7 @@ def __init__( self._blueprint = None self._sql = None self.schema = schema + self.config_path = config_path if not self.connection_class: self.on(self.connection) @@ -86,7 +87,7 @@ def on(self, connection_key): Returns: cls """ - DB = load_config().DB + DB = load_config(config_path=self.config_path).DB if connection_key == "default": self.connection = self.connection_details.get("default") @@ -287,12 +288,26 @@ def truncate(self, table, foreign_keys=False): return bool(self.new_connection().query(sql, ())) def get_schema(self): - """Gets the schema set on the migration class - """ + """Gets the schema set on the migration class""" return self.schema or self.get_connection_information().get("full_details").get( "schema" ) + def get_all_tables(self): + """Gets all tables in the database""" + sql = self.platform().compile_get_all_tables( + database=self.get_connection_information().get("database"), + schema=self.get_schema(), + ) + + if self._dry: + self._sql = sql + return sql + + result = self.new_connection().query(sql, ()) + + return list(map(lambda t: list(t.values())[0], result)) if result else [] + def has_table(self, table, query_only=False): """Checks if the a database has a specific table Arguments: diff --git a/src/masoniteorm/schema/Table.py b/src/masoniteorm/schema/Table.py index 75d0d413a..b64f33d15 100644 --- a/src/masoniteorm/schema/Table.py +++ b/src/masoniteorm/schema/Table.py @@ -25,6 +25,7 @@ def add_column( values=None, nullable=False, default=None, + signed=None, default_is_raw=False, primary=False, column_python_type=str, @@ -36,6 +37,7 @@ def add_column( nullable=nullable, values=values or [], default=default, + signed=signed, default_is_raw=default_is_raw, column_python_type=column_python_type, ) diff --git a/src/masoniteorm/schema/platforms/MSSQLPlatform.py b/src/masoniteorm/schema/platforms/MSSQLPlatform.py index 7f3ba591c..bf9386b64 100644 --- a/src/masoniteorm/schema/platforms/MSSQLPlatform.py +++ b/src/masoniteorm/schema/platforms/MSSQLPlatform.py @@ -3,7 +3,6 @@ class MSSQLPlatform(Platform): - types_without_lengths = [ "integer", "big_integer", @@ -34,6 +33,7 @@ class MSSQLPlatform(Platform): "double": "DOUBLE",
"enum": "VARCHAR", "text": "TEXT", + "tiny_text": "TINYTEXT", "float": "FLOAT", "geometry": "GEOMETRY", "json": "JSON", @@ -120,7 +120,6 @@ def compile_alter_sql(self, table): if table.renamed_columns: for name, column in table.get_renamed_columns().items(): - sql.append( self.rename_column_string(table.name, name, column.name).strip() ) @@ -336,6 +335,9 @@ def compile_drop_table(self, table): def compile_column_exists(self, table, column): return f"SELECT 1 FROM sys.columns WHERE Name = N'{column}' AND Object_ID = Object_ID(N'{table}')" + def compile_get_all_tables(self, database, schema=None): + return f"SELECT name FROM {database}.sys.tables" + def get_current_schema(self, connection, table_name, schema=None): return Table(table_name) diff --git a/src/masoniteorm/schema/platforms/MySQLPlatform.py b/src/masoniteorm/schema/platforms/MySQLPlatform.py index 9c8de30fb..07b3a1742 100644 --- a/src/masoniteorm/schema/platforms/MySQLPlatform.py +++ b/src/masoniteorm/schema/platforms/MySQLPlatform.py @@ -29,6 +29,7 @@ class MySQLPlatform(Platform): "double": "DOUBLE", "enum": "ENUM", "text": "TEXT", + "tiny_text": "TINYTEXT", "float": "FLOAT", "geometry": "GEOMETRY", "json": "JSON", @@ -55,6 +56,8 @@ class MySQLPlatform(Platform): "null": " DEFAULT NULL", } + signed = {"unsigned": "UNSIGNED", "signed": "SIGNED"} + def columnize(self, columns): sql = [] for name, column in columns.items(): @@ -87,7 +90,6 @@ def columnize(self, columns): if column.column_type == "enum": values = ", ".join(f"'{x}'" for x in column.values) column_constraint = f"({values})" - sql.append( self.columnize_string() .format( @@ -98,6 +100,9 @@ def columnize(self, columns): constraint=constraint, nullable=self.premapped_nulls.get(column.is_null) or "", default=default, + signed=" " + self.signed.get(column._signed) + if column._signed + else "", comment="COMMENT '" + column.comment + "'" if column.comment else "", @@ -171,15 +176,23 @@ def compile_alter_sql(self, table): else: default = "" + column_constraint = "" + if column.column_type == "enum": + values = ", ".join(f"'{x}'" for x in column.values) + column_constraint = f"({values})" add_columns.append( self.add_column_string() .format( name=self.get_column_string().format(column=column.name), data_type=self.type_map.get(column.column_type, ""), + column_constraint=column_constraint, length=length, constraint="PRIMARY KEY" if column.primary else "", nullable="NULL" if column.is_null else "NOT NULL", default=default, + signed=" " + self.signed.get(column._signed) + if column._signed + else "", after=(" AFTER " + self.wrap_column(column._after)) if column._after else "", @@ -324,19 +337,21 @@ def compile_alter_sql(self, table): return sql def add_column_string(self): - return "ADD {name} {data_type}{length} {nullable}{default}{after}{comment}" + return ( + "ADD {name} {data_type}{length}{column_constraint}{signed} {nullable}{default}{after}{comment}" + ) def drop_column_string(self): return "DROP COLUMN {name}" def change_column_string(self): - return "MODIFY {name}{data_type}{length} {nullable}{default} {constraint}" + return "MODIFY {name}{data_type}{length}{column_constraint} {nullable}{default} {constraint}" def rename_column_string(self): return "CHANGE {old} {to}" def columnize_string(self): - return "{name} {data_type}{length}{column_constraint} {nullable}{default} {constraint}{comment}" + return "{name} {data_type}{length}{column_constraint}{signed} {nullable}{default} {constraint}{comment}" def constraintize(self, constraints, table): sql = [] @@ -403,6 +418,9 @@ 
def compile_drop_table(self, table): def compile_column_exists(self, table, column): return f"SELECT column_name FROM information_schema.columns WHERE table_name='{table}' and column_name='{column}'" + def compile_get_all_tables(self, database, schema=None): + return f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{database}'" + def get_current_schema(self, connection, table_name, schema=None): table = Table(table_name) sql = f"DESCRIBE {table_name}" diff --git a/src/masoniteorm/schema/platforms/Platform.py b/src/masoniteorm/schema/platforms/Platform.py index 25e5a4865..cbda6814c 100644 --- a/src/masoniteorm/schema/platforms/Platform.py +++ b/src/masoniteorm/schema/platforms/Platform.py @@ -1,5 +1,4 @@ class Platform: - foreign_key_actions = { "cascade": "CASCADE", "set null": "SET NULL", @@ -9,6 +8,8 @@ class Platform: "default": "SET DEFAULT", } + signed = {"signed": "SIGNED", "unsigned": "UNSIGNED"} + def columnize(self, columns): sql = [] for name, column in columns.items(): diff --git a/src/masoniteorm/schema/platforms/PostgresPlatform.py b/src/masoniteorm/schema/platforms/PostgresPlatform.py index 4c8fc0978..c6f45ef7e 100644 --- a/src/masoniteorm/schema/platforms/PostgresPlatform.py +++ b/src/masoniteorm/schema/platforms/PostgresPlatform.py @@ -40,6 +40,7 @@ class PostgresPlatform(Platform): "double": "DOUBLE PRECISION", "enum": "VARCHAR", "text": "TEXT", + "tiny_text": "TEXT", "float": "FLOAT", "geometry": "GEOMETRY", "json": "JSON", @@ -193,6 +194,11 @@ def compile_alter_sql(self, table): else: default = "" + column_constraint = "" + if column.column_type == "enum": + values = ", ".join(f"'{x}'" for x in column.values) + column_constraint = f" CHECK({column.name} IN ({values}))" + add_columns.append( self.add_column_string() .format( @@ -200,6 +206,7 @@ def compile_alter_sql(self, table): data_type=self.type_map.get(column.column_type, ""), length=length, constraint="PRIMARY KEY" if column.primary else "", + column_constraint=column_constraint, nullable="NULL" if column.is_null else "NOT NULL", default=default, after=(" AFTER " + self.wrap_column(column._after)) @@ -262,12 +269,18 @@ def compile_alter_sql(self, table): changed_sql = [] for name, column in table.changed_columns.items(): + + column_constraint = "" + if column.column_type == "enum": + values = ", ".join(f"'{x}'" for x in column.values) + column_constraint = f" CHECK({column.name} IN ({values}))" changed_sql.append( self.modify_column_string() .format( name=self.wrap_column(name), data_type=self.type_map.get(column.column_type), - nullable="NULL" if column.is_null else "NOT NULL", + column_constraint=column_constraint, + constraint="PRIMARY KEY" if column.primary else "", length="(" + str(column.length) + ")" if column.column_type not in self.types_without_lengths else "", @@ -379,13 +392,13 @@ def alter_format_add_foreign_key(self): return "ALTER TABLE {table} {columns}" def add_column_string(self): - return "ADD COLUMN {name} {data_type}{length} {nullable}{default} {constraint}" + return "ADD COLUMN {name} {data_type}{length}{column_constraint} {nullable}{default} {constraint}" def drop_column_string(self): return "DROP COLUMN {name}" def modify_column_string(self): - return "ALTER COLUMN {name} TYPE {data_type}{length}" + return "ALTER COLUMN {name} TYPE {data_type}{length}{column_constraint} {constraint}" def rename_column_string(self): return "RENAME COLUMN {old} TO {to}" @@ -459,6 +472,9 @@ def compile_drop_table(self, table): def compile_column_exists(self, table, column): return 
f"SELECT column_name FROM information_schema.columns WHERE table_name='{table}' and column_name='{column}'" + def compile_get_all_tables(self, database=None, schema=None): + return f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_catalog = '{database}'" + def get_current_schema(self, connection, table_name, schema=None): sql = self.table_information_string().format( table=table_name, schema=schema or "public" diff --git a/src/masoniteorm/schema/platforms/SQLitePlatform.py b/src/masoniteorm/schema/platforms/SQLitePlatform.py index d73684fb5..93a42bcc3 100644 --- a/src/masoniteorm/schema/platforms/SQLitePlatform.py +++ b/src/masoniteorm/schema/platforms/SQLitePlatform.py @@ -4,7 +4,6 @@ class SQLitePlatform(Platform): - types_without_lengths = [ "integer", "big_integer", @@ -13,6 +12,8 @@ class SQLitePlatform(Platform): "medium_integer", ] + types_without_signs = ["decimal"] + type_map = { "string": "VARCHAR", "char": "CHAR", @@ -35,6 +36,7 @@ class SQLitePlatform(Platform): "double": "DOUBLE", "enum": "VARCHAR", "text": "TEXT", + "tiny_text": "TEXT", "float": "FLOAT", "geometry": "GEOMETRY", "json": "JSON", @@ -133,6 +135,10 @@ def columnize(self, columns): data_type=self.type_map.get(column.column_type, ""), column_constraint=column_constraint, length=length, + signed=" " + self.signed.get(column._signed) + if column.column_type not in self.types_without_signs + and column._signed + else "", constraint=constraint, nullable=self.premapped_nulls.get(column.is_null) or "", default=default, @@ -144,7 +150,6 @@ def columnize(self, columns): def compile_alter_sql(self, diff): sql = [] - if diff.removed_indexes or diff.removed_unique_indexes: indexes = diff.removed_indexes indexes += diff.removed_unique_indexes @@ -166,17 +171,27 @@ def compile_alter_sql(self, diff): else: default = "" constraint = "" + column_constraint = "" if column.name in diff.added_foreign_keys: foreign_key = diff.added_foreign_keys[column.name] constraint = f" REFERENCES {self.wrap_table(foreign_key.foreign_table)}({self.wrap_column(foreign_key.foreign_column)})" + if column.column_type == "enum": + values = ", ".join(f"'{x}'" for x in column.values) + column_constraint = f" CHECK('{column.name}' IN({values}))" sql.append( - "ALTER TABLE {table} ADD COLUMN {name} {data_type} {nullable}{default}{constraint}".format( + self.add_column_string() + .format( table=self.wrap_table(diff.name), name=self.wrap_column(column.name), data_type=self.type_map.get(column.column_type, ""), + column_constraint=column_constraint, nullable="NULL" if column.is_null else "NOT NULL", default=default, + signed=" " + self.signed.get(column._signed) + if column.column_type not in self.types_without_signs + and column._signed + else "", constraint=constraint, ).strip() ) @@ -282,13 +297,18 @@ def get_table_string(self): def get_column_string(self): return '"{column}"' + def add_column_string(self): + return ( + "ALTER TABLE {table} ADD COLUMN {name} {data_type}{column_constraint}{signed} {nullable}{default}{constraint}" + ) + def create_column_length(self, column_type): if column_type in self.types_without_lengths: return "" return "({length})" def columnize_string(self): - return "{name} {data_type}{length}{column_constraint} {nullable}{default} {constraint}" + return "{name} {data_type}{length}{column_constraint}{signed} {nullable}{default} {constraint}" def get_unique_constraint_string(self): return "UNIQUE({columns})" @@ -363,6 +383,7 @@ def get_current_schema(self, connection, table_name, schema=None): 
column_python_type=Schema._type_hints_map.get(column_type, str), default=default, length=length, + nullable=int(column.get("notnull")) == 0, ) if column.get("pk") == 1: table.set_primary_key(column["name"]) @@ -402,6 +423,9 @@ def compile_table_exists(self, table, database=None, schema=None): def compile_column_exists(self, table, column): return f"SELECT column_name FROM information_schema.columns WHERE table_name='{table}' and column_name='{column}'" + def compile_get_all_tables(self, database, schema=None): + return "SELECT name FROM sqlite_master WHERE type='table'" + def compile_truncate(self, table, foreign_keys=False): if not foreign_keys: return f"DELETE FROM {self.wrap_table(table)}" diff --git a/src/masoniteorm/scopes/SoftDeleteScope.py b/src/masoniteorm/scopes/SoftDeleteScope.py index c7a6abc8a..ee79fcf51 100644 --- a/src/masoniteorm/scopes/SoftDeleteScope.py +++ b/src/masoniteorm/scopes/SoftDeleteScope.py @@ -34,8 +34,10 @@ def _only_trashed(self, model, builder): builder.remove_global_scope("_where_null", action="select") return builder.where_not_null(self.deleted_at_column) - def _force_delete(self, model, builder): - return builder.remove_global_scope(self).set_action("delete") + def _force_delete(self, model, builder, query=False): + if query: + return builder.remove_global_scope(self).set_action("delete") + return builder.remove_global_scope(self).delete() def _restore(self, model, builder): return builder.remove_global_scope(self).update({self.deleted_at_column: None}) diff --git a/src/masoniteorm/scopes/TimeStampsScope.py b/src/masoniteorm/scopes/TimeStampsScope.py index d24ed9574..e9da387e9 100644 --- a/src/masoniteorm/scopes/TimeStampsScope.py +++ b/src/masoniteorm/scopes/TimeStampsScope.py @@ -1,6 +1,5 @@ -from .BaseScope import BaseScope - from ..expressions.expressions import UpdateQueryExpression +from .BaseScope import BaseScope class TimeStampsScope(BaseScope): @@ -27,8 +26,8 @@ def set_timestamp_create(self, builder): builder._creates.update( { - "updated_at": builder._model.get_new_date().to_datetime_string(), - "created_at": builder._model.get_new_date().to_datetime_string(), + builder._model.date_updated_at: builder._model.get_new_date().to_datetime_string(), + builder._model.date_created_at: builder._model.get_new_date().to_datetime_string(), } ) @@ -36,8 +35,13 @@ def set_timestamp_update(self, builder): if not builder._model.__timestamps__: return builder + for update in builder._updates: + if builder._model.date_updated_at in update.column: + return builder._updates += ( UpdateQueryExpression( - {"updated_at": builder._model.get_new_date().to_datetime_string()} + { + builder._model.date_updated_at: builder._model.get_new_date().to_datetime_string() + } ), ) diff --git a/src/masoniteorm/scopes/UUIDPrimaryKeyScope.py b/src/masoniteorm/scopes/UUIDPrimaryKeyScope.py index ef091de78..00f814ad9 100644 --- a/src/masoniteorm/scopes/UUIDPrimaryKeyScope.py +++ b/src/masoniteorm/scopes/UUIDPrimaryKeyScope.py @@ -1,4 +1,5 @@ import uuid + from .BaseScope import BaseScope @@ -9,6 +10,9 @@ def on_boot(self, builder): builder.set_global_scope( "_UUID_primary_key", self.set_uuid_create, action="insert" ) + builder.set_global_scope( + "_UUID_primary_key", self.set_bulk_uuid_create, action="bulk_create" + ) def on_remove(self, builder): pass @@ -22,15 +26,21 @@ def generate_uuid(self, builder, uuid_version, bytes=False): return uuid_func(*args).bytes if bytes else str(uuid_func(*args)) + def build_uuid_pk(self, builder): + uuid_version = getattr(builder._model, 
"__uuid_version__", 4) + uuid_bytes = getattr(builder._model, "__uuid_bytes__", False) + return { + builder._model.__primary_key__: self.generate_uuid( + builder, uuid_version, uuid_bytes + ) + } + def set_uuid_create(self, builder): # if there is already a primary key, no need to set a new one if builder._model.__primary_key__ not in builder._creates: - uuid_version = getattr(builder._model, "__uuid_version__", 4) - uuid_bytes = getattr(builder._model, "__uuid_bytes__", False) - builder._creates.update( - { - builder._model.__primary_key__: self.generate_uuid( - builder, uuid_version, uuid_bytes - ) - } - ) + builder._creates.update(self.build_uuid_pk(builder)) + + def set_bulk_uuid_create(self, builder): + for idx, create_atts in enumerate(builder._creates): + if builder._model.__primary_key__ not in create_atts: + builder._creates[idx].update(self.build_uuid_pk(builder)) diff --git a/src/masoniteorm/scopes/scope.py b/src/masoniteorm/scopes/scope.py index 482a5e284..79c2cb81b 100644 --- a/src/masoniteorm/scopes/scope.py +++ b/src/masoniteorm/scopes/scope.py @@ -3,7 +3,10 @@ def __init__(self, callback, *params, **kwargs): self.fn = callback def __set_name__(self, cls, name): - cls._scopes.update({name: self.fn}) + if cls not in cls._scopes: + cls._scopes[cls] = {name: self.fn} + else: + cls._scopes[cls].update({name: self.fn}) self.cls = cls def __call__(self, *args, **kwargs): diff --git a/src/masoniteorm/testing/BaseTestCaseSelectGrammar.py b/src/masoniteorm/testing/BaseTestCaseSelectGrammar.py index 7c50f8db7..7cee017ec 100644 --- a/src/masoniteorm/testing/BaseTestCaseSelectGrammar.py +++ b/src/masoniteorm/testing/BaseTestCaseSelectGrammar.py @@ -9,7 +9,6 @@ class MockConnection: - connection_details = {} def make_connection(self): @@ -341,7 +340,22 @@ def test_can_compile_join_clause_with_null(self): clause = ( JoinClause("report_groups as rg") .on_null("bgt.acct") + .or_on_null("bgt.dept") + .on_value("rg.abc", 10) + ) + to_sql = self.builder.join(clause).to_sql() + + sql = getattr( + self, inspect.currentframe().f_code.co_name.replace("test_", "") + )() + self.assertEqual(to_sql, sql) + + def test_can_compile_join_clause_with_not_null(self): + clause = ( + JoinClause("report_groups as rg") + .on_not_null("bgt.acct") .or_on_not_null("bgt.dept") + .on_value("rg.abc", 10) ) to_sql = self.builder.join(clause).to_sql() diff --git a/tests/User.py b/tests/User.py index 771817112..452f0b4d6 100644 --- a/tests/User.py +++ b/tests/User.py @@ -8,7 +8,7 @@ class User(Model): __fillable__ = ["name", "email", "password"] - __connection__ = "mysql" + __connection__ = "t" __auth__ = "email" diff --git a/tests/collection/test_collection.py b/tests/collection/test_collection.py index 706666542..02bab4cf1 100644 --- a/tests/collection/test_collection.py +++ b/tests/collection/test_collection.py @@ -185,6 +185,24 @@ def test_max(self): collection = Collection([{"batch": 1}, {"batch": 1}]) self.assertEqual(collection.max("batch"), 1) + def test_min(self): + collection = Collection([1, 1, 2, 4]) + self.assertEqual(collection.min(), 1) + + collection = Collection( + [ + {"name": "Corentin All", "age": 1}, + {"name": "Corentin All", "age": 2}, + {"name": "Corentin All", "age": 3}, + {"name": "Corentin All", "age": 4}, + ] + ) + self.assertEqual(collection.min("age"), 1) + self.assertEqual(collection.min(), 0) + + collection = Collection([{"batch": 1}, {"batch": 1}]) + self.assertEqual(collection.min("batch"), 1) + def test_count(self): collection = Collection([1, 1, 2, 4]) 
self.assertEqual(collection.count(), 4) diff --git a/tests/connections/test_base_connections.py b/tests/connections/test_base_connections.py index 7711af163..96fc2e3c3 100644 --- a/tests/connections/test_base_connections.py +++ b/tests/connections/test_base_connections.py @@ -6,7 +6,6 @@ class TestDefaultBehaviorConnections(unittest.TestCase): def test_should_return_connection_with_enabled_logs(self): - connection = DB.begin_transaction("dev") should_log_queries = connection.full_details.get("log_queries") DB.commit("dev") @@ -14,7 +13,6 @@ def test_should_return_connection_with_enabled_logs(self): self.assertTrue(should_log_queries) def test_should_disable_log_queries_in_connection(self): - connection = DB.begin_transaction("dev") connection.disable_query_log() diff --git a/tests/eagers/test_eager.py b/tests/eagers/test_eager.py index 97894f48d..482f21607 100644 --- a/tests/eagers/test_eager.py +++ b/tests/eagers/test_eager.py @@ -6,7 +6,6 @@ class TestEagerRelation(unittest.TestCase): def test_can_register_string_eager_load(self): - self.assertEqual( EagerRelations().register("profile").get_eagers(), [["profile"]] ) @@ -31,7 +30,6 @@ def test_can_register_string_eager_load(self): ) def test_can_register_tuple_eager_load(self): - self.assertEqual( EagerRelations().register(("profile",)).get_eagers(), [["profile"]] ) @@ -45,7 +43,6 @@ def test_can_register_tuple_eager_load(self): ) def test_can_register_list_eager_load(self): - self.assertEqual( EagerRelations().register(["profile"]).get_eagers(), [["profile"]] ) diff --git a/tests/integrations/config/database.py b/tests/integrations/config/database.py index 096918c0f..ed1fd02e0 100644 --- a/tests/integrations/config/database.py +++ b/tests/integrations/config/database.py @@ -24,6 +24,7 @@ They can be named whatever you want. 
""" + DATABASES = { "default": "mysql", "mysql": { @@ -37,8 +38,11 @@ "options": {"charset": "utf8mb4"}, "log_queries": True, "propagate": False, + "connection_pooling_enabled": True, + "connection_pooling_max_size": 10, + "connection_pooling_min_size": None, }, - "t": {"driver": "sqlite", "database": "ormtestreg.sqlite3", "log_queries": True}, + "t": {"driver": "sqlite", "database": "orm.sqlite3", "log_queries": True, "foreign_keys": True}, "devprod": { "driver": "mysql", "host": os.getenv("MYSQL_DATABASE_HOST"), @@ -69,6 +73,9 @@ "password": os.getenv("POSTGRES_DATABASE_PASSWORD"), "database": os.getenv("POSTGRES_DATABASE_DATABASE"), "port": os.getenv("POSTGRES_DATABASE_PORT"), + "connection_pooling_enabled": True, + "connection_pooling_max_size": 10, + "connection_pooling_min_size": 2, "prefix": "", "log_queries": True, "propagate": False, @@ -101,6 +108,8 @@ "authentication": "ActiveDirectoryPassword", "driver": "ODBC Driver 17 for SQL Server", "connection_timeout": 15, + "connection_pooling": False, + "connection_pooling_size": 100, }, }, } diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 97368381e..0def41530 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -1,8 +1,10 @@ +import datetime import json import unittest -from src.masoniteorm.models import Model + import pendulum -import datetime + +from src.masoniteorm.models import Model class ModelTest(Model): @@ -16,10 +18,30 @@ class ModelTest(Model): } +class FillableModelTest(Model): + __fillable__ = ["due_date", "is_vip"] + + +class InvalidFillableGuardedModelTest(Model): + __fillable__ = ["due_date"] + __guarded__ = ["is_vip", "payload"] + + +class InvalidFillableGuardedChildModelTest(ModelTest): + __fillable__ = ["due_date"] + __guarded__ = ["is_vip", "payload"] + + class ModelTestForced(Model): __table__ = "users" __force_update__ = True +class BaseModel(Model): + def get_selects(self): + return [f"{self.get_table_name()}.*"] + +class ModelWithBaseModel(BaseModel): + __table__ = "users" class TestModels(unittest.TestCase): def test_model_can_access_str_dates_as_pendulum(self): @@ -30,7 +52,6 @@ def test_model_can_access_str_dates_as_pendulum(self): self.assertIsInstance(model.due_date, pendulum.now().__class__) def test_model_can_access_str_dates_as_pendulum_from_correct_datetimes(self): - model = ModelTest() self.assertEqual( @@ -77,13 +98,13 @@ def test_model_creates_when_new(self): model = ModelTest.hydrate({"id": 1, "username": "joe", "admin": True}) model.name = "Bill" - sql = model.save(query=True) + sql = model.save(query=True).to_sql() self.assertTrue(sql.startswith("UPDATE")) model = ModelTest() model.name = "Bill" - sql = model.save(query=True) + sql = model.save(query=True).to_sql() self.assertTrue(sql.startswith("INSERT")) def test_model_can_cast_attributes(self): @@ -97,8 +118,7 @@ def test_model_can_cast_attributes(self): } ) - self.assertEqual(type(model.payload), str) - self.assertEqual(type(json.loads(model.payload)), list) + self.assertEqual(type(model.payload), list) self.assertEqual(type(model.x), int) self.assertEqual(type(model.f), float) self.assertEqual(type(model.is_vip), bool) @@ -129,13 +149,43 @@ def test_model_can_cast_dict_attributes(self): {"is_vip": 1, "payload": dictcasttest, "x": True, "f": "10.5"} ) - self.assertEqual(type(model.payload), str) - self.assertEqual(type(json.loads(model.payload)), dict) + self.assertEqual(type(model.payload), dict) self.assertEqual(type(model.x), int) self.assertEqual(type(model.f), float) 
self.assertEqual(type(model.is_vip), bool) self.assertEqual(type(model.serialize()["is_vip"]), bool) + def test_valid_json_cast(self): + model = ModelTest.hydrate( + {"payload": {"this": "dict", "is": "usable", "as": "json"}} + ) + + self.assertEqual(type(model.payload), dict) + + model = ModelTest.hydrate( + {"payload": {"this": "dict", "is": "invalid", "as": "json"}} + ) + + self.assertEqual(type(model.payload), dict) + + model = ModelTest.hydrate( + {"payload": '{"this": "dict", "is": "usable", "as": "json"}'} + ) + + self.assertEqual(type(model.payload), dict) + + model = ModelTest.hydrate({"payload": '{"valid": "json", "int": 1}'}) + + self.assertEqual(type(model.payload), dict) + + model = ModelTest.hydrate({"payload": "{'this': 'should', 'throw': 'error'}"}) + + self.assertEqual(model.payload, None) + + with self.assertRaises(ValueError): + model.payload = "{'this': 'should', 'throw': 'error'}" + model.save() + def test_model_update_without_changes(self): model = ModelTest.hydrate( {"id": 1, "username": "joe", "name": "Joe", "admin": True} @@ -143,7 +193,7 @@ def test_model_update_without_changes(self): model.username = "joe" model.name = "Bill" - sql = model.save(query=True) + sql = model.save(query=True).to_sql() self.assertTrue(sql.startswith("UPDATE")) self.assertNotIn("username", sql) @@ -154,11 +204,19 @@ def test_force_update_on_model_class(self): model.username = "joe" model.name = "Bill" - sql = model.save(query=True) + sql = model.save(query=True).to_sql() self.assertTrue(sql.startswith("UPDATE")) self.assertIn("username", sql) self.assertIn("name", sql) + def test_only_method(self): + model = ModelTestForced.hydrate( + {"id": 1, "username": "joe", "name": "Joe", "admin": True} + ) + + self.assertEqual({"username": "joe"}, model.only("username")) + self.assertEqual({"username": "joe"}, model.only(["username"])) + def test_model_update_without_changes_at_all(self): model = ModelTest.hydrate( {"id": 1, "username": "joe", "name": "Joe", "admin": True} @@ -166,7 +224,7 @@ def test_model_update_without_changes_at_all(self): model.username = "joe" model.name = "Joe" - sql = model.save(query=True) + sql = model.save(query=True).to_sql() self.assertFalse(sql.startswith("UPDATE")) def test_model_using_or_where(self): @@ -195,3 +253,40 @@ def test_model_using_or_where_and_chaining_wheres(self): sql, """SELECT * FROM `model_tests` WHERE `model_tests`.`name` = 'joe' OR (`model_tests`.`username` = 'Joseph' OR `model_tests`.`age` >= '18'))""", ) + + def test_both_fillable_and_guarded_attributes_raise(self): + # Both fillable and guarded props are populated on this class + with self.assertRaises(AttributeError): + InvalidFillableGuardedModelTest() + # Child that inherits from an intermediary class also fails + with self.assertRaises(AttributeError): + InvalidFillableGuardedChildModelTest() + # Still shouldn't be allowed to define even if empty + InvalidFillableGuardedModelTest.__fillable__ = [] + with self.assertRaises(AttributeError): + InvalidFillableGuardedModelTest() + # Or wildcard + InvalidFillableGuardedModelTest.__fillable__ = ["*"] + with self.assertRaises(AttributeError): + InvalidFillableGuardedModelTest() + # Empty guarded attr still raises + InvalidFillableGuardedModelTest.__guarded__ = [] + with self.assertRaises(AttributeError): + InvalidFillableGuardedModelTest() + # Removing one of the props allows us to instantiate + delattr(InvalidFillableGuardedModelTest, "__guarded__") + InvalidFillableGuardedModelTest() + + def test_model_can_provide_default_select(self): + sql = 
ModelWithBaseModel.to_sql() + self.assertEqual( + sql, + """SELECT `users`.* FROM `users`""", + ) + + def test_model_can_add_to_default_select(self): + sql = ModelWithBaseModel.select(["products.name", "products.id", "store.name"]).to_sql() + self.assertEqual( + sql, + """SELECT `users`.*, `products`.`name`, `products`.`id`, `store`.`name` FROM `users`""", + ) diff --git a/tests/mssql/builder/test_mssql_query_builder.py b/tests/mssql/builder/test_mssql_query_builder.py index a3b450792..17fa5e38b 100644 --- a/tests/mssql/builder/test_mssql_query_builder.py +++ b/tests/mssql/builder/test_mssql_query_builder.py @@ -9,7 +9,6 @@ class MockConnection: - connection_details = {} def make_connection(self): @@ -411,3 +410,31 @@ def test_truncate_without_foreign_keys(self): builder = self.get_builder(dry=True) sql = builder.truncate(foreign_keys=True) self.assertEqual(sql, "TRUNCATE TABLE [users]") + + def test_latest(self): + builder = self.get_builder() + builder.latest("email") + self.assertEqual( + builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] DESC" + ) + + def test_latest_multiple(self): + builder = self.get_builder() + builder.latest("email", "created_at") + self.assertEqual( + builder.to_sql(), + "SELECT * FROM [users] ORDER BY [email] DESC, [created_at] DESC", + ) + + def test_oldest(self): + builder = self.get_builder() + builder.oldest("email") + self.assertEqual(builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] ASC") + + def test_oldest_multiple(self): + builder = self.get_builder() + builder.oldest("email", "created_at") + self.assertEqual( + builder.to_sql(), + "SELECT * FROM [users] ORDER BY [email] ASC, [created_at] ASC", + ) diff --git a/tests/mssql/builder/test_mssql_query_builder_relationships.py b/tests/mssql/builder/test_mssql_query_builder_relationships.py index 2c50f8fdb..b4fb5566e 100644 --- a/tests/mssql/builder/test_mssql_query_builder_relationships.py +++ b/tests/mssql/builder/test_mssql_query_builder_relationships.py @@ -40,9 +40,16 @@ def articles(self): def profile(self): return Profile + @belongs_to("id", "parent_dynamic_id") + def parent_dynamic(self): + return self.__class__ + + @belongs_to("id", "parent_specified_id") + def parent_specified(self): + return User -class BaseTestQueryRelationships(unittest.TestCase): +class BaseTestQueryRelationships(unittest.TestCase): maxDiff = None def get_builder(self, table="users"): @@ -65,6 +72,26 @@ def test_has(self): """)""", ) + def test_has_reference_to_self(self): + builder = self.get_builder() + sql = builder.has("parent_dynamic").to_sql() + self.assertEqual( + sql, + """SELECT * FROM [users] WHERE EXISTS (""" + """SELECT * FROM [users] WHERE [users].[parent_dynamic_id] = [users].[id]""" + """)""", + ) + + def test_has_reference_to_self_using_class(self): + builder = self.get_builder() + sql = builder.has("parent_specified").to_sql() + self.assertEqual( + sql, + """SELECT * FROM [users] WHERE EXISTS (""" + """SELECT * FROM [users] WHERE [users].[parent_specified_id] = [users].[id]""" + """)""", + ) + def test_where_has_query(self): builder = self.get_builder() sql = builder.where_has("articles", lambda q: q.where("active", 1)).to_sql() diff --git a/tests/mssql/grammar/test_mssql_insert_grammar.py b/tests/mssql/grammar/test_mssql_insert_grammar.py index b5f7bc8b3..8980db786 100644 --- a/tests/mssql/grammar/test_mssql_insert_grammar.py +++ b/tests/mssql/grammar/test_mssql_insert_grammar.py @@ -9,7 +9,6 @@ def setUp(self): self.builder = QueryBuilder(MSSQLGrammar, table="users") def 
test_can_compile_insert(self): - to_sql = self.builder.create({"name": "Joe"}, query=True).to_sql() sql = "INSERT INTO [users] ([users].[name]) VALUES ('Joe')" @@ -17,10 +16,16 @@ def test_can_compile_insert(self): def test_can_compile_bulk_create(self): to_sql = self.builder.bulk_create( - [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True + # These keys are intentionally out of order to show column to value alignment works + [ + {"name": "Joe", "age": 5}, + {"age": 35, "name": "Bill"}, + {"name": "John", "age": 10}, + ], + query=True, ).to_sql() - sql = "INSERT INTO [users] ([name]) VALUES ('Joe'), ('Bill'), ('John')" + sql = "INSERT INTO [users] ([age], [name]) VALUES ('5', 'Joe'), ('35', 'Bill'), ('10', 'John')" self.assertEqual(to_sql, sql) def test_can_compile_bulk_create_qmark(self): diff --git a/tests/mssql/grammar/test_mssql_select_grammar.py b/tests/mssql/grammar/test_mssql_select_grammar.py index be747595d..a0595df8f 100644 --- a/tests/mssql/grammar/test_mssql_select_grammar.py +++ b/tests/mssql/grammar/test_mssql_select_grammar.py @@ -6,7 +6,6 @@ class TestMSSQLGrammar(BaseTestCaseSelectGrammar, unittest.TestCase): - grammar = MSSQLGrammar def can_compile_select(self): @@ -246,6 +245,12 @@ def can_compile_having(self): """ return "SELECT SUM([users].[age]) AS age FROM [users] GROUP BY [users].[age] HAVING [users].[age]" + def can_compile_having_order(self): + """ + builder.sum('age').group_by('age').having('age').order_by('age', 'desc').to_sql() + """ + return "SELECT SUM([users].[age]) AS age FROM [users] GROUP BY [users].[age] HAVING [users].[age] ORDER [users].[age] DESC" + def can_compile_between(self): """ builder.between('age', 18, 21).to_sql() @@ -308,6 +313,18 @@ def test_can_compile_having_raw(self): to_sql, "SELECT COUNT(*) as counts FROM [users] HAVING counts > 10" ) + def test_can_compile_having_raw_order(self): + to_sql = ( + self.builder.select_raw("COUNT(*) as counts") + .having_raw("counts > 10") + .order_by_raw("counts DESC") + .to_sql() + ) + self.assertEqual( + to_sql, + "SELECT COUNT(*) as counts FROM [users] HAVING counts > 10 ORDER BY counts DESC", + ) + def test_can_compile_select_raw(self): to_sql = self.builder.select_raw("COUNT(*)").to_sql() sql = getattr( @@ -388,11 +405,25 @@ def can_compile_join_clause_with_null(self): clause = ( JoinClause("report_groups as rg") .on_null("bgt.acct") + .or_on_null("bgt.dept") + .on_value("rg.abc", 10) + ) + builder.join(clause).to_sql() + """ + return "SELECT * FROM [users] INNER JOIN [report_groups] AS [rg] ON [acct] IS NULL OR [dept] IS NULL AND [rg].[abc] = '10'" + + def can_compile_join_clause_with_not_null(self): + """ + builder = self.get_builder() + clause = ( + JoinClause("report_groups as rg") + .on_not_null("bgt.acct") .or_on_not_null("bgt.dept") + .on_value("rg.abc", 10) ) builder.join(clause).to_sql() """ - return "SELECT * FROM [users] INNER JOIN [report_groups] AS [rg] ON [acct] IS NULL OR [dept] IS NOT NULL" + return "SELECT * FROM [users] INNER JOIN [report_groups] AS [rg] ON [acct] IS NOT NULL OR [dept] IS NOT NULL AND [rg].[abc] = '10'" def can_compile_join_clause_with_lambda(self): """ diff --git a/tests/mssql/grammar/test_mssql_update_grammar.py b/tests/mssql/grammar/test_mssql_update_grammar.py index d43704a43..49c6e4ab4 100644 --- a/tests/mssql/grammar/test_mssql_update_grammar.py +++ b/tests/mssql/grammar/test_mssql_update_grammar.py @@ -10,7 +10,6 @@ def setUp(self): self.builder = QueryBuilder(MSSQLGrammar, table="users") def test_can_compile_update(self): - to_sql = ( 
self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql() ) diff --git a/tests/mssql/schema/test_mssql_schema_builder.py b/tests/mssql/schema/test_mssql_schema_builder.py index e46c6efb5..4205bc7eb 100644 --- a/tests/mssql/schema/test_mssql_schema_builder.py +++ b/tests/mssql/schema/test_mssql_schema_builder.py @@ -29,6 +29,26 @@ def test_can_add_columns(self): ["CREATE TABLE [users] ([name] VARCHAR(255) NOT NULL, [age] INT NOT NULL)"], ) + def test_can_add_tiny_text(self): + with self.schema.create("users") as blueprint: + blueprint.tiny_text("description") + + self.assertEqual(len(blueprint.table.added_columns), 1) + self.assertEqual( + blueprint.to_sql(), + ["CREATE TABLE [users] ([description] TINYTEXT NOT NULL)"], + ) + + def test_can_add_unsigned_decimal(self): + with self.schema.create("users") as blueprint: + blueprint.unsigned_decimal("amount", 19, 4) + + self.assertEqual(len(blueprint.table.added_columns), 1) + self.assertEqual( + blueprint.to_sql(), + ["CREATE TABLE [users] ([amount] DECIMAL(19, 4) NOT NULL)"], + ) + def test_can_add_columns_with_constaint(self): with self.schema.create("users") as blueprint: blueprint.string("name") @@ -280,3 +300,27 @@ def test_can_truncate_without_foreign_keys(self): "ALTER TABLE [users] WITH CHECK CHECK CONSTRAINT ALL", ], ) + + def test_can_add_enum(self): + with self.schema.create("users") as blueprint: + blueprint.enum("status", ["active", "inactive"]).default("active") + + self.assertEqual(len(blueprint.table.added_columns), 1) + self.assertEqual( + blueprint.to_sql(), + [ + "CREATE TABLE [users] ([status] VARCHAR(255) NOT NULL DEFAULT 'active' CHECK([status] IN ('active', 'inactive')))" + ], + ) + + def test_can_change_column_enum(self): + with self.schema.table("users") as blueprint: + blueprint.enum("status", ["active", "inactive"]).default("active").change() + + self.assertEqual(len(blueprint.table.changed_columns), 1) + self.assertEqual( + blueprint.to_sql(), + [ + "ALTER TABLE [users] ALTER COLUMN [status] VARCHAR(255) NOT NULL DEFAULT 'active' CHECK([status] IN ('active', 'inactive'))" + ], + ) diff --git a/tests/mssql/schema/test_mssql_schema_builder_alter.py b/tests/mssql/schema/test_mssql_schema_builder_alter.py index ae85faa73..f1b323e27 100644 --- a/tests/mssql/schema/test_mssql_schema_builder_alter.py +++ b/tests/mssql/schema/test_mssql_schema_builder_alter.py @@ -11,7 +11,6 @@ class TestMySQLSchemaBuilderAlter(unittest.TestCase): maxDiff = None def setUp(self): - self.schema = Schema( connection_class=MSSQLConnection, connection="mssql", @@ -269,3 +268,15 @@ def test_timestamp_alter_add_nullable_column(self): sql = ["ALTER TABLE [users] ADD [due_date] DATETIME NULL"] self.assertEqual(blueprint.to_sql(), sql) + + def test_can_add_column_enum(self): + with self.schema.table("users") as blueprint: + blueprint.enum("status", ["active", "inactive"]).default("active") + + self.assertEqual(len(blueprint.table.added_columns), 1) + + sql = [ + "ALTER TABLE [users] ADD [status] VARCHAR(255) NOT NULL DEFAULT 'active' CHECK([status] IN ('active', 'inactive'))" + ] + + self.assertEqual(blueprint.to_sql(), sql) diff --git a/tests/mysql/builder/test_mysql_builder_transaction.py b/tests/mysql/builder/test_mysql_builder_transaction.py index df6a0f218..7e0fa39bd 100644 --- a/tests/mysql/builder/test_mysql_builder_transaction.py +++ b/tests/mysql/builder/test_mysql_builder_transaction.py @@ -17,7 +17,6 @@ class User(Model): __timestamps__ = False class BaseTestQueryRelationships(unittest.TestCase): - maxDiff = None def 
get_builder(self, table="users"): diff --git a/tests/mysql/builder/test_query_builder.py b/tests/mysql/builder/test_query_builder.py index 8ce5df4e8..88a7382b0 100644 --- a/tests/mysql/builder/test_query_builder.py +++ b/tests/mysql/builder/test_query_builder.py @@ -1,14 +1,14 @@ +import datetime import inspect import unittest -from tests.integrations.config.database import DATABASES from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar from src.masoniteorm.relationships import has_many from src.masoniteorm.scopes import SoftDeleteScope +from tests.integrations.config.database import DATABASES from tests.utils import MockConnectionFactory -import datetime class Articles(Model): @@ -549,15 +549,6 @@ def test_update_lock(self): )() self.assertEqual(sql, sql_ref) - def test_cast_values(self): - builder = self.get_builder(dry=True) - result = builder.cast_dates({"created_at": datetime.datetime(2021, 1, 1)}) - self.assertEqual(result, {"created_at": "2021-01-01T00:00:00+00:00"}) - result = builder.cast_dates({"created_at": datetime.date(2021, 1, 1)}) - self.assertEqual(result, {"created_at": "2021-01-01T00:00:00+00:00"}) - result = builder.cast_dates([{"created_at": datetime.date(2021, 1, 1)}]) - self.assertEqual(result, [{"created_at": "2021-01-01T00:00:00+00:00"}]) - class MySQLQueryBuilderTest(BaseTestQueryBuilder, unittest.TestCase): grammar = MySQLGrammar @@ -917,3 +908,31 @@ def update_lock(self): builder.truncate() """ return "SELECT * FROM `users` WHERE `users`.`votes` >= '100' FOR UPDATE" + + def test_latest(self): + builder = self.get_builder() + builder.latest("email") + sql = getattr( + self, inspect.currentframe().f_code.co_name.replace("test_", "") + )() + self.assertEqual(builder.to_sql(), sql) + + def test_oldest(self): + builder = self.get_builder() + builder.oldest("email") + sql = getattr( + self, inspect.currentframe().f_code.co_name.replace("test_", "") + )() + self.assertEqual(builder.to_sql(), sql) + + def latest(self): + """ + builder.order_by('email', 'desc') + """ + return "SELECT * FROM `users` ORDER BY `email` DESC" + + def oldest(self): + """ + builder.order_by('email', 'asc') + """ + return "SELECT * FROM `users` ORDER BY `email` ASC" diff --git a/tests/mysql/connections/test_mysql_connection_selects.py b/tests/mysql/connections/test_mysql_connection_selects.py index 0db03e3a5..acb09e50e 100644 --- a/tests/mysql/connections/test_mysql_connection_selects.py +++ b/tests/mysql/connections/test_mysql_connection_selects.py @@ -7,7 +7,6 @@ class MockUser(Model): - __table__ = "users" diff --git a/tests/mysql/grammar/test_mysql_delete_grammar.py b/tests/mysql/grammar/test_mysql_delete_grammar.py index 84ee5fea4..c17acee73 100644 --- a/tests/mysql/grammar/test_mysql_delete_grammar.py +++ b/tests/mysql/grammar/test_mysql_delete_grammar.py @@ -41,7 +41,6 @@ def test_can_compile_delete_with_where(self): class TestMySQLDeleteGrammar(BaseDeleteGrammarTest, unittest.TestCase): - grammar = "mysql" def can_compile_delete(self): diff --git a/tests/mysql/grammar/test_mysql_insert_grammar.py b/tests/mysql/grammar/test_mysql_insert_grammar.py index 5c3afa236..0089ba2c2 100644 --- a/tests/mysql/grammar/test_mysql_insert_grammar.py +++ b/tests/mysql/grammar/test_mysql_insert_grammar.py @@ -27,7 +27,13 @@ def test_can_compile_insert_with_keywords(self): def test_can_compile_bulk_create(self): to_sql = self.builder.bulk_create( - [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True + # 
These keys are intentionally out of order to show column to value alignment works + [ + {"name": "Joe", "age": 5}, + {"age": 35, "name": "Bill"}, + {"name": "John", "age": 10}, + ], + query=True, ).to_sql() sql = getattr( @@ -62,7 +68,6 @@ def test_can_compile_bulk_create_multiple(self): class TestMySQLUpdateGrammar(BaseInsertGrammarTest, unittest.TestCase): - grammar = "mysql" def can_compile_insert(self): @@ -83,13 +88,13 @@ def can_compile_bulk_create(self): """ self.builder.create(name="Joe").to_sql() """ - return """INSERT INTO `users` (`name`) VALUES ('Joe'), ('Bill'), ('John')""" + return """INSERT INTO `users` (`age`, `name`) VALUES ('5', 'Joe'), ('35', 'Bill'), ('10', 'John')""" def can_compile_bulk_create_multiple(self): """ self.builder.create(name="Joe").to_sql() """ - return """INSERT INTO `users` (`name`, `active`) VALUES ('Joe', '1'), ('Bill', '1'), ('John', '1')""" + return """INSERT INTO `users` (`active`, `name`) VALUES ('1', 'Joe'), ('1', 'Bill'), ('1', 'John')""" def can_compile_bulk_create_qmark(self): """ diff --git a/tests/mysql/grammar/test_mysql_qmark.py b/tests/mysql/grammar/test_mysql_qmark.py index 5d804baae..ece461d86 100644 --- a/tests/mysql/grammar/test_mysql_qmark.py +++ b/tests/mysql/grammar/test_mysql_qmark.py @@ -82,7 +82,6 @@ def test_can_compile_where_with_false_value(self): self.assertEqual(mark._bindings, bindings) def test_can_compile_sub_group_bindings(self): - mark = self.builder.where( lambda query: ( query.where("challenger", 1) diff --git a/tests/mysql/grammar/test_mysql_select_grammar.py b/tests/mysql/grammar/test_mysql_select_grammar.py index 53cf49e66..47532a06d 100644 --- a/tests/mysql/grammar/test_mysql_select_grammar.py +++ b/tests/mysql/grammar/test_mysql_select_grammar.py @@ -6,7 +6,6 @@ class TestMySQLGrammar(BaseTestCaseSelectGrammar, unittest.TestCase): - grammar = MySQLGrammar def can_compile_select(self): @@ -248,6 +247,12 @@ def can_compile_having(self): """ return "SELECT SUM(`users`.`age`) AS age FROM `users` GROUP BY `users`.`age` HAVING `users`.`age`" + def can_compile_having_order(self): + """ + builder.sum('age').group_by('age').having('age').order_by('age', 'desc').to_sql() + """ + return "SELECT SUM(`users`.`age`) AS age FROM `users` GROUP BY `users`.`age` HAVING `users`.`age` ORDER `users`.`age` DESC" + def can_compile_having_with_expression(self): """ builder.sum('age').group_by('age').having('age', 10).to_sql() @@ -304,6 +309,18 @@ def test_can_compile_having_raw(self): to_sql, "SELECT COUNT(*) as counts FROM `users` HAVING counts > 10" ) + def test_can_compile_having_raw_order(self): + to_sql = ( + self.builder.select_raw("COUNT(*) as counts") + .having_raw("counts > 10") + .order_by_raw("counts DESC") + .to_sql() + ) + self.assertEqual( + to_sql, + "SELECT COUNT(*) as counts FROM `users` HAVING counts > 10 ORDER BY counts DESC", + ) + def test_can_compile_select_raw(self): to_sql = self.builder.select_raw("COUNT(*)").to_sql() self.assertEqual(to_sql, "SELECT COUNT(*) FROM `users`") @@ -378,11 +395,25 @@ def can_compile_join_clause_with_null(self): clause = ( JoinClause("report_groups as rg") .on_null("bgt.acct") + .or_on_null("bgt.dept") + .on_value("rg.abc", 10) + ) + builder.join(clause).to_sql() + """ + return "SELECT * FROM `users` INNER JOIN `report_groups` AS `rg` ON `acct` IS NULL OR `dept` IS NULL AND `rg`.`abc` = '10'" + + def can_compile_join_clause_with_not_null(self): + """ + builder = self.get_builder() + clause = ( + JoinClause("report_groups as rg") + .on_not_null("bgt.acct") .or_on_not_null("bgt.dept") 
+ .on_value("rg.abc", 10) ) builder.join(clause).to_sql() """ - return "SELECT * FROM `users` INNER JOIN `report_groups` AS `rg` ON `acct` IS NULL OR `dept` IS NOT NULL" + return "SELECT * FROM `users` INNER JOIN `report_groups` AS `rg` ON `acct` IS NOT NULL OR `dept` IS NOT NULL AND `rg`.`abc` = '10'" def can_compile_join_clause_with_lambda(self): """ diff --git a/tests/mysql/grammar/test_mysql_update_grammar.py b/tests/mysql/grammar/test_mysql_update_grammar.py index a427f02e1..0212ec8fe 100644 --- a/tests/mysql/grammar/test_mysql_update_grammar.py +++ b/tests/mysql/grammar/test_mysql_update_grammar.py @@ -70,7 +70,6 @@ def test_raw_expression(self): class TestMySQLUpdateGrammar(BaseTestCaseUpdateGrammar, unittest.TestCase): - grammar = MySQLGrammar def can_compile_update(self): diff --git a/tests/mysql/model/test_accessors_and_mutators.py b/tests/mysql/model/test_accessors_and_mutators.py index 97df526df..6423816da 100644 --- a/tests/mysql/model/test_accessors_and_mutators.py +++ b/tests/mysql/model/test_accessors_and_mutators.py @@ -12,7 +12,6 @@ class User(Model): - __casts__ = {"is_admin": "bool"} def get_name_attribute(self): @@ -23,7 +22,6 @@ def set_name_attribute(self, attribute): class SetUser(Model): - __casts__ = {"is_admin": "bool"} def set_name_attribute(self, attribute): diff --git a/tests/mysql/model/test_model.py b/tests/mysql/model/test_model.py index 6711f98af..402da2650 100644 --- a/tests/mysql/model/test_model.py +++ b/tests/mysql/model/test_model.py @@ -2,234 +2,342 @@ import json import os import unittest + import pendulum -from src.masoniteorm.exceptions import ModelNotFound from src.masoniteorm.collection import Collection +from src.masoniteorm.exceptions import ModelNotFound from src.masoniteorm.models import Model from tests.User import User -if os.getenv("RUN_MYSQL_DATABASE", False) == "True": - - class ProfileFillable(Model): - __fillable__ = ["name"] - __table__ = "profiles" - __timestamps__ = None - - class ProfileFillTimeStamped(Model): - __fillable__ = ["*"] - __table__ = "profiles" - - class ProfileFillAsterisk(Model): - __fillable__ = ["*"] - __table__ = "profiles" - __timestamps__ = None - - class ProfileGuarded(Model): - __guarded__ = ["email"] - __table__ = "profiles" - __timestamps__ = None - - class ProfileSerialize(Model): - __fillable__ = ["*"] - __table__ = "profiles" - __hidden__ = ["password"] - - class ProfileSerializeWithVisible(Model): - __fillable__ = ["*"] - __table__ = "profiles" - __visible__ = ["name", "email"] - - class ProfileSerializeWithVisibleAndHidden(Model): - __fillable__ = ["*"] - __table__ = "profiles" - __visible__ = ["name", "email"] - __hidden__ = ["password"] - - class Profile(Model): - pass - - class Company(Model): - pass - - class User(Model): - @property - def meta(self): - return {"is_subscribed": True} - - class ProductNames(Model): - pass - - class TestModel(unittest.TestCase): - def test_can_use_fillable(self): - sql = ProfileFillable.create( - {"name": "Joe", "email": "user@example.com"}, query=True - ) - self.assertEqual( - sql, "INSERT INTO `profiles` (`profiles`.`name`) VALUES ('Joe')" - ) +class ProfileFillable(Model): + __fillable__ = ["name"] + __table__ = "profiles" + __timestamps__ = None - def test_can_use_fillable_asterisk(self): - sql = ProfileFillAsterisk.create( - {"name": "Joe", "email": "user@example.com"}, query=True - ) - self.assertEqual( - sql, - "INSERT INTO `profiles` (`profiles`.`name`, `profiles`.`email`) VALUES ('Joe', 'user@example.com')", - ) +class ProfileFillTimeStamped(Model): + 
__fillable__ = ["*"] + __table__ = "profiles" - def test_can_use_guarded(self): - sql = ProfileGuarded.create( - {"name": "Joe", "email": "user@example.com"}, query=True - ) - self.assertEqual( - sql, "INSERT INTO `profiles` (`profiles`.`name`) VALUES ('Joe')" - ) +class ProfileFillAsterisk(Model): + __fillable__ = ["*"] + __table__ = "profiles" + __timestamps__ = None - def test_can_use_guarded_asterisk(self): - sql = ProfileFillAsterisk.create( - {"name": "Joe", "email": "user@example.com"}, query=True - ) - self.assertEqual( - sql, - "INSERT INTO `profiles` (`profiles`.`name`, `profiles`.`email`) VALUES ('Joe', 'user@example.com')", - ) +class ProfileGuarded(Model): + __guarded__ = ["email"] + __table__ = "profiles" + __timestamps__ = None - def test_can_touch(self): - profile = ProfileFillTimeStamped.hydrate({"name": "Joe", "id": 1}) - sql = profile.touch("now", query=True) +class ProfileGuardedAsterisk(Model): + __guarded__ = ["*"] + __table__ = "profiles" + __timestamps__ = None - self.assertEqual( - sql, - "UPDATE `profiles` SET `profiles`.`updated_at` = 'now' WHERE `profiles`.`id` = '1'", - ) - def test_table_name(self): - table_name = Profile.get_table_name() - self.assertEqual(table_name, "profiles") +class ProfileSerialize(Model): + __fillable__ = ["*"] + __table__ = "profiles" + __hidden__ = ["password"] - table_name = Company.get_table_name() - self.assertEqual(table_name, "companies") - table_name = ProductNames.get_table_name() - self.assertEqual(table_name, "product_names") +class ProfileSerializeWithVisible(Model): + __fillable__ = ["*"] + __table__ = "profiles" + __visible__ = ["name", "email"] - def test_returns_correct_data_type(self): - self.assertIsInstance(User.all(), Collection) - # self.assertIsInstance(User.first(), User) - # self.assertIsInstance(User.first(), User) - def test_serialize(self): - profile = ProfileFillAsterisk.hydrate({"name": "Joe", "id": 1}) +class ProfileSerializeWithVisibleAndHidden(Model): + __fillable__ = ["*"] + __table__ = "profiles" + __visible__ = ["name", "email"] + __hidden__ = ["password"] - self.assertEqual(profile.serialize(), {"name": "Joe", "id": 1}) - def test_json(self): - profile = ProfileFillAsterisk.hydrate({"name": "Joe", "id": 1}) +class Profile(Model): + pass - self.assertEqual(profile.to_json(), '{"name": "Joe", "id": 1}') - def test_serialize_with_hidden(self): - profile = ProfileSerialize.hydrate( - {"name": "Joe", "id": 1, "password": "secret"} - ) +class Company(Model): + pass - self.assertTrue(profile.serialize().get("name")) - self.assertTrue(profile.serialize().get("id")) - self.assertFalse(profile.serialize().get("password")) - - def test_serialize_with_visible(self): - profile = ProfileSerializeWithVisible.hydrate( - { - "name": "Joe", - "id": 1, - "password": "secret", - "email": "joe@masonite.com", - } - ) - self.assertTrue( - {"name": "Joe", "email": "joe@masonite.com"}, profile.serialize() - ) - def test_serialize_with_visible_and_hidden_raise_error(self): - profile = ProfileSerializeWithVisibleAndHidden.hydrate( - { - "name": "Joe", - "id": 1, - "password": "secret", - "email": "joe@masonite.com", - } - ) - with self.assertRaises(AttributeError): - profile.serialize() - - def test_serialize_with_on_the_fly_appends(self): - user = User.hydrate({"name": "Joe", "id": 1}) - - user.set_appends(["meta"]) - - serialized = user.serialize() - self.assertEqual(serialized["id"], 1) - self.assertEqual(serialized["name"], "Joe") - self.assertEqual(serialized["meta"]["is_subscribed"], True) - - def 
test_serialize_with_model_appends(self): - User.__appends__ = ["meta"] - user = User.hydrate({"name": "Joe", "id": 1}) - serialized = user.serialize() - self.assertEqual(serialized["id"], 1) - self.assertEqual(serialized["name"], "Joe") - self.assertEqual(serialized["meta"]["is_subscribed"], True) - - def test_serialize_with_date(self): - user = User.hydrate({"name": "Joe", "created_at": pendulum.now()}) - - self.assertTrue(json.dumps(user.serialize())) - - def test_set_as_date(self): - user = User.hydrate( - { - "name": "Joe", - "created_at": pendulum.now().add(days=10).to_datetime_string(), - } - ) +class User(Model): + @property + def meta(self): + return {"is_subscribed": True} - self.assertTrue(user.created_at) - self.assertTrue(user.created_at.is_future()) - def test_access_as_date(self): - user = User.hydrate( - { - "name": "Joe", - "created_at": datetime.datetime.now() + datetime.timedelta(days=1), - } - ) +class ProductNames(Model): + pass - self.assertTrue(user.created_at) - self.assertTrue(user.created_at.is_future()) - def test_hydrate_with_none(self): - profile = ProfileFillAsterisk.hydrate(None) +class TestModel(unittest.TestCase): + def test_create_can_use_fillable(self): + sql = ProfileFillable.create( + {"name": "Joe", "email": "user@example.com"}, query=True + ).to_sql() - self.assertEqual(profile, None) + self.assertEqual( + sql, "INSERT INTO `profiles` (`profiles`.`name`) VALUES ('Joe')" + ) + def test_create_can_use_fillable_asterisk(self): + sql = ProfileFillAsterisk.create( + {"name": "Joe", "email": "user@example.com"}, query=True + ).to_sql() + + self.assertEqual( + sql, + "INSERT INTO `profiles` (`profiles`.`name`, `profiles`.`email`) VALUES ('Joe', 'user@example.com')", + ) + + def test_create_can_use_guarded(self): + sql = ProfileGuarded.create( + {"name": "Joe", "email": "user@example.com"}, query=True + ).to_sql() + + self.assertEqual( + sql, "INSERT INTO `profiles` (`profiles`.`name`) VALUES ('Joe')" + ) + + def test_create_can_use_guarded_asterisk(self): + sql = ProfileGuardedAsterisk.create( + {"name": "Joe", "email": "user@example.com"}, query=True + ).to_sql() + + # An asterisk guarded attribute excludes all fields from mass-assignment. + # This would raise a DB error if there are any required fields. 
+ self.assertEqual(sql, "INSERT INTO `profiles` (*) VALUES ()") + + def test_bulk_create_can_use_fillable(self): + query_builder = ProfileFillable.bulk_create( + [ + {"name": "Joe", "email": "user@example.com"}, + {"name": "Joe II", "email": "userII@example.com"}, + ], + query=True, + ) + + self.assertEqual( + query_builder.to_sql(), + "INSERT INTO `profiles` (`name`) VALUES ('Joe'), ('Joe II')", + ) + + def test_bulk_create_can_use_fillable_asterisk(self): + query_builder = ProfileFillAsterisk.bulk_create( + [ + {"name": "Joe", "email": "user@example.com"}, + {"name": "Joe II", "email": "userII@example.com"}, + ], + query=True, + ) + + self.assertEqual( + query_builder.to_sql(), + "INSERT INTO `profiles` (`email`, `name`) VALUES ('user@example.com', 'Joe'), ('userII@example.com', 'Joe II')", + ) + + def test_bulk_create_can_use_guarded(self): + query_builder = ProfileGuarded.bulk_create( + [ + {"name": "Joe", "email": "user@example.com"}, + {"name": "Joe II", "email": "userII@example.com"}, + ], + query=True, + ) + + self.assertEqual( + query_builder.to_sql(), + "INSERT INTO `profiles` (`name`) VALUES ('Joe'), ('Joe II')", + ) + + def test_bulk_create_can_use_guarded_asterisk(self): + query_builder = ProfileGuardedAsterisk.bulk_create( + [ + {"name": "Joe", "email": "user@example.com"}, + {"name": "Joe II", "email": "userII@example.com"}, + ], + query=True, + ) + + # An asterisk guarded attribute excludes all fields from mass-assignment. + # This would obviously raise an invalid SQL syntax error. + # TODO: Raise a clearer error? + self.assertEqual( + query_builder.to_sql(), "INSERT INTO `profiles` () VALUES (), ()" + ) + + def test_update_can_use_fillable(self): + query_builder = ProfileFillable().update( + {"name": "Joe", "email": "user@example.com"}, dry=True + ) + + self.assertEqual( + query_builder.to_sql(), "UPDATE `profiles` SET `profiles`.`name` = 'Joe'" + ) + + def test_update_can_use_fillable_asterisk(self): + query_builder = ProfileFillAsterisk().update( + {"name": "Joe", "email": "user@example.com"}, dry=True + ) + + self.assertEqual( + query_builder.to_sql(), + "UPDATE `profiles` SET `profiles`.`name` = 'Joe', `profiles`.`email` = 'user@example.com'", + ) + + def test_update_can_use_guarded(self): + query_builder = ProfileGuarded().update( + {"name": "Joe", "email": "user@example.com"}, dry=True + ) + + self.assertEqual( + query_builder.to_sql(), "UPDATE `profiles` SET `profiles`.`name` = 'Joe'" + ) + + def test_update_can_use_guarded_asterisk(self): + profile = ProfileGuardedAsterisk() + initial_sql = profile.get_builder().to_sql() + query_builder = profile.update( + {"name": "Joe", "email": "user@example.com"}, dry=True + ) + + # An asterisk guarded attribute excludes all fields from mass-assignment. + # The query builder's sql should not have been altered in any way. 
+ self.assertEqual(query_builder.to_sql(), initial_sql) + + def test_table_name(self): + table_name = Profile.get_table_name() + self.assertEqual(table_name, "profiles") + + table_name = Company.get_table_name() + self.assertEqual(table_name, "companies") + + table_name = ProductNames.get_table_name() + self.assertEqual(table_name, "product_names") + + def test_serialize(self): + profile = ProfileFillAsterisk.hydrate({"name": "Joe", "id": 1}) + + self.assertEqual(profile.serialize(), {"name": "Joe", "id": 1}) + + def test_json(self): + profile = ProfileFillAsterisk.hydrate({"name": "Joe", "id": 1}) + + self.assertEqual(profile.to_json(), '{"name": "Joe", "id": 1}') + + def test_serialize_with_hidden(self): + profile = ProfileSerialize.hydrate( + {"name": "Joe", "id": 1, "password": "secret"} + ) + + self.assertTrue(profile.serialize().get("name")) + self.assertTrue(profile.serialize().get("id")) + self.assertFalse(profile.serialize().get("password")) + + def test_serialize_with_visible(self): + profile = ProfileSerializeWithVisible.hydrate( + {"name": "Joe", "id": 1, "password": "secret", "email": "joe@masonite.com"} + ) + self.assertTrue( + {"name": "Joe", "email": "joe@masonite.com"}, profile.serialize() + ) + + def test_serialize_with_visible_and_hidden_raise_error(self): + profile = ProfileSerializeWithVisibleAndHidden.hydrate( + {"name": "Joe", "id": 1, "password": "secret", "email": "joe@masonite.com"} + ) + with self.assertRaises(AttributeError): + profile.serialize() + + def test_serialize_with_on_the_fly_appends(self): + user = User.hydrate({"name": "Joe", "id": 1}) + + user.set_appends(["meta"]) + + serialized = user.serialize() + self.assertEqual(serialized["id"], 1) + self.assertEqual(serialized["name"], "Joe") + self.assertEqual(serialized["meta"]["is_subscribed"], True) + + def test_serialize_with_model_appends(self): + User.__appends__ = ["meta"] + user = User.hydrate({"name": "Joe", "id": 1}) + serialized = user.serialize() + self.assertEqual(serialized["id"], 1) + self.assertEqual(serialized["name"], "Joe") + self.assertEqual(serialized["meta"]["is_subscribed"], True) + + def test_serialize_with_date(self): + user = User.hydrate({"name": "Joe", "created_at": pendulum.now()}) + + self.assertTrue(json.dumps(user.serialize())) + + def test_set_as_date(self): + user = User.hydrate( + { + "name": "Joe", + "created_at": pendulum.now().add(days=10).to_datetime_string(), + } + ) + + self.assertTrue(user.created_at) + self.assertTrue(user.created_at.is_future()) + + def test_access_as_date(self): + user = User.hydrate( + { + "name": "Joe", + "created_at": datetime.datetime.now() + datetime.timedelta(days=1), + } + ) + + self.assertTrue(user.created_at) + self.assertTrue(user.created_at.is_future()) + + def test_hydrate_with_none(self): + profile = ProfileFillAsterisk.hydrate(None) + + self.assertEqual(profile, None) + + def test_serialize_with_dirty_attribute(self): + profile = ProfileFillAsterisk.hydrate({"name": "Joe", "id": 1}) + + profile.age = 18 + self.assertEqual(profile.serialize(), {"age": 18, "name": "Joe", "id": 1}) + + def test_attribute_check_with_hasattr(self): + self.assertFalse(hasattr(Profile(), "__password__")) + + +if os.getenv("RUN_MYSQL_DATABASE", "false").lower() == "true": + + class MysqlTestModel(unittest.TestCase): + # TODO: these tests aren't getting run in CI... is that intentional? 
def test_can_find_first(self): profile = User.find(1) + def test_can_touch(self): + profile = ProfileFillTimeStamped.hydrate({"name": "Joe", "id": 1}) + + sql = profile.touch("now", query=True).to_sql() + + self.assertEqual( + sql, + "UPDATE `profiles` SET `profiles`.`updated_at` = 'now' WHERE `profiles`.`id` = '1'", + ) + def test_find_or_fail_raise_an_exception_if_not_exists(self): with self.assertRaises(ModelNotFound): User.find(100) - def test_serialize_with_dirty_attribute(self): - profile = ProfileFillAsterisk.hydrate({"name": "Joe", "id": 1}) - - profile.age = 18 - self.assertEqual(profile.serialize(), {"age": 18, "name": "Joe", "id": 1}) - - def test_attribute_check_with_hasattr(self): - self.assertFalse(hasattr(Profile(), "__password__")) + def test_returns_correct_data_type(self): + self.assertIsInstance(User.all(), Collection) + # self.assertIsInstance(User.first(), User) + # self.assertIsInstance(User.first(), User) diff --git a/tests/mysql/relationships/test_has_many_through.py b/tests/mysql/relationships/test_has_many_through.py index 310a3b3c4..3c6d5b7ef 100644 --- a/tests/mysql/relationships/test_has_many_through.py +++ b/tests/mysql/relationships/test_has_many_through.py @@ -2,10 +2,7 @@ from src.masoniteorm.models import Model from src.masoniteorm.relationships import ( - has_one, - belongs_to_many, has_many_through, - has_many, ) from dotenv import load_dotenv @@ -34,7 +31,7 @@ def test_has_query(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""", + """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_or_has(self): @@ -42,7 +39,7 @@ def test_or_has(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""", + """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_where_has_query(self): @@ -52,7 +49,7 @@ def test_where_has_query(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""", + """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) def test_or_where_has(self): @@ -64,7 +61,7 @@ def test_or_where_has(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""", + """SELECT * FROM `inbound_shipments` WHERE 
`inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) def test_doesnt_have(self): @@ -72,7 +69,7 @@ def test_doesnt_have(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""", + """SELECT * FROM `inbound_shipments` WHERE NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_or_where_doesnt_have(self): @@ -86,13 +83,5 @@ def test_or_where_doesnt_have(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""", - ) - - def test_has_one_through_with_count(self): - sql = InboundShipment.with_count("from_country").to_sql() - - self.assertEqual( - sql, - """SELECT `inbound_shipments`.*, (SELECT COUNT(*) AS m_count_reserved FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AS from_country_count FROM `inbound_shipments`""", + """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) diff --git a/tests/mysql/relationships/test_has_one_through.py b/tests/mysql/relationships/test_has_one_through.py index f0ae66a11..4337dc837 100644 --- a/tests/mysql/relationships/test_has_one_through.py +++ b/tests/mysql/relationships/test_has_one_through.py @@ -2,10 +2,7 @@ from src.masoniteorm.models import Model from src.masoniteorm.relationships import ( - has_one, - belongs_to_many, has_one_through, - has_many, ) from dotenv import load_dotenv @@ -13,7 +10,7 @@ class InboundShipment(Model): - @has_one_through("port_id", "country_id", "from_port_id", "country_id") + @has_one_through(None, "from_port_id", "country_id", "port_id", "country_id") def from_country(self): return Country, Port @@ -26,7 +23,7 @@ class Port(Model): pass -class MySQLRelationships(unittest.TestCase): +class MySQLHasOneThroughRelationship(unittest.TestCase): maxDiff = None def test_has_query(self): @@ -34,7 +31,7 @@ def test_has_query(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""", + """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_or_has(self): @@ -42,7 +39,7 @@ def test_or_has(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` 
WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""", + """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_where_has_query(self): @@ -52,7 +49,7 @@ def test_where_has_query(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""", + """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) def test_or_where_has(self): @@ -64,7 +61,7 @@ def test_or_where_has(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""", + """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) def test_doesnt_have(self): @@ -72,7 +69,7 @@ def test_doesnt_have(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""", + """SELECT * FROM `inbound_shipments` WHERE NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_or_where_doesnt_have(self): @@ -86,7 +83,7 @@ def test_or_where_doesnt_have(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""", + """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) def test_has_one_through_with_count(self): @@ -94,5 +91,5 @@ def test_has_one_through_with_count(self): self.assertEqual( sql, - """SELECT `inbound_shipments`.*, (SELECT COUNT(*) AS m_count_reserved FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AS from_country_count FROM `inbound_shipments`""", + """SELECT `inbound_shipments`.*, (SELECT COUNT(*) AS m_count_reserved FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`) AS from_country_count FROM `inbound_shipments`""", ) diff --git 
a/tests/mysql/schema/test_mysql_schema_builder.py b/tests/mysql/schema/test_mysql_schema_builder.py index f2019e9b6..b3aefc033 100644 --- a/tests/mysql/schema/test_mysql_schema_builder.py +++ b/tests/mysql/schema/test_mysql_schema_builder.py @@ -1,12 +1,14 @@ import os import unittest -from masoniteorm import Model +from src.masoniteorm import Model from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import MySQLConnection from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import MySQLPlatform +from tests.integrations.config.database import DATABASES + class Discussion(Model): pass @@ -37,6 +39,26 @@ def test_can_add_columns1(self): ], ) + def test_can_add_tiny_text(self): + with self.schema.create("users") as blueprint: + blueprint.tiny_text("description") + + self.assertEqual(len(blueprint.table.added_columns), 1) + self.assertEqual( + blueprint.to_sql(), + ["CREATE TABLE `users` (`description` TINYTEXT NOT NULL)"], + ) + + def test_can_add_unsigned_decimal(self): + with self.schema.create("users") as blueprint: + blueprint.unsigned_decimal("amount", 19, 4) + + self.assertEqual(len(blueprint.table.added_columns), 1) + self.assertEqual( + blueprint.to_sql(), + ["CREATE TABLE `users` (`amount` DECIMAL(19, 4) UNSIGNED NOT NULL)"], + ) + def test_can_create_table_if_not_exists(self): with self.schema.create_table_if_not_exists("users") as blueprint: blueprint.string("name") @@ -203,7 +225,7 @@ def test_can_advanced_table_creation2(self): "CREATE TABLE `users` (`id` BIGINT UNSIGNED AUTO_INCREMENT NOT NULL, `name` VARCHAR(255) NOT NULL, " "`duration` VARCHAR(255) NOT NULL, `url` VARCHAR(255) NOT NULL, `last_address` VARCHAR(255) NULL, `route_origin` VARCHAR(255) NULL, `mac_address` VARCHAR(255) NULL, " "`published_at` DATETIME NOT NULL, `thumbnail` VARCHAR(255) NULL, " - "`premium` INT(11) NOT NULL, `author_id` INT UNSIGNED NULL, `description` TEXT NOT NULL, `created_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP, " + "`premium` INT(11) NOT NULL, `author_id` INT(11) UNSIGNED NULL, `description` TEXT NOT NULL, `created_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP, " "`updated_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP, CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_author_id_foreign FOREIGN KEY (`author_id`) REFERENCES `users`(`id`) ON DELETE CASCADE)" ], ) @@ -278,13 +300,13 @@ def test_can_have_unsigned_columns(self): blueprint.to_sql(), [ "CREATE TABLE `users` (" - "`profile_id` INT UNSIGNED NOT NULL, " - "`big_profile_id` BIGINT UNSIGNED NOT NULL, " - "`tiny_profile_id` TINYINT UNSIGNED NOT NULL, " - "`small_profile_id` SMALLINT UNSIGNED NOT NULL, " - "`medium_profile_id` MEDIUMINT UNSIGNED NOT NULL, " + "`profile_id` INT(11) UNSIGNED NOT NULL, " + "`big_profile_id` BIGINT(32) UNSIGNED NOT NULL, " + "`tiny_profile_id` TINYINT(1) UNSIGNED NOT NULL, " + "`small_profile_id` SMALLINT(5) UNSIGNED NOT NULL, " + "`medium_profile_id` MEDIUMINT(7) UNSIGNED NOT NULL, " "`unsigned_profile_id` INT UNSIGNED NOT NULL, " - "`unsigned_big_profile_id` BIGINT UNSIGNED NOT NULL)" + "`unsigned_big_profile_id` BIGINT(32) UNSIGNED NOT NULL)" ], ) @@ -362,3 +384,15 @@ def test_can_truncate_without_foreign_keys(self): "SET FOREIGN_KEY_CHECKS=1", ], ) + + def test_can_add_enum(self): + with self.schema.create("users") as blueprint: + blueprint.enum("status", ["active", "inactive"]).default("active") + + self.assertEqual(len(blueprint.table.added_columns), 1) + self.assertEqual( + blueprint.to_sql(), + [ + "CREATE TABLE `users` (`status` 
ENUM('active', 'inactive') NOT NULL DEFAULT 'active')" + ], + ) diff --git a/tests/mysql/schema/test_mysql_schema_builder_alter.py b/tests/mysql/schema/test_mysql_schema_builder_alter.py index 4b2befe8c..10e345e76 100644 --- a/tests/mysql/schema/test_mysql_schema_builder_alter.py +++ b/tests/mysql/schema/test_mysql_schema_builder_alter.py @@ -12,7 +12,6 @@ class TestMySQLSchemaBuilderAlter(unittest.TestCase): maxDiff = None def setUp(self): - self.schema = Schema( connection_class=MySQLConnection, connection="mysql", @@ -295,3 +294,27 @@ def test_can_create_indexes(self): "ALTER TABLE `users` ADD FULLTEXT description_fulltext(description)", ], ) + + def test_can_add_column_enum(self): + with self.schema.table("users") as blueprint: + blueprint.enum("status", ["active", "inactive"]).default("active") + + self.assertEqual(len(blueprint.table.added_columns), 1) + + sql = [ + "ALTER TABLE `users` ADD `status` ENUM('active', 'inactive') NOT NULL DEFAULT 'active'" + ] + + self.assertEqual(blueprint.to_sql(), sql) + + def test_can_change_column_enum(self): + with self.schema.table("users") as blueprint: + blueprint.enum("status", ["active", "inactive"]).default("active").change() + + self.assertEqual(len(blueprint.table.changed_columns), 1) + + sql = [ + "ALTER TABLE `users` MODIFY `status` ENUM('active', 'inactive') NOT NULL DEFAULT 'active'" + ] + + self.assertEqual(blueprint.to_sql(), sql) diff --git a/tests/mysql/scopes/test_can_use_global_scopes.py b/tests/mysql/scopes/test_can_use_global_scopes.py index e763d0539..1f670fcf9 100644 --- a/tests/mysql/scopes/test_can_use_global_scopes.py +++ b/tests/mysql/scopes/test_can_use_global_scopes.py @@ -15,7 +15,6 @@ class UserSoft(Model, SoftDeletesMixin): class User(Model): - __dry__ = True @@ -36,7 +35,7 @@ def test_can_use_global_scopes_on_select(self): def test_can_use_global_scopes_on_time(self): sql = "INSERT INTO `users` (`users`.`name`, `users`.`updated_at`, `users`.`created_at`) VALUES ('Joe'" - self.assertTrue(User.create({"name": "Joe"}, query=True).startswith(sql)) + self.assertTrue(User.create({"name": "Joe"}, query=True).to_sql().startswith(sql)) # def test_can_use_global_scopes_on_inherit(self): # sql = "SELECT * FROM `user_softs` WHERE `user_softs`.`deleted_at` IS NULL" diff --git a/tests/mysql/scopes/test_soft_delete.py b/tests/mysql/scopes/test_soft_delete.py index f405defbe..1a0c6c43d 100644 --- a/tests/mysql/scopes/test_soft_delete.py +++ b/tests/mysql/scopes/test_soft_delete.py @@ -1,8 +1,8 @@ -import inspect import unittest +import pendulum + from tests.integrations.config.database import DATABASES -from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar from src.masoniteorm.scopes import SoftDeleteScope @@ -10,16 +10,14 @@ from src.masoniteorm.models import Model from src.masoniteorm.scopes import SoftDeletesMixin -from tests.User import User class UserSoft(Model, SoftDeletesMixin): __dry__ = True - + __table__ = "users" class UserSoftArchived(Model, SoftDeletesMixin): __dry__ = True - __deleted_at__ = "archived_at" __table__ = "users" @@ -44,7 +42,7 @@ def test_with_trashed(self): def test_force_delete(self): sql = "DELETE FROM `users`" builder = self.get_builder().set_global_scope(SoftDeleteScope()) - self.assertEqual(sql, builder.force_delete().to_sql()) + self.assertEqual(sql, builder.force_delete(query=True).to_sql()) def test_restore(self): sql = "UPDATE `users` SET `users`.`deleted_at` = 'None'" @@ -52,9 +50,11 @@ def 
test_restore(self): self.assertEqual(sql, builder.restore().to_sql()) def test_force_delete_with_wheres(self): - sql = "DELETE FROM `user_softs` WHERE `user_softs`.`active` = '1'" + sql = "DELETE FROM `users` WHERE `users`.`active` = '1'" builder = self.get_builder().set_global_scope(SoftDeleteScope()) - self.assertEqual(sql, UserSoft.where("active", 1).force_delete().to_sql()) + self.assertEqual( + sql, UserSoft.where("active", 1).force_delete(query=True).to_sql() + ) def test_that_trashed_users_are_not_returned_by_default(self): sql = "SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL" @@ -67,9 +67,24 @@ def test_only_trashed(self): self.assertEqual(sql, builder.only_trashed().to_sql()) def test_only_trashed_on_model(self): - sql = "SELECT * FROM `user_softs` WHERE `user_softs`.`deleted_at` IS NOT NULL" + sql = "SELECT * FROM `users` WHERE `users`.`deleted_at` IS NOT NULL" self.assertEqual(sql, UserSoft.only_trashed().to_sql()) def test_can_change_column(self): sql = "SELECT * FROM `users` WHERE `users`.`archived_at` IS NOT NULL" self.assertEqual(sql, UserSoftArchived.only_trashed().to_sql()) + + def test_find_with_global_scope(self): + find_sql = UserSoft.find("1", query=True).to_sql() + raw_sql = """SELECT * FROM `users` WHERE `users`.`id` = '1' AND `users`.`deleted_at` IS NULL""" + self.assertEqual(find_sql, raw_sql) + + def test_find_with_trashed_scope(self): + find_sql = UserSoft.with_trashed().find("1", query=True).to_sql() + raw_sql = """SELECT * FROM `users` WHERE `users`.`id` = '1'""" + self.assertEqual(find_sql, raw_sql) + + def test_find_with_only_trashed_scope(self): + find_sql = UserSoft.only_trashed().find("1", query=True).to_sql() + raw_sql = """SELECT * FROM `users` WHERE `users`.`deleted_at` IS NOT NULL AND `users`.`id` = '1'""" + self.assertEqual(find_sql, raw_sql) diff --git a/tests/postgres/builder/test_postgres_query_builder.py b/tests/postgres/builder/test_postgres_query_builder.py index d905e0d80..86e5b62f8 100644 --- a/tests/postgres/builder/test_postgres_query_builder.py +++ b/tests/postgres/builder/test_postgres_query_builder.py @@ -9,7 +9,6 @@ class MockConnection: - connection_details = {} def make_connection(self): @@ -462,7 +461,6 @@ def test_update_lock(self): class PostgresQueryBuilderTest(BaseTestQueryBuilder, unittest.TestCase): - grammar = PostgresGrammar def sum(self): @@ -772,3 +770,31 @@ def shared_lock(self): builder.truncate() """ return """SELECT * FROM "users" WHERE "users"."votes" >= '100' FOR SHARE""" + + def test_latest(self): + builder = self.get_builder() + builder.latest("email") + sql = getattr( + self, inspect.currentframe().f_code.co_name.replace("test_", "") + )() + self.assertEqual(builder.to_sql(), sql) + + def test_oldest(self): + builder = self.get_builder() + builder.oldest("email") + sql = getattr( + self, inspect.currentframe().f_code.co_name.replace("test_", "") + )() + self.assertEqual(builder.to_sql(), sql) + + def oldest(self): + """ + builder.order_by('email', 'asc') + """ + return """SELECT * FROM "users" ORDER BY "email" ASC""" + + def latest(self): + """ + builder.order_by('email', 'des') + """ + return """SELECT * FROM "users" ORDER BY "email" DESC""" diff --git a/tests/postgres/builder/test_postgres_transaction.py b/tests/postgres/builder/test_postgres_transaction.py index d0b92c9ce..a07a77f45 100644 --- a/tests/postgres/builder/test_postgres_transaction.py +++ b/tests/postgres/builder/test_postgres_transaction.py @@ -17,7 +17,6 @@ class User(Model): __timestamps__ = False class 
BaseTestQueryRelationships(unittest.TestCase): - maxDiff = None def get_builder(self, table="users"): diff --git a/tests/postgres/grammar/test_delete_grammar.py b/tests/postgres/grammar/test_delete_grammar.py index 66b3a9532..690e72537 100644 --- a/tests/postgres/grammar/test_delete_grammar.py +++ b/tests/postgres/grammar/test_delete_grammar.py @@ -41,7 +41,6 @@ def test_can_compile_delete_with_where(self): class TestPostgresDeleteGrammar(BaseDeleteGrammarTest, unittest.TestCase): - grammar = "postgres" def can_compile_delete(self): diff --git a/tests/postgres/grammar/test_insert_grammar.py b/tests/postgres/grammar/test_insert_grammar.py index 03e50c8c4..6404d2e34 100644 --- a/tests/postgres/grammar/test_insert_grammar.py +++ b/tests/postgres/grammar/test_insert_grammar.py @@ -27,7 +27,13 @@ def test_can_compile_insert_with_keywords(self): def test_can_compile_bulk_create(self): to_sql = self.builder.bulk_create( - [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True + # These keys are intentionally out of order to show column to value alignment works + [ + {"name": "Joe", "age": 5}, + {"age": 35, "name": "Bill"}, + {"name": "John", "age": 10}, + ], + query=True, ).to_sql() sql = getattr( @@ -47,7 +53,6 @@ def test_can_compile_bulk_create_qmark(self): class TestPostgresUpdateGrammar(BaseInsertGrammarTest, unittest.TestCase): - grammar = "postgres" def can_compile_insert(self): @@ -68,7 +73,7 @@ def can_compile_bulk_create(self): """ self.builder.create(name="Joe").to_sql() """ - return """INSERT INTO "users" ("name") VALUES ('Joe'), ('Bill'), ('John') RETURNING *""" + return """INSERT INTO "users" ("age", "name") VALUES ('5', 'Joe'), ('35', 'Bill'), ('10', 'John') RETURNING *""" def can_compile_bulk_create_qmark(self): """ diff --git a/tests/postgres/grammar/test_select_grammar.py b/tests/postgres/grammar/test_select_grammar.py index 1cd8be855..2793a60bd 100644 --- a/tests/postgres/grammar/test_select_grammar.py +++ b/tests/postgres/grammar/test_select_grammar.py @@ -6,7 +6,6 @@ class TestPostgresGrammar(BaseTestCaseSelectGrammar, unittest.TestCase): - grammar = PostgresGrammar def can_compile_select(self): @@ -247,6 +246,12 @@ def can_compile_having(self): """ return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age\"""" + def can_compile_having_order(self): + """ + builder.sum('age').group_by('age').having('age').order_by('age', 'desc').to_sql() + """ + return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age" ORDER "users"."age" DESC""" + def can_compile_having_with_expression(self): """ builder.sum('age').group_by('age').having('age', 10).to_sql() @@ -309,6 +314,18 @@ def test_can_compile_having_raw(self): to_sql, """SELECT COUNT(*) as counts FROM "users" HAVING counts > 10""" ) + def test_can_compile_having_raw_order(self): + to_sql = ( + self.builder.select_raw("COUNT(*) as counts") + .having_raw("counts > 10") + .order_by_raw("counts DESC") + .to_sql() + ) + self.assertEqual( + to_sql, + """SELECT COUNT(*) as counts FROM "users" HAVING counts > 10 ORDER BY counts DESC""", + ) + def test_can_compile_where_raw_and_where_with_multiple_bindings(self): query = self.builder.where_raw( """ "age" = ? 
AND "is_admin" = ?""", [18, True] @@ -393,11 +410,25 @@ def can_compile_join_clause_with_null(self): clause = ( JoinClause("report_groups as rg") .on_null("bgt.acct") + .or_on_null("bgt.dept") + .on_value("rg.abc", 10) + ) + builder.join(clause).to_sql() + """ + return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "acct" IS NULL OR "dept" IS NULL AND "rg"."abc" = '10'""" + + def can_compile_join_clause_with_not_null(self): + """ + builder = self.get_builder() + clause = ( + JoinClause("report_groups as rg") + .on_not_null("bgt.acct") .or_on_not_null("bgt.dept") + .on_value("rg.abc", 10) ) builder.join(clause).to_sql() """ - return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "acct" IS NULL OR "dept" IS NOT NULL""" + return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "acct" IS NOT NULL OR "dept" IS NOT NULL AND "rg"."abc" = '10'""" def can_compile_join_clause_with_lambda(self): """ diff --git a/tests/postgres/grammar/test_update_grammar.py b/tests/postgres/grammar/test_update_grammar.py index cb6d7fe5c..76d19f7f0 100644 --- a/tests/postgres/grammar/test_update_grammar.py +++ b/tests/postgres/grammar/test_update_grammar.py @@ -71,9 +71,16 @@ def test_raw_expression(self): self.assertEqual(to_sql, sql) + def test_update_null(self): + to_sql = self.builder.update({"name": None}, dry=True).to_sql() + print(to_sql) -class TestPostgresUpdateGrammar(BaseTestCaseUpdateGrammar, unittest.TestCase): + sql = """UPDATE "users" SET "name" = \'None\'""" + + self.assertEqual(to_sql, sql) + +class TestPostgresUpdateGrammar(BaseTestCaseUpdateGrammar, unittest.TestCase): grammar = "postgres" def can_compile_update(self): diff --git a/tests/postgres/relationships/test_postgres_relationships.py b/tests/postgres/relationships/test_postgres_relationships.py index a20477e16..023fd536d 100644 --- a/tests/postgres/relationships/test_postgres_relationships.py +++ b/tests/postgres/relationships/test_postgres_relationships.py @@ -23,7 +23,6 @@ class Logo(Model): __connection__ = "postgres" class User(Model): - __connection__ = "postgres" _eager_loads = () diff --git a/tests/postgres/schema/test_postgres_schema_builder.py b/tests/postgres/schema/test_postgres_schema_builder.py index 9c7751b7c..9e04b3dd5 100644 --- a/tests/postgres/schema/test_postgres_schema_builder.py +++ b/tests/postgres/schema/test_postgres_schema_builder.py @@ -31,6 +31,25 @@ def test_can_add_columns(self): ], ) + def test_can_add_tiny_text(self): + with self.schema.create("users") as blueprint: + blueprint.tiny_text("description") + + self.assertEqual(len(blueprint.table.added_columns), 1) + self.assertEqual( + blueprint.to_sql(), ['CREATE TABLE "users" ("description" TEXT NOT NULL)'] + ) + + def test_can_add_unsigned_decimal(self): + with self.schema.create("users") as blueprint: + blueprint.unsigned_decimal("amount", 19, 4) + + self.assertEqual(len(blueprint.table.added_columns), 1) + self.assertEqual( + blueprint.to_sql(), + ['CREATE TABLE "users" ("amount" DECIMAL(19, 4) NOT NULL)'], + ) + def test_can_create_table_if_not_exists(self): with self.schema.create_table_if_not_exists("users") as blueprint: blueprint.string("name") @@ -349,3 +368,15 @@ def test_can_truncate_without_foreign_keys(self): 'ALTER TABLE "users" ENABLE TRIGGER ALL', ], ) + + def test_can_add_enum(self): + with self.schema.create("users") as blueprint: + blueprint.enum("status", ["active", "inactive"]).default("active") + + self.assertEqual(len(blueprint.table.added_columns), 1) + self.assertEqual( + blueprint.to_sql(), + 
[ + 'CREATE TABLE "users" ("status" VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')) NOT NULL ' 'DEFAULT \'active\')' + ], + ) diff --git a/tests/postgres/schema/test_postgres_schema_builder_alter.py b/tests/postgres/schema/test_postgres_schema_builder_alter.py index 0f0519c52..d77d632b6 100644 --- a/tests/postgres/schema/test_postgres_schema_builder_alter.py +++ b/tests/postgres/schema/test_postgres_schema_builder_alter.py @@ -11,7 +11,6 @@ class TestPostgresSchemaBuilderAlter(unittest.TestCase): maxDiff = None def setUp(self): - self.schema = Schema( connection_class=PostgresConnection, connection="postgres", @@ -302,3 +301,27 @@ def test_alter_drop_on_table_schema_table(self): with schema.table("table_schema") as blueprint: blueprint.string("name") + + def test_can_add_column_enum(self): + with self.schema.table("users") as blueprint: + blueprint.enum("status", ["active", "inactive"]).default("active") + + self.assertEqual(len(blueprint.table.added_columns), 1) + + sql = [ + 'ALTER TABLE "users" ADD COLUMN "status" VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')) NOT NULL DEFAULT \'active\'', + ] + + self.assertEqual(blueprint.to_sql(), sql) + + def test_can_change_column_enum(self): + with self.schema.table("users") as blueprint: + blueprint.enum("status", ["active", "inactive"]).default("active").change() + + self.assertEqual(len(blueprint.table.changed_columns), 1) + + sql = [ + 'ALTER TABLE "users" ALTER COLUMN "status" TYPE VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')), ALTER COLUMN "status" SET NOT NULL, ALTER COLUMN "status" SET DEFAULT active', + ] + + self.assertEqual(blueprint.to_sql(), sql) diff --git a/tests/scopes/test_default_global_scopes.py b/tests/scopes/test_default_global_scopes.py index a12b53f0d..f890be6fd 100644 --- a/tests/scopes/test_default_global_scopes.py +++ b/tests/scopes/test_default_global_scopes.py @@ -28,6 +28,12 @@ class UserWithTimeStamps(Model, TimeStampsMixin): __dry__ = True +class UserWithCustomTimeStamps(Model, TimeStampsMixin): + __dry__ = True + date_updated_at = "updated_ts" + date_created_at = "created_ts" + + class UserSoft(Model, SoftDeletesMixin): __dry__ = True @@ -112,3 +118,24 @@ def test_timestamps_can_be_disabled(self): self.scope.set_timestamp_create(self.builder) self.assertNotIn("created_at", self.builder._creates) self.assertNotIn("updated_at", self.builder._creates) + + def test_uses_custom_timestamp_columns_on_create(self): + self.builder = MockBuilder(UserWithCustomTimeStamps) + self.scope.set_timestamp_create(self.builder) + created_column = UserWithCustomTimeStamps.date_created_at + updated_column = UserWithCustomTimeStamps.date_updated_at + self.assertNotIn("created_at", self.builder._creates) + self.assertNotIn("updated_at", self.builder._creates) + self.assertIn(created_column, self.builder._creates) + self.assertIn(updated_column, self.builder._creates) + self.assertIsInstance( + pendulum.parse(self.builder._creates[created_column]), pendulum.DateTime + ) + self.assertIsInstance( + pendulum.parse(self.builder._creates[updated_column]), pendulum.DateTime + ) + + def test_uses_custom_updated_column_on_update(self): + user = UserWithCustomTimeStamps.hydrate({"id": 1}) + sql = user.update({"id": 2}).to_sql() + self.assertTrue(UserWithCustomTimeStamps.date_updated_at in sql) diff --git a/tests/sqlite/builder/test_sqlite_builder_insert.py b/tests/sqlite/builder/test_sqlite_builder_insert.py index df58e73a1..52f6a29fb 100644 --- a/tests/sqlite/builder/test_sqlite_builder_insert.py +++ 
b/tests/sqlite/builder/test_sqlite_builder_insert.py @@ -17,7 +17,6 @@ class User(Model): class BaseTestQueryRelationships(unittest.TestCase): - maxDiff = None def get_builder(self, table="users"): diff --git a/tests/sqlite/builder/test_sqlite_builder_pagination.py b/tests/sqlite/builder/test_sqlite_builder_pagination.py index 832df6631..2338b6b38 100644 --- a/tests/sqlite/builder/test_sqlite_builder_pagination.py +++ b/tests/sqlite/builder/test_sqlite_builder_pagination.py @@ -15,7 +15,6 @@ class User(Model): class BaseTestQueryRelationships(unittest.TestCase): - maxDiff = None def get_builder(self, table="users", model=User): diff --git a/tests/sqlite/builder/test_sqlite_query_builder.py b/tests/sqlite/builder/test_sqlite_query_builder.py index 7fa4a80d8..9da9f235c 100644 --- a/tests/sqlite/builder/test_sqlite_query_builder.py +++ b/tests/sqlite/builder/test_sqlite_query_builder.py @@ -392,7 +392,6 @@ def test_between(self): self.assertEqual(builder.to_sql(), sql) def test_between_persisted(self): - builder = QueryBuilder().table("users").on("dev") users = builder.between("age", 1, 2).count() @@ -407,7 +406,6 @@ def test_not_between(self): self.assertEqual(builder.to_sql(), sql) def test_not_between_persisted(self): - builder = QueryBuilder().table("users").on("dev") users = builder.where_not_null("id").not_between("age", 1, 2).count() @@ -583,7 +581,6 @@ def test_truncate_without_foreign_keys(self): class SQLiteQueryBuilderTest(BaseTestQueryBuilder, unittest.TestCase): - grammar = SQLiteGrammar def sum(self): @@ -971,3 +968,31 @@ def truncate_without_foreign_keys(self): 'DELETE FROM "users"', "PRAGMA foreign_keys = ON", ] + + def test_latest(self): + builder = self.get_builder() + builder.latest("email") + sql = getattr( + self, inspect.currentframe().f_code.co_name.replace("test_", "") + )() + self.assertEqual(builder.to_sql(), sql) + + def test_oldest(self): + builder = self.get_builder() + builder.oldest("email") + sql = getattr( + self, inspect.currentframe().f_code.co_name.replace("test_", "") + )() + self.assertEqual(builder.to_sql(), sql) + + def oldest(self): + """ + builder.order_by('email', 'asc') + """ + return """SELECT * FROM "users" ORDER BY "email" ASC""" + + def latest(self): + """ + builder.order_by('email', 'des') + """ + return """SELECT * FROM "users" ORDER BY "email" DESC""" diff --git a/tests/sqlite/builder/test_sqlite_query_builder_eager_loading.py b/tests/sqlite/builder/test_sqlite_query_builder_eager_loading.py index aa2524dde..800e94405 100644 --- a/tests/sqlite/builder/test_sqlite_query_builder_eager_loading.py +++ b/tests/sqlite/builder/test_sqlite_query_builder_eager_loading.py @@ -56,7 +56,6 @@ def profile(self): class BaseTestQueryRelationships(unittest.TestCase): - maxDiff = None def get_builder(self, table="users", model=User): diff --git a/tests/sqlite/builder/test_sqlite_query_builder_relationships.py b/tests/sqlite/builder/test_sqlite_query_builder_relationships.py index 8a35a4e8e..3e4dc03b9 100644 --- a/tests/sqlite/builder/test_sqlite_query_builder_relationships.py +++ b/tests/sqlite/builder/test_sqlite_query_builder_relationships.py @@ -42,7 +42,6 @@ def profile(self): class BaseTestQueryRelationships(unittest.TestCase): - maxDiff = None def get_builder(self, table="users"): diff --git a/tests/sqlite/builder/test_sqlite_transaction.py b/tests/sqlite/builder/test_sqlite_transaction.py index 2cf0238c8..41e87612b 100644 --- a/tests/sqlite/builder/test_sqlite_transaction.py +++ b/tests/sqlite/builder/test_sqlite_transaction.py @@ -18,7 +18,6 @@ class 
User(Model): class BaseTestQueryRelationships(unittest.TestCase): - maxDiff = None def get_builder(self, table="users"): diff --git a/tests/sqlite/grammar/test_sqlite_delete_grammar.py b/tests/sqlite/grammar/test_sqlite_delete_grammar.py index ee501ffb6..3bd36a872 100644 --- a/tests/sqlite/grammar/test_sqlite_delete_grammar.py +++ b/tests/sqlite/grammar/test_sqlite_delete_grammar.py @@ -40,7 +40,6 @@ def test_can_compile_delete_with_where(self): class TestSqliteDeleteGrammar(BaseDeleteGrammarTest, unittest.TestCase): - grammar = "sqlite" def can_compile_delete(self): diff --git a/tests/sqlite/grammar/test_sqlite_insert_grammar.py b/tests/sqlite/grammar/test_sqlite_insert_grammar.py index 47d6d125f..35ee7eb91 100644 --- a/tests/sqlite/grammar/test_sqlite_insert_grammar.py +++ b/tests/sqlite/grammar/test_sqlite_insert_grammar.py @@ -27,7 +27,13 @@ def test_can_compile_insert_with_keywords(self): def test_can_compile_bulk_create(self): to_sql = self.builder.bulk_create( - [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True + # These keys are intentionally out of order to show column to value alignment works + [ + {"name": "Joe", "age": 5}, + {"age": 35, "name": "Bill"}, + {"name": "John", "age": 10}, + ], + query=True, ).to_sql() sql = getattr( @@ -62,7 +68,6 @@ def test_can_compile_bulk_create_multiple(self): class TestSqliteUpdateGrammar(BaseInsertGrammarTest, unittest.TestCase): - grammar = "sqlite" def can_compile_insert(self): @@ -83,13 +88,13 @@ def can_compile_bulk_create(self): """ self.builder.create(name="Joe").to_sql() """ - return """INSERT INTO "users" ("name") VALUES ('Joe'), ('Bill'), ('John')""" + return """INSERT INTO "users" ("age", "name") VALUES ('5', 'Joe'), ('35', 'Bill'), ('10', 'John')""" def can_compile_bulk_create_multiple(self): """ self.builder.create(name="Joe").to_sql() """ - return """INSERT INTO "users" ("name", "active") VALUES ('Joe', '1'), ('Bill', '1'), ('John', '1')""" + return """INSERT INTO "users" ("active", "name") VALUES ('1', 'Joe'), ('1', 'Bill'), ('1', 'John')""" def can_compile_bulk_create_qmark(self): """ diff --git a/tests/sqlite/grammar/test_sqlite_select_grammar.py b/tests/sqlite/grammar/test_sqlite_select_grammar.py index 82147d680..f835bb434 100644 --- a/tests/sqlite/grammar/test_sqlite_select_grammar.py +++ b/tests/sqlite/grammar/test_sqlite_select_grammar.py @@ -6,7 +6,6 @@ class TestSQLiteGrammar(BaseTestCaseSelectGrammar, unittest.TestCase): - grammar = SQLiteGrammar maxDiff = None @@ -239,6 +238,12 @@ def can_compile_having(self): """ return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age\"""" + def can_compile_having_order(self): + """ + builder.sum('age').group_by('age').having('age').order_by('age', 'desc').to_sql() + """ + return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age\" ORDER "users"."age" DESC""" + def can_compile_having_raw(self): """ builder.select_raw("COUNT(*) as counts").having_raw("counts > 18").to_sql() @@ -375,11 +380,25 @@ def can_compile_join_clause_with_null(self): clause = ( JoinClause("report_groups as rg") .on_null("bgt.acct") + .or_on_null("bgt.dept") + .on_value("rg.abc", 10) + ) + builder.join(clause).to_sql() + """ + return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "acct" IS NULL OR "dept" IS NULL AND "rg"."abc" = '10'""" + + def can_compile_join_clause_with_not_null(self): + """ + builder = self.get_builder() + clause = ( + JoinClause("report_groups as rg") + .on_not_null("bgt.acct") 
.or_on_not_null("bgt.dept") + .on_value("rg.abc", 10) ) builder.join(clause).to_sql() """ - return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "acct" IS NULL OR "dept" IS NOT NULL""" + return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "acct" IS NOT NULL OR "dept" IS NOT NULL AND "rg"."abc" = '10'""" def can_compile_join_clause_with_lambda(self): """ diff --git a/tests/sqlite/grammar/test_sqlite_update_grammar.py b/tests/sqlite/grammar/test_sqlite_update_grammar.py index 1bf081b40..4ee265432 100644 --- a/tests/sqlite/grammar/test_sqlite_update_grammar.py +++ b/tests/sqlite/grammar/test_sqlite_update_grammar.py @@ -68,7 +68,6 @@ def test_raw_expression(self): class TestSqliteUpdateGrammar(BaseTestCaseUpdateGrammar, unittest.TestCase): - grammar = "sqlite" def can_compile_update(self): diff --git a/tests/sqlite/models/test_observers.py b/tests/sqlite/models/test_observers.py index a0ce9e50a..178eaee46 100644 --- a/tests/sqlite/models/test_observers.py +++ b/tests/sqlite/models/test_observers.py @@ -63,7 +63,6 @@ class Observer(Model): class BaseTestQueryRelationships(unittest.TestCase): - maxDiff = None def test_created_is_observed(self): diff --git a/tests/sqlite/models/test_sqlite_model.py b/tests/sqlite/models/test_sqlite_model.py index 7e76116a8..1456e1834 100644 --- a/tests/sqlite/models/test_sqlite_model.py +++ b/tests/sqlite/models/test_sqlite_model.py @@ -55,7 +55,6 @@ def team(self): class BaseTestQueryRelationships(unittest.TestCase): - maxDiff = None def test_update_specific_record(self): @@ -73,16 +72,29 @@ def test_update_all_records(self): self.assertEqual(sql, """UPDATE "users" SET "name" = 'joe'""") def test_can_find_list(self): - sql = User.find(1, query=True) + sql = User.find(1, query=True).to_sql() self.assertEqual(sql, """SELECT * FROM "users" WHERE "users"."id" = '1'""") - sql = User.find([1, 2, 3], query=True) + sql = User.find([1, 2, 3], query=True).to_sql() self.assertEqual( sql, """SELECT * FROM "users" WHERE "users"."id" IN ('1','2','3')""" ) + def test_find_or_if_record_not_found(self): + # Insane record number so record cannot be found + record_id = 1_000_000_000_000_000 + + result = User.find_or(record_id, lambda: "Record not found.") + self.assertEqual(result, "Record not found.") + + def test_find_or_if_record_found(self): + record_id = 1 + result_id = User.find_or(record_id, lambda: "Record not found.").id + + self.assertEqual(result_id, record_id) + def test_can_set_and_retreive_attribute(self): user = User.hydrate({"id": 1, "name": "joe", "customer_id": 1}) user.customer_id = "CUST1" @@ -96,7 +108,7 @@ def test_model_can_use_selects(self): def test_model_can_use_selects_from_methods(self): self.assertEqual( - SelectPass.all(["username"], query=True), + SelectPass.all(["username"], query=True).to_sql(), 'SELECT "select_passes"."username" FROM "select_passes"', ) @@ -163,7 +175,6 @@ class ModelUser(Model): self.assertEqual(count, 0) def test_get_columns(self): - columns = User.get_columns() self.assertEqual( columns, @@ -192,7 +203,6 @@ def test_get_columns(self): ) def test_should_return_relation_applying_hidden_attributes(self): - schema = Schema( connection_details=DATABASES, connection="dev", platform=SQLitePlatform ).on("dev") diff --git a/tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py b/tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py new file mode 100644 index 000000000..baf68eae8 --- /dev/null +++ 
b/tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py @@ -0,0 +1,144 @@ +import unittest + +from src.masoniteorm.collection import Collection +from src.masoniteorm.models import Model +from src.masoniteorm.relationships import has_many_through +from tests.integrations.config.database import DATABASES +from src.masoniteorm.schema import Schema +from src.masoniteorm.schema.platforms import SQLitePlatform + + +class Enrolment(Model): + __table__ = "enrolment" + __connection__ = "dev" + __fillable__ = ["active_student_id", "in_course_id"] + + +class Student(Model): + __table__ = "student" + __connection__ = "dev" + __fillable__ = ["student_id", "name"] + + +class Course(Model): + __table__ = "course" + __connection__ = "dev" + __fillable__ = ["course_id", "name"] + + @has_many_through( + None, + "in_course_id", + "active_student_id", + "course_id", + "student_id" + ) + def students(self): + return [Student, Enrolment] + + +class TestHasManyThroughRelationship(unittest.TestCase): + def setUp(self): + self.schema = Schema( + connection="dev", + connection_details=DATABASES, + platform=SQLitePlatform, + ).on("dev") + + with self.schema.create_table_if_not_exists("student") as table: + table.integer("student_id").primary() + table.string("name") + + with self.schema.create_table_if_not_exists("course") as table: + table.integer("course_id").primary() + table.string("name") + + with self.schema.create_table_if_not_exists("enrolment") as table: + table.integer("enrolment_id").primary() + table.integer("active_student_id") + table.integer("in_course_id") + + if not Course.count(): + Course.builder.new().bulk_create( + [ + {"course_id": 10, "name": "Math 101"}, + {"course_id": 20, "name": "History 101"}, + {"course_id": 30, "name": "Math 302"}, + {"course_id": 40, "name": "Biology 302"}, + ] + ) + + if not Student.count(): + Student.builder.new().bulk_create( + [ + {"student_id": 100, "name": "Bob"}, + {"student_id": 200, "name": "Alice"}, + {"student_id": 300, "name": "Steve"}, + {"student_id": 400, "name": "Megan"}, + ] + ) + + if not Enrolment.count(): + Enrolment.builder.new().bulk_create( + [ + {"active_student_id": 100, "in_course_id": 30}, + {"active_student_id": 200, "in_course_id": 10}, + {"active_student_id": 100, "in_course_id": 10}, + {"active_student_id": 400, "in_course_id": 20}, + ] + ) + + def test_has_many_through_can_eager_load(self): + courses = Course.where("name", "Math 101").with_("students").get() + students = courses.first().students + + self.assertIsInstance(students, Collection) + self.assertEqual(students.count(), 2) + + student1 = students.shift() + self.assertIsInstance(student1, Student) + self.assertEqual(student1.name, "Alice") + + student2 = students.shift() + self.assertIsInstance(student2, Student) + self.assertEqual(student2.name, "Bob") + + # check .first() and .get() produce the same result + single = ( + Course.where("name", "History 101") + .with_("students") + .first() + ) + self.assertIsInstance(single.students, Collection) + + single_get = ( + Course.where("name", "History 101").with_("students").get() + ) + + print(single.students) + print(single_get.first().students) + self.assertEqual(single.students.count(), 1) + self.assertEqual(single_get.first().students.count(), 1) + + single_name = single.students.first().name + single_get_name = single_get.first().students.first().name + self.assertEqual(single_name, single_get_name) + + def test_has_many_through_eager_load_can_be_empty(self): + courses = ( + Course.where("name", "Biology 302") + 
.with_("students") + .get() + ) + self.assertIsNone(courses.first().students) + + def test_has_many_through_can_get_related(self): + course = Course.where("name", "Math 101").first() + self.assertIsInstance(course.students, Collection) + self.assertIsInstance(course.students.first(), Student) + self.assertEqual(course.students.count(), 2) + + def test_has_many_through_has_query(self): + courses = Course.where_has( + "students", lambda query: query.where("name", "Bob") + ) + self.assertEqual(courses.count(), 2) diff --git a/tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py b/tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py new file mode 100644 index 000000000..dee1bff9d --- /dev/null +++ b/tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py @@ -0,0 +1,139 @@ +import unittest + +from src.masoniteorm.models import Model +from src.masoniteorm.relationships import has_one_through +from tests.integrations.config.database import DATABASES +from src.masoniteorm.schema import Schema +from src.masoniteorm.schema.platforms import SQLitePlatform + + +class Port(Model): + __table__ = "ports" + __connection__ = "dev" + __fillable__ = ["port_id", "name", "port_country_id"] + + +class Country(Model): + __table__ = "countries" + __connection__ = "dev" + __fillable__ = ["country_id", "name"] + + +class IncomingShipment(Model): + __table__ = "incoming_shipments" + __connection__ = "dev" + __fillable__ = ["shipment_id", "name", "from_port_id"] + + @has_one_through(None, "from_port_id", "port_country_id", "port_id", "country_id") + def from_country(self): + return [Country, Port] + + + +class TestHasOneThroughRelationship(unittest.TestCase): + def setUp(self): + self.schema = Schema( + connection="dev", + connection_details=DATABASES, + platform=SQLitePlatform, + ).on("dev") + + with self.schema.create_table_if_not_exists("incoming_shipments") as table: + table.integer("shipment_id").primary() + table.string("name") + table.integer("from_port_id") + + with self.schema.create_table_if_not_exists("ports") as table: + table.integer("port_id").primary() + table.string("name") + table.integer("port_country_id") + + with self.schema.create_table_if_not_exists("countries") as table: + table.integer("country_id").primary() + table.string("name") + + if not Country.count(): + Country.builder.new().bulk_create( + [ + {"country_id": 10, "name": "Australia"}, + {"country_id": 20, "name": "USA"}, + {"country_id": 30, "name": "Canada"}, + {"country_id": 40, "name": "United Kingdom"}, + ] + ) + + if not Port.count(): + Port.builder.new().bulk_create( + [ + {"port_id": 100, "name": "Melbourne", "port_country_id": 10}, + {"port_id": 200, "name": "Darwin", "port_country_id": 10}, + {"port_id": 300, "name": "South Louisiana", "port_country_id": 20}, + {"port_id": 400, "name": "Houston", "port_country_id": 20}, + {"port_id": 500, "name": "Montreal", "port_country_id": 30}, + {"port_id": 600, "name": "Vancouver", "port_country_id": 30}, + {"port_id": 700, "name": "Southampton", "port_country_id": 40}, + {"port_id": 800, "name": "London Gateway", "port_country_id": 40}, + ] + ) + + if not IncomingShipment.count(): + IncomingShipment.builder.new().bulk_create( + [ + {"name": "Bread", "from_port_id": 300}, + {"name": "Milk", "from_port_id": 100}, + {"name": "Tractor Parts", "from_port_id": 100}, + {"name": "Fridges", "from_port_id": 700}, + {"name": "Wheat", "from_port_id": 600}, + {"name": "Kettles", "from_port_id": 400}, + {"name": "Bread", "from_port_id": 700}, + ] + ) 
+ + def test_has_one_through_can_eager_load(self): + shipments = IncomingShipment.where("name", "Bread").with_("from_country").get() + self.assertEqual(shipments.count(), 2) + + shipment1 = shipments.shift() + self.assertIsInstance(shipment1.from_country, Country) + self.assertEqual(shipment1.from_country.country_id, 20) + + shipment2 = shipments.shift() + self.assertIsInstance(shipment2.from_country, Country) + self.assertEqual(shipment2.from_country.country_id, 40) + + # check .first() and .get() produce the same result + single = ( + IncomingShipment.where("name", "Tractor Parts") + .with_("from_country") + .first() + ) + single_get = ( + IncomingShipment.where("name", "Tractor Parts").with_("from_country").get() + ) + self.assertEqual(single.from_country.country_id, 10) + self.assertEqual(single_get.count(), 1) + self.assertEqual( + single.from_country.country_id, single_get.first().from_country.country_id + ) + + def test_has_one_through_eager_load_can_be_empty(self): + shipments = ( + IncomingShipment.where("name", "Bread") + .where_has("from_country", lambda query: query.where("name", "Ueaguay")) + .with_( + "from_country", + ) + .get() + ) + self.assertEqual(shipments.count(), 0) + + def test_has_one_through_can_get_related(self): + shipment = IncomingShipment.where("name", "Milk").first() + self.assertIsInstance(shipment.from_country, Country) + self.assertEqual(shipment.from_country.country_id, 10) + + def test_has_one_through_has_query(self): + shipments = IncomingShipment.where_has( + "from_country", lambda query: query.where("name", "USA") + ) + self.assertEqual(shipments.count(), 2) diff --git a/tests/sqlite/relationships/test_sqlite_polymorphic.py b/tests/sqlite/relationships/test_sqlite_polymorphic.py index 5a69dbca7..e0cd4d392 100644 --- a/tests/sqlite/relationships/test_sqlite_polymorphic.py +++ b/tests/sqlite/relationships/test_sqlite_polymorphic.py @@ -26,7 +26,6 @@ class Logo(Model): class Like(Model): - __connection__ = "dev" @morph_to("record_type", "record_id") @@ -35,7 +34,6 @@ def record(self): class User(Model): - __connection__ = "dev" _eager_loads = () diff --git a/tests/sqlite/relationships/test_sqlite_relationships.py b/tests/sqlite/relationships/test_sqlite_relationships.py index 0349a5a32..a3be52469 100644 --- a/tests/sqlite/relationships/test_sqlite_relationships.py +++ b/tests/sqlite/relationships/test_sqlite_relationships.py @@ -1,6 +1,4 @@ -import os import unittest - from src.masoniteorm.models import Model from src.masoniteorm.relationships import belongs_to, has_many, has_one, belongs_to_many from tests.integrations.config.database import DB @@ -30,7 +28,6 @@ class Logo(Model): class User(Model): - __connection__ = "dev" _eager_loads = () @@ -50,7 +47,6 @@ def get_is_admin(self): class Store(Model): - __connection__ = "dev" @belongs_to_many("store_id", "product_id", "id", "id", with_timestamps=True) @@ -67,12 +63,10 @@ def store_products(self): class Product(Model): - __connection__ = "dev" class UserHasOne(Model): - __table__ = "users" __connection__ = "dev" diff --git a/tests/sqlite/schema/test_sqlite_schema_builder.py b/tests/sqlite/schema/test_sqlite_schema_builder.py index 6b37303d9..2259ea7ec 100644 --- a/tests/sqlite/schema/test_sqlite_schema_builder.py +++ b/tests/sqlite/schema/test_sqlite_schema_builder.py @@ -30,6 +30,25 @@ def test_can_add_columns(self): ], ) + def test_can_add_tiny_text(self): + with self.schema.create("users") as blueprint: + blueprint.tiny_text("description") + + self.assertEqual(len(blueprint.table.added_columns), 1) + 
self.assertEqual( + blueprint.to_sql(), ['CREATE TABLE "users" ("description" TEXT NOT NULL)'] + ) + + def test_can_add_unsigned_decimal(self): + with self.schema.create("users") as blueprint: + blueprint.unsigned_decimal("amount", 19, 4) + + self.assertEqual(len(blueprint.table.added_columns), 1) + self.assertEqual( + blueprint.to_sql(), + ['CREATE TABLE "users" ("amount" DECIMAL(19, 4) NOT NULL)'], + ) + def test_can_create_table_if_not_exists(self): with self.schema.create_table_if_not_exists("users") as blueprint: blueprint.string("name") @@ -114,7 +133,7 @@ def test_can_use_morphs_for_polymorphism_relationships(self): self.assertEqual(len(blueprint.table.added_columns), 2) sql = [ - 'CREATE TABLE "likes" ("record_id" INT UNSIGNED NOT NULL, "record_type" VARCHAR(255) NOT NULL)', + 'CREATE TABLE "likes" ("record_id" INTEGER UNSIGNED NOT NULL, "record_type" VARCHAR(255) NOT NULL)', 'CREATE INDEX likes_record_id_index ON "likes"(record_id)', 'CREATE INDEX likes_record_type_index ON "likes"(record_type)', ] @@ -251,7 +270,7 @@ def test_can_advanced_table_creation2(self): [ 'CREATE TABLE "users" ("id" BIGINT NOT NULL, "name" VARCHAR(255) NOT NULL, "duration" VARCHAR(255) NOT NULL, ' '"url" VARCHAR(255) NOT NULL, "payload" JSON NOT NULL, "birth" VARCHAR(4) NOT NULL, "last_address" VARCHAR(255) NULL, "route_origin" VARCHAR(255) NULL, "mac_address" VARCHAR(255) NULL, ' - '"published_at" DATETIME NOT NULL, "wakeup_at" TIME NOT NULL, "thumbnail" VARCHAR(255) NULL, "premium" INTEGER NOT NULL, "author_id" INT UNSIGNED NULL, "description" TEXT NOT NULL, ' + '"published_at" DATETIME NOT NULL, "wakeup_at" TIME NOT NULL, "thumbnail" VARCHAR(255) NULL, "premium" INTEGER NOT NULL, "author_id" INTEGER UNSIGNED NULL, "description" TEXT NOT NULL, ' '"created_at" DATETIME NULL DEFAULT CURRENT_TIMESTAMP, "updated_at" DATETIME NULL DEFAULT CURRENT_TIMESTAMP, ' 'CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_author_id_foreign FOREIGN KEY ("author_id") REFERENCES "users"("id") ON DELETE SET NULL)' ] @@ -301,13 +320,11 @@ def test_can_have_unsigned_columns(self): blueprint.small_integer("small_profile_id").unsigned() blueprint.medium_integer("medium_profile_id").unsigned() - print(blueprint.to_sql()) - self.assertEqual( blueprint.to_sql(), [ """CREATE TABLE "users" (""" - """"profile_id" INT UNSIGNED NOT NULL, """ + """"profile_id" INTEGER UNSIGNED NOT NULL, """ """"big_profile_id" BIGINT UNSIGNED NOT NULL, """ """"tiny_profile_id" TINYINT UNSIGNED NOT NULL, """ """"small_profile_id" SMALLINT UNSIGNED NOT NULL, """ @@ -336,3 +353,15 @@ def test_can_truncate_without_foreign_keys(self): "PRAGMA foreign_keys = ON", ], ) + + def test_can_add_enum(self): + with self.schema.create("users") as blueprint: + blueprint.enum("status", ["active", "inactive"]).default("active") + + self.assertEqual(len(blueprint.table.added_columns), 1) + self.assertEqual( + blueprint.to_sql(), + [ + 'CREATE TABLE "users" ("status" VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')) NOT NULL DEFAULT \'active\')' + ], + ) diff --git a/tests/sqlite/schema/test_sqlite_schema_builder_alter.py b/tests/sqlite/schema/test_sqlite_schema_builder_alter.py index fde74e94f..8f0ed8a78 100644 --- a/tests/sqlite/schema/test_sqlite_schema_builder_alter.py +++ b/tests/sqlite/schema/test_sqlite_schema_builder_alter.py @@ -176,10 +176,10 @@ def test_alter_add_column_and_foreign_key(self): blueprint.table.from_table = table sql = [ - 'ALTER TABLE "users" ADD COLUMN "playlist_id" INT UNSIGNED NULL REFERENCES "playlists"("id")', + 'ALTER 
TABLE "users" ADD COLUMN "playlist_id" INTEGER UNSIGNED NULL REFERENCES "playlists"("id")', "CREATE TEMPORARY TABLE __temp__users AS SELECT age, email FROM users", 'DROP TABLE "users"', - 'CREATE TABLE "users" ("age" VARCHAR NOT NULL, "email" VARCHAR NOT NULL, "playlist_id" INT UNSIGNED NULL, ' + 'CREATE TABLE "users" ("age" VARCHAR NOT NULL, "email" VARCHAR NOT NULL, "playlist_id" INTEGER UNSIGNED NULL, ' 'CONSTRAINT users_playlist_id_foreign FOREIGN KEY ("playlist_id") REFERENCES "playlists"("id") ON DELETE CASCADE ON UPDATE SET NULL)', 'INSERT INTO "users" ("age", "email") SELECT age, email FROM __temp__users', "DROP TABLE __temp__users", @@ -209,3 +209,33 @@ def test_alter_add_foreign_key_only(self): ] self.assertEqual(blueprint.to_sql(), sql) + + def test_can_add_column_enum(self): + with self.schema.table("users") as blueprint: + blueprint.enum("status", ["active", "inactive"]).default("active") + + self.assertEqual(len(blueprint.table.added_columns), 1) + + sql = [ + 'ALTER TABLE "users" ADD COLUMN "status" VARCHAR CHECK(\'status\' IN(\'active\', \'inactive\')) NOT NULL DEFAULT \'active\'' + ] + + self.assertEqual(blueprint.to_sql(), sql) + + def test_can_change_column_enum(self): + with self.schema.table("users") as blueprint: + blueprint.enum("status", ["active", "inactive"]).default("active").change() + + blueprint.table.from_table = Table("users") + + self.assertEqual(len(blueprint.table.changed_columns), 1) + + sql = [ + 'CREATE TEMPORARY TABLE __temp__users AS SELECT FROM users', + 'DROP TABLE "users"', + 'CREATE TABLE "users" ("status" VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')) NOT NULL DEFAULT \'active\')', + 'INSERT INTO "users" ("status") SELECT status FROM __temp__users', + 'DROP TABLE __temp__users' + ] + + self.assertEqual(blueprint.to_sql(), sql) diff --git a/tests/sqlite/schema/test_table.py b/tests/sqlite/schema/test_table.py index 5f95755eb..2e73a4ab4 100644 --- a/tests/sqlite/schema/test_table.py +++ b/tests/sqlite/schema/test_table.py @@ -7,7 +7,6 @@ class TestTable(unittest.TestCase): - maxDiff = None def setUp(self):