Skip to content
Snippets Groups Projects
Commit a03b8a52 authored by Matteo Westerwinter's avatar Matteo Westerwinter
Browse files

remove duplicate code, fix some edge cases, and integrate not_in_tr...

remove duplicate code, fix some edge cases, and integrate not_in_tr filtering into the SQL query
parent 9aea36a6
No related branches found
No related tags found
2 merge requests!227merge dev into testing,!226enable aggregated filtering
......@@ -14,6 +14,7 @@ from fastapi import File, UploadFile
from fastapi.responses import JSONResponse
import datetime as dt
import json
from functools import reduce
from . import models
from .models import TimeseriesChangelog, timeseries_timeseries_roles_table, \
timeseries_timeseries_annotations_table
......@@ -139,9 +140,12 @@ def get_citation(db: Session, timeseries_id: int, datetime: dt.datetime = None):
license_txt = "This data is published under a Creative Commons Attribution 4.0 International (CC BY 4.0). https://creativecommons.org/licenses/by/4.0/"
return {"attribution": attribution, "citation": citation, "license": license_txt}
def adapt_db_object(db_object_immut, fields_str = "", lconstr_roles = False):
if fields_str:
db_object = dict(zip(fields_str.split(","), db_object_immut))
def adapt_db_object(db_object, db, fields=False, lconstr_roles=False):
if fields:
db_object = dict(
zip((field for field in fields if field not in {"station_changelog", "changelog"}), db_object)
)
# there is a mismatch with coordinates and additional_metadata
if "coordinates" in db_object:
db_object["coordinates"] = get_coordinates_from_string(db_object["coordinates"])
......@@ -172,95 +176,61 @@ def adapt_db_object(db_object_immut, fields_str = "", lconstr_roles = False):
roles_atts = {key: value for key, value in db_object.items() if key in roles_params}
db_object = {key: value for key, value in db_object.items() if key not in roles_params}
db_object["roles"] = TimeseriesRoleFields(**roles_atts)
return db_object
else:
if isinstance(db_object_immut.station.coordinates, (WKBElement, WKTElement)):
db_object_immut.station.coordinates = get_coordinates_from_geom(db_object_immut.station.coordinates)
if isinstance(db_object.station.coordinates, (WKBElement, WKTElement)):
db_object.station.coordinates = get_coordinates_from_geom(db_object.station.coordinates)
# there is a mismatch with additional_metadata
if isinstance(db_object_immut.station.additional_metadata, dict):
db_object_immut.station.additional_metadata = json.dumps(db_object_immut.station.additional_metadata)
db_object_immut.additional_metadata = clean_additional_metadata(db_object_immut.additional_metadata)
if isinstance(db_object.station.additional_metadata, dict):
db_object.station.additional_metadata = json.dumps(db_object.station.additional_metadata)
db_object.additional_metadata = clean_additional_metadata(db_object.additional_metadata)
#Internall use
try:
del db_object_immut.data_license_accepted
del db_object.data_license_accepted
except AttributeError:
pass
try:
del db_object_immut.dataset_approved_by_provider
del db_object.dataset_approved_by_provider
except AttributeError:
pass
return db_object_immut
return db_object
class TimeseriesQuery:
def __init__(self, sign, query, fields, lconstr_roles):
    """Wrap a SQLAlchemy query together with its aggregation metadata.

    Parameters
    ----------
    sign : bool
        True if this part is added (union) in an aggregated search,
        False if it is subtracted (except).
    query : sqlalchemy query
        The underlying (not yet executed) query.
    fields : list | bool
        Sorted list of selected field names, or a falsy value when whole
        ORM objects are returned instead of tuples.
    lconstr_roles : bool
        Whether role fields have to be reassembled when adapting results.
    """
    self.sign = sign
    self.query = query
    self.fields = fields
    self.lconstr_roles = lconstr_roles
def __add__(self, other):
    """Combine two search parts into one aggregated query.

    A right operand with sign=True is merged via UNION ALL, one with
    sign=False is subtracted via EXCEPT ALL.  Both parts must select the
    same fields, otherwise the row shapes are incompatible.

    Raises
    ------
    ValueError
        If the two parts select different fields.
    """
    if self.fields != other.fields:
        raise ValueError("Fields of subqueries are different")
    return TimeseriesQuery(
        True,
        self.query.union_all(other.query) if other.sign else self.query.except_all(other.query),
        self.fields,
        # role reconstruction is needed if either part requires it
        self.lconstr_roles or other.lconstr_roles,
    )
def __radd__(self, other):
    """Make sum() work over TimeseriesQuery parts.

    sum() seeds its accumulator with the integer 0; treat that as the
    identity element and return self unchanged.
    """
    if other == 0:
        return self
    return self + other
