import inspect
from datetime import datetime
from geoalchemy2.elements import WKTElement
from iterfzf import iterfzf
from prompt_toolkit.shortcuts import prompt
from sqlalchemy import or_
from sqlalchemy.sql.functions import func
from pepys_import.core.store import sqlite_db
from pepys_import.core.store.db_status import TableTypes
def row_to_dict(table_object, data_store):
    """Read every row of a table and convert each one into a ``{column_name: value}`` dict.

    :param table_object: A table object
    :type table_object: sqlalchemy.ext.declarative.DeclarativeMeta
    :param data_store: A :class:`DataStore` object
    :type data_store: DataStore
    :return: One dictionary per row, keyed by column name (empty list if the table is empty)
    :rtype: List[Dict]
    """
    with data_store.session_scope():
        rows = data_store.session.query(table_object).all()
        # Build the dicts while the session scope is still open, so row attributes are not
        # read after the scope has been closed.
        objects = [
            {column.name: getattr(row, column.name) for column in row.__table__.columns}
            for row in rows
        ]
    return objects
def find_sqlite_table_object(table_object, data_store):
    """Find the SQLite Base class used to create and insert values for ``table_object``.

    When ``data_store`` is a PostgreSQL store, the class with the same ``__name__`` is looked up
    in ``pepys_import.core.store.sqlite_db``. For any other ``db_type``, ``table_object`` is
    returned unchanged.

    :param table_object: A table object
    :type table_object: sqlalchemy.ext.declarative.DeclarativeMeta
    :param data_store: A :class:`DataStore` object
    :type data_store: DataStore
    :return: Returns a table object
    :rtype: sqlalchemy.ext.declarative.DeclarativeMeta
    :raises ValueError: if ``data_store`` is PostgreSQL and ``sqlite_db`` has no class with the
        same name. Raising here surfaces the mismatch immediately, instead of returning ``None``
        and failing later at insert time.
    """
    if data_store.db_type != "postgres":
        return table_object
    # Direct attribute lookup rather than scanning every member of the module
    sqlite_class = getattr(sqlite_db, table_object.__name__, None)
    if not inspect.isclass(sqlite_class):
        raise ValueError(
            f"No SQLite table class named {table_object.__name__!r} found in sqlite_db"
        )
    return sqlite_class
def get_time_from_user(prompt_text):
    """Keep prompting until the user types a timestamp in ``YYYY-MM-DD HH:MM:SS`` form.

    :param prompt_text: Text shown to the user, followed by the expected-format hint
    :return: The parsed (naive) datetime
    :rtype: datetime.datetime
    """
    while True:
        entered = prompt(f"{prompt_text} (YYYY-MM-DD HH:MM:SS): ")
        try:
            return datetime.strptime(entered, "%Y-%m-%d %H:%M:%S")
        except ValueError:
            # Unparseable input: tell the user and ask again
            print("Invalid time entered, please try again")
def get_number_from_user(prompt_text):
    """Keep prompting until the user types something ``float()`` accepts.

    :param prompt_text: Text shown to the user
    :return: The entered value converted to a float
    :rtype: float
    """
    while True:
        answer = prompt(prompt_text)
        try:
            return float(answer)
        except ValueError:
            # Not a number: tell the user and ask again
            print("Invalid number entered, please try again")
def export_reference_tables(source_store, destination_store, table_objects):
    """Copy every row of each given table from :code:`source_store` into :code:`destination_store`.

    Tables are processed one at a time. Each table's rows are read in their own source session
    scope and bulk-inserted in their own destination session scope.

    :param source_store: A :class:`DataStore` object to fetch objects
    :type source_store: DataStore
    :param destination_store: A :class:`DataStore` object to copy the objects from source_store
    :type destination_store: DataStore
    :param table_objects: A list of table objects
    :type table_objects: List
    :return:
    """
    for reference_table in table_objects:
        rows_as_dicts = row_to_dict(reference_table, source_store)
        # The insert class is chosen according to the *source* store's database type
        insert_class = find_sqlite_table_object(reference_table, source_store)
        with destination_store.session_scope():
            destination_store.session.bulk_insert_mappings(insert_class, rows_as_dicts)
