"""
Script to fill the calibration database (clbrdb) with filtered Compton measurements from slowdb
"""

import argparse
from configparser import ConfigParser
from datetime import datetime, timedelta, timezone
import sys
from typing import Tuple, List, Dict, Union, Optional
import warnings
import logging

import psycopg2
from psycopg2.extras import execute_values

class PostgreSQLHandler:
    """A common class for processing postgresql databases

    Opens a psycopg2 connection when constructed and keeps one cursor
    (``self.cur``) that is used for all queries of the handler.
    """

    def __init__(self, host: str = 'cmddb', database: str = 'slowdb',
                 user: Optional[str] = None, password: Optional[str] = None):
        """
        Parameters
        ----------
        host : str
            host name (default is "cmddb")
        database : str
            database name (default is "slowdb")
        user : Optional[str]
            username (default is None)
        password : Optional[str]
            password (default is None)
        """

        self.conn = psycopg2.connect(host=host, database=database, user=user, password=password)
        self.cur = self.conn.cursor()
        logging.info("PostgreSQL Hander created")

    @property
    def list_tables(self) -> List[str]:
        """Returns names of the tables in the ``public`` schema of the connected database

        Returns
        -------
        List[str]
            list of table names (in the order returned by the server)
        """

        logging.info("Get list of the slowdb tables")
        self.cur.execute("""
        SELECT table_name FROM information_schema.tables
        WHERE table_schema = 'public'
        """)
        return [row[0] for row in self.cur.fetchall()]

class SlowdbComptonHandler(PostgreSQLHandler):
    """A class for processing and filtering of compton measurements from slowdb
    """

    def __is_overlapped_row(self, start_time_next: datetime, stop_time_prev: datetime) -> bool:
        """Checks whether an earlier measurement overlaps a later one

        Parameters
        ----------
        start_time_next : datetime
            start time of the later (already kept) measurement
        stop_time_prev : datetime
            stop time of the earlier measurement

        Returns
        -------
        bool
            True if the earlier measurement stops more than 2 seconds after
            the later one starts (overlaps of up to 2 seconds are tolerated)
        """
        gap = timedelta(seconds=2)
        if start_time_next < stop_time_prev:
            logging.debug(f'time gap {abs(start_time_next - stop_time_prev)}')
        return start_time_next < stop_time_prev - gap

    def __drop_overlapping_rows_list(self, table: list) -> list:
        """Removes rows with overlapping time intervals from the table

        The table is scanned from the newest row to the oldest; when an older
        row overlaps the last kept row, the older row is dropped, so newer
        measurements take precedence.

        Parameters
        ----------
        table : list
            the table MUST BE ORDERED BY TIME; field index 5 (6th column) is
            start_time and field index 6 (7th column) is stop_time

        Returns
        -------
        list
            new list with the kept rows in the original order (the input is not modified)
        """

        if len(table) == 0:
            logging.info("Empty list. No overlapping rows")
            return table

        logging.info("Drop overlapping rows in list representation")
        # The first visited row (the newest) is always kept: its stop time is
        # compared with itself, which is never an overlap.
        min_time = table[-1][6]
        kept_newest_first = []
        for row in reversed(table):
            start_time, stop_time = row[5], row[6]
            if self.__is_overlapped_row(min_time, stop_time):
                continue
            min_time = start_time
            kept_newest_first.append(row)

        return kept_newest_first[::-1]

    def load_tables(self, tables: List[str], daterange: Optional[datetime] = None) -> list:
        """Returns tables containing compton energy measurements

        Parameters
        ----------
        tables : List[str]
            names of tables in the slowdb compton measurements database
            (full list of available tables can be seen with the property list_tables).
            NOTE: the names are inserted into the SQL text as is, so they must be trusted
        daterange : Optional[datetime]
            minimum time for selection (should contain timezone); only rows with
            ``time`` strictly greater than it are selected (default is None - no limit)

        Returns
        -------
        list
            list of row tuples ordered by time, with overlapping measurements removed;
            the fields of each row are:
            time - time when the row was written (contains timezone)
            mean_energy - compton mean of the energy measurement [MeV]
            std_energy - compton std of the energy measurement [MeV]
            mean_spread - compton mean of the spread measurement [MeV]
            std_spread - compton std of the spread measurement [MeV]
            start_time - beginning time of the compton measurement (contains timezone)
            stop_time - end time of the compton measurement (contains timezone)
        """

        if not tables:
            # An empty UNION would produce invalid SQL ("ORDER BY" with no SELECT)
            logging.info("No tables requested")
            return []

        time_condition = "AND time>(%(date)s)" if daterange is not None else ""

        # values_array[5..6] are divided by 1000 (presumably keV -> MeV, TODO confirm);
        # values_array[8] and values_array[7] are offsets/durations in seconds added to `time`
        sql_query = lambda table: f"""SELECT 
            time AS time, 
            CAST(values_array[1] AS numeric) AS mean_energy, 
            CAST(values_array[2] AS numeric) AS std_energy, 
            ROUND(CAST(values_array[5]/1000 AS numeric), 6) AS mean_spread,
            ROUND(CAST(values_array[6]/1000 AS numeric), 6) AS std_spread, 
            date_trunc('second', time + (values_array[8] * interval '1 second')) AS start_time, 
            date_trunc('second', time + (values_array[8] * interval '1 second') + (values_array[7] * interval '1 second')) AS stop_time
            FROM {table} WHERE g_id=43 AND dt>0 {time_condition}"""

        full_sql_query = '\nUNION ALL\n'.join([sql_query(table) for table in tables]) + '\nORDER BY time;'

        logging.debug(f"Full sql query {full_sql_query}")

        self.cur.execute(full_sql_query, {'date': daterange})
        table = self.cur.fetchall()
        return self.__drop_overlapping_rows_list(table)
    
