diff --git a/pathtraits/access.py b/pathtraits/access.py
index 13acdf8..8d9a916 100644
--- a/pathtraits/access.py
+++ b/pathtraits/access.py
@@ -50,12 +50,12 @@ def get_dict(db, path):
 
     # get traits from path and its parents
     dirs_data = []
-    data = db.get("data", path=abs_path)
+    data = db.get_pathtraits(abs_path)
     if data:
         dirs_data.append(data)
     for i in reversed(range(0, len(dirs))):
         cur_path = "/".join(dirs[0 : i + 1])
-        data = db.get("data", path=cur_path)
+        data = db.get_pathtraits(cur_path)
         if data:
             dirs_data.append(data)
 
@@ -70,40 +70,6 @@ def get_dict(db, path):
     return res
 
 
-def get_paths(db, query_str):
-    """
-    Docstring for get_paths
-
-    :param db: Description
-    :param query_str: Description
-    """
-    query_str = f"SELECT DISTINCT path FROM data where {query_str};"
-    res = db.execute(query_str, ignore_error=False).fetchall()
-    res = [x["path"] for x in res]
-    return res
-
-
-def get_paths_values(db, query_str):
-    """
-    Docstring for get_paths_values
-
-    :param db: Description
-    :param query_str: Description
-    """
-    query_str = f"SELECT * FROM data where {query_str};"
-    response = db.execute(query_str, ignore_error=False).fetchall()
-    res = {}
-    for r in response:
-        path = r["path"]
-        # ensure distinct paths
-        # pylint: disable=C0201
-        if path not in res.keys():
-            r = nest_dict(r)
-            r.pop("path")
-            res[path] = r
-    return res
-
-
 def get(path, db_path, verbose):
     """
     Docstring for get
@@ -124,7 +90,7 @@ def get(path, db_path, verbose):
         sys.exit(1)
 
 
-def query(query_str, db_path, show_values):
+def query(query_str, db_path):
     """
     Docstring for query
 
@@ -132,25 +98,8 @@ def query(query_str, db_path, show_values):
     :param db_path: Description
     """
     db = TraitsDB(db_path)
-    if show_values:
-        res = get_paths_values(db, query_str)
-        if len(res) > 0:
-            print(yaml.safe_dump(res))
-        else:
-            logger.error(
-                "No paths found for traits matching %s in database %s",
-                query_str,
-                db_path,
-            )
-            sys.exit(1)
-    else:
-        res = get_paths(db, query_str)
-        if len(res) > 0:
-            for r in res:
-                print(r)
-        else:
-            logger.error(
-                "No paths found for traits matching %s in database %s",
-                query_str,
-                db_path,
-            )
+    paths = db.get_paths(query_str)
+    if not paths:
+        sys.exit(f"No paths matching query '{query_str}'")
+    for path in paths:
+        print(path)
diff --git a/pathtraits/cli.py b/pathtraits/cli.py
index 04cb373..a9c7131 100644
--- a/pathtraits/cli.py
+++ b/pathtraits/cli.py
@@ -75,17 +75,14 @@ def get(path, db_path, verbose):
     default=DB_PATH,
     type=click.Path(file_okay=True, dir_okay=False),
 )
-@click.option(
-    "--show-values", flag_value=True, default=False, help="Also show their trait values"
-)
-def query(query_str, db_path, show_values):
+def query(query_str, db_path):
     """
     Get paths of given traits
 
     Enter QUERY_STR in SQLite3 where statement format,
-    e.g. "[score/REAL]>1" to get all paths having a score >1.
+    e.g. "score_REAL>1" to get all paths having a numerical score >1.
     """
-    access.query(query_str, db_path, show_values)
+    access.query(query_str, db_path)
 
 
 if __name__ == "__main__":
diff --git a/pathtraits/db.py b/pathtraits/db.py
index 2b5457d..ee9f042 100644
--- a/pathtraits/db.py
+++ b/pathtraits/db.py
@@ -21,6 +21,17 @@ class TraitsDB:
     cursor = None
     traits = []
 
+    @staticmethod
+    def remove_type_suffixes(s: str):
+        """
+        Strip the trailing SQL type marker (_TEXT, _REAL or _BOOL) from a name
+
+        :param s: column or trait table name, e.g. "score_REAL"
+        :returns: name without its type marker; at most one marker is removed
+        """
+        suffix = next((x for x in ("_TEXT", "_REAL", "_BOOL") if s.endswith(x)), "")
+        return s.removesuffix(suffix)
+
     @staticmethod
     def row_factory(cursor, row):
         """
@@ -35,12 +46,12 @@ def row_factory(cursor, row):
             if v is None:
                 continue
             # sqlite don't know bool
-            if k.endswith("/BOOL"):
+            if k.endswith("_BOOL"):
                 v = v > 0
             if isinstance(v, float):
                 v_int = int(v)
                 v = v_int if v_int == v else v