def export_all_measurement_tables(source_store, destination_store):
    """Copy every measurement table from ``source_store`` to ``destination_store`` without filtering.

    The set of tables is taken from ``source_store.meta_classes[TableTypes.MEASUREMENT]``. All
    tables are exported while one source and one destination session scope are held open.

    :param source_store: A :class:`DataStore` object to read from
    :param destination_store: A :class:`DataStore` object to write to
    """
    measurement_tables = source_store.meta_classes[TableTypes.MEASUREMENT]
    with source_store.session_scope(), destination_store.session_scope():
        for measurement_table in measurement_tables:
            export_measurement_table_with_filter(
                source_store, destination_store, measurement_table
            )
def export_measurement_table_with_filter(source_store, destination_store, table, filter=None):
    """Copy the rows of one measurement table, optionally filtered, into another store.

    The query runs on ``source_store.session`` directly, so the caller is expected to have
    opened ``source_store.session_scope()`` already.

    :param source_store: A :class:`DataStore` object to read from
    :param destination_store: A :class:`DataStore` object to bulk-insert into
    :param table: A table class, or the name of a class in ``source_store.db_classes``
    :param filter: Optional callable ``filter(table_object, query) -> query`` that narrows
        the query before it is executed
    :return: None. Prints the number of exported entries; if nothing matches, returns
        without printing or inserting.
    """
    if isinstance(table, str):
        table_object = getattr(source_store.db_classes, table)
    else:
        table_object = table

    query = source_store.session.query(table_object)
    if filter is not None:
        query = filter(table_object, query)
    results = query.all()
    if not results:
        return

    # ORM attribute key for each mapped column (read via the column's attribute property)
    attribute_keys = [
        getattr(table_object, column.name).property.key
        for column in table_object.__table__.columns
    ]
    dict_values = [{key: getattr(row, key) for key in attribute_keys} for row in results]

    # Use the resolved class, not the raw ``table`` argument: ``table`` may be a string, which
    # has no ``__name__`` and cannot be mapped to a SQLite class.
    object_ = find_sqlite_table_object(table_object, source_store)
    with destination_store.session_scope():
        destination_store.session.bulk_insert_mappings(object_, dict_values)
    print(f"Exported {len(results)} entries from {table_object.__name__}")
def export_measurement_tables_filtered_by_time(
    source_store, destination_store, start_time, end_time
):
    """Export measurement rows that fall inside the window ``[start_time, end_time]``.

    Tables with a single ``time`` attribute (State, Contact, Comment, LogsHolding, Media) keep
    rows whose ``time`` is within the window, inclusive at both ends. Tables with ``start``/``end``
    attributes (Activation, Geometry1) are filtered with ``_start_end_filter``.

    :param source_store: A :class:`DataStore` object to read from
    :param destination_store: A :class:`DataStore` object to write to
    :param start_time: Start of the time window
    :param end_time: End of the time window
    """

    def time_attribute_filter(table_object, query):
        """Keep rows whose ``time`` lies within the window (both ends inclusive)."""
        lower_bounded = query.filter(table_object.time >= start_time)
        return lower_bounded.filter(table_object.time <= end_time)

    def start_end_attribute_filter(table_object, query):
        """Keep rows whose start/end interval matches the window per ``_start_end_filter``."""
        return _start_end_filter(table_object, query, start_time, end_time)

    with source_store.session_scope(), destination_store.session_scope():
        classes = source_store.db_classes
        for point_table in (
            classes.State,
            classes.Contact,
            classes.Comment,
            classes.LogsHolding,
            classes.Media,
        ):
            export_measurement_table_with_filter(
                source_store, destination_store, point_table, time_attribute_filter
            )
        for interval_table in (classes.Activation, classes.Geometry1):
            export_measurement_table_with_filter(
                source_store, destination_store, interval_table, start_end_attribute_filter
            )
