diff --git a/collectoss/api/gunicorn_conf.py b/collectoss/api/gunicorn_conf.py index 22c11231a..ee7797471 100644 --- a/collectoss/api/gunicorn_conf.py +++ b/collectoss/api/gunicorn_conf.py @@ -7,6 +7,7 @@ from collectoss.application.db.lib import get_value from collectoss.application.db import dispose_database_engine +from collectoss.application.environment import SystemEnv logger = logging.getLogger(__name__) @@ -20,8 +21,8 @@ workers = multiprocessing.cpu_count() * 2 + 1 umask = 0o007 reload = True - -is_dev = os.getenv("AUGUR_DEV", 'False').lower() in ('true', '1', 't', 'y', 'yes') +# this satisfies the type checker +is_dev = SystemEnv.get_bool("AUGUR_DEV", False) if is_dev: @@ -40,7 +41,8 @@ # set the log location for gunicorn logs_directory = get_value('Logging', 'logs_directory') -is_docker = os.getenv("AUGUR_DOCKER_DEPLOY", 'False').lower() in ('true', '1', 't', 'y', 'yes') +# this syntax satisfies the type checker +is_docker = SystemEnv.get_bool("AUGUR_DOCKER_DEPLOY", False) accesslog = f"{logs_directory}/gunicorn.log" errorlog = f"{logs_directory}/gunicorn.log" diff --git a/collectoss/api/routes/auggie.py b/collectoss/api/routes/auggie.py index 18642498f..6d036045a 100644 --- a/collectoss/api/routes/auggie.py +++ b/collectoss/api/routes/auggie.py @@ -14,6 +14,8 @@ import requests import slack +from collectoss.application.environment import SystemEnv + from ..server import app @@ -252,7 +254,7 @@ def get_auggie_user(): # return Response(response=response, status=200, mimetype="application/json") ## From Method profile_name = 'collectoss' - if os.environ.get('AUGUR_IS_PROD'): + if SystemEnv.get('AUGUR_IS_PROD'): profile_name = 'default' client = boto3.Session(region_name='us-east-1', profile_name=profile_name).client('dynamodb') response = client.get_item( @@ -278,7 +280,7 @@ def update_auggie_user_tracking(): # return Response(response=response, status=200, mimetype="application/json") ## From Method profile_name = 'collectoss' - if os.environ.get('AUGUR_IS_PROD'): + if SystemEnv.get('AUGUR_IS_PROD'): profile_name = 'default' client = boto3.Session(region_name='us-east-1', profile_name=profile_name).client('dynamodb') response = client.update_item( @@ -326,7 +328,7 @@ def slack_login(): print("slack_login") r = requests.get( - url=f'https://slack.com/api/oauth.v2.access?code={body["code"]}&client_id={os.environ["AUGGIE_CLIENT_ID"]}&client_secret={os.environ["AUGGIE_CLIENT_SECRET"]}&redirect_uri=http%3A%2F%2Flocalhost%3A8080') + url=f'https://slack.com/api/oauth.v2.access?code={body["code"]}&client_id={SystemEnv.get("AUGGIE_CLIENT_ID")}&client_secret={SystemEnv.get("AUGGIE_CLIENT_SECRET")}&redirect_uri=http%3A%2F%2Flocalhost%3A8080') data = r.json() if (data["ok"]): @@ -340,7 +342,7 @@ def slack_login(): email = user_response["user"]["email"] profile_name = 'collectoss' - if os.environ.get('AUGUR_IS_PROD'): + if SystemEnv.get('AUGUR_IS_PROD'): profile_name = 'default' print("Making Boto3 Session") client = boto3.Session(region_name='us-east-1', diff --git a/collectoss/api/view/init.py b/collectoss/api/view/init.py index ab4708793..1ab68912c 100644 --- a/collectoss/api/view/init.py +++ b/collectoss/api/view/init.py @@ -1,13 +1,11 @@ import os from pathlib import Path -from .server import Environment from collectoss.application.logs import SystemLogger import secrets, yaml - -env = Environment() +from collectoss.application.environment import SystemEnv # load configuration files and initialize globals -configFile = Path(env.setdefault("CONFIG_LOCATION", "config.yml")) +configFile = Path(SystemEnv.get("CONFIG_LOCATION") or "config.yml") settings = {} diff --git a/collectoss/api/view/server/Environment.py b/collectoss/api/view/server/Environment.py deleted file mode 100644 index 76b8207ca..000000000 --- a/collectoss/api/view/server/Environment.py +++ /dev/null @@ -1,52 +0,0 @@ -import os - -class Environment: - """ - This class is used to make dealing with environment variables easier. It - allows you to set multiple environment variables at once, and to get items - with subscript notation without needing to deal with the particularities of - non-existent values. - """ - def __init__(self, **kwargs): - for (key, value) in kwargs.items(): - self[key] = value - - def setdefault(self, key, value): - if not self[key]: - self[key] = value - return value - return self[key] - - def setall(self, **kwargs): - result = {} - for (key, value) in kwargs.items(): - if self[key]: - result[key] = self[key] - self[key] = value - - def getany(self, *args): - result = {} - for arg in args: - if self[arg]: - result[arg] = self[arg] - return result - - def as_type(self, type, key): - if self[key]: - return type(self[key]) - return None - - def __getitem__(self, key): - return os.getenv(key) - - def __setitem__(self, key, value): - os.environ[key] = str(value) - - def __len__(self)-> int: - return len(os.environ) - - def __str__(self)-> str: - return str(os.environ) - - def __iter__(self): - return (item for item in os.environ.items()) \ No newline at end of file diff --git a/collectoss/api/view/server/__init__.py b/collectoss/api/view/server/__init__.py index e919a597a..98ce903be 100644 --- a/collectoss/api/view/server/__init__.py +++ b/collectoss/api/view/server/__init__.py @@ -1,2 +1 @@ -from .LoginException import LoginException -from .Environment import Environment \ No newline at end of file +from .LoginException import LoginException \ No newline at end of file diff --git a/collectoss/application/cli/__init__.py b/collectoss/application/cli/__init__.py index 8081d6a8e..b398614e2 100644 --- a/collectoss/application/cli/__init__.py +++ b/collectoss/application/cli/__init__.py @@ -10,7 +10,9 @@ from collectoss.application.db.engine import DatabaseEngine from collectoss.application.db import get_engine, dispose_database_engine -from sqlalchemy.exc import OperationalError +from sqlalchemy.exc import OperationalError +from collectoss.application.environment import SystemEnv + def check_connectivity(urls=["http://chaoss.community", "http://github.com", "http://gitlab.com"], timeout=10.0): @@ -65,11 +67,11 @@ def new_func(ctx, *args, **kwargs): return ctx.invoke(function_db_connection, *args, **kwargs) except OperationalError as e: - db_environment_var = os.getenv("AUGUR_DB") + db_environment_var = SystemEnv.get("AUGUR_DB") # determine the location to print in error string if db_environment_var: - location = f"the AUGUR_DB environment variable\nAUGUR_DB={os.getenv('AUGUR_DB')}" + location = f"the AUGUR_DB environment variable\nAUGUR_DB={SystemEnv.get('AUGUR_DB')}" else: with open("db.config.json", 'r') as f: db_config = json.load(f) diff --git a/collectoss/application/cli/api.py b/collectoss/application/cli/api.py index a8bb9e53b..4f7077a78 100644 --- a/collectoss/application/cli/api.py +++ b/collectoss/application/cli/api.py @@ -17,6 +17,8 @@ from collectoss.application.cli import test_connection, test_db_connection, with_database, DatabaseContext from collectoss.application.cli._cli_util import _broadcast_signal_to_processes, raise_open_file_limit, clear_redis_caches, clear_rabbitmq_messages from collectoss.application.db.lib import get_value +from collectoss.application.environment import SystemEnv + logger = SystemLogger("collectoss", reset_logfiles=False).get_logger() @@ -36,7 +38,7 @@ def start(ctx, development, port): """Start CollectOSS's backend server.""" try: - if os.environ.get('AUGUR_DOCKER_DEPLOY') != "1": + if SystemEnv.get('AUGUR_DOCKER_DEPLOY') != "1": raise_open_file_limit(100000) except Exception as e: logger.error( @@ -46,7 +48,7 @@ def start(ctx, development, port): raise e if development: - os.environ["AUGUR_DEV"] = "1" + SystemEnv.set("AUGUR_DEV", "1") logger.info("Starting in development mode") try: @@ -142,7 +144,7 @@ def get_api_processes(): def is_api_process(process): command = ''.join(process.info['cmdline'][:]).lower() - if os.getenv('VIRTUAL_ENV') in process.info['environ']['VIRTUAL_ENV'] and 'python' in command: + if SystemEnv.get('VIRTUAL_ENV') in process.info['environ']['VIRTUAL_ENV'] and 'python' in command: if process.pid != os.getpid(): diff --git a/collectoss/application/cli/backend.py b/collectoss/application/cli/backend.py index a07ddf198..d4586083b 100644 --- a/collectoss/application/cli/backend.py +++ b/collectoss/application/cli/backend.py @@ -15,6 +15,7 @@ import requests from redis.exceptions import ConnectionError as RedisConnectionError +from collectoss.application.environment import SystemEnv from collectoss.tasks.start_tasks import collection_monitor, create_collection_status_records from collectoss.tasks.git.facade_tasks import clone_repos from collectoss.tasks.github.contributors import process_contributors @@ -31,7 +32,7 @@ from keyman.KeyClient import KeyClient, KeyPublisher -reset_logs = os.getenv("AUGUR_RESET_LOGS", 'True').lower() in ('true', '1', 't', 'y', 'yes') +reset_logs = SystemEnv.get_bool("AUGUR_RESET_LOGS", True) logger = SystemLogger("collectoss", reset_logfiles=reset_logs).get_logger() @@ -61,7 +62,7 @@ def start(ctx, disable_collection, development, pidfile, port): signal.signal(signal.SIGINT, manager.shutdown_signal_handler) try: - if os.environ.get('AUGUR_DOCKER_DEPLOY') != "1": + if SystemEnv.get('AUGUR_DOCKER_DEPLOY') != "1": raise_open_file_limit(100000) except Exception as e: logger.error( @@ -71,10 +72,10 @@ def start(ctx, disable_collection, development, pidfile, port): raise e if development: - os.environ["AUGUR_DEV"] = "1" + SystemEnv.set("AUGUR_DEV", "1") logger.info("Starting in development mode") - os.environ["AUGUR_PIDFILE"] = pidfile + SystemEnv.set("AUGUR_PIDFILE", pidfile) try: gunicorn_location = os.getcwd() + "/collectoss/api/gunicorn_conf.py" @@ -86,10 +87,10 @@ def start(ctx, disable_collection, development, pidfile, port): if not port: port = get_value("Server", "port") - os.environ["AUGUR_PORT"] = str(port) + SystemEnv.set("AUGUR_PORT", str(port)) if disable_collection: - os.environ["AUGUR_DISABLE_COLLECTION"] = "1" + SystemEnv.set("AUGUR_DISABLE_COLLECTION", "1") core_worker_count = get_value("Celery", 'core_worker_count') secondary_worker_count = get_value("Celery", 'secondary_worker_count') @@ -130,7 +131,7 @@ def start(ctx, disable_collection, development, pidfile, port): processes = start_celery_worker_processes((core_worker_count, secondary_worker_count, facade_worker_count), disable_collection) manager.processes = processes - celery_beat_schedule_db = os.getenv("CELERYBEAT_SCHEDULE_DB", "celerybeat-schedule.db") + celery_beat_schedule_db = SystemEnv.get("CELERYBEAT_SCHEDULE_DB", "celerybeat-schedule.db") if os.path.exists(celery_beat_schedule_db): logger.info("Deleting old task schedule") os.remove(celery_beat_schedule_db) @@ -144,7 +145,7 @@ def start(ctx, disable_collection, development, pidfile, port): manager.keypub = keypub if not disable_collection: - if os.environ.get('AUGUR_DOCKER_DEPLOY') != "1": + if SystemEnv.get('AUGUR_DOCKER_DEPLOY') != "1": orchestrator = subprocess.Popen("python keyman/Orchestrator.py".split()) # Wait for orchestrator startup @@ -355,10 +356,10 @@ def export_env(config): Exports your GitHub key and database credentials """ - export_file = open(os.getenv('AUGUR_EXPORT_FILE', 'collectoss_export_env.sh'), 'w+') + export_file = open(SystemEnv.get('AUGUR_EXPORT_FILE') or 'collectoss_export_env.sh', 'w+') export_file.write('#!/bin/bash') export_file.write('\n') - env_file = open(os.getenv('AUGUR_ENV_FILE', 'docker_env.txt'), 'w+') + env_file = open(SystemEnv.get('AUGUR_ENV_FILE') or 'docker_env.txt', 'w+') for env_var in config.get_env_config().items(): if "LOG" not in env_var[0]: @@ -403,7 +404,7 @@ def get_backend_processes(): for process in psutil.process_iter(['cmdline', 'name', 'environ']): if process.info['cmdline'] is not None and process.info['environ'] is not None: try: - if os.getenv('VIRTUAL_ENV') in process.info['environ']['VIRTUAL_ENV'] and 'python' in ''.join(process.info['cmdline'][:]).lower(): + if SystemEnv.get('VIRTUAL_ENV') in process.info['environ']['VIRTUAL_ENV'] and 'python' in ''.join(process.info['cmdline'][:]).lower(): if process.pid != os.getpid(): process_list.append(process) except (KeyError, FileNotFoundError): diff --git a/collectoss/application/cli/collection.py b/collectoss/application/cli/collection.py index b1a93ce80..5127a8d17 100644 --- a/collectoss/application/cli/collection.py +++ b/collectoss/application/cli/collection.py @@ -14,6 +14,7 @@ import traceback import sqlalchemy as s +from collectoss.application.environment import SystemEnv from collectoss.tasks.start_tasks import collection_monitor, create_collection_status_records from collectoss.tasks.git.facade_tasks import clone_repos from collectoss.tasks.github.util.github_api_key_handler import GithubApiKeyHandler @@ -45,7 +46,7 @@ def start(ctx, development): """Start CollectOSS's backend server.""" try: - if os.environ.get('AUGUR_DOCKER_DEPLOY') != "1": + if SystemEnv.get('AUGUR_DOCKER_DEPLOY') != "1": raise_open_file_limit(100000) except Exception as e: logger.error( @@ -75,7 +76,7 @@ def start(ctx, development): keypub.publish(key, "gitlab_rest") if development: - os.environ["AUGUR_DEV"] = "1" + SystemEnv.set("AUGUR_DEV", "1") logger.info("Starting in development mode") core_worker_count = get_value("Celery", 'core_worker_count') @@ -237,7 +238,7 @@ def get_collection_processes(): def is_collection_process(process): command = ''.join(process.info['cmdline'][:]).lower() - if os.getenv('VIRTUAL_ENV') in process.info['environ']['VIRTUAL_ENV'] and 'python' in command: + if SystemEnv.get('VIRTUAL_ENV') in process.info['environ']['VIRTUAL_ENV'] and 'python' in command: if process.pid != os.getpid(): if "collectossbackendcollection" in command or "celery_app.celery_appbeat" in command: diff --git a/collectoss/application/cli/db.py b/collectoss/application/cli/db.py index fd5db52cf..e2b1b7e3f 100644 --- a/collectoss/application/cli/db.py +++ b/collectoss/application/cli/db.py @@ -28,6 +28,7 @@ process_repo_csv, process_repo_group_csv, ) +from collectoss.application.environment import SystemEnv logger = logging.getLogger(__name__) @@ -379,7 +380,7 @@ def get_api_key(ctx): short_help="Check the ~/.pgpass file for CollectOSS's database credentials", ) def check_pgpass(): - db_environment_var = getenv("AUGUR_DB") + db_environment_var = SystemEnv.get("AUGUR_DB") if db_environment_var: # gets the user, passowrd, host, port, and database_name out of environment variable # assumes database string of structure //:@:/ @@ -495,7 +496,7 @@ def run_psql_command_in_database(target_type, target): logger.error("Invalid target type. Exiting...") exit(1) - db_environment_var = getenv("AUGUR_DB") + db_environment_var = SystemEnv.get("AUGUR_DB") # db_json_file_location = os.getcwd() + "/db.config.json" # db_json_exists = os.path.exists(db_json_file_location) diff --git a/collectoss/application/config.py b/collectoss/application/config.py index 56e6c57ae..051235323 100644 --- a/collectoss/application/config.py +++ b/collectoss/application/config.py @@ -7,6 +7,8 @@ from collectoss.application.db.models import Config from collectoss.application.db.util import execute_session_query, convert_type_of_value from pathlib import Path +from collectoss.application.environment import SystemEnv + import logging def get_development_flag_from_config(): @@ -27,7 +29,7 @@ def get_development_flag_from_config(): return flag def get_development_flag(): - return os.getenv("AUGUR_DEV") or get_development_flag_from_config() or False + return SystemEnv.get("AUGUR_DEV") or get_development_flag_from_config() or False def redact_setting_value(section_name, setting_name, value): value_redacted = value if section_name != "Keys" else "REDACTED" @@ -167,7 +169,7 @@ def __init__(self, logger, session: DatabaseSession, config_sources: list = None JsonConfig(default_config, logger) ] - config_dir = Path(os.getenv("CONFIG_DATADIR", "./")) + config_dir = Path(SystemEnv.get("CONFIG_DATADIR") or "./") config_path = config_dir.joinpath("augur.json") if config_path.exists(): config_sources.append(JsonConfig(json.loads(config_path.read_text(encoding="UTF-8")), logger)) diff --git a/collectoss/application/environment.py b/collectoss/application/environment.py new file mode 100644 index 000000000..3a28c12a9 --- /dev/null +++ b/collectoss/application/environment.py @@ -0,0 +1,80 @@ +from typing import Optional +import os +import warnings +import logging + +logger = logging.getLogger(__name__) + +def extract_prefix(key: str, prefixes: list[str], separator = "_") -> Optional[str]: + """Detect and return the prefix present on the provided key + + Args: + key (str): the key to remove the prefix from + prefixes (list[str]): the prefixes to look for + separator (str, optional): the separator between elements of the key to also remove (if they would otherwise be dangling). Defaults to "_". + + Returns: + str: The detected prefix (including any separators) if any, otherwise None + """ + prefix_len = 0 + for p in prefixes: + p = p.upper() + k = key.upper() + if k.startswith(p): + prefix_len += len(p) + + if k[prefix_len] == separator: + prefix_len += len(separator) + return key[0:prefix_len] + return None + + +class SystemEnv: + """Centralized environment variable access + Built for enabling migration of environment variable names + """ + + _prefixes = ["COLLECTOSS", "AUGUR"] + _warn_prefixes = ["AUGUR"] + _separator = "_" + + @classmethod + def get(cls, key: str, default = None, prefixes = _prefixes) -> Optional[str]: + # extract the suffix so we can try multiple prefixes + canonical_prefix = extract_prefix(key, prefixes, cls._separator) + suffix = key[len(canonical_prefix):] if canonical_prefix is not None else key + # check prefixes in order and use the first one that has a value + for p in prefixes: + check_key = f"{p}{cls._separator}{suffix}" + value = os.getenv(check_key, None) + + if value is not None: + # emit a warning if configured + if p in cls._warn_prefixes: + msg = ( + f"Environment variable '{check_key}' is deprecated. " + f"Use '{key}' instead. This automatic recovery may become a failure in a future version " + ) + logger.warning(msg) + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + return value + + if not canonical_prefix: + return os.getenv(key, default) + + return default + + @classmethod + def get_bool(cls, key:str, default: bool, prefixes = _prefixes) -> bool: + """gets a value from the environment and cast it to a boolean + """ + raw_val = cls.get(key, None, prefixes) + return raw_val.lower() in ('true', '1', 't', 'y', 'yes') if raw_val else default + + @classmethod + def set(cls, key: str, value: str, overwrite=True) -> None: + if os.getenv(key) is not None and not overwrite: + return + + os.environ[key] = value \ No newline at end of file diff --git a/collectoss/tasks/git/dependency_tasks/core.py b/collectoss/tasks/git/dependency_tasks/core.py index a9e74b4e1..21f24246a 100644 --- a/collectoss/tasks/git/dependency_tasks/core.py +++ b/collectoss/tasks/git/dependency_tasks/core.py @@ -2,6 +2,7 @@ import os from collectoss.application.db.models import * from collectoss.application.db.lib import bulk_insert_dicts, get_repo_by_repo_git, get_value, get_session +from collectoss.application.environment import SystemEnv from collectoss.tasks.github.util.github_api_key_handler import GithubApiKeyHandler from collectoss.tasks.git.dependency_tasks.dependency_util import dependency_calculator as dep_calc from collectoss.tasks.util.worker_util import parse_json_from_subprocess_call @@ -79,19 +80,11 @@ def generate_scorecard(logger, repo_git): command = '--repo=' + path #this is path where our scorecard project is located - path_to_scorecard = os.getenv('SCORECARD_DIR', os.environ['HOME'] + '/scorecard') + path_to_scorecard = SystemEnv.get('SCORECARD_DIR', (SystemEnv.get('HOME') or "~") + '/scorecard') #setting the environmental variable which is required by scorecard - - with get_session() as session: - #key_handler = GithubRandomKeyAuth(logger) - key_handler = GithubApiKeyHandler(logger) - os.environ['GITHUB_AUTH_TOKEN'] = key_handler.get_random_key() - - # This seems outdated - #setting the environmental variable which is required by scorecard - #key_handler = GithubApiKeyHandler(session, session.logger) - #os.environ['GITHUB_AUTH_TOKEN'] = key_handler.get_random_key() + key_handler = GithubApiKeyHandler(logger) + SystemEnv.set('GITHUB_AUTH_TOKEN', key_handler.get_random_key()) try: required_output = parse_json_from_subprocess_call(logger,['./scorecard', command, '--format=json'],cwd=path_to_scorecard) diff --git a/collectoss/tasks/git/scc_value_tasks/core.py b/collectoss/tasks/git/scc_value_tasks/core.py index 7c9e0bafd..a526af990 100644 --- a/collectoss/tasks/git/scc_value_tasks/core.py +++ b/collectoss/tasks/git/scc_value_tasks/core.py @@ -2,6 +2,7 @@ import os from collectoss.application.db.models import * from collectoss.application.db.lib import bulk_insert_dicts, get_repo_by_repo_git, get_value +from collectoss.application.environment import SystemEnv from collectoss.tasks.util.worker_util import parse_json_from_subprocess_call from collectoss.tasks.git.util.facade_worker.facade_worker.utilitymethods import get_absolute_repo_path @@ -20,7 +21,7 @@ def value_model(logger,repo_git): logger.info(f"Repo ID: {repo_id}, Path: {path}") logger.info('Running scc...') - path_to_scc = os.getenv('SCC_DIR', os.environ['HOME'] + '/scc') + path_to_scc = SystemEnv.get('SCC_DIR', (SystemEnv.get('HOME') or "~") + '/scc') required_output = parse_json_from_subprocess_call(logger,['./scc', '-f','json','--by-file', path], cwd=path_to_scc) diff --git a/collectoss/tasks/git/util/facade_worker/facade_worker/config.py b/collectoss/tasks/git/util/facade_worker/facade_worker/config.py index 7da6495bd..9db7d8866 100644 --- a/collectoss/tasks/git/util/facade_worker/facade_worker/config.py +++ b/collectoss/tasks/git/util/facade_worker/facade_worker/config.py @@ -40,11 +40,13 @@ from collectoss.application.db.lib import execute_sql from logging import Logger +from collectoss.application.environment import SystemEnv + logger = logging.getLogger(__name__) def get_database_args_from_env(): - db_str = os.getenv("AUGUR_DB") + db_str = SystemEnv.get("AUGUR_DB") try: db_json_file_location = os.getcwd() + "/db.config.json" except FileNotFoundError: diff --git a/collectoss/tasks/init/celery_app.py b/collectoss/tasks/init/celery_app.py index e14230f99..a33e1e961 100644 --- a/collectoss/tasks/init/celery_app.py +++ b/collectoss/tasks/init/celery_app.py @@ -63,7 +63,7 @@ tasks = start_tasks + github_tasks + gitlab_tasks + git_tasks + materialized_view_tasks + frontend_tasks -if os.environ.get('AUGUR_DOCKER_DEPLOY') != "1": +if SystemEnv.get('AUGUR_DOCKER_DEPLOY') != "1": tasks += data_analysis_tasks redis_db_number, redis_conn_string = get_redis_conn_values() diff --git a/collectoss/tasks/start_tasks.py b/collectoss/tasks/start_tasks.py index 644b6cbc4..38c721235 100644 --- a/collectoss/tasks/start_tasks.py +++ b/collectoss/tasks/start_tasks.py @@ -7,8 +7,9 @@ import sqlalchemy as s +from collectoss.application.environment import SystemEnv from collectoss.tasks.github import * -if os.environ.get('AUGUR_DOCKER_DEPLOY') != "1": +if SystemEnv.get('AUGUR_DOCKER_DEPLOY') != "1": from collectoss.tasks.data_analysis import * from collectoss.tasks.github.detect_move.tasks import detect_github_repo_move_core, detect_github_repo_move_secondary from collectoss.tasks.github.releases.tasks import collect_releases @@ -32,7 +33,7 @@ from collectoss.application.db.lib import execute_sql, get_session from collectoss.application.config import SystemConfig -RUNNING_DOCKER = os.environ.get('AUGUR_DOCKER_DEPLOY') == "1" +RUNNING_DOCKER = SystemEnv.get('AUGUR_DOCKER_DEPLOY') == "1" CELERY_GROUP_TYPE = type(group()) CELERY_CHAIN_TYPE = type(chain()) diff --git a/keyman/Orchestrator.py b/keyman/Orchestrator.py index 71cfae8bb..d93a1f064 100644 --- a/keyman/Orchestrator.py +++ b/keyman/Orchestrator.py @@ -4,15 +4,16 @@ import time from keyman.KeyOrchestrationAPI import spec, WaitKeyTimeout, InvalidRequest +from collectoss.application.environment import SystemEnv -if os.environ.get("KEYMAN_DOCKER"): +if SystemEnv.get("KEYMAN_DOCKER"): import sys import redis import logging sys.path.append("/collectoss") - conn = redis.Redis.from_url(os.environ.get("REDIS_CONN_STRING")) + conn = redis.Redis.from_url(SystemEnv.get("REDIS_CONN_STRING")) # Just log to stdout if we're running in docker logger = logging.Logger("KeyOrchestrator") diff --git a/tests/test_application/test_config/test_environment.py b/tests/test_application/test_config/test_environment.py new file mode 100644 index 000000000..6b62f2ec9 --- /dev/null +++ b/tests/test_application/test_config/test_environment.py @@ -0,0 +1,80 @@ +from collectoss.application.environment import SystemEnv, extract_prefix +import logging +import os + +logger = logging.getLogger(__name__) + +prefixes = ["COLLECTOSS", "OTHER"] + +def test_env_extract_prefix(): + assert extract_prefix("OTHER_DB", prefixes) == "OTHER_" + assert extract_prefix("COLLECTOSS_DB", prefixes) == "COLLECTOSS_" + +def test_env_extract_prefix_default(): + assert extract_prefix("SOME_DB", prefixes) is None + assert extract_prefix("THINGY_DB", prefixes) is None + + +def test_env_extract_prefix_unprefixed(): + assert extract_prefix("DB", prefixes) is None + +def test_fetching_env(): + # plain + os.environ["COLLECTOSS_NAME"] = "A" + assert SystemEnv.get("COLLECTOSS_NAME") == "A" + + # fallback handling + os.environ["OTHER_THING"] = "B" + assert SystemEnv.get("COLLECTOSS_THING", None, prefixes) == "B" + + # cleanup + del os.environ["COLLECTOSS_NAME"] + del os.environ["OTHER_THING"] + +def test_fetching_env_backwards(): + os.environ["COLLECTOSS_NAME"] = "A" + assert SystemEnv.get("OTHER_NAME", None, prefixes) == "A" + + # cleanup + del os.environ["COLLECTOSS_NAME"] + +def test_fetching_env_no_value(): + assert SystemEnv.get("COLLECTOSS_MISSING", None, prefixes) is None + +def test_fetching_env_default(): + assert SystemEnv.get("COLLECTOSS_DEFAULT", "SOME", prefixes) == "SOME" + +def test_no_known_prefix(): + # fallback handling + os.environ["THING"] = "C" + assert SystemEnv.get("THING", None, prefixes) == "C" + + +def test_get_bool_trues(): + + cases = ["1", "true", "True", "TRUE", "y", "Y", "yes", "Yes"] + + for case in cases: + os.environ["OTHER_BOOL"] = case + assert SystemEnv.get_bool("OTHER_BOOL", False, prefixes) == True + del os.environ["OTHER_BOOL"] + +def test_get_bool_falses(): + + cases = ["0", "false", "False", "FALSE", "n", "N", "no", "No"] + + for case in cases: + os.environ["OTHER_BOOL"] = case + assert SystemEnv.get_bool("OTHER_BOOL", True, prefixes) == False + del os.environ["OTHER_BOOL"] + +def test_get_bool_default(): + + cases = ["?", "maybe", "Stuff", "333"] + + for case in cases: + os.environ["OTHER_BOOL"] = case + assert SystemEnv.get_bool("OTHER_BOOL", False, prefixes) == False + del os.environ["OTHER_BOOL"] + +