-            k = k.removesuffix("/TEXT").removesuffix("/REAL").removesuffix("/BOOL")
+            k = TraitsDB.remove_type_suffixes(k)
             res[k] = v
         return res
 
@@ -90,8 +101,8 @@ def __init__(self, db_path):
         self.cursor.row_factory = TraitsDB.row_factory
 
         init_path_table_query = """
-            CREATE TABLE IF NOT EXISTS path (
-                id INTEGER PRIMARY KEY AUTOINCREMENT,
+            CREATE TABLE IF NOT EXISTS _path (
+                path_id INTEGER PRIMARY KEY AUTOINCREMENT,
                 path text NOT NULL UNIQUE
             );
         """
@@ -99,10 +110,30 @@ def __init__(self, db_path):
 
         init_path_index_query = """
             CREATE INDEX IF NOT EXISTS idx_path_path
-            ON path(path);
+            ON _path(path);
         """
         self.execute(init_path_index_query)
-        self.update_traits()
+
+        init_trait_table_query = """
+            CREATE TABLE IF NOT EXISTS _trait (
+                trait_id INTEGER PRIMARY KEY AUTOINCREMENT,
+                trait text NOT NULL UNIQUE
+            );
+        """
+        self.execute(init_trait_table_query)
+
+        init_trait_path_table_query = """
+            CREATE TABLE IF NOT EXISTS _trait_path (
+                trait_id INTEGER,
+                path_id INTEGER,
+                FOREIGN KEY(trait_id) REFERENCES _trait(trait_id),
+                FOREIGN KEY(path_id) REFERENCES _path(path_id),
+                UNIQUE(trait_id, path_id)
+            );
+        """
+        self.execute(init_trait_path_table_query)
+
+        self.update_trait()
 
     # pylint: disable=R1710
     def execute(self, query, ignore_error=True):
@@ -151,6 +182,81 @@ def get(self, table, cols="*", condition=None, **kwargs):
 
         return res
 
+    def get_path_id(self, path: str):
+        """
+        Look up the id of a path in the _path table
+
+        :param path: path to look up
+        :type path: str
+        :returns: the path_id, or None if the path is not in the data base
+        """
+        res = self.get("_path", path=path, cols="path_id")
+        if not res:
+            return None
+
+        return res["path_id"]
+
+    def get_traits(self, path_id: int):
+        """
+        Get the names of all trait tables linked to a path id (None on SQL error)
+        """
+        query = f"""
+            SELECT DISTINCT trait
+            FROM _trait
+            INNER JOIN _trait_path USING (trait_id)
+            WHERE path_id = '{path_id}'
+        """
+        response = self.execute(query)
+
+        if response is None:
+            return None
+
+        # always return a list of trait names, even for a single match,
+        # so that callers can iterate over the result uniformly
+        res = response.fetchall()
+        return [x["trait"] for x in res]
+
+    def get_pathtraits(self, path: str):
+        """
+        Collect all traits stored directly for a path
+
+        :param path: path to look up
+        :type path: str
+        :returns: dict of trait names and values, empty if nothing is stored
+        """
+        path_id = self.get_path_id(path)
+        traits = self.get_traits(path_id)
+        res = {}
+        for trait in traits or []:
+            pathtraits = self.get(trait, path_id=path_id)
+            if isinstance(pathtraits, dict):
+                pathtraits.pop("path_id", None)
+                for k, v in pathtraits.items():
+                    res[k] = v
+        return res
+
+    def get_paths(self, query_str):
+        """
+        Get paths matching pathtraits
+
+        :param query_str: SQLite WHERE clause over trait tables, e.g. "score_REAL > 3"
+        :returns: list of matching paths, or None if the query failed
+        """
+        traits = filter(lambda x: x in query_str, self.traits)
+        query = "SELECT DISTINCT path FROM _path"
+        for trait in traits:
+            # pylint: disable=R1713
+            query += f" NATURAL JOIN [{trait}]"
+        query += f" WHERE {query_str};"
+
+        response = self.execute(query)
+        if response is None:
+            return None
+
+        res = response.fetchall()
+        res = [x["path"] for x in res]
+        return res
+
     def put_path_id(self, path):
         """
         Docstring for put_path_id
@@ -159,13 +265,13 @@ def put_path_id(self, path):
         :param path: path to put to the data base
         :returns: the id of that path
         """
-        get_row_query = f"SELECT id FROM path WHERE path = '{path}' LIMIT 1;"
+        get_row_query = f"SELECT path_id FROM _path WHERE path = '{path}' LIMIT 1;"
         res = self.execute(get_row_query).fetchone()
         if res:
-            return res["id"]
+            return res["path_id"]
         # create
-        self.put("path", path=path)
-        path_id = self.get("path", path=path, cols="id")["id"]
+        self.put("_path", path=path)
+        path_id = self.get_path_id(path)
         return path_id
 
     @staticmethod