def search_all(db, path_params, query_params, lts=False):
endpoint = "timeseries" if lts else "search"
try:
@staticmethod
def from_query_params(query_params, db, endpoint="timeseries", sign=True):
limit, offset, fields, format, filters = create_filter(query_params, endpoint)
t_filter = filters["t_filter"]
t_r_filter = filters["t_r_filter"]
s_c_filter = filters["s_c_filter"]
s_g_filter = filters["s_g_filter"]
except (KeyError, ValueError) as e:
status_code = 400
return JSONResponse(status_code=status_code, content=str(e))
if fields:
# If only certain fields are selected the return type is not a orm object anymore but a dict
# sort input fields to be sure to replace station_changelog before changelog
fields = sorted(fields.split(","), reverse=True)
lconstr_roles = any(field in roles_params for field in fields)
if "role" in fields:
......@@ -278,11 +248,7 @@ def search_all(db, path_params, query_params, lts=False):
"coordinates": "ST_AsText(coordinates)",
"station_country": "stationmeta_core.country",
}
fields_str = ",".join(field for field in fields if field not in ("station_changelog", "changelog"))
fields_lst = [text(field_map.get(field, field)) for field in fields]
query = db.query(*fields_lst)
query = db.query(*[text(field_map.get(field, field)) for field in fields])
# ordering is needed (because of limit/offset-option)
# --> Timeseries.id is added to columns while doing distinct! (see: https://docs.sqlalchemy.org/en/14/changelog/migration_20.html#migration-20-query-distinct)
......@@ -290,11 +256,11 @@ def search_all(db, path_params, query_params, lts=False):
# hot fix to speed up the search
else:
lconstr_roles = False
fields_str = ""
query = db.query(models.Timeseries)
query = (
query.select_from(models.Timeseries)
query
.select_from(models.Timeseries)
.filter(text(t_filter))
.distinct()
.join(StationmetaCore)
......@@ -316,33 +282,49 @@ def search_all(db, path_params, query_params, lts=False):
.order_by(models.Timeseries.id)
)
# is required like this due to test and previous implementation seems weird
# order by defeats purpose of limit limited data retrieval
if "NOT" in t_r_filter:
role_string = query_params.get("has_role")[1:]
role_ids = get_role_id_from_string(db, role_string)
query = query.filter(
~models.Timeseries.id.in_(
select(timeseries_timeseries_roles_table.c.timeseries_id).where(
timeseries_timeseries_roles_table.c.role_id.in_(role_ids)
)
)
)
if limit:
query = query.limit(limit).offset(offset)
return TimeseriesQuery(sign, query, fields, lconstr_roles)
def adapt_objects(self, db):
    """Execute the query and adapt every resulting row for the API.

    Delegates to adapt_db_object, which turns tuple rows into dicts when
    a field selection is active and reassembles role fields on demand.
    Iterating self.query triggers execution; NOT-role filtering is now
    part of the SQL query itself, so no Python-side filtering is needed.
    """
    return [
        adapt_db_object(db_object, db, self.fields, self.lconstr_roles)
        for db_object in self.query
    ]
def search_all(db, path_params, query_params, lts=False):
    """Search timeseries matching the given query parameters.

    Returns the adapted result list, or a 400 JSONResponse when the
    filter parameters cannot be parsed.
    """
    endpoint = "timeseries" if lts else "search"
    try:
        result_query = TimeseriesQuery.from_query_params(query_params, db, endpoint)
        return result_query.adapt_objects(db)
    except (KeyError, ValueError) as bad_filter:
        return JSONResponse(status_code=400, content=str(bad_filter))
def search_all_aggreagtion(db, path_params, signs, query_params_list, lts=False):
    """Combine several searches into one aggregated result.

    Each entry of query_params_list is turned into a TimeseriesQuery whose
    sign decides whether it is unioned into or subtracted from the total
    (combination happens in TimeseriesQuery.__add__ via sum()).

    Returns the adapted result list, or a 400 JSONResponse when any of
    the filter parameter sets cannot be parsed.
    """
    endpoint = "timeseries" if lts else "search"
    try:
        return sum(
            TimeseriesQuery.from_query_params(query_params, db, endpoint, sign)
            for sign, query_params in zip(signs, query_params_list)
        ).adapt_objects(db)
    except (KeyError, ValueError) as e:
        status_code = 400
        return JSONResponse(status_code=status_code, content=str(e))
def get_timeseries_by_unique_constraints(db: Session, station_id: int, variable_id: int, resource_provider: str = None,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment