Source code for pepys_admin.utils

import json
import os
import sqlite3
import sys
from contextlib import contextmanager

import pint
from sqlalchemy.sql.schema import UniqueConstraint

from paths import MIGRATIONS_DIRECTORY
from pepys_import.utils.sqlalchemy_utils import get_primary_key_for_table
from pepys_import.utils.text_formatting_utils import (
    custom_print_formatted_text,
    format_error_message,
)


def get_default_export_folder():
    """Return the default folder for exports: the user's home directory when Pepys is being
    run from a 'bin' folder, otherwise the current working directory."""
    current_folder_name = os.path.basename(os.path.normpath(os.getcwd()))
    if current_folder_name == "bin":
        return os.path.expanduser("~")
    else:
        return os.getcwd()

def round_object_if_necessary(obj):
    """Round floats and pint Quantities to three decimal places; return any other object
    unchanged."""
    if isinstance(obj, pint.quantity._Quantity) or isinstance(obj, float):
        return round(obj, 3)
    else:
        return obj
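
# Example (hedged sketch -- the unit registry below is purely illustrative; Pepys configures
# its own registry elsewhere):
#
#     ureg = pint.UnitRegistry()
#     round_object_if_necessary(5.0000000024 * ureg.metre)  # -> <Quantity(5.0, 'meter')>
#     round_object_if_necessary("GPS")                      # -> "GPS", unchanged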

def sqlalchemy_obj_to_dict(obj, remove_id=False):
    """Converts a SQLAlchemy result from a query into a dict of {column_name: value}s,
    excluding the 'created_date' column.

    This is used for tests. To make the tests work on machines that round floats differently,
    we round the objects if necessary before putting them in the dict. This deals with issues
    we have if we have a Quantity with a value of 5.0000000024 from Postgres and a value of
    5.0000 on SQLite.
    """
    d = {
        column.name: round_object_if_necessary(getattr(obj, column.name))
        for column in obj.__table__.columns
    }
    del d["created_date"]

    if remove_id:
        pri_key_col_name = get_primary_key_for_table(obj)
        del d[pri_key_col_name]

    return d
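
# Example usage (hedged sketch -- `Platform` and `session` are assumed to come from a Pepys
# data store configured elsewhere; they are not defined in this module):
#
#     platform = session.query(Platform).first()
#     sqlalchemy_obj_to_dict(platform, remove_id=True)
#     # -> {column_name: rounded value, ...} without created_date or the primary key column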

def check_sqlalchemy_results_are_equal(results1, results2):
    """Compare two lists of SQLAlchemy results to see if they are equal"""
    list1 = [sqlalchemy_obj_to_dict(item) for item in results1]
    list2 = [sqlalchemy_obj_to_dict(item) for item in results2]

    return list1 == list2

def make_query_for_cols(table_object, comparison_object, columns, session):
    """Create a SQLAlchemy query for the given table_object, with a filter comparing the given
    columns with the given comparison_object.

    For example, if the comparison object contains the values
    {'name': 'GPS', 'host': 42, 'type': 12, 'blah': 'hello'} and the columns are
    ['name', 'host'], then this will return a query like:

        session.query(table_object)
            .filter(table_object.name == 'GPS')
            .filter(table_object.host == 42)
    """
    query = session.query(table_object)

    for col_name in columns:
        query = query.filter(
            getattr(table_object, col_name) == getattr(comparison_object, col_name)
        )

    return query
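
# Example usage (hedged sketch -- `Sensor`, `candidate` and `session` are illustrative names
# that are not defined in this module):
#
#     query = make_query_for_cols(Sensor, candidate, ["name", "host"], session)
#     existing = query.first()  # None if no Sensor matches candidate on name and host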

def make_query_for_unique_cols_or_all(table_object, comparison_object, session):
    """Create a SQLAlchemy query object for the given table_object, with a filter comparing it
    to the comparison object.

    The filter will use just the columns defined in the unique constraint for the table if a
    unique constraint is defined, otherwise it will compare all columns.
    """
    unique_constraints = [
        c for c in table_object.__table__.constraints if isinstance(c, UniqueConstraint)
    ]

    if len(unique_constraints) == 0:
        return make_query_for_all_data_columns(table_object, comparison_object, session)
    elif len(unique_constraints) == 1:
        unique_col_names = [c.name for c in unique_constraints[0].columns]
        return make_query_for_cols(table_object, comparison_object, unique_col_names, session)
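
# Example usage (hedged sketch -- `Platform`, `candidate` and `session` are illustrative):
#
#     # If Platform's table defines a UniqueConstraint, only the columns in that constraint
#     # are compared; otherwise every data column is compared.
#     match = make_query_for_unique_cols_or_all(Platform, candidate, session).first()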

def make_query_for_all_data_columns(table_object, comparison_object, session):
    """Makes a query to search for an object where all data columns match the comparison object.

    In this case, the data columns are all columns excluding the primary key, the created_date
    column and (if present) the privacy_id column.
    """
    primary_key = get_primary_key_for_table(table_object)
    column_names = [col.name for col in table_object.__table__.columns.values()]
    # Get rid of the primary key column from the list
    column_names.remove(primary_key)
    # And get rid of the created_date column
    column_names.remove("created_date")
    # And get rid of the privacy column, if it exists
    if "privacy_id" in column_names:
        column_names.remove("privacy_id")

    query = session.query(table_object)

    for col_name in column_names:
        # Don't add a filter to match any columns which have missing values in the
        # comparison object
        if getattr(comparison_object, col_name) is None:
            continue
        query = query.filter(
            getattr(table_object, col_name) == getattr(comparison_object, col_name)
        )

    return query

def get_name_for_obj(obj):
    """Return a 'name' field for an object.

    Most objects have a field called `name`, so we try this first. If this fails, we try
    `reference` (for Datafiles) and `synonym` (for Synonyms), otherwise we just return 'Unknown'.
    """
    if "name" in obj.__dict__:
        return obj.name
    elif "reference" in obj.__dict__:
        return obj.reference
    elif "synonym" in obj.__dict__:
        return obj.synonym
    else:
        return "Unknown"
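
# Example (hedged sketch -- `platform`, `datafile` and `synonym` are illustrative objects
# loaded from a Pepys data store):
#
#     get_name_for_obj(platform)  # -> platform.name
#     get_name_for_obj(datafile)  # -> datafile.reference
#     get_name_for_obj(synonym)   # -> synonym.synonym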

def statistics_to_table_data(statistics):
    """Convert a dictionary of statistics data into tuples ready for displaying as a table
    with the tabulate function."""
    return [
        (k, v["already_there"], v["added"], v["modified"])
        for k, v in sorted(statistics.items())
    ]
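
# Example (hedged sketch -- assumes `tabulate` is imported by the calling code; the header
# names shown are illustrative):
#
#     stats = {"Platforms": {"already_there": 2, "added": 5, "modified": 1}}
#     rows = statistics_to_table_data(stats)
#     # rows == [("Platforms", 2, 5, 1)]
#     # print(tabulate(rows, headers=["Table", "Already there", "Added", "Modified"]))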

def create_statistics_from_ids(ids):
    """Create a statistics dictionary from a dict of ids/details for items added, modified
    and already_there"""
    return {
        "added": len(ids["added"]),
        "modified": len([item for item in ids["modified"] if item["data_changed"]]),
        "already_there": len(ids["already_there"]),
    }
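
# Example (hedged sketch -- the ids and their structure are purely illustrative):
#
#     ids = {
#         "added": ["id1", "id2"],
#         "modified": [{"id": "id3", "data_changed": True}, {"id": "id4", "data_changed": False}],
#         "already_there": ["id5"],
#     }
#     create_statistics_from_ids(ids)
#     # -> {"added": 2, "modified": 1, "already_there": 1}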

def database_at_latest_revision(db_path):
    """Return True if the SQLite database at db_path is at the latest Alembic schema revision
    recorded in latest_revisions.json, and False otherwise (including when the revision cannot
    be read)."""
    try:
        conn = sqlite3.connect(db_path)
        result = conn.execute("SELECT version_num from alembic_version;")
        slave_version = next(result)[0]
        conn.close()
    except Exception:
        custom_print_formatted_text(
            format_error_message(
                "Could not read schema revision from database - is this a valid Pepys database file?"
            )
        )
        return False

    versions = read_latest_revisions_file()

    if len(versions) == 0:
        return False

    if slave_version == versions["LATEST_SQLITE_VERSION"]:
        return True

    return False
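
# Example usage (hedged sketch -- the path is illustrative):
#
#     if not database_at_latest_revision("/path/to/pepys.sqlite"):
#         print("Database schema is out of date - run the Pepys migration before continuing")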

def read_latest_revisions_file():
    """Read latest_revisions.json from the migrations directory, returning its contents as a
    dict, or an empty dict if the file is missing or does not contain the LATEST_SQLITE_VERSION
    key."""
    try:
        with open(os.path.join(MIGRATIONS_DIRECTORY, "latest_revisions.json"), "r") as file:
            versions = json.load(file)
    except Exception:
        custom_print_formatted_text(format_error_message("Could not find latest_revisions.json"))
        return {}

    if "LATEST_SQLITE_VERSION" not in versions:
        custom_print_formatted_text(
            format_error_message("Latest revision IDs couldn't be found in latest_revisions.json")
        )
        return {}

    return versions
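
# For reference (hedged sketch -- the exact revision ID depends on the installed version, and
# the file may contain other keys as well):
#
#     {
#         "LATEST_SQLITE_VERSION": "<alembic revision id>"
#     }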

class StdoutAndFileWriter:
    """File-like object that writes every message to both the original stdout and a log file."""

    def __init__(self, filename):
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    # Pretend not to be a terminal
    def isatty(self):
        return False

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def close(self):
        self.log.close()

    def flush(self):
        # This flush method is needed for Python 3 compatibility: it handles the flush
        # command by doing nothing. You might want to specify some extra behaviour here.
        pass

@contextmanager
def redirect_stdout_to_file_and_screen(filename):
    """Context manager that redirects stdout and stderr so that output still appears on
    screen but is also appended to the given file."""
    old_stdout = sys.stdout
    old_stderr = sys.stderr

    sys.stdout = StdoutAndFileWriter(filename)
    sys.stderr = sys.stdout

    yield

    sys.stdout.close()
    sys.stdout = old_stdout
    sys.stderr = old_stderr
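
# Example usage (hedged sketch -- the filename is illustrative):
#
#     with redirect_stdout_to_file_and_screen("import_log.txt"):
#         print("This line appears on screen and is also appended to import_log.txt")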