@@ -225,50 +331,21 @@ def put(self, table, condition=None, update=True, **kwargs):
         insert_query = f"INSERT INTO [{table}] ({keys}) VALUES ({values});"
         self.execute(insert_query)
 
-    def put_data_view(self):
-        """
-        Creates a SQL View with all denormalized traits
-        """
-        self.execute("DROP VIEW IF EXISTS DATA;")
-
-        if self.traits:
-            join_query = " ".join(
-                [
-                    f"LEFT JOIN [{x}] ON [{x}].path = path.id \n"
-                    for x in self.traits
-                    if x != "path"
-                ]
-            )
-
-            create_view_query = f"""
-                CREATE VIEW data AS
-                SELECT path.path, [{'], ['.join(self.traits)}]
-                FROM path
-                {join_query};
-            """
-        else:
-            create_view_query = """
-                CREATE VIEW data AS
-                SELECT path.path
-                FROM path;
-            """
-        self.execute(create_view_query)
-
-    def update_traits(self):
+    def update_trait(self):
         """
         Get all traits from the database
         """
-        get_traits_query = """
-            SELECT name
-            FROM sqlite_master
-            WHERE type='table'
-            AND name NOT LIKE 'sqlite_%'
-            AND name != 'path'
-            ORDER BY name;
+        get_trait_query = """
+            SELECT trait
+            FROM _trait
+            ORDER BY trait;
         """
-        traits = self.execute(get_traits_query).fetchall()
+        traits = self.execute(get_trait_query)
+        if traits is not None:
+            traits = traits.fetchall()
+        else:
+            traits = []
         self.traits = [list(x.values())[0] for x in traits]
-        self.put_data_view()
 
     def create_trait_table(self, trait_name, value_type):
         """
@@ -289,13 +366,14 @@ def create_trait_table(self, trait_name, value_type):
         sql_type = TraitsDB.sql_type(value_type)
         add_table_query = f"""
             CREATE TABLE [{trait_name}] (
-                path INTEGER,
+                path_id INTEGER,
                 [{trait_name}] {sql_type},
-                FOREIGN KEY(path) REFERENCES path(id)
+                FOREIGN KEY(path_id) REFERENCES _path(path_id)
             );
         """
         self.execute(add_table_query)
-        self.update_traits()
+        self.put("_trait", trait=trait_name)
+        self.update_trait()
 
     def put_trait(self, path_id, trait_name, value, update=True):
         """
@@ -306,8 +384,10 @@ def put_trait(self, path_id, trait_name, value, update=True):
         :param key: trait name
         :param value: trait value
         """
-        kwargs = {"path": path_id, trait_name: value}
-        self.put(trait_name, condition=f"path = {path_id}", update=update, **kwargs)
+        kwargs = {"path_id": path_id, trait_name: value}
+        trait_id = self.get("_trait", trait=trait_name)["trait_id"]
+        self.put("_trait_path", trait_id=trait_id, path_id=path_id)
+        self.put(trait_name, condition=f"path_id = {path_id}", update=update, **kwargs)
 
     def add_pathpair(self, pair: PathPair):
         """
@@ -339,7 +419,7 @@ def add_pathpair(self, pair: PathPair):
             # get element type for list
             # add: handle lists with mixed element type
             t = type(v[0]) if isinstance(v, list) and len(v) > 0 else type(v)
-            k = f"{k}/{TraitsDB.sql_type(t)}"
+            k = f"{k}_{TraitsDB.sql_type(t)}"
             if k not in self.traits:
                 self.create_trait_table(k, t)
             if k in self.traits:
diff --git a/test/test.py b/test/test.py
index 197c7aa..d1d2490 100644
--- a/test/test.py
+++ b/test/test.py
@@ -70,20 +70,14 @@ def test_missing_north_america(self):
         for k, v in target.items():
             self.assertEqual(source[k], v)
 
-    def test_data_view(self):
-        source = len(self.db.execute("SELECT * FROM data;").fetchall())
-        target = 8
-        self.assertEqual(source, target)
+    def test_query(self):
+        q1 = self.db.get_paths("score_REAL > 3")
+        self.assertEqual(len(q1), 2)
 
-    def test_data_query(self):
-        source = len(pathtraits.access.get_paths(self.db, "[score/REAL] >= 5"))
-        target = 1
-        self.assertEqual(source, target)
-
-        traits = pathtraits.access.get_paths_values(self.db, "TRUE")
-        self.assertEqual(len(traits), 3)
-        for v in traits.values():
-            self.assertTrue("path" not in v.keys())
+        q2 = self.db.get_paths(
+            "score_TEXT = 'zero' AND description_TEXT LIKE '%Germany%'"
+        )
+        self.assertEqual(len(q2), 1)
 
 
 if __name__ == "__main__":