
Upserting capabilities added at creation #108


Open · wants to merge 16 commits into main
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -12,7 +12,7 @@ jobs:
uses: docker/setup-buildx-action@v1

- name: Cache Docker layers
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-${{ github.sha }}
125 changes: 84 additions & 41 deletions src/api/create.py
@@ -1,3 +1,4 @@
import json
import logging
import math
import uuid
@@ -9,8 +10,9 @@
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
from psycopg2.extras import Json
from sqlalchemy import MetaData, create_engine, text
from sqlalchemy import MetaData, create_engine, inspect, text
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from sqlalchemy_utils.functions import create_database, database_exists, drop_database
from sqlmodel import SQLModel
from tqdm import tqdm
@@ -86,56 +88,98 @@ def normalize_signal_name(name):
)
return signal_name

def get_primary_keys(table_name, engine):
inspector = inspect(engine)
pk_columns = inspector.get_pk_constraint(table_name).get("constrained_columns", [])
if not pk_columns:
raise ValueError(f"No primary key found for table {table_name}")
# return the column names as a list so composite keys also work with on_conflict_do_update
return pk_columns

class DBCreationClient:
def __init__(self, uri: str, db_name: str):
def __init__(self, uri: str, db_name: str, mode: str = "create"):
self.uri = uri
self.db_name = db_name
self.mode = mode

def create_database(self):
if database_exists(self.uri):
drop_database(self.uri)

create_database(self.uri)
if self.mode == "create":
logging.info("creating database")
if database_exists(self.uri):
drop_database(self.uri)
create_database(self.uri)
else:
logging.info("updating database")
if not database_exists(self.uri):
raise ValueError("Cannot update as the database hasn't been created.")

self.metadata_obj, self.engine = connect(self.uri)

engine = create_engine(self.uri, echo=True)
SQLModel.metadata.create_all(engine)
# recreate the engine/metadata object

self.metadata_obj, self.engine = connect(self.uri)
return engine

def create_user(self):
engine = create_engine(self.uri, echo=True)
name = password = "public_user"
drop_user = text(f"DROP USER IF EXISTS {name}")

drop_user = text(f"DROP USER IF EXISTS {name};")
create_user_query = text(f"CREATE USER {name} WITH PASSWORD :password;")
grant_privledges = text(f"GRANT CONNECT ON DATABASE {self.db_name} TO {name};")
grant_public_schema = text(f"GRANT USAGE ON SCHEMA public TO {name};")
grant_public_schema_tables = text(
f"GRANT SELECT ON ALL TABLES IN SCHEMA public TO {name};"
)

grant_privileges = [
text(f"GRANT CONNECT ON DATABASE {self.db_name} TO {name};"),
text(f"GRANT USAGE ON SCHEMA public TO {name};"),
text(f"GRANT SELECT ON ALL TABLES IN SCHEMA public TO {name};"),
]

with engine.connect() as conn:
conn.execute(drop_user)
conn.execute(create_user_query, {"password": password})
conn.execute(grant_privledges)
conn.execute(grant_public_schema)
conn.execute(grant_public_schema_tables)
if self.mode == "create":
conn.execute(drop_user)
conn.execute(create_user_query, {"password": password})
elif self.mode == "update":
user_exists_query = text(
f"SELECT 1 FROM pg_catalog.pg_roles WHERE rolname = '{name}';"
)
result = conn.execute(user_exists_query).fetchone()

if not result:
conn.execute(create_user_query, {"password": password})

for grant_query in grant_privileges:
conn.execute(grant_query)

def create_or_upsert_table(self, table_name: str, df: pd.DataFrame):
Review comment (Member):

Could we also have unit testing for this method?


self.metadata_obj.reflect(bind=self.engine)
if self.mode == 'create':
df.to_sql(table_name, self.engine, if_exists="append", index=False)

elif self.mode == 'update':
table = self.metadata_obj.tables[table_name]

insert_stmt = insert(table).values(df.to_dict(orient="records"))
primary_key = get_primary_keys(table_name, self.engine)
update_column_stmt = {col.key: insert_stmt.excluded[col.key] for col in table.columns if col.key != "uuid"}

stmt = insert_stmt.on_conflict_do_update(index_elements=primary_key, set_=update_column_stmt)
with Session(self.engine) as session:
session.execute(stmt)
session.commit()
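
Re the unit-testing question in the review comment above: a minimal pytest sketch for the update path of create_or_upsert_table, not a definitive test. It assumes a disposable PostgreSQL database reachable through a TEST_DATABASE_URL environment variable (a hypothetical name) and that DBCreationClient is importable from src.api.create; the engine and metadata are wired up directly rather than via create_database().

import os

import pandas as pd
import pytest
from sqlalchemy import MetaData, create_engine, text

from src.api.create import DBCreationClient  # import path is an assumption

TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL")


@pytest.mark.skipif(TEST_DATABASE_URL is None, reason="needs a PostgreSQL test database")
def test_create_or_upsert_table_updates_existing_rows():
    engine = create_engine(TEST_DATABASE_URL)
    # seed a small table with a primary key and one pre-existing row
    with engine.begin() as conn:
        conn.execute(text("DROP TABLE IF EXISTS demo"))
        conn.execute(text("CREATE TABLE demo (uuid TEXT PRIMARY KEY, value TEXT)"))
        conn.execute(text("INSERT INTO demo (uuid, value) VALUES ('a', 'old')"))

    client = DBCreationClient(TEST_DATABASE_URL, "testdb", mode="update")
    # create_database() normally sets these; wire them up directly for the test
    client.engine = engine
    client.metadata_obj = MetaData()

    # 'a' conflicts with the seeded row and should be updated in place; 'b' is new
    df = pd.DataFrame({"uuid": ["a", "b"], "value": ["new", "fresh"]})
    client.create_or_upsert_table("demo", df)

    with engine.connect() as conn:
        rows = dict(conn.execute(text("SELECT uuid, value FROM demo")).fetchall())
    assert rows == {"a": "new", "b": "fresh"}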

def create_cpf_summary(self, data_path: Path):
"""Create the CPF summary table"""
paths = data_path.glob("cpf/*_cpf_columns.parquet")
dfs = [pd.read_parquet(path) for path in paths]
df = pd.concat(dfs).reset_index(drop=True)
df["context"] = [Json(base_context)] * len(df)
df = pd.concat(dfs).reset_index(drop=False)
df["context"] = [json.dumps(base_context)] * len(df)
df = df.drop_duplicates(subset=["name"])
df["name"] = df["name"].apply(
lambda x: models.ShotModel.__fields__.get("cpf_" + x.lower()).alias
if models.ShotModel.__fields__.get("cpf_" + x.lower())
else x
)
df.to_sql("cpf_summary", self.uri, if_exists="append")
self.create_or_upsert_table("cpf_summary", df)

def create_scenarios(self, data_path: Path):
"""Create the scenarios metadata table"""
@@ -144,10 +188,10 @@ def create_scenarios(self, data_path: Path):
ids = shot_metadata["scenario_id"].unique()
scenarios = shot_metadata["scenario"].unique()

data = pd.DataFrame(dict(id=ids, name=scenarios)).set_index("id")
data = pd.DataFrame(dict(id=ids, name=scenarios))
data = data.dropna()
data["context"] = [Json(base_context)] * len(data)
data.to_sql("scenarios", self.uri, if_exists="append")
data["context"] = [json.dumps(base_context)] * len(data)
self.create_or_upsert_table("scenarios", data)

def create_shots(self, data_path: Path):
"""Create the shot metadata table"""
@@ -165,7 +209,7 @@ def create_shots(self, data_path: Path):
shot_metadata["scenario"] = shot_metadata["scenario_id"]
shot_metadata["facility"] = "MAST"
shot_metadata = shot_metadata.drop(["scenario_id", "reference_id"], axis=1)
shot_metadata["context"] = [Json(base_context)] * len(shot_metadata)
shot_metadata["context"] = [json.dumps(base_context)] * len(shot_metadata)
shot_metadata["uuid"] = shot_metadata.index.map(get_dataset_uuid)
shot_metadata["url"] = (
"s3://mast/level1/shots/" + shot_metadata.index.astype(str) + ".zarr"
@@ -182,7 +226,7 @@ def create_shots(self, data_path: Path):
cpfs = pd.concat(cpfs, axis=0)
cpfs = cpfs.reset_index()
cpfs = cpfs.loc[cpfs.shot_id <= LAST_MAST_SHOT]
cpfs = cpfs.drop_duplicates(subset="shot_id")
cpfs = cpfs.drop_duplicates(subset="shot_id").sort_values(by="shot_id")
cpfs = cpfs.set_index("shot_id")

shot_metadata = pd.merge(
@@ -192,11 +236,11 @@ def create_shots(self, data_path: Path):
right_on="shot_id",
how="left",
)

shot_metadata.to_sql("shots", self.uri, if_exists="append")
shot_metadata = shot_metadata.reset_index()
shot_metadata = shot_metadata.replace(np.nan, None)
self.create_or_upsert_table("shots", shot_metadata)

def create_signals(self, data_path: Path):
logging.info(f"Loading signals from {data_path}")
file_name = data_path / "signals.parquet"

parquet_file = pq.ParquetFile(file_name)
@@ -212,7 +256,7 @@ def create_signals(self, data_path: Path):
df = signals_metadata
df = df[df.shot_id <= LAST_MAST_SHOT]
df = df.drop_duplicates(subset="uuid")
df["context"] = [Json(base_context)] * len(df)
df['context'] = [json.dumps(base_context)] * len(df)
df["shape"] = df["shape"].map(lambda x: x.tolist())
df["dimensions"] = df["dimensions"].map(lambda x: x.tolist())
df["url"] = (
@@ -225,15 +269,14 @@ def create_signals(self, data_path: Path):
uda_attributes = ["uda_name", "mds_name", "file_name", "format"]
df = df.drop(uda_attributes, axis=1)
df["shot_id"] = df.shot_id.astype(int)
df = df.set_index("shot_id", drop=True)
df["description"] = df.description.map(lambda x: "" if x is None else x)
df.to_sql("signals", self.uri, if_exists="append")
self.create_or_upsert_table("signals", df)

def create_sources(self, data_path: Path):
source_metadata = pd.read_parquet(data_path / "sources.parquet")
source_metadata = source_metadata.drop_duplicates("uuid")
source_metadata = source_metadata.loc[source_metadata.shot_id <= LAST_MAST_SHOT]
source_metadata["context"] = [Json(base_context)] * len(source_metadata)
source_metadata["context"] = [json.dumps(base_context)] * len(source_metadata)
source_metadata["url"] = (
"s3://mast/level1/shots/"
+ source_metadata["shot_id"].map(str)
@@ -250,7 +293,7 @@ def create_sources(self, data_path: Path):
"context",
]
source_metadata = source_metadata[column_names]
source_metadata.to_sql("sources", self.uri, if_exists="append", index=False)
self.create_or_upsert_table("sources", source_metadata)

def create_serve_dataset(self):
data = {
@@ -296,10 +339,10 @@ def create_serve_dataset(self):
}
}
df = pd.DataFrame(data, index=[0])
df["publisher"] = Json(publisher)
df['publisher'] = json.dumps(publisher)
df["id"] = "host/json/data-service"
df["context"] = Json(dict(list(base_context.items())[-3:]))
df.to_sql("dataservice", self.uri, if_exists="append", index=False)
df['context'] = json.dumps(dict(list(base_context.items())[-3:]))
self.create_or_upsert_table("dataservice", df)


def read_cpf_metadata(cpf_file_name: Path) -> pd.DataFrame:
@@ -317,10 +360,11 @@ def read_cpf_metadata(cpf_file_name: Path) -> pd.DataFrame:

@click.command()
@click.argument("data_path", default="~/mast-data/meta")
def create_db_and_tables(data_path):
@click.argument("mode", type=click.Choice(["create", "update"]), default="create")
def create_db_and_tables(data_path, mode):
data_path = Path(data_path)

client = DBCreationClient(SQLALCHEMY_DATABASE_URL, DB_NAME)
client = DBCreationClient(SQLALCHEMY_DATABASE_URL, DB_NAME, mode)
client.create_database()

# populate the database tables
@@ -347,5 +391,4 @@ def create_db_and_tables(data_path):

if __name__ == "__main__":
dask.config.set({"dataframe.convert-string": False})
# print(models.ShotModel.__fields__)
create_db_and_tables()
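
For reference, a hedged sketch of exercising the new mode argument through click's test runner; the src.api.create import path is an assumption about the package layout, and invoking it would run the full load against the configured database.

from click.testing import CliRunner

from src.api.create import create_db_and_tables  # import path is an assumption

runner = CliRunner()
# "update" is the new second argument: upsert into the existing database
# instead of dropping and recreating it (the default remains "create").
result = runner.invoke(create_db_and_tables, ["~/mast-data/meta", "update"])
print(result.exit_code, result.output)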