Source code for dvc.core.config

import enum
import re
import traceback

import psycopg2
import yaml
from pathlib import Path
from typing import Dict, List, Any, Optional, Union
import logging
from psycopg2._psycopg import connection
import os

from dvc.core.database import SupportedDatabaseFlavour, DBConnLike
from dvc.core.regex import get_matched_files_in_folder_by_regex
from dvc.core.exception import RequestedDatabaseFlavourNotSupportedException, InvalidDatabaseRevisionFilesException, \
    EnvironmentVariableNotSetException, Operation
from dvc.core.struct import DatabaseRevisionFile, DatabaseVersion


class ConfigDefault:
    """
    Default configuration constants: environment variable names, default
    values and the default config file location.
    """

    # Default keys for Environment variable
    KEY__DATABASE_REVISION_SQL_FILES_FOLDER = "DVC__DATABASE_REVISION_SQL_FILES_FOLDER"
    KEY__TARGET_SCHEMA = "DVC__TARGET_SCHEMA"
    KEY__USER = "DVC__USER"
    KEY__PASSWORD = "DVC__PASSWORD"
    KEY__HOST = "DVC__HOST"
    KEY__PORT = "DVC__PORT"
    KEY__DBNAME = "DVC__DBNAME"
    KEY__DBFLAVOUR = "DVC__DBFLAVOUR"
    KEY__LOGGING_LEVEL = "DVC__LOGGING_LEVEL"

    # Default values for environment variables
    VAL__DATABASE_REVISION_SQL_FILES_FOLDER = "sample_revision_sql_files"
    VAL__TARGET_SCHEMA = "dvc"
    VAL__USER = ""
    VAL__PASSWORD = ""
    VAL__HOST = ""
    VAL__PORT = 5432
    VAL__DBNAME = ""
    VAL__DBFLAVOUR = "postgres"
    # Stored as the level *name* (e.g. "INFO"), not the numeric value.
    VAL__LOGGING_LEVEL: str = logging._levelToName[logging.INFO]

    # Default values for config file
    # NOTE: the mixed-case spelling of the next attribute is part of the
    # public name and is kept as-is.
    VAL__FilE_NAME: str = "config.yaml"
    VAL__FILE_PATH: Path = Path(VAL__FilE_NAME)

    @classmethod
    def get_config_dict(
            cls,
            database_revision_sql_files_folder: str,
            target_schema: str,
            user: str,
            password: str,
            host: str,
            port: int,
            dbname: str,
            dbflavour: str,
            logging_level: int,
            as_file=False
    ):
        """
        Assemble the nested configuration dictionary.

        :param database_revision_sql_files_folder: stored under ``database_revision_sql_files_folder``
        :param target_schema: stored under ``target_schema``
        :param user: stored under ``credentials.user``
        :param password: stored under ``credentials.password``
        :param host: stored under ``credentials.host``
        :param port: stored under ``credentials.port``
        :param dbname: stored under ``credentials.dbname``
        :param dbflavour: stored under ``credentials.dbflavour``
        :param logging_level: Assumed to be integer value
        :param as_file: whether the dict is meant to be dumped as file. When
            true the logging level is stored as its name instead of its number.
        :return: the configuration dictionary
        :raises KeyError: if ``as_file`` is true and ``logging_level`` is not a
            level number known to the ``logging`` module
        """
        if as_file:
            # A file should carry the readable level name rather than the number.
            level_entry = logging._levelToName[logging_level]
        else:
            level_entry = logging_level

        credentials: Dict = {
            "user": user,
            "password": password,
            "host": host,
            "port": port,
            "dbname": dbname,
            "dbflavour": dbflavour,
        }

        return {
            "logging_level": level_entry,
            "database_revision_sql_files_folder": database_revision_sql_files_folder,
            "target_schema": target_schema,
            "credentials": credentials,
        }
