# Source code for dvc.core.database.postgres

import psycopg2
from psycopg2._psycopg import connection
from pathlib import Path
from typing import Optional, Tuple

from dvc.core.database import SQLFileExecutorTemplate
from dvc.core.struct import DatabaseRevisionFile, DatabaseVersion
from dvc.core.hash import FileHasher


class PostgresSQLFileExecutor(SQLFileExecutorTemplate):
    """PostgreSQL implementation of :class:`SQLFileExecutorTemplate`.

    Executes database revision SQL files on a psycopg2 connection and records
    each executed file in the revision-control tables of ``target_schema``.
    The SQL for those tables is read from ``scm_dvc__*.sql`` templates located
    in this module's directory, with ``{target_schema}`` substituted via
    :meth:`str.format`.
    """

    # Folder containing the scm_dvc__*.sql metadata templates (this module's directory).
    METADATA_SQL_FOLDER_PATH = Path(__file__).parent
    # Shared hasher used to fingerprint executed revision files.
    FILE_HASHER = FileHasher()

    def __init__(self,
                 db_conn: connection,
                 target_schema: str,
                 ):
        """
        :param db_conn: psycopg2 connection that all statements run on
        :param target_schema: schema holding the revision-control tables.
            NOTE(review): it is substituted into SQL text with ``str.format``
            (not quoted as an identifier), so it must come from a trusted source.
        """
        super(PostgresSQLFileExecutor, self).__init__(db_conn, target_schema)
        self.cur = self.conn.cursor()
        self.target_schema = target_schema

    def _read_metadata_sql(self, sql_file_name: str) -> str:
        """
        Read a metadata SQL template and fill in ``{target_schema}``

        :param sql_file_name: file name relative to METADATA_SQL_FOLDER_PATH
        :return: SQL text ready to execute
        """
        sql_file_path = self.__class__.METADATA_SQL_FOLDER_PATH.joinpath(sql_file_name)
        with open(sql_file_path, 'r', encoding='utf-8') as sql_file:
            return sql_file.read().format(target_schema=self.target_schema)

    def set_up_database_revision_control_tables(self):
        """
        Create all database revision control schema and tables

        On any error while executing or committing, the transaction is rolled
        back and the error is re-raised.

        :return:
        """
        create_sql = self._read_metadata_sql("scm_dvc__create_scm_and_tbls.sql")
        try:
            self.cur.execute(query=create_sql)
            self.conn.commit()
        except Exception:
            # psycopg2 keeps a failed transaction in the aborted state; roll back
            # so later statements on this connection are not rejected.
            self.conn.rollback()
            raise

    def get_latest_database_version(self) -> DatabaseVersion:
        """
        Get the latest database version

        On a query error the transaction is rolled back and the error re-raised.

        :return: the version from the first row returned by the
            select-latest-version template, or version ``"V0"`` with
            ``created_at=None`` when the query returns no row
        """
        select_latest_database_version_sql = self._read_metadata_sql(
            "scm_dvc__select_latest_database_version.sql")
        try:
            self.cur.execute(query=select_latest_database_version_sql)
            result: Optional[Tuple] = self.cur.fetchone()
        except Exception:
            # e.g. the revision-control tables do not exist yet: without a
            # rollback every later statement on this connection would fail.
            self.conn.rollback()
            raise

        if result is None:
            # Nothing is in the table
            return DatabaseVersion(version="V0", created_at=None)

        # The template must return exactly four columns: the version first and
        # its creation timestamp last.
        current_version, _, _, created_at = result
        return DatabaseVersion(version=current_version, created_at=created_at)

    def execute_database_revision(self,
                                  database_revision_file: DatabaseRevisionFile
                                  ):
        """
        Execute database revision and write to database version control tables

        The revision SQL and its history row are committed together in one
        transaction, so a failure while recording metadata does not leave the
        revision applied but unrecorded (unless the revision SQL itself issues
        transaction-control statements). On any error the transaction is rolled
        back and the error is re-raised.

        :param database_revision_file: revision whose ``file_path`` is executed
        :return:
        """
        with open(database_revision_file.file_path, 'r', encoding='utf-8', ) as sql_file:
            sql = sql_file.read()
        try:
            self.cur.execute(sql)
            # Deliberately no commit here: _write_database_revision_metadata
            # commits, so the revision and its history row land atomically.
            self._write_database_revision_metadata(database_revision_file=database_revision_file)
        except Exception:
            self.conn.rollback()
            raise

    def _write_database_revision_metadata(self,
                                          database_revision_file: DatabaseRevisionFile
                                          ):
        """
        Write a given database revision to database version control tables and commit

        Inserts the file's folder, name, content hash (``FILE_HASHER.md5``),
        full text and operation type name using the insert-history template,
        then commits the current transaction. Any uncommitted statements
        issued earlier on this connection are committed with it.

        :param database_revision_file:
        :return:
        """
        file_path = database_revision_file.file_path
        executed_sql_file_folder = str(file_path.parent)
        executed_sql_file_name = str(file_path.name)
        executed_sql_file_content_hash = self.__class__.FILE_HASHER.md5(file_path)
        with open(file_path, 'r', encoding='utf-8') as executed_sql_file:
            executed_sql_file_content = executed_sql_file.read()

        insert_sql = self._read_metadata_sql("scm_dvc__insert_tbl_database_revision_history.sql")
        self.cur.execute(query=insert_sql,
                         vars=(executed_sql_file_folder,
                               executed_sql_file_name,
                               executed_sql_file_content_hash,
                               executed_sql_file_content,
                               database_revision_file.operation_type.name))
        self.conn.commit()