def export_measurement_tables_filtered_by_location(
    source_store, destination_store, xmin, ymin, xmax, ymax
):
    """Export measurement rows lying inside an axis-aligned bounding box.

    State, Contact and Media are tested on their ``location`` attribute, and Geometry1 on its
    ``geometry`` attribute. Each test uses ``ST_Within`` against the rectangle spanned by
    (``xmin``, ``ymin``) and (``xmax``, ``ymax``), with SRID 4326.

    :param source_store: A :class:`DataStore` object to read from
    :param destination_store: A :class:`DataStore` object to write to
    :param xmin: Minimum x of the box (presumably longitude for SRID 4326 - confirm)
    :param ymin: Minimum y of the box (presumably latitude - confirm)
    :param xmax: Maximum x of the box
    :param ymax: Maximum y of the box
    """
    # ST_MakeEnvelope is not supported by SpatiaLite, so the bounding box is written out as a
    # closed WKT polygon. It is built once here and shared by every filter below.
    bounding_box_wkt = (
        f"POLYGON(({xmin} {ymin},{xmin} {ymax},{xmax} {ymax},{xmax} {ymin},{xmin} {ymin}))"
    )

    def _make_within_filter(attribute_name):
        """Return a query filter keeping rows whose ``attribute_name`` lies within the box."""

        def within_filter(table_object, query):
            return query.filter(
                func.ST_Within(
                    getattr(table_object, attribute_name),
                    WKTElement(bounding_box_wkt, srid=4326),
                )
            )

        return within_filter

    location_attribute_filter = _make_within_filter("location")
    geometry_attribute_filter = _make_within_filter("geometry")

    with source_store.session_scope():
        with destination_store.session_scope():
            tables_with_location = [
                source_store.db_classes.State,
                source_store.db_classes.Contact,
                source_store.db_classes.Media,
            ]
            for table in tables_with_location:
                export_measurement_table_with_filter(
                    source_store, destination_store, table, location_attribute_filter
                )
            export_measurement_table_with_filter(
                source_store,
                destination_store,
                source_store.db_classes.Geometry1,
                geometry_attribute_filter,
            )
def export_measurement_tables_filtered_by_wargame_participation(
    source_store, destination_store, selected_wargame
):
    """Export measurements from platforms taking part in ``selected_wargame`` during its run.

    State, Contact and Comment rows are kept when their platform participates in the wargame
    and their ``time`` lies within ``[selected_wargame.start, selected_wargame.end]``, inclusive.
    Activation rows are kept when their platform participates and their ``start``/``end``
    interval matches the wargame's time span per ``_start_end_filter``.

    :param source_store: A :class:`DataStore` object to read from
    :param destination_store: A :class:`DataStore` object to write to
    :param selected_wargame: Wargame object providing ``start`` and ``end``
    """

    def _restrict_to_participating_platforms(table_object, query):
        """Keep only rows whose ``platform_id`` belongs to a platform in the wargame."""
        # A criterion on Platform.wargame_participations_objects alone would add ``Platform``
        # to the FROM clause with no join to ``table_object``. That is a cartesian product,
        # which does not restrict rows to participating platforms. Correlating through
        # ``platform_id`` with an IN (subquery) ties each row to its own platform.
        platform = source_store.db_classes.Platform
        participating_platform_ids = source_store.session.query(platform.platform_id).filter(
            platform.wargame_participations_objects.contains(selected_wargame)
        )
        return query.filter(table_object.platform_id.in_(participating_platform_ids))

    def wargame_participation_filter_with_time(table_object, query):
        wargame_start = selected_wargame.start
        wargame_end = selected_wargame.end
        query = _restrict_to_participating_platforms(table_object, query)
        return query.filter(
            table_object.time >= wargame_start,
            table_object.time <= wargame_end,
        )

    def wargame_participation_filter_with_start_end(table_object, query):
        wargame_start = selected_wargame.start
        wargame_end = selected_wargame.end
        query = _restrict_to_participating_platforms(table_object, query)
        return _start_end_filter(table_object, query, wargame_start, wargame_end)

    with source_store.session_scope():
        with destination_store.session_scope():
            tables_with_time = [
                source_store.db_classes.State,
                source_store.db_classes.Contact,
                source_store.db_classes.Comment,
            ]
            for table in tables_with_time:
                export_measurement_table_with_filter(
                    source_store, destination_store, table, wargame_participation_filter_with_time
                )
            export_measurement_table_with_filter(
                source_store,
                destination_store,
                source_store.db_classes.Activation,
                wargame_participation_filter_with_start_end,
            )
