diff --git a/toardb/timeseries/crud.py b/toardb/timeseries/crud.py index 906259861ebf56ebd1f28c166f849080fac3a697..83e661ae96eb0a5248cbae9a5a8f859f5e72c53c 100644 --- a/toardb/timeseries/crud.py +++ b/toardb/timeseries/crud.py @@ -14,6 +14,7 @@ from fastapi import File, UploadFile from fastapi.responses import JSONResponse import datetime as dt import json +from functools import reduce from . import models from .models import TimeseriesChangelog, timeseries_timeseries_roles_table, \ timeseries_timeseries_annotations_table @@ -139,9 +140,12 @@ def get_citation(db: Session, timeseries_id: int, datetime: dt.datetime = None): license_txt = "This data is published under a Creative Commons Attribution 4.0 International (CC BY 4.0). https://creativecommons.org/licenses/by/4.0/" return {"attribution": attribution, "citation": citation, "license": license_txt} -def adapt_db_object(db_object_immut, fields_str = "", lconstr_roles = False): - if fields_str: - db_object = dict(zip(fields_str.split(","), db_object_immut)) +def adapt_db_object(db_object, db, fields=False, lconstr_roles=False): + if fields: + db_object = dict( + zip((field for field in fields if field not in {"station_changelog", "changelog"}), db_object) + ) + # there is a mismatch with coordinates and additional_metadata if "coordinates" in db_object: db_object["coordinates"] = get_coordinates_from_string(db_object["coordinates"]) @@ -172,52 +176,91 @@ def adapt_db_object(db_object_immut, fields_str = "", lconstr_roles = False): roles_atts = {key: value for key, value in db_object.items() if key in roles_params} db_object = {key: value for key, value in db_object.items() if key not in roles_params} db_object["roles"] = TimeseriesRoleFields(**roles_atts) - - return db_object - else: - if isinstance(db_object_immut.station.coordinates, (WKBElement, WKTElement)): - db_object_immut.station.coordinates = get_coordinates_from_geom(db_object_immut.station.coordinates) + if isinstance(db_object.station.coordinates, (WKBElement, WKTElement)): + db_object.station.coordinates = get_coordinates_from_geom(db_object.station.coordinates) # there is a mismatch with additional_metadata - if isinstance(db_object_immut.station.additional_metadata, dict): - db_object_immut.station.additional_metadata = json.dumps(db_object_immut.station.additional_metadata) - db_object_immut.additional_metadata = clean_additional_metadata(db_object_immut.additional_metadata) + if isinstance(db_object.station.additional_metadata, dict): + db_object.station.additional_metadata = json.dumps(db_object.station.additional_metadata) + db_object.additional_metadata = clean_additional_metadata(db_object.additional_metadata) #Internall use try: - del db_object_immut.data_license_accepted + del db_object.data_license_accepted except AttributeError: pass try: - del db_object_immut.dataset_approved_by_provider + del db_object.dataset_approved_by_provider except AttributeError: pass - return db_object_immut + return db_object class TimeseriesQuery: - def __init__(self, sign, query): + def __init__(self, sign, query, fields, lconstr_roles): self.sign = sign self.query = query + self.fields = fields + self.lconstr_roles = lconstr_roles def __add__(self, other): - return TimeseriesQuery(True, self.query.union_all(other.query) if other.sign else self.query.except_all(other.query)) + if self.fields != other.fields: + raise ValueError("Fields of subquerys are diffrent") + + return TimeseriesQuery( + True, + self.query.union_all(other.query) if other.sign else self.query.except_all(other.query), + self.fields, + self.lconstr_roles or other.lconstr_roles, + ) - def from_query_params(query_params, db, endpoint = False, sign = True): - try: - limit, offset, fields, format, filters = create_filter(query_params, endpoint) - t_filter = filters["t_filter"] - t_r_filter = filters["t_r_filter"] - s_c_filter = filters["s_c_filter"] - s_g_filter = filters["s_g_filter"] - except (KeyError, ValueError) as e: - status_code = 400 - return JSONResponse(status_code=status_code, content=str(e)) + def __radd__(self, other): + return self if other == 0 else self + other # To make sum() work + + @staticmethod + def from_query_params(query_params, db, endpoint="timeseries", sign=True): + limit, offset, fields, format, filters = create_filter(query_params, endpoint) + t_filter = filters["t_filter"] + t_r_filter = filters["t_r_filter"] + s_c_filter = filters["s_c_filter"] + s_g_filter = filters["s_g_filter"] + + if fields: + # If only certain fields are selected the return type is not a orm object anymore but a dict + # sort input fields to be sure to replace station_changelog before changelog + fields = sorted(fields.split(","), reverse=True) + lconstr_roles = any(field in roles_params for field in fields) + + if "role" in fields: + fields.remove("role") + fields += roles_params + # lconstr_glob = True #used to be set here but is unused + + field_map = { + "id": "timeseries.id", + "order": "timeseries.order", + "additional_metadata": "timeseries.additional_metadata", + "station_id": "stationmeta_core.id", + "variable_id": "variables.id", + "name": "stationmeta_core.name", + "coordinates": "ST_AsText(coordinates)", + "station_country": "stationmeta_core.country", + } + query = db.query(*[text(field_map.get(field, field)) for field in fields]) + + # ordering is needed (because of limit/offset-option) + # --> Timeseries.id is added to columns while doing distinct! (see: https://docs.sqlalchemy.org/en/14/changelog/migration_20.html#migration-20-query-distinct) + # (duplicates are being created, which means that limit/offset in the query itself is useless!) + # hot fix to speed up the search + else: + lconstr_roles = False + query = db.query(models.Timeseries) query = ( - db.query(models.Timeseries) + query + .select_from(models.Timeseries) .filter(text(t_filter)) .distinct() .join(StationmetaCore) @@ -236,113 +279,52 @@ class TimeseriesQuery: ) ) .execution_options(stream_results=True) + .order_by(models.Timeseries.id) ) + if "NOT" in t_r_filter: + role_string = query_params.get("has_role")[1:] + role_ids = get_role_id_from_string(db, role_string) + query = query.filter( + ~models.Timeseries.id.in_( + select(timeseries_timeseries_roles_table.c.timeseries_id).where( + timeseries_timeseries_roles_table.c.role_id.in_(role_ids) + ) + ) + ) + if limit: query = query.limit(limit).offset(offset) - return TimeseriesQuery(sign, query) + return TimeseriesQuery(sign, query, fields, lconstr_roles) + + def adapt_objects(self, db): + return [adapt_db_object(db_object, db, self.fields, self.lconstr_roles) for db_object in self.query] def search_all(db, path_params, query_params, lts=False): endpoint = "timeseries" if lts else "search" + try: - limit, offset, fields, format, filters = create_filter(query_params, endpoint) - t_filter = filters["t_filter"] - t_r_filter = filters["t_r_filter"] - s_c_filter = filters["s_c_filter"] - s_g_filter = filters["s_g_filter"] + return TimeseriesQuery.from_query_params(query_params, db, endpoint).adapt_objects(db) + except (KeyError, ValueError) as e: status_code = 400 return JSONResponse(status_code=status_code, content=str(e)) - if fields: - # sort input fields to be sure to replace station_changelog before changelog - - fields = sorted(fields.split(","), reverse=True) - - lconstr_roles = any(field in roles_params for field in fields) - - if "role" in fields: - fields.remove("role") - fields += roles_params - # lconstr_glob = True #used to be set here but is unused - - field_map = { - "id": "timeseries.id", - "order": "timeseries.order", - "additional_metadata": "timeseries.additional_metadata", - "station_id": "stationmeta_core.id", - "variable_id": "variables.id", - "name": "stationmeta_core.name", - "coordinates": "ST_AsText(coordinates)", - "station_country": "stationmeta_core.country", - } - - fields_str = ",".join(field for field in fields if field not in ("station_changelog", "changelog")) - fields_lst = [text(field_map.get(field, field)) for field in fields] - - query = db.query(*fields_lst) - - # ordering is needed (because of limit/offset-option) - # --> Timeseries.id is added to columns while doing distinct! (see: https://docs.sqlalchemy.org/en/14/changelog/migration_20.html#migration-20-query-distinct) - # (duplicates are being created, which means that limit/offset in the query itself is useless!) - # hot fix to speed up the search - else: - lconstr_roles = False - fields_str = "" - query = db.query(models.Timeseries) - - query = ( - query.select_from(models.Timeseries) - .filter(text(t_filter)) - .distinct() - .join(StationmetaCore) - .filter(and_(models.Timeseries.station_id == StationmetaCore.id, text(s_c_filter))) - .join(StationmetaGlobal) - .filter(and_(StationmetaCore.id == StationmetaGlobal.station_id, text(s_g_filter))) - .filter( - and_( - (models.Timeseries.id == timeseries_timeseries_roles_table.c.timeseries_id), - (models.timeseries_timeseries_roles_table.c.role_id == models.TimeseriesRole.id), - (models.TimeseriesRole.contact_id == Contact.id), - (Contact.organisation_id == Organisation.id), - (Contact.person_id == Person.id), - (Variable.id == models.Timeseries.variable_id), - text(t_r_filter), - ) - ) - .execution_options(stream_results=True) - .order_by(models.Timeseries.id) - ) - - # is required like this due to test and previous implementation seems weird - # order by defeats purpose of limit limited data retrieval - if limit: - query = query.limit(limit).offset(offset) - - - not_in_tr = "NOT" in t_r_filter - if not_in_tr: - role_string = query_params.get("has_role")[1:] - role_ids = get_role_id_from_string(db, role_string) - - return [ - adapt_db_object(db_object_immut, fields_str, lconstr_roles) - for db_object_immut in query.all() - if not not_in_tr or not all(role[1] in role_ids for role in get_timeseries_roles(db, db_object_immut.id)) - ] - def search_all_aggreagtion(db, path_params, signs, query_params_list, lts=False): endpoint = "timeseries" if lts else "search" - queries = [ - TimeseriesQuery.from_query_params(query_params, db, endpoint, sign) - for sign, query_params in zip(signs, query_params_list) - ] - - return [adapt_db_object(db_object_immut) for db_object_immut in sum(queries[1:], queries[0]).query] + try: + return sum( + TimeseriesQuery.from_query_params(query_params, db, endpoint, sign) + for sign, query_params in zip(signs, query_params_list) + ).adapt_objects(db) + + except (KeyError, ValueError) as e: + status_code = 400 + return JSONResponse(status_code=status_code, content=str(e)) def get_timeseries_by_unique_constraints(db: Session, station_id: int, variable_id: int, resource_provider: str = None,