class CalibrdbHandler(PostgreSQLHandler):
    """A class for processing of calibration database

    Calibration rows live in the ``clbrdata`` table; they are grouped into
    calibration sets described by the ``clbrset`` table and identified by ``sid``.
    """

    def select_table(self, system: str, algo: str, name: str, version: str = 'Default') -> int:
        """Selects the calibration set from the ``clbrset`` table

        Parameters
        ----------
        system : str
            name of the system
        algo : str
            name of the algorithm
        name : str
            name of the calibration
        version : str
            name of the calibration version (default is Default)

        Returns
        -------
        sid : int
            first column of the matching ``clbrset`` row (the set identifier);
            if several rows match, a warning is logged and the first one is used

        Raises
        ------
        IndexError
            if no calibration set matches (IndexError is kept for backward compatibility)
        """

        # Values are passed as query parameters so that quotes in names cannot break the SQL
        self.cur.execute(
            """SELECT * FROM clbrset 
        WHERE system = %s AND algo = %s AND name = %s AND version = %s""",
            (system, algo, name, version),
        )
        result = self.cur.fetchall()
        logging.debug(f"selected clbrset: {result}")
        if not result:
            raise IndexError(f"No calibration set {system}/{algo}/{name}/{version} in clbrset")
        if len(result) > 1:
            logging.warning('Multiple equal calibration sets. clbrset DB problem')
        # Always return the sid itself, not the whole clbrset row
        return result[0][0]

    def load_table(self, system: str, algo: str, name: str, version: str = 'Default', 
                   num_last_rows: Optional[int] = None, timerange: Optional[Tuple[datetime, datetime]] = None, 
                   return_timezone: bool = False) -> Tuple[list, list]:
        """Loads the calibration table

        Parameters
        ----------
        system : str
            name of the system
        algo : str
            name of the algorithm
        name : str
            name of the calibration
        version : str
            name of the calibration version (default is Default)
        num_last_rows : Optional[int]
            if given, only this number of rows with the latest ``time`` is returned
        timerange : Optional[Tuple[datetime, datetime]]
            if given, only rows with ``begintime`` between the two bounds (inclusive) are selected
            (default is None)
        return_timezone : bool
            if True, time/begintime/endtime are converted with ``AT TIME ZONE 'ALMST'``
            in the query (default is False)

        Returns
        -------
        Tuple[list, list]
            the calibration rows ordered by time (newest first) and the names of the fields
        """

        sid = self.select_table(system, algo, name, version)
        tzone = "AT TIME ZONE 'ALMST'" if return_timezone else ''

        params: list = [sid]
        time_condition = ""
        if timerange is not None:
            time_condition = "AND begintime BETWEEN %s AND %s"
            params.extend(timerange)

        sql_query = f"""SELECT 
        cid, sid, createdby, 
        time {tzone} AS time, 
        begintime {tzone} AS begintime, 
        endtime {tzone} AS endtime, 
        comment, parameters, data
        FROM clbrdata WHERE sid = %s {time_condition} ORDER BY time DESC """
        if num_last_rows is not None:
            sql_query += "LIMIT %s"
            params.append(int(num_last_rows))

        self.cur.execute(sql_query, params)
        fields_name = [column[0] for column in self.cur.description]
        table = self.cur.fetchall()
        return table, fields_name

    def update(self, new_rows: list, system: str = "Misc", algo: str = "RunHeader", 
               name: str = "Compton_run", version: str = 'Default', handle_last_time_row: bool = False):
        """Writes new_rows in clbrdb

        Parameters
        ----------
        new_rows : list
            rows of 7 fields (time, mean_energy, std_energy, mean_spread, std_spread,
            start_time, stop_time), i.e. the format returned by SlowdbComptonHandler.load_tables;
            fields 1-4 are stored as the ``data`` array, fields 5 and 6 as begintime/endtime
        system, algo, name, version : str
            identifiers of the calibration set (see select_table)
        handle_last_time_row : bool
            (DANGEROUS PLACE - keep default False or don't commit changes if you don't know what you want)
            update current values or not: before inserting, delete the rows of the set created by
            'lxeuser' that overlap the interval from min(begintime in new_rows) to max(endtime in new_rows)
        """

        if len(new_rows) == 0:
            return

        sid = self.select_table(system, algo, name, version)

        db_rows = [
            (sid, 'lxeuser', row[0], row[5], row[6], [row[1], row[2], row[3], row[4]])
            for row in new_rows
        ]

        if handle_last_time_row:
            min_new_time = min(row[3] for row in db_rows)
            max_new_time = max(row[4] for row in db_rows)
            self.delete_rows(sid=sid, createdby='lxeuser', time=(min_new_time, max_new_time))

        insert_query = """INSERT INTO clbrdata (sid, createdby, time, begintime, endtime, data) VALUES %s;"""
        execute_values(self.cur, insert_query, db_rows, fetch=False)
        logging.info(f"Inserted {len(db_rows)} new rows")
        return

    def insert(self, new_rows: list, system: str, algo: str, name: str, version: str, 
               update: bool = True, comment: Optional[str] = None):
        """Inserts new_rows into the calibration set, optionally updating existing rows

        Every row gets the comment ``<comment>_<row[3]>`` (``comment`` defaults to an
        empty string); the comment is the key that matches new rows to existing ones.
        After the insertion, rows of the set that share a comment are deduplicated,
        keeping the one with the smallest ``cid`` (presumably the pre-existing row,
        assuming cid grows with insertion - TODO confirm).

        Parameters
        ----------
        new_rows : list
            rows in the format (time, begintime, endtime, *data); all fields starting
            from the 4th one are stored as the ``data`` array
        system, algo, name, version : str
            identifiers of the calibration set (see select_table)
        update : bool
            if True, existing rows with a matching comment are overwritten with the new
            values; otherwise they are kept and the new duplicates are dropped
        comment : Optional[str]
            common prefix of the comment field
        """

        sid = self.select_table(system, algo, name, version)
        comment_prefix = comment if comment is not None else ''

        def make_comment(row) -> str:
            # Same key is used for the UPDATE match, the INSERT and the deduplication
            return comment_prefix + '_' + str(row[3])

        if update:
            update_query = """UPDATE clbrdata 
            SET data = %(data)s, createdby = %(createdby)s, time = %(time)s, begintime = %(begintime)s, endtime = %(endtime)s
            WHERE sid = %(sid)s AND comment = %(comment)s
            """
            for row in new_rows:
                self.cur.execute(update_query, {
                    'sid': sid,
                    'createdby': 'lxeuser',
                    'time': row[0],
                    'begintime': row[1],
                    'endtime': row[2],
                    'comment': make_comment(row),
                    # list (not a tuple slice) so that psycopg2 adapts it to a SQL ARRAY
                    'data': list(row[3:]),
                })

        insert_query = """INSERT INTO clbrdata (sid, createdby, time, begintime, endtime, comment, data) VALUES %s"""
        insert_rows = [
            (sid, 'lxeuser', row[0], row[1], row[2], make_comment(row), list(row[3:]))
            for row in new_rows
        ]
        execute_values(self.cur, insert_query, insert_rows, fetch=False)

        drop_query = """
            DELETE FROM clbrdata a
            USING clbrdata b
            WHERE
                a.sid = %s
                AND a.cid > b.cid
                AND a.sid = b.sid
                AND a.comment = b.comment
        """
        self.cur.execute(drop_query, (sid,))

        logging.info(f"Inserted {len(insert_rows)} rows into table: {system}/{algo}/{name}/{version}")
        return

    def clear_table(self, sid: int, createdby: str):
        """Deletes all rows of the set ``sid`` created by ``createdby``"""
        delete_query = """DELETE FROM clbrdata WHERE sid = %s AND createdby = %s"""
        logging.info(f"Clear ({sid}, {createdby}) table")
        self.cur.execute(delete_query, (sid, createdby))
        return

    def delete_row(self, sid: int, createdby: str, time: datetime):
        """Deletes the rows of the set ``sid`` created by ``createdby`` with exactly this ``time``"""
        delete_query = """DELETE FROM clbrdata 
        WHERE sid = %s AND createdby = %s AND time = %s
        """
        self.cur.execute(delete_query, (sid, createdby, time))
        logging.info(f"Deleted ({sid}, {createdby}, {time}) row")
        return

    def delete_rows(self, sid: int, createdby: str, time: Tuple[datetime, datetime]):
        """Deletes rows of the set ``sid`` created by ``createdby`` whose
        [begintime, endtime] interval overlaps the open interval (time[0], time[1])
        """
        delete_query = """DELETE FROM clbrdata 
        WHERE sid = %s AND createdby = %s AND endtime > %s AND begintime < %s
        """
        self.cur.execute(delete_query, (sid, createdby, time[0], time[1]))
        logging.info(f"Deleted ({sid}, {createdby} from {time[0]} to {time[1]}) rows")
        return

    def remove_duplicates(self, system: str = "Misc", algo: str = "RunHeader", name: str = "Compton_run", version: str = 'Default', keep: str = 'last'):
        """Removes rows of the calibration set that share the same ``time``

        Parameters
        ----------
        system, algo, name, version : str
            identifiers of the calibration set (see select_table)
        keep : str
            'last' keeps the row with the largest cid, 'first' the one with the smallest cid

        Raises
        ------
        ValueError
            if ``keep`` is neither 'last' nor 'first'
        """
        # Rows with the smaller (for 'last') / larger (for 'first') cid are deleted
        keep_rules = {'last': '<', 'first': '>'}
        if keep not in keep_rules:
            raise ValueError("keep argument must be 'last' or 'first'")

        sid = self.select_table(system, algo, name, version)

        remove_query = f"""
            DELETE FROM clbrdata a
            USING clbrdata b
            WHERE
                a.sid = %s
                AND a.cid {keep_rules[keep]} b.cid
                AND a.sid = b.sid
                AND a.time = b.time
        """
        self.cur.execute(remove_query, (sid,))

    def commit(self):
        """Commits the current transaction"""
        logging.info("Changes commited")
        self.conn.commit()
        return

    def rollback(self):
        """Rolls back the current transaction"""
        logging.info("Changes aborted")
        self.conn.rollback()
        return

    def __del__(self):
        logging.info("del clbr class")
        # __init__ may have failed before the connection/cursor attributes were set
        cur = getattr(self, 'cur', None)
        if cur is not None:
            cur.close()
        conn = getattr(self, 'conn', None)
        if conn is not None:
            conn.close()
    
    