class ConfigFileWriter:
    """
    Write the default configuration to a config file.
    """

    def __init__(self,
                 config_file_path: Union[Path, str] = ConfigDefault.VAL__FILE_PATH,
                 ):
        """
        :param config_file_path: where the config file should be written
        :raises TypeError: if ``config_file_path`` is neither a str nor a Path
        """
        # isinstance (rather than an exact type comparison) also accepts str subclasses.
        if isinstance(config_file_path, str):
            self.config_file_path = Path(config_file_path)
        elif isinstance(config_file_path, Path):
            self.config_file_path = config_file_path
        else:
            raise TypeError(
                f"config file path must be of either type str or is instance of Path. Yours is {type(config_file_path)}")

    def write_to_yaml(self) -> None:
        """
        Dump the default configuration as YAML to ``self.config_file_path``.

        An existing path is never overwritten; in that case only an info
        message is logged.
        """
        if self.config_file_path.exists():
            logging.info(f"{self.config_file_path} already exists! Do nothing.")
            return

        default_config_dict: Dict = ConfigDefault.get_config_dict(
            database_revision_sql_files_folder=ConfigDefault.VAL__DATABASE_REVISION_SQL_FILES_FOLDER,
            target_schema=ConfigDefault.VAL__TARGET_SCHEMA,
            user=ConfigDefault.VAL__USER,
            password=ConfigDefault.VAL__PASSWORD,
            host=ConfigDefault.VAL__HOST,
            port=ConfigDefault.VAL__PORT,
            dbname=ConfigDefault.VAL__DBNAME,
            dbflavour=ConfigDefault.VAL__DBFLAVOUR,
            # The default is stored as a level name; get_config_dict expects the number.
            logging_level=logging._nameToLevel[ConfigDefault.VAL__LOGGING_LEVEL],
            as_file=True
        )

        logging.info(f"Now generating default config file {self.config_file_path}")
        # Explicit encoding so the file matches what ConfigReader opens it with.
        with open(self.config_file_path, 'w', encoding='utf-8') as default_config_file:
            yaml.dump(default_config_dict, default_config_file, default_flow_style=False)
