diff --git a/pathtraits/access.py b/pathtraits/access.py index 0300ae6..0101d47 100644 --- a/pathtraits/access.py +++ b/pathtraits/access.py @@ -37,7 +37,7 @@ def nest_dict(flat_dict, delimiter="/"): return nested_dict -def get_dict(self, path): +def get_dict(db, path): """ Get traits for a path as a Python dictionary @@ -50,12 +50,12 @@ def get_dict(self, path): # get traits from path and its parents dirs_data = [] - data = self.get("data", path=abs_path) + data = db.get("data", path=abs_path) if data: dirs_data.append(data) for i in reversed(range(0, len(dirs))): cur_path = "/".join(dirs[0 : i + 1]) - data = self.get("data", path=cur_path) + data = db.get("data", path=cur_path) if data: dirs_data.append(data) @@ -70,6 +70,31 @@ def get_dict(self, path): return res +def get_paths(db, query_str): + """ + Docstring for get_paths + + :param db: Description + :param query_str: Description + """ + query_str = f"SELECT DISTINCT path FROM data where {query_str};" + res = db.execute(query_str, ignore_error=False).fetchall() + res = [x["path"] for x in res] + return res + + +def get_paths_values(db, query_str): + """ + Docstring for get_paths_values + + :param db: Description + :param query_str: Description + """ + query_str = f"SELECT * FROM data where {query_str};" + res = db.execute(query_str, ignore_error=False).fetchall() + return res + + def get(path, db_path, verbose): """ Docstring for get @@ -88,3 +113,35 @@ def get(path, db_path, verbose): else: logger.error("No traits found for path %s in database %s", path, db_path) sys.exit(1) + + +def query(query_str, db_path, show_values): + """ + Docstring for query + + :param query_str: Description + :param db_path: Description + """ + db = TraitsDB(db_path) + if show_values: + res = get_paths_values(db, query_str) + if len(res) > 0: + print(yaml.safe_dump(res)) + else: + logger.error( + "No paths found for traits matching %s in database %s", + query_str, + db_path, + ) + sys.exit(1) + else: + res = get_paths(db, query_str) + if len(res) > 0: + for r in res: + print(r) + else: + logger.error( + "No paths found for traits matching %s in database %s", + query_str, + db_path, + ) diff --git a/pathtraits/cli.py b/pathtraits/cli.py index f4ff07c..04cb373 100644 --- a/pathtraits/cli.py +++ b/pathtraits/cli.py @@ -19,7 +19,7 @@ def main(): """ -@main.command(help="Update database once, searches for all directories recursively.") +@main.command() @click.argument("path", required=True, type=click.Path(exists=True)) @click.option( "--db-path", @@ -34,16 +34,11 @@ def main(): def batch(path, db_path, exclude_regex, verbose): """ Update database once, searches for all directories recursively. - - :param path: path to scan in batch mode recursively - :param db_path: path to the database - :param exclude_regex: exclue file paths matching this regex - :param verbose: enable verbose logging """ scan.batch(path, db_path, exclude_regex, verbose) -@main.command(help="Update database continiously, watches for new or changed files.") +@main.command() @click.argument("path", required=True, type=click.Path(exists=True)) @click.option( "--db-path", @@ -54,15 +49,11 @@ def batch(path, db_path, exclude_regex, verbose): def watch(path, db_path, verbose): """ Update database continiously, watches for new or changed files. - - :param path: path to watch recursively - :param db_path: path to the database - :param verbose: enable verbose logging """ scan.watch(path, db_path, verbose) -@main.command(help="Get traits of a given path") +@main.command() @click.argument("path", required=True, type=click.Path(exists=True)) @click.option( "--db-path", @@ -73,13 +64,29 @@ def watch(path, db_path, verbose): def get(path, db_path, verbose): """ Get traits of a given path - - :param path: path to get traits for - :param db_path: path to the database - :param verbose: enable verbose logging """ access.get(path, db_path, verbose) +@main.command() +@click.argument("query_str", required=True) +@click.option( + "--db-path", + default=DB_PATH, + type=click.Path(file_okay=True, dir_okay=False), +) +@click.option( + "--show-values", flag_value=True, default=False, help="Also show their trait values" +) +def query(query_str, db_path, show_values): + """ + Get paths of given traits + + Enter QUERY_STR in SQLite3 where statement format, + e.g. "[score/REAL]>1" to get all paths having a score >1. + """ + access.query(query_str, db_path, show_values) + + if __name__ == "__main__": main() diff --git a/pathtraits/db.py b/pathtraits/db.py index 0fb2653..13f276c 100644 --- a/pathtraits/db.py +++ b/pathtraits/db.py @@ -5,6 +5,7 @@ import logging import sqlite3 import os +import sys from collections.abc import MutableMapping import yaml from pathtraits.pathpair import PathPair @@ -104,7 +105,7 @@ def __init__(self, db_path): self.update_traits() # pylint: disable=R1710 - def execute(self, query): + def execute(self, query, ignore_error=True): """ Execute a SQLite query @@ -114,8 +115,11 @@ def execute(self, query): try: res = self.cursor.execute(query) return res - except sqlite3.DatabaseError: - logger.debug("Ignore failed query %s", query) + except sqlite3.DatabaseError as e: + if ignore_error: + logger.debug("Ignore failed query %s: %s", query, e) + else: + sys.exit(e) def get(self, table, cols="*", condition=None, **kwargs): """ diff --git a/test/test.py b/test/test.py index e214e1e..48d6639 100644 --- a/test/test.py +++ b/test/test.py @@ -75,6 +75,15 @@ def test_data_view(self): target = 8 self.assertEqual(source, target) + def test_data_query(self): + source = len(pathtraits.access.get_paths(self.db, "[score/REAL] >= 5")) + target = 1 + self.assertEqual(source, target) + + source = len(pathtraits.access.get_paths(self.db, "TRUE")) + target = 3 + self.assertEqual(source, target) + if __name__ == "__main__": unittest.main()