def main():
    """Entry point: filter Compton measurements from slowdb and write them into clbrdb

    Command-line arguments
    ----------------------
    --season : name of the slowdb table with Compton measurements (required)
    --config : INI file whose ``clbrDB`` and ``postgresql`` sections are passed as
               keyword arguments to the database handlers (required)
    --update : load only measurements newer than the last row written to the
               Misc/RunHeader/Compton_run calibration set and replace the rows overlapping them
    """
    log_format = '[%(asctime)s] %(levelname)s: %(message)s'
    logging.basicConfig(stream=sys.stdout, format=log_format, level=logging.INFO) #"filename=compton_filter.log"
    logging.info("Program started")

    arg_parser = argparse.ArgumentParser(description = 'Filter compton energy measurements from slowdb')
    # Both options are mandatory: the script cannot run without them
    arg_parser.add_argument('--season', required = True, help = 'Name of compton measurement table from slowdb')
    arg_parser.add_argument('--config', required = True, help = 'Config file containing information for access to databases')
    arg_parser.add_argument('--update', action = 'store_true', help = 'Writes only newest values into the db')

    args = arg_parser.parse_args()
    logging.info(f"Arguments: season: {args.season}, config {args.config}, update {args.update}")

    config = ConfigParser()
    # ConfigParser.read silently skips unreadable files and returns the list of files it parsed
    if not config.read(args.config):
        arg_parser.error(f"cannot read config file '{args.config}'")
    logging.info("Config parsed")

    clbrdb = CalibrdbHandler(**config['clbrDB'])

    last_time = None
    if args.update:
        # The most recent written row; field 3 of a load_table row is ``time``
        last_written_row, _ = clbrdb.load_table('Misc', 'RunHeader', 'Compton_run', num_last_rows = 1, return_timezone = True)
        if last_written_row:
            last_time = last_written_row[0][3]

    compton_slowdb = SlowdbComptonHandler(**config['postgresql'])
    res = compton_slowdb.load_tables([args.season], last_time)

    clbrdb.update(res, handle_last_time_row = args.update)
    clbrdb.commit()
    del clbrdb
    
# Usage example:
#   python scripts/compton_filter.py --season cmd3_2021_2 --config database.ini --update
if __name__ == '__main__':
    main()