class ConfigReader:
    """
    Read Config (in different formats) to Python Dictionary

    Precedence in descending order
    1. Config File
    2. Environment Variable
    """

    def __init__(self,
                 config_file_path: Union[Path, str] = ConfigDefault.VAL__FILE_PATH,
                 ):
        """
        :param config_file_path: location of the config file
        :raises TypeError: if ``config_file_path`` is neither a str nor a Path
        """
        # isinstance (rather than an exact type comparison) also accepts str subclasses.
        if isinstance(config_file_path, str):
            self.config_file_path = Path(config_file_path)
        elif isinstance(config_file_path, Path):
            self.config_file_path = config_file_path
        else:
            raise TypeError(f"config file path must be of either type str or Path. Yours is {type(config_file_path)}")

        # read user config
        self.user_config = self._read_user_config()
        self.requested_db_flavour = self._read_requested_db_flavour()
        self.logging_level = self._read_logging_level()

    @staticmethod
    def _parse_logging_level(raw_logging_level: Union[int, str]) -> int:
        """
        Convert a logging level given as a number or as a level name to its integer value.

        :param raw_logging_level: e.g. ``20``, ``"20"`` or ``"INFO"``
        :return: the numeric logging level
        :raises KeyError: if the value is neither numeric nor a level name known to ``logging``
        """
        try:
            # Assume is value
            return int(raw_logging_level)
        except (TypeError, ValueError):
            # Not numeric (or None): fall through and treat it as a level name.
            pass

        try:
            # Assume is string
            return logging._nameToLevel[raw_logging_level]
        except (KeyError, TypeError):
            logging.error("logging_level must be one of the below:")
            logging.error(logging._nameToLevel)
            raise

    def _read_logging_level(self) -> int:
        """Return the logging level stored in the already loaded user config."""
        logging_level: int = self.user_config['logging_level']
        return logging_level

    def _read_requested_db_flavour(self) -> str:
        """Return the database flavour stored in the already loaded user config."""
        requested_db_flavour: str = self.user_config['credentials']['dbflavour']
        return requested_db_flavour

    def _read_user_config(self) -> Dict:
        """
        Check if config_file_path points to an existing file.
        If yes, read config from the file.
        If not, read config from env var.

        :return: the user config dictionary
        """
        # is_file() is already False for a path that does not exist.
        if self.config_file_path.is_file():
            return self._read_from_yaml()
        return self._read_from_environment()

    def _read_from_yaml(self) -> Dict:
        """
        Read User Config from Yaml File

        :return: the parsed config with ``logging_level`` converted to an integer
        :raises TypeError: if the file does not contain a mapping at the top level
        :raises KeyError: if ``logging_level`` is missing or is not a valid level
        """
        logging.info("Reading config from file...")
        with open(self.config_file_path, 'r', encoding='utf-8') as config_file:
            # NOTE(review): FullLoader can construct more than plain scalars and
            # containers. If the config file may come from an untrusted source,
            # consider yaml.safe_load instead.
            user_config: Dict = yaml.load(config_file, Loader=yaml.FullLoader)

        # An empty file parses to None, which would otherwise fail below with an
        # opaque "'NoneType' object is not subscriptable".
        if not isinstance(user_config, dict):
            raise TypeError(
                f"Config file {self.config_file_path} must contain a mapping at the top level. "
                f"Got {type(user_config)}")

        user_config['logging_level'] = self._parse_logging_level(user_config['logging_level'])
        return user_config

    def _read_from_environment(self) -> Dict:
        """
        Read User Config from environment variables

        :return: the config dictionary built from the environment
        :raises EnvironmentVariableNotSetException: if a required variable is not set
        :raises ValueError: if the port variable is not an integer
        :raises KeyError: if the logging level is neither numeric nor a known level name
        """
        logging.info("Reading config from environment...")

        # Raise Key error if the environment variable is not set
        try:
            database_revision_sql_files_folder = os.environ[ConfigDefault.KEY__DATABASE_REVISION_SQL_FILES_FOLDER]
            target_schema = os.environ[ConfigDefault.KEY__TARGET_SCHEMA]
            host = os.environ[ConfigDefault.KEY__HOST]
            port = int(os.environ[ConfigDefault.KEY__PORT])
            user = os.environ[ConfigDefault.KEY__USER]
            password = os.environ[ConfigDefault.KEY__PASSWORD]
            dbname = os.environ[ConfigDefault.KEY__DBNAME]
            dbflavour = os.environ[ConfigDefault.KEY__DBFLAVOUR]
            logging_level = os.environ[ConfigDefault.KEY__LOGGING_LEVEL]
        except KeyError as err:
            missing_env_var = err.args[0]
            # Chain explicitly so the original lookup failure stays attached as the cause.
            raise EnvironmentVariableNotSetException(missing_env_var) from err

        user_config = ConfigDefault.get_config_dict(
            database_revision_sql_files_folder=database_revision_sql_files_folder,
            target_schema=target_schema,
            host=host,
            user=user,
            password=password,
            dbname=dbname,
            dbflavour=dbflavour,
            port=port,
            logging_level=self._parse_logging_level(logging_level)
        )
        return user_config