def export_measurement_tables_filtered_by_serial_participation(
    source_store, destination_store, selected_serial
):
    """Export measurements for each platform participating in ``selected_serial``.

    For every SerialParticipant of the serial, the participation window is the participant's own
    ``start``/``end``, falling back to the serial's ``start``/``end`` where the participant's
    value is None. For that platform:

    * State, Contact and Comment rows are exported when ``time`` is within the window, inclusive.
    * Activation rows are exported when their interval matches per ``_start_end_filter``.

    The platform id and window are printed before each participant's rows are exported.

    :param source_store: A :class:`DataStore` object to read from
    :param destination_store: A :class:`DataStore` object to write to
    :param selected_serial: Serial object providing ``serial_id``, ``start`` and ``end``
    """
    with source_store.session_scope(), destination_store.session_scope():
        participant_class = source_store.db_classes.SerialParticipant
        participants = source_store.session.query(participant_class).filter(
            participant_class.serial_id == selected_serial.serial_id
        )
        for participant in participants:
            platform_id = participant.platform.platform_id
            start_time = selected_serial.start if participant.start is None else participant.start
            end_time = selected_serial.end if participant.end is None else participant.end
            print(f"{platform_id}: {start_time}, {end_time}")

            # Both filters close over this iteration's platform_id/start_time/end_time. They are
            # only used within this iteration, so late binding of the closure is harmless.
            def filter_function_time(table_object, query):
                return query.filter(
                    table_object.platform_id == platform_id,
                    table_object.time >= start_time,
                    table_object.time <= end_time,
                )

            def filter_function_start_end(table_object, query):
                platform_rows = query.filter(table_object.platform_id == platform_id)
                return _start_end_filter(table_object, platform_rows, start_time, end_time)

            classes = source_store.db_classes
            for point_table in (classes.State, classes.Contact, classes.Comment):
                export_measurement_table_with_filter(
                    source_store, destination_store, point_table, filter_function_time
                )
            export_measurement_table_with_filter(
                source_store, destination_store, classes.Activation, filter_function_start_end
            )
def _start_end_filter(table_object, query, start_time, end_time):
    """Restrict ``query`` to rows whose ``start``/``end`` interval meets ``[start_time, end_time]``.

    * ``start`` is NULL: keep the row if ``end`` lies within the window (inclusive).
    * ``end`` is NULL: keep the row if ``start`` lies within the window (inclusive).
    * Both set: keep the row if the intervals overlap, i.e. it starts no later than
      ``end_time`` and ends no earlier than ``start_time``. Touching endpoints count as
      overlapping. This covers the cases where the row is contained in the window, straddles
      either edge, or spans the whole window.
    * Both NULL: the row is excluded.

    :param table_object: Mapped class with ``start`` and ``end`` attributes
    :param query: Query to narrow
    :param start_time: Start of the window
    :param end_time: End of the window
    :return: The filtered query
    """
    start = table_object.start
    end = table_object.end
    return query.filter(
        or_(
            # Missing start: plain 'between' test on end
            start.is_(None) & (end >= start_time) & (end <= end_time),
            # Missing end: plain 'between' test on start
            end.is_(None) & (start >= start_time) & (start <= end_time),
            # Both present: standard closed-interval overlap test
            start.isnot(None) & end.isnot(None) & (start <= end_time) & (end >= start_time),
        )
    )
def _select_wargame(data_store):
    """Let the user pick a wargame by name with an fzf picker and return its ``wargame_id``.

    Queries ``data_store.session`` directly, so a session must already be open.

    :param data_store: A :class:`DataStore` object
    :return: The selected ``wargame_id``, or None if no wargames exist or nothing was selected
    """
    wargame = data_store.db_classes.Wargame
    results = data_store.session.query(wargame.name, wargame.wargame_id).all()
    # name -> wargame_id; if two wargames share a name, the later row wins
    wargame_dict = dict(results)
    if not wargame_dict:
        print("No wargames defined")
        return None
    selected_wargame_name = iterfzf(wargame_dict.keys(), prompt="Select wargame: ")
    if selected_wargame_name is None:
        # iterfzf gives None when no item is chosen (e.g. the picker is cancelled).
        # Indexing the dict with it would raise KeyError.
        return None
    return wargame_dict[selected_wargame_name]