Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 99 additions & 76 deletions src/DatabaseLibrary/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import inspect
import re
import sys
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import sqlparse
from robot.api import logger
Expand Down Expand Up @@ -328,9 +328,7 @@ def execute_sql_script(
else:
statements_to_execute = self.split_sql_script(script_path, external_parser=external_parser)
for statement in statements_to_execute:
proc_end_pattern = re.compile("end(?!( if;| loop;| case;| while;| repeat;)).*;()?")
line_ends_with_proc_end = re.compile(r"(\s|;)" + proc_end_pattern.pattern + "$")
omit_semicolon = not line_ends_with_proc_end.search(statement.lower())
omit_semicolon = self._omit_semicolon_needed(statement)
self._execute_sql(cur, statement, omit_semicolon, replace_robot_variables=replace_robot_variables)
self._commit_if_needed(db_connection, no_transaction)
except Exception as e:
Expand All @@ -350,72 +348,82 @@ def split_sql_script(
"""
with open(script_path, encoding="UTF-8") as sql_file:
logger.info("Splitting script file into statements...")
statements_to_execute = []
if external_parser:
split_statements = sqlparse.split(sql_file.read())
for statement in split_statements:
statement_without_comments = sqlparse.format(statement, strip_comments=True)
if statement_without_comments:
statements_to_execute.append(statement_without_comments)
else:
current_statement = ""
inside_statements_group = False
proc_start_pattern = re.compile("create( or replace)? (procedure|function){1}( )?")
proc_end_pattern = re.compile("end(?!( if;| loop;| case;| while;| repeat;)).*;()?")
for line in sql_file:
line = line.strip()
if line.startswith("#") or line.startswith("--") or line == "/":
continue

# check if the line matches the creating procedure regexp pattern
if proc_start_pattern.match(line.lower()):
inside_statements_group = True
elif line.lower().startswith("begin"):
inside_statements_group = True

# semicolons inside the line? use them to separate statements
# ... but not if they are inside a begin/end block (aka. statements group)
sqlFragments = line.split(";")
# no semicolons
if len(sqlFragments) == 1:
current_statement += line + " "
continue
return self.split_sql_string(sql_file.read(), external_parser=external_parser)

def split_sql_string(self, sql_string: str, external_parser: bool = False):
if external_parser:
return self._split_statements_using_external_parser(sql_string)
else:
return self._parse_sql_internally(sql_string.splitlines())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Finding by my side after commit. If general approach is ok, we may shift split to _parse_sql_internally function


def _parse_sql_internally(self, sql_file: List[str]) -> list[str]:
statements_to_execute = []
current_statement = ""
inside_statements_group = False
proc_start_pattern = re.compile("create( or replace)? (procedure|function){1}( )?")
proc_end_pattern = re.compile("end(?!( if;| loop;| case;| while;| repeat;)).*;()?")
for line in sql_file:
line = line.strip()
if line.startswith("#") or line.startswith("--") or line == "/":
continue

# check if the line matches the creating procedure regexp pattern
if proc_start_pattern.match(line.lower()):
inside_statements_group = True
elif line.lower().startswith("begin"):
inside_statements_group = True

# semicolons inside the line? use them to separate statements
# ... but not if they are inside a begin/end block (aka. statements group)
sqlFragments = line.split(";")
# no semicolons
if len(sqlFragments) == 1:
current_statement += line + " "
continue
quotes = 0
# "select * from person;" -> ["select..", ""]
for sqlFragment in sqlFragments:
if len(sqlFragment.strip()) == 0:
continue

if inside_statements_group:
# if statements inside a begin/end block have semicolns,
# they must persist - even with oracle
sqlFragment += "; "

if proc_end_pattern.match(sqlFragment.lower()):
inside_statements_group = False
elif proc_start_pattern.match(sqlFragment.lower()):
inside_statements_group = True
elif sqlFragment.lower().startswith("begin"):
inside_statements_group = True

# check if the semicolon is a part of the value (quoted string)
quotes += sqlFragment.count("'")
quotes -= sqlFragment.count("\\'")
inside_quoted_string = quotes % 2 != 0
if inside_quoted_string:
sqlFragment += ";" # restore the semicolon

current_statement += sqlFragment
if not inside_statements_group and not inside_quoted_string:
statements_to_execute.append(current_statement.strip())
current_statement = ""
quotes = 0
# "select * from person;" -> ["select..", ""]
for sqlFragment in sqlFragments:
if len(sqlFragment.strip()) == 0:
continue

if inside_statements_group:
# if statements inside a begin/end block have semicolns,
# they must persist - even with oracle
sqlFragment += "; "

if proc_end_pattern.match(sqlFragment.lower()):
inside_statements_group = False
elif proc_start_pattern.match(sqlFragment.lower()):
inside_statements_group = True
elif sqlFragment.lower().startswith("begin"):
inside_statements_group = True

# check if the semicolon is a part of the value (quoted string)
quotes += sqlFragment.count("'")
quotes -= sqlFragment.count("\\'")
inside_quoted_string = quotes % 2 != 0
if inside_quoted_string:
sqlFragment += ";" # restore the semicolon

current_statement += sqlFragment
if not inside_statements_group and not inside_quoted_string:
statements_to_execute.append(current_statement.strip())
current_statement = ""
quotes = 0

current_statement = current_statement.strip()
if len(current_statement) != 0:
statements_to_execute.append(current_statement)

return statements_to_execute

current_statement = current_statement.strip()
if len(current_statement) != 0:
statements_to_execute.append(current_statement)
return statements_to_execute

def _split_statements_using_external_parser(self, sql_file_content: str):
statements_to_execute = []
split_statements = sqlparse.split(sql_file_content)
for statement in split_statements:
statement_without_comments = sqlparse.format(statement, strip_comments=True)
if statement_without_comments:
statements_to_execute.append(statement_without_comments)
return statements_to_execute

@renamed_args(
mapping={
Expand All @@ -436,6 +444,8 @@ def execute_sql_string(
sqlString: Optional[str] = None,
sansTran: Optional[bool] = None,
omitTrailingSemicolon: Optional[bool] = None,
split: bool = False,
external_parser: bool = False,
):
"""
Executes the ``sql_string`` as a single SQL command.
Expand Down Expand Up @@ -473,17 +483,30 @@ def execute_sql_string(
cur = db_connection.client.cursor()
if omit_trailing_semicolon is None:
omit_trailing_semicolon = db_connection.omit_trailing_semicolon
self._execute_sql(
cur,
sql_string,
omit_trailing_semicolon=omit_trailing_semicolon,
parameters=parameters,
replace_robot_variables=replace_robot_variables,
)
if not split:
self._execute_sql(
cur,
sql_string,
omit_trailing_semicolon=omit_trailing_semicolon,
parameters=parameters,
replace_robot_variables=replace_robot_variables,
)
else:
statements_to_execute = self.split_sql_string(sql_string, external_parser=external_parser)
for statement in statements_to_execute:
omit_semicolon = self._omit_semicolon_needed(statement)
self._execute_sql(cur, statement, omit_semicolon, replace_robot_variables=replace_robot_variables)

self._commit_if_needed(db_connection, no_transaction)
except Exception as e:
self._rollback_and_raise(db_connection, no_transaction, e)

def _omit_semicolon_needed(self, statement: str) -> bool:
proc_end_pattern = re.compile("end(?!( if;| loop;| case;| while;| repeat;)).*;()?")
line_ends_with_proc_end = re.compile(r"(\s|;)" + proc_end_pattern.pattern + "$")
omit_semicolon = not line_ends_with_proc_end.search(statement.lower())
return omit_semicolon

@renamed_args(mapping={"spName": "procedure_name", "spParams": "procedure_params", "sansTran": "no_transaction"})
def call_stored_procedure(
self,
Expand Down
Loading