class DatabaseRevisionFilesManager:
    """
    Manage all Database Revision Files
    """

    class Pointer:
        """
        Symbolic targets accepted by ``get_target_database_revision_files_by_pointer``.

        HEAD: all the way to the latest
        BASE: selects the downgrade direction, starting from the current version
        """
        HEAD = 'head'
        BASE = 'base'

    def __init__(self,
                 config_file_reader: ConfigReader,
                 ):
        """
        :param config_file_reader: reader whose ``user_config`` names the revision files folder
        """
        self.config_file_reader = config_file_reader
        # The folder must be resolved first: the scan below reads it.
        self.database_revision_files_folder = self._get_database_revision_files_folder()
        self.all_database_revision_files = self._scan_database_revision_files()

    def _scan_database_revision_files(self) -> List[DatabaseRevisionFile]:
        """
        Return all the available database revision files

        The folder is searched recursively; every regular file found is wrapped
        in a ``DatabaseRevisionFile``.

        :return: one ``DatabaseRevisionFile`` per file found
        """
        candidate_database_revision_files: List[DatabaseRevisionFile] = []

        logging.debug("---Scanning database revision files----")
        for file_or_dir in self.database_revision_files_folder.glob('**/*'):
            logging.debug(file_or_dir)
            if file_or_dir.is_file():
                candidate_database_revision_files.append(DatabaseRevisionFile(file_or_dir))
        logging.debug("---/Scanning database revision files----")

        return candidate_database_revision_files

    def _get_database_revision_files_folder(self) -> Path:
        """
        Get database revision files folder

        :return: the folder configured under ``database_revision_sql_files_folder``
        """
        return Path(self.config_file_reader.user_config['database_revision_sql_files_folder'])

    def create_database_revision_files_folder(self) -> None:
        """
        Safely create the database revision files folder.

        Nothing is created when the path already exists.
        """
        database_revision_sql_folder_path = Path(self.database_revision_files_folder)
        if database_revision_sql_folder_path.exists():
            logging.info(f"{database_revision_sql_folder_path} already exists. Do nothing")
            return

        logging.info("Generating database revision folder")
        # exist_ok closes the window in which another process may create the
        # folder between the check above and this call.
        database_revision_sql_folder_path.mkdir(parents=True, exist_ok=True)

    def _raise_for_status(self,
                          database_revision_files: List[DatabaseRevisionFile],
                          steps: int,
                          ) -> None:
        """
        Raise Exception when number of database revision files are not the same as number of steps

        :param database_revision_files: the revision files that were found
        :param steps: the number of steps requested; only its absolute value is compared
        :raises InvalidDatabaseRevisionFilesException: if more or fewer files than steps were found
        """
        # Step 3: Raise Error if number of returned revision files are different from the number of steps specified
        logging.debug(f"database revision files: {database_revision_files}")
        logging.debug(f"steps: {steps}")

        number_of_files = len(database_revision_files)
        number_of_steps = abs(steps)
        if number_of_files == number_of_steps:
            # All good
            return

        if number_of_files > number_of_steps:
            status = InvalidDatabaseRevisionFilesException.Status.MORE_REVISION_SQL_FILES_FOUND_THAN_REQUIRED_STEPS_SPECIFIED
        else:
            status = InvalidDatabaseRevisionFilesException.Status.FEWER_REVISION_SQL_FILES_FOUND_THAN_REQUIRED_STEPS_SPECIFIED

        raise InvalidDatabaseRevisionFilesException(
            config_file_path=self.database_revision_files_folder,
            status=status,
            database_revision_file_paths=[actual_revision_file.file_path
                                          for actual_revision_file in database_revision_files],
        )

    def get_target_database_revision_files_by_pointer(
            self,
            current_database_version: DatabaseVersion,
            candidate_database_revision_files: List[DatabaseRevisionFile],
            pointer: Pointer,
    ) -> List[DatabaseRevisionFile]:
        """
        Given current database version and pointer, filter for target database revision files in the folder

        :param current_database_version: the version the database is currently at
        :param candidate_database_revision_files: the revision files to filter
        :param pointer: ``Pointer.HEAD`` or ``Pointer.BASE``
        :return: the matching revision files, ordered closest to the current version first
        :raises ValueError: if ``pointer`` is neither HEAD nor BASE
        :raises InvalidDatabaseRevisionFilesException: if the number of matching
            files differs from the number of steps deduced from them
        """
        # Step 1: Deduce the number of steps
        current_database_version_number = current_database_version.version_number

        # create a reference database revision file
        if pointer == self.Pointer.HEAD:
            reference_database_revision_file = DatabaseRevisionFile.get_dummy_revision_file(
                revision=f'RV{current_database_version_number + 1}',
                operation_type=Operation.Upgrade,
            )
            target_database_revision_files = [file for file in candidate_database_revision_files
                                              if file >= reference_database_revision_file]
            # Closest to current db version. Ascending order
            target_database_revision_files.sort()
            if target_database_revision_files:
                # Last element is the furthest revision from the current version.
                deduced_steps = target_database_revision_files[-1].revision_number - current_database_version_number
            else:
                deduced_steps = 0
        elif pointer == self.Pointer.BASE:
            reference_database_revision_file = DatabaseRevisionFile.get_dummy_revision_file(
                revision=f'RV{current_database_version_number}',
                operation_type=Operation.Downgrade,
            )
            target_database_revision_files = [file for file in candidate_database_revision_files
                                              if file <= reference_database_revision_file]
            # Closest to current db version. Descending order
            target_database_revision_files.sort(reverse=True)
            if target_database_revision_files:
                # Last element is the furthest revision from the current version;
                # the current revision itself counts as one step, hence the +1.
                deduced_steps = current_database_version_number - target_database_revision_files[-1].revision_number + 1
            else:
                deduced_steps = 0
        else:
            raise ValueError(f"Unhandled Pointer {pointer}!")

        self._raise_for_status(database_revision_files=target_database_revision_files,
                               steps=deduced_steps,
                               )

        return target_database_revision_files

    def get_target_database_revision_files_by_steps(
            self,
            current_database_version: DatabaseVersion,
            steps: int,
            candidate_database_revision_files: List[DatabaseRevisionFile],
    ) -> List[DatabaseRevisionFile]:
        """
        Given current database version and number of steps, filter for target database revision files in the folder.

        :param current_database_version: the version the database is currently at
        :param steps: number of steps relative to the current version
        :param candidate_database_revision_files: the revision files to filter
        :return: the candidate files that compare equal to one of the expected revisions
        :raises InvalidDatabaseRevisionFilesException: if the number of matching
            files differs from ``abs(steps)``
        """
        # Step 1: Get a list of dummy database revision files
        current_database_version_number = current_database_version.version_number
        target_database_version = DatabaseVersion(version=f"V{current_database_version_number + steps}")
        # Subtracting two versions yields the dummy revision files between them.
        dummy_revision_files: List[DatabaseRevisionFile] = target_database_version - current_database_version

        # Step 2: Loop folder for actual files
        actual_revision_files: List[DatabaseRevisionFile] = []
        for dummy_revision_file in dummy_revision_files:
            logging.debug(f"Looking for Revision File with revision number {dummy_revision_file.revision_number}.....")
            actual_revision_files.extend(
                candidate_database_revision_file
                for candidate_database_revision_file in candidate_database_revision_files
                if dummy_revision_file == candidate_database_revision_file
            )

        self._raise_for_status(database_revision_files=actual_revision_files,
                               steps=steps,
                               )

        return actual_revision_files
class DatabaseConnectionFactory:
    """
    Return connections for various databases
    """

    # Maps each supported flavour to the property of this class that returns its
    # connection. The 'self.' prefix is kept for backward compatibility.
    MAPPING = {
        SupportedDatabaseFlavour.Postgres: 'self.pgconn'
    }

    def __init__(self,
                 config_reader: ConfigReader,
                 ):
        """
        :param config_reader: Config Reader
        """
        self.config_reader = config_reader

    def validate_requested_database_flavour(self) -> SupportedDatabaseFlavour:
        """
        Validate if requested database flavour is supported

        :return: the requested flavour as a ``SupportedDatabaseFlavour``
        :raises RequestedDatabaseFlavourNotSupportedException: if the requested flavour is not supported
        """
        requested_db_flavour = self.config_reader.requested_db_flavour
        try:
            return SupportedDatabaseFlavour(requested_db_flavour)
        except ValueError as e:
            logging.error(traceback.format_exc())
            # Chain explicitly so the original lookup failure stays attached as the cause.
            raise RequestedDatabaseFlavourNotSupportedException(
                requested_database_flavour=requested_db_flavour) from e

    @property
    def conn(self) -> DBConnLike:
        """
        Return the expected connection object for different database flavours

        :return: the connection returned by the property mapped to the requested flavour
        :raises RequestedDatabaseFlavourNotSupportedException: if the requested flavour is not supported
        """
        supported_db_flavour = self.validate_requested_database_flavour()

        # Map Supported database flavours to different connections.
        # Resolved with getattr instead of eval so that no string is ever executed as code.
        attribute_name = self.__class__.MAPPING[supported_db_flavour]
        self_prefix = 'self.'
        if attribute_name.startswith(self_prefix):
            attribute_name = attribute_name[len(self_prefix):]
        return getattr(self, attribute_name)

    @property
    def pgconn(self) -> connection:
        """
        Return Postgres Database Connection

        Every access calls ``psycopg2.connect`` again.

        :return: the connection object returned by ``psycopg2.connect``
        """
        credentials = self.config_reader.user_config['credentials']
        return psycopg2.connect(
            dbname=credentials['dbname'],
            user=credentials['user'],
            password=credentials['password'],
            port=credentials['port'],
            host=credentials['host'],
        )