#!/usr/bin/python3

from __future__ import print_function

import os
import re
import time
import random
import signal
import hashlib
import sqlite3
import optparse
import collections

from gratia.common.Gratia import DebugPrint
from gratia.common.debug import DebugPrintTraceback
import gratia.common.GratiaCore as GratiaCore
import gratia.common.GratiaWrapper as GratiaWrapper
import gratia.common.Gratia as Gratia
#import gratia.common.file_utils as file_utils
import gratia.common.config as config

import classad
import htcondor

# Version string reported to Gratia via RegisterService() in register_gratia().
probe_version = "2.9.1-2.osg25.el9.26ec412"

# Name of the directory that contains this script; used to build the
# default config and database paths below.
probe_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))

# default locations
probe_config = "/etc/gratia/%s/ProbeConfig" % probe_name
db_path = "/var/lib/gratia/%s/data.db" % probe_name

# Values stamped on every usage record by outbox_row_to_jur().
VO = 'osg'  # lowercase
LOCAL_USER = "osgvo-container-pilot"

# Default only: parse_opts() overwrites this with the "Grid" attribute
# read from the Gratia probe configuration.
grid_type = "OSG"

# Limit the number of outbox jobs delivered per run of the cron job
OUTBOX_MESSAGES_PER_JOB=10000

# in seconds
# DONE_IF_NOT_SEEN_AFTER: a tracked job absent from the current poll is
#   reported as finished once its last update is older than this.
# POLL_INTERVAL: only referenced from commented-out code in this file.
# PARTIAL_RECORD_CUTOFF: a job spanning at least this long is reported as a
#   partial record and its aggregates are reset.
DONE_IF_NOT_SEEN_AFTER =  1 * 60 * 60
POLL_INTERVAL          = 15 * 60
PARTIAL_RECORD_CUTOFF  = 24 * 60 * 60

# log levels
# Passed as the first argument to DebugPrint().
CRITICAL = 0
ERROR    = 1
WARNING  = 2
INFO     = 3
DEBUG    = 4

# Attributes read from each slot ad itself (the "p" side of the
# pslot_totals view).
WANT_PSLOT_ATTRS = [
    'DaemonStartTime',
    'Cpus',
    'GLIDEIN_Site',
    'GLIDEIN_ResourceName',
    'Machine',
    'Name',
    'SlotType',
    'ChildName',
    'SlotId',
    'MyAddress',
]

# Attributes that the pslot_totals view sums over the dynamic slots linked
# to a slot through the `slots` table.
WANT_DSLOT_ATTRS = [
    'CPUsUsage',
    'MemoryUsage',
]

# Full projection requested from the collector.  The order matters: it is
# also the column order of the `jobs`, `inbox` and `inbox_archive` tables.
WANT_ATTRS = WANT_PSLOT_ATTRS + WANT_DSLOT_ATTRS

# attribute name -> position within WANT_ATTRS
WANT_ATTRS_IDX = { v:i for i,v in enumerate(WANT_ATTRS) }

# Columns of the `outbox` table (after its id column), in order; the
# trailing comments name the source of each value.
SEND_ATTRS = [
    'StartTime',     # DaemonStartTime
    'EndTime',       # <last seen time, after gone for N hours>
    'WallDuration',  # <EndTime or Now> - <StartTime>
    'CpuDuration',   #(User): CPUsUsage * <seconds elapsed since previous poll>
                     # + <previous CPU total> across all dynamic slots carved
                     #                        off from the partitionable slot
    'Processors',    # Cpus
    'Memory',        # MemoryUsage
    'SiteName',      # GLIDEIN_ResourceName (or GLIDEIN_Site)
    'MachineName',   # Machine
]

# One row of `select * from outbox`: the autoincrement id followed by the
# SEND_ATTRS columns.
OutboxRow = collections.namedtuple('OutboxRow', ['ID'] + SEND_ATTRS)


def coljoin(cols):
    """Join column names/expressions into a SQL column list, one per line.

    Each item after the first is placed on its own line with a leading
    ", " separator.
    """
    separator = "\n, "
    return separator.join(cols)

def fmt_coljoin(fmt, cols):
    """Render every column through `fmt.format()` and join with coljoin().

    `fmt` is a str.format template taking one positional argument, e.g.
    "p.{}" or "sum(d.{0}) as {0}".
    """
    rendered = (fmt.format(col) for col in cols)
    return coljoin(rendered)


# DDL statements that get_db() executes, in order, when it creates a new
# database.  The later views build on the earlier tables and views.
schema_sql = [
    # jobs: one row per tracked slot.  Columns are last_updated, every
    # WANT_ATTRS attribute, and the accumulated total_cpu.
    """
    create table jobs
    ( last_updated
    , {jobs_cols}
    , total_cpu
    , unique (Name, MyAddress)
    );
    """.format(jobs_cols=coljoin(WANT_ATTRS)),


    # slots: partitionable-slot -> dynamic-slot name pairs.  Emptied and
    # refilled on every run by update_slots_info().
    """
    create table slots
    ( pslot
    , dslot
    , unique (pslot, dslot)
    , unique (dslot)
    );
    """,


    # inbox: the ads of the current poll, stamped with the poll time
    # (next_updated).  Emptied again by detect_new_jobs().
    """
    create table inbox
    ( next_updated
    , {inbox_cols}
    , unique (Name, MyAddress)
    );
    """.format(inbox_cols=coljoin(WANT_ATTRS)),

    # inbox_archive: history of inbox contents.  Only written by
    # archive_new_ads(), whose call is commented out in this file.
    """
    create table inbox_archive
    ( next_updated
    , {inbox_cols}
    , unique (Name, MyAddress, next_updated)
    );
    """.format(inbox_cols=coljoin(WANT_ATTRS)),

    """
    create index inbox_match
        on inbox (MyAddress, SlotType, SlotId);
    """,

    # outbox: usage records waiting to be sent; drained by send_updates().
    """
    create table outbox
    ( id integer primary key autoincrement
    , {outbox_cols}
    );
    """.format(outbox_cols=coljoin(SEND_ATTRS)),


    # pslot_totals: for every inbox row p, its own WANT_PSLOT_ATTRS plus the
    # WANT_DSLOT_ATTRS summed over the inbox rows of the dynamic slots linked
    # to it through `slots`.  NOTE: sum() yields NULL when no dynamic slot
    # row matches.
    """
    create view pslot_totals as
    select p.next_updated
         , {pslot_cols}
         , {dslot_cols}
--       , count(*) as n_dslots
      from inbox p
      left join slots s
        on p.name = s.pslot
--     and p.SlotType = 'Partitionable'
      left join inbox d
        on d.name = s.dslot
--     and d.SlotType = 'Dynamic'
     group by p.name
       ;
    """.format(pslot_cols=fmt_coljoin("p.{}", WANT_PSLOT_ATTRS),
               dslot_cols=fmt_coljoin("sum(d.{0}) as {0}", WANT_DSLOT_ATTRS)),


    # pslot_totals2: variant of pslot_totals that pairs slots on
    # MyAddress/SlotId instead of the `slots` table.  Not referenced by any
    # other view or query in this file.
    """
    create view pslot_totals2 as
    select p.next_updated
         , {pslot_cols}
         , {dslot_cols}
--       , count(*) as n_dslots
      from inbox p
      left join inbox d
        on p.MyAddress = d.MyAddress
       and p.SlotId = d.SlotId
       and p.SlotType = 'P'  -- 'Partitionable'
       and d.SlotType = 'D'  -- 'Dynamic'
     group by p.name
       ;
    """.format(pslot_cols=fmt_coljoin("p.{}", WANT_PSLOT_ATTRS),
               dslot_cols=fmt_coljoin("sum(d.{0}) as {0}", WANT_DSLOT_ATTRS)),


    # updates: every tracked job, left-joined by name to its pslot_totals row
    # from the current poll (the b_name/next_* columns are NULL when the job
    # was not seen).
    """
    create view updates as
    select a.*
         , b.name                      as b_name
         , b.next_updated
         , b.CPUsUsage                 as next_CPUsUsage
         , b.MemoryUsage               as next_MemoryUsage
         , next_updated - last_updated as elapsed_interval
      from jobs a
--    left join inbox b
      left join pslot_totals b
        on a.name = b.name;
    """,


    # finished: tracked jobs that did not appear in the current poll.
    """
    create view finished as
    select *
      from updates
     where b_name is null;
    """,


    # finished2: the `finished` rows reshaped into the outbox (SEND_ATTRS)
    # column layout.
    """
    create view finished2 as
    select DaemonStartTime                as StartTime
         , last_updated                   as EndTime
         , last_updated - DaemonStartTime as WallDuration
         , total_cpu                      as CpuDuration
         , Cpus                           as Processors
         , MemoryUsage                    as Memory
         , ifnull(GLIDEIN_ResourceName,
                  GLIDEIN_Site)           as SiteName
         , Machine                        as MachineName
      from finished;
    """,


    # updated: tracked jobs that did appear in the current poll, with the CPU
    # time accrued since the previous update (next_cpu) and the larger of the
    # stored and current memory usage (next_mem).
    """
    create view updated as
    select *
         , next_CPUsUsage * elapsed_interval as next_cpu
         , max(MemoryUsage,next_MemoryUsage) as next_mem
      from updates
     where b_name is not null;
    """,


    # new_jobs: pslot_totals rows of the current poll that have no row in
    # `jobs` yet.
    """
    create view new_jobs as
    select b.*
--    from inbox b
      from pslot_totals b
      left join jobs a
        on a.name = b.name
     where a.name is null;
    """,


    # cutoff: tracked jobs spanning at least PARTIAL_RECORD_CUTOFF seconds,
    # in the outbox column layout.  The threshold is baked into the view text
    # when the database is created.
    """
    create view cutoff as
    select DaemonStartTime                as StartTime
         , last_updated                   as EndTime
         , last_updated - DaemonStartTime as WallDuration
         , total_cpu                      as CpuDuration
         , Cpus                           as Processors
         , MemoryUsage                    as Memory
         , ifnull(GLIDEIN_ResourceName,
                  GLIDEIN_Site)           as SiteName
         , Machine                        as MachineName
      from jobs
     where last_updated - DaemonStartTime >= {partial_record_cutoff};
    """.format(partial_record_cutoff=PARTIAL_RECORD_CUTOFF),

]


# query collector for machine (startd) ads

def query_current_attrs(opts, pool):
    """Query the collector at `pool` for startd ads of non-static slots.

    Only the WANT_ATTRS attributes are requested.  Unless
    `opts.any_records` is set, the query is further restricted to ads with
    IsOsgVoContainer =?= True.  Returns whatever Collector.query() returns.
    """
    constraint = 'SlotType != "Static"'
    if not opts.any_records:
        constraint = constraint + ' && IsOsgVoContainer =?= True'

    collector = htcondor.Collector(pool)
    return collector.query(
        ad_type=htcondor.AdTypes.Startd,
        projection=WANT_ATTRS,
        constraint=constraint,
    )


def get_db(opts):
    """Open the probe's sqlite database, creating the schema if needed.

    The schema is created when the database contains no objects at all,
    rather than when the file is missing.  sqlite creates an empty file
    as soon as it opens a new path, so a run interrupted before the schema
    was written used to leave a file that no later run would initialise.

    All schema statements run inside one transaction, so a failure does not
    leave a partially created schema behind.

    Parameters:
        opts: parsed options; only `opts.db_path` is read.

    Returns:
        An open sqlite3.Connection.

    Raises:
        sqlite3.Error: if the database cannot be read or the schema cannot
            be created (the connection is closed before re-raising).
    """
    sqldb = sqlite3.connect(opts.db_path)
    try:
        n_objects = sqldb.execute(
            "select count(*) from sqlite_master").fetchone()[0]
        if n_objects == 0:
            # The sqlite3 module does not open a transaction for DDL by
            # itself, so begin one explicitly to make the schema atomic.
            sqldb.execute("begin")
            for sql in schema_sql:
                sqldb.execute(sql)
            sqldb.commit()
    except Exception:
        sqldb.rollback()
        sqldb.close()
        raise

    return sqldb


def sha1sum(s):
    """Return the hexadecimal SHA-1 digest of the UTF-8 encoding of `s`."""
    digest = hashlib.sha1()
    digest.update(s.encode("utf-8"))
    return digest.hexdigest()

def slottype_shortname(s):
    """Abbreviate a known slot type to its first letter.

    "Partitionable", "Dynamic" and "Static" become "P", "D" and "S"; any
    other value is returned unchanged.
    """
    if s not in ("Partitionable", "Dynamic", "Static"):
        return s
    return s[0]

def eval_expr(x):
    """Turn an ad attribute value into something storable in sqlite.

    A classad.ExprTree is evaluated first; a list (either given directly or
    produced by the evaluation) is flattened into one tab-separated string.
    Anything else is returned as is.
    """
    value = x
    if isinstance(value, classad.ExprTree):
        # TODO: Per htcondor2 docs, eval() with no positional arg may yield unexpected results.
        # We must confirm continued functionality before updating.
        value = value.eval()
    return "\t".join(value) if isinstance(value, list) else value


def mangle_job_attrs(vals, attr_idx):
    """Rewrite selected attribute values of one row in place.

    `vals` is the list of attribute values and `attr_idx` maps attribute
    names to positions in it.  SlotType is shortened with
    slottype_shortname() and MyAddress is replaced by its SHA-1 hex digest;
    attributes missing from `attr_idx` are skipped.
    """
    transforms = (
        ("SlotType", slottype_shortname),
        ("MyAddress", sha1sum),
    )
    for attr, transform in transforms:
        if attr not in attr_idx:
            continue
        pos = attr_idx[attr]
        vals[pos] = transform(vals[pos])

def qmarks(n):
    """Return `n` comma-separated "?" placeholders for a SQL statement."""
    # Joining the characters of "???" gives "?,?,?".
    return ",".join("?" * n)

def write_ads_to_sqlite(sqldb, ads, current_ts):
    """Insert one inbox row per ad, stamped with `current_ts`.

    Each row holds `current_ts` followed by the ad's WANT_ATTRS values, each
    passed through eval_expr() and then mangle_job_attrs().
    """
    DebugPrint(INFO, "Writing new ads to inbox")

    def inbox_rows():
        for ad in ads:
            values = [eval_expr(ad.get(attr)) for attr in WANT_ATTRS]
            mangle_job_attrs(values, WANT_ATTRS_IDX)
            yield [current_ts] + values

    # one placeholder for the timestamp plus one per attribute
    placeholders = qmarks(1 + len(WANT_ATTRS))
    cursor = sqldb.executemany(
        "insert into inbox values (%s)" % placeholders, inbox_rows())

    DebugPrint(INFO, "- inserted %s rows into inbox" % cursor.rowcount)


def archive_new_ads(sqldb, ads, current_ts):
    """Copy every row currently in `inbox` into `inbox_archive`.

    `ads` and `current_ts` are accepted but not used.
    """
    DebugPrint(INFO, "Archiving new ads")

    cursor = sqldb.execute("insert into inbox_archive select * from inbox")

    DebugPrint(INFO, "- inserted %s rows into inbox_archive" % cursor.rowcount)


def get_pslots_children(ads):
    """Yield (Name, ChildName) for every ad that has a ChildName attribute."""
    for ad in ads:
        if 'ChildName' not in ad:
            continue
        yield ad['Name'], ad['ChildName']


def get_pslots_dslots(items):
    """Flatten (pslot, [dslot, ...]) pairs into individual (pslot, dslot) pairs."""
    for pslot, dslots in items:
        yield from ((pslot, dslot) for dslot in dslots)

def get_pslots_dslots2(ads):
    """Yield (Name, dslot) for each entry in the ChildName of each ad.

    Ads without a ChildName attribute contribute nothing.
    """
    for ad in ads:
        if 'ChildName' not in ad:
            continue
        yield from ((ad['Name'], dslot) for dslot in ad['ChildName'])

def update_slots_info(sqldb, ads):
    """Replace the contents of `slots` with the pairs found in `ads`.

    All existing rows are deleted, then one (pslot, dslot) row is inserted
    for every pair produced by get_pslots_dslots2().
    """
    DebugPrint(INFO, "Updating slot info")

    n_deleted = sqldb.execute("delete from slots;").rowcount
    pairs = get_pslots_dslots2(ads)
    n_inserted = sqldb.executemany(
        "insert into slots values (?,?);", pairs).rowcount

    DebugPrint(INFO, "- deleted %s rows from slots" % n_deleted)
    DebugPrint(INFO, "- inserted %s rows into slots" % n_inserted)


def detect_finished_jobs(sqldb, current_ts):
    """Move jobs that have disappeared for long enough into the outbox.

    A job counts as finished when it has no match in the current poll (the
    `finished` view) and its last update is older than
    `current_ts - DONE_IF_NOT_SEEN_AFTER`.  Finished jobs with a positive
    WallDuration are copied to `outbox`; all finished jobs older than the
    cutoff are then removed from `jobs`, including those that were not
    copied because their WallDuration was not positive.
    """
    DebugPrint(INFO, "Detecting finished jobs")

    dead_cutoff = current_ts - DONE_IF_NOT_SEEN_AFTER

    # finished2 exposes the finished jobs in the SEND_ATTRS column layout;
    # its EndTime is the job's last_updated.
    insert_sql = """
        insert into outbox (%s)
        select *
          from finished2
         where EndTime < ?
           and WallDuration > 0;
    """ % ", ".join(SEND_ATTRS)

    delete_sql = """
        delete from jobs
         where exists ( select 1
                          from finished b
                         where jobs.name = b.name
                           and b.last_updated < ?);
    """

    # The insert must run before the delete: it reads the rows the delete
    # removes.
    rc1 = sqldb.execute(insert_sql, (dead_cutoff,)).rowcount
    rc2 = sqldb.execute(delete_sql, (dead_cutoff,)).rowcount

    DebugPrint(INFO, "- inserted %s rows into outbox" % rc1)
    DebugPrint(INFO, "- deleted %s rows from jobs" % rc2)


def update_active_jobs(sqldb):
    """Fold the current poll into every tracked job that was seen again.

    For each row of `jobs` with a same-named row in `inbox`:
      * last_updated becomes the poll time,
      * total_cpu grows by the CPU time accrued since the previous update
        (the `updated` view's next_cpu),
      * MemoryUsage becomes the larger of the stored and current values.

    The matched rows are then removed from `inbox`, so that what remains
    there are slots not tracked yet (see detect_new_jobs()).

    NULL handling: the per-slot CPU and memory figures come from sum()
    over the slot's dynamic slots and are NULL when none matched, and
    total_cpu itself can start out NULL for the same reason.  Without the
    ifnull() guards a single NULL made total_cpu (and MemoryUsage) NULL
    for good, because NULL + x and max(NULL, x) are both NULL in SQLite.
    Missing values are therefore treated as 0.
    """
    DebugPrint(INFO, "Updating active jobs")

    # TODO: simplify with UPDATE FROM syntax after SQLite 3.33.0 (2020-08-14)
    update_sql = """
        update jobs
           set last_updated = (
                        select next_updated
                          from updated b
                         where jobs.name = b.name),
               total_cpu = ifnull(total_cpu, 0) + ifnull((
                        select next_cpu
                          from updated b
                         where jobs.name = b.name), 0),
               MemoryUsage = (
                        select max(ifnull(b.MemoryUsage, 0),
                                   ifnull(b.next_MemoryUsage, 0))
                          from updated b
                         where jobs.name = b.name)
         where exists ( select 1
                          from inbox b
                         where jobs.name = b.name);
    """

    delete_sql = """
        delete from inbox
         where exists ( select 1
                          from jobs a
                         where a.name = inbox.name);
    """

    # The update reads the inbox rows (through the `updated` view) that the
    # delete removes, so it has to run first.
    rc1 = sqldb.execute(update_sql).rowcount
    rc2 = sqldb.execute(delete_sql).rowcount

    DebugPrint(INFO, "- updated %s rows from jobs" % rc1)
    DebugPrint(INFO, "- deleted %s rows from inbox" % rc2)



def detect_new_jobs(sqldb, current_ts):
    """Start tracking slots seen in this poll that are not in `jobs` yet.

    Every row of the `new_jobs` view is inserted into `jobs` with an initial
    total_cpu of (current_ts - DaemonStartTime) * CPUsUsage, i.e. the current
    usage is assumed for the whole time since the daemon started.  The inbox
    is emptied afterwards.

    NOTE(review): CPUsUsage here is the pslot_totals sum over dynamic slots,
    which is NULL when no dynamic slot matched, so total_cpu can start out
    as NULL.
    """
    DebugPrint(INFO, "Detecting new jobs...")

    insert_sql = """
        insert into jobs
        select *
             , (? - DaemonStartTime) * CPUsUsage as total_cpu
--           , min(?, (? - DaemonStartTime)) * CPUsUsage as total_cpu
          from new_jobs;
    """

    # Unconditional: drops every remaining inbox row.
    delete_sql = """
        delete from inbox;
    """

    rc1 = sqldb.execute(insert_sql, (current_ts,)).rowcount
#   rc1 = sqldb.execute(insert_sql, (POLL_INTERVAL, current_ts,)).rowcount

    # XXX: not before update_active_jobs
    rc2 = sqldb.execute(delete_sql).rowcount

    DebugPrint(INFO, "- inserted %s rows into jobs" % rc1)
    DebugPrint(INFO, "- deleted %s rows from inbox" % rc2)


def detect_and_update_partial_jobs(sqldb):
    """Report long-running jobs as partial records and restart their totals.

    Every job whose span (last_updated - DaemonStartTime) has reached
    PARTIAL_RECORD_CUTOFF is copied to `outbox` through the `cutoff` view.
    Its DaemonStartTime is then moved up to last_updated and its total_cpu
    and MemoryUsage are reset to 0, so the next record covers only the time
    after this one.

    The `cutoff` view and the update apply the same condition, so the two
    row counts should be equal; a mismatch is logged as an error.  They can
    differ if the view was created with another PARTIAL_RECORD_CUTOFF value
    than the one used here.
    """
    DebugPrint(INFO, "Reporting and breaking up partial jobs")

    insert_sql = """
        insert into outbox (%s)
        select *
          from cutoff
    """ % ", ".join(SEND_ATTRS)

    # reset aggregates
    update_sql = """
        update jobs
           set DaemonStartTime = last_updated,
               total_cpu = 0,
               MemoryUsage = 0
         where last_updated - DaemonStartTime >= {partial_record_cutoff};
    """.format(partial_record_cutoff=PARTIAL_RECORD_CUTOFF)

    # Insert first: the update changes the values the `cutoff` view selects on.
    rc1 = sqldb.execute(insert_sql).rowcount
    rc2 = sqldb.execute(update_sql).rowcount

    DebugPrint(INFO, "- inserted %s rows into outbox" % rc1)
    DebugPrint(INFO, "- updated %s rows from jobs" % rc2)
    if rc1 != rc2:
        # The rows are inserted into the outbox (the message used to say inbox).
        DebugPrint(ERROR, "Counts do not match for partial records"
                          " inserted into outbox and updated from jobs")


def _query_unique_ads(opts):
    """Query every collector and return the ads, de-duplicated across pools.

    Two ads are treated as the same when their Name and MyAddress agree.
    To let a set do the de-duplication, ClassAd.__eq__/__hash__ are replaced
    temporarily; the previous methods are put back even when a query raises.
    Collectors that cannot be contacted are skipped with a warning.
    """
    def _startd_collector_hash(self):
        """Hash StartD ads the same way the Collector does
        """
        try:
            # The Collector will fallback to Machine and SlotID concatenated
            # together but TJ said we could just get away with Name + MyAddress
            # as this simplifies the SQL column uniqueness constraints
            return hash(','.join([str(self['Name']),
                                  str(self['MyAddress'])]))
        except KeyError:
            # This could happen if the StartD ad is missing Name or MyAddress attrs,
            # which the Collector should not allow to happen
            return 1

    def _startd_collector_eq(self, other):
        return self['Name'] == other['Name'] and self['MyAddress'] == other['MyAddress']

    _old_eq = classad.ClassAd.__eq__
    _old_hash = classad.ClassAd.__hash__
    classad.ClassAd.__eq__ = _startd_collector_eq
    classad.ClassAd.__hash__ = _startd_collector_hash
    try:
        unique_ads = set()
        for pool in ('cm-1.ospool.osg-htc.org', 'cm-2.ospool.osg-htc.org', 'flock.opensciencegrid.org'):
            try:
                # An ad already collected from an earlier pool is kept.
                unique_ads.update(query_current_attrs(opts, pool))
            except htcondor.HTCondorIOError:
                # skip collectors that we can't contact
                DebugPrint(WARNING, f"Could not contact collector to query ads: {pool}")
        return list(unique_ads)
    finally:
        # Put back whatever equality/hashing ClassAd had before.
        classad.ClassAd.__eq__ = _old_eq
        classad.ClassAd.__hash__ = _old_hash


def query_and_update_db(opts):
    """Poll the collectors and bring the database up to date.

    The steps run in one transaction and their order matters: load the ads
    into `inbox` and `slots`, report finished jobs, update jobs seen again,
    add new jobs (which empties the inbox), then split off partial records.
    The connection is closed afterwards; if any step fails, nothing is
    committed.
    """
    # Taken before querying so every row of this run carries the same time.
    current_ts = int(time.time())

    ads = _query_unique_ads(opts)

    sqldb = get_db(opts)
    try:
        write_ads_to_sqlite(sqldb, ads, current_ts)
        #archive_new_ads(sqldb, ads, current_ts)  # uncomment to keep full history
        update_slots_info(sqldb, ads)

        detect_finished_jobs(sqldb, current_ts)
        update_active_jobs(sqldb)
        detect_new_jobs(sqldb, current_ts) # XXX: must go after update_active_jobs
        detect_and_update_partial_jobs(sqldb)
        sqldb.commit()
    finally:
        # Closing without a commit discards the uncommitted changes.
        sqldb.close()


def site_ProbeName(site):
    """Build the Gratia probe name for `site`.

    Every character other than an ASCII letter, digit or "-" is replaced by
    "-", and leading/trailing dashes are stripped, before the result is
    embedded in the probe name.
    """
    dashed = re.sub(r'[^a-zA-Z0-9-]', '-', site)  # sanitize site
    host_label = dashed.strip('-')
    return "osg-pilot-container:%s.gratia.opensciencegrid.org" % host_label


def outbox_row_to_jur(row):
    """Build a Gratia 'Batch' UsageRecord from one OutboxRow.

    A missing CpuDuration is sent as 0; Memory is only set when the row has
    a truthy value.  Grid, LocalUserId and the VO names come from the
    module-level settings.
    """
    seconds = "Was entered in seconds"
    record = Gratia.UsageRecord('Batch')

    record.StartTime(row.StartTime, seconds)
    record.EndTime(row.EndTime, seconds)
    record.WallDuration(row.WallDuration, seconds)
    record.CpuDuration(row.CpuDuration or 0, "user", seconds)
    record.Processors(row.Processors, metric="max")
    if row.Memory:
        record.Memory(row.Memory, "MB", description="RSS")

    record.SiteName(row.SiteName)
    record.ProbeName(site_ProbeName(row.SiteName))
    record.MachineName(row.MachineName)

    record.Grid(grid_type)
    record.LocalUserId(LOCAL_USER)
    record.VOName(VO)
    record.ReportableVOName(VO)

    return record


def send_updates(opts):
    """Send queued outbox records to Gratia and remove them from the outbox.

    At most OUTBOX_MESSAGES_PER_JOB records are taken per run, oldest
    StartTime first.  They are grouped by (probe name, site name), handed to
    send_alternate_records(), and then deleted from `outbox`.  The
    connection is closed before returning.
    """
    sqldb = get_db(opts)
    try:
        # Ascending StartTime: the oldest records are delivered first.
        cur = sqldb.execute(
            "select * from outbox order by StartTime limit ?;",
            (OUTBOX_MESSAGES_PER_JOB,))

        persite_info = collections.defaultdict(list)
        delete_ids = []
        for row in map(OutboxRow._make, cur):
            rec = outbox_row_to_jur(row)
            sitekey = rec.GetProbeName(), rec.GetSiteName()
            persite_info[sitekey].append(rec)
            delete_ids.append(row.ID)

        send_alternate_records(persite_info, opts.gratia_config)

        # XXX: we are deleting our outbox records without checking that
        #      sending them in the child process was successful
        count_submit = len(delete_ids)
        count_delete = 0
        if delete_ids:
            # For executemany() rowcount is the sum over all statements.
            count_delete = sqldb.executemany(
                "delete from outbox where id = ?",
                [(rowid,) for rowid in delete_ids]).rowcount

        sqldb.commit()
    finally:
        sqldb.close()

    DebugPrint(INFO, "- sent %s records to gracc" % count_submit)
    DebugPrint(INFO, "- deleted %s rows from outbox" % count_delete)


def parse_opts():
    """Parse the command line and load the Gratia probe configuration.

    Side effects: sets GratiaCore.Config from the chosen config file, raises
    the debug level to 5 when --verbose is given, and overwrites the
    module-level `grid_type` with the config's "Grid" attribute.

    Returns the (opts, args) pair from optparse.
    Raises Exception when the Gratia config file does not exist.
    """
    global grid_type

    parser = optparse.OptionParser(usage="""%prog""")

    parser.add_option(
        "-f", "--gratia_config",
        dest="gratia_config", default=probe_config,
        help="Location of the Gratia config; defaults to %s." % probe_config)

    parser.add_option(
        "-d", "--db_path",
        dest="db_path", default=db_path,
        help="Location of the sqlite db; defaults to %s." % db_path)

    parser.add_option(
        "-v", "--verbose",
        dest="verbose", action="store_true", default=False,
        help="Enable verbose logging to stdout.")

    parser.add_option(
        "-a", "--any-records",
        dest="any_records", action="store_true", default=False,
        help="Do not limit records to those with IsOsgVoContainer.")

    parser.add_option(
        "-s", "--sleep",
        dest="sleep", type="int", default=0,
        help="""This should be used with normal cron usage. It sets a random
amount of sleep, up to the specified number of seconds before running.
This reduces the load on the Gratia collector.""")

    opts, args = parser.parse_args()

    # Initialize Gratia
    config_file = opts.gratia_config
    if not (config_file and os.path.exists(config_file)):
        raise Exception("Gratia config '%s' does not exist." %
                        config_file)
    GratiaCore.Config = GratiaCore.ProbeConfiguration(config_file)

    if opts.verbose:
        GratiaCore.Config.set_DebugLevel(5)

    grid_type = GratiaCore.Config.getConfigAttribute("Grid")

    return opts, args


def register_gratia():
    """Register this probe's reporter, service and batch manager with GratiaCore.

    The service is registered together with `probe_version`.
    """
    GratiaCore.RegisterReporter("osgpilot_meter")
    GratiaCore.RegisterService("OSGPilotContainer", probe_version)
    GratiaCore.setProbeBatchManager("condor")


def main():
    """Run one probe cycle: poll the collectors, then send queued records.

    After parsing options and checking preconditions, optionally sleeps a
    random 1..--sleep seconds, takes the exclusive lock, initialises Gratia,
    updates the database from the collectors and delivers the outbox.
    """
    opts, args = parse_opts()

    GratiaWrapper.CheckPreconditions()

    # Only a positive value makes sense here; a negative one used to make
    # random.randint() raise ValueError.
    if opts.sleep and opts.sleep > 0:
        rnd = random.randint(1, int(opts.sleep))
        DebugPrint(INFO, "Sleeping for %d seconds before proceeding." % rnd)
        time.sleep(rnd)

    GratiaWrapper.ExclusiveLock()

    register_gratia()
    GratiaCore.Initialize(opts.gratia_config)

    query_and_update_db(opts)
    send_updates(opts)


def send_alternate_records(gratia_info, gratia_config):
    """
    For any accumulated records with an alternate probe/site name, send in a sub-process.

    `gratia_info` maps (probe name, site name) to a list of usage records.
    One child process is forked per entry, one after the other, and is
    given five minutes (SIGALRM) to finish.  A child that fails exits with
    status 1, and the parent logs any child that exits non-zero or is
    killed by a signal.  Nothing is returned.
    """
    if not gratia_info:
        return
    GratiaCore.Disconnect()
    for info, records in gratia_info.items():
        pid = os.fork()
        if pid == 0: # I am the child
            exit_code = 1
            try:
                signal.alarm(5*60)
                send_alternate_records_child(info, records, gratia_config)
                exit_code = 0
            except Exception as e:
                DebugPrint(2, "Failed to send alternate records: %s" % str(e))
                DebugPrintTraceback(2)
            finally:
                # Never let the child return into the parent's code path,
                # whatever was raised.
                os._exit(exit_code)
        else: # I am parent
            _, status = os.waitpid(pid, 0)
            if os.WIFSIGNALED(status):
                DebugPrint(2, "Child sending records for probe %s / site %s"
                              " was killed by signal %d"
                              % (info[0], info[1], os.WTERMSIG(status)))
            elif os.WIFEXITED(status) and os.WEXITSTATUS(status) != 0:
                DebugPrint(2, "Child sending records for probe %s / site %s"
                              " exited with status %d"
                              % (info[0], info[1], os.WEXITSTATUS(status)))


def send_alternate_records_child(info, record_list, gratia_config):
    """Send `record_list` to Gratia under the given probe and site name.

    `info` is a (probe name, site name) pair.  Gratia is re-initialised with
    `gratia_config`, the site and meter names are overridden, outstanding
    records are reprocessed, and each record is sent.  On success the
    process exits with status 0, so this function does not return;
    exceptions from initialisation and the outstanding-record search are
    logged and re-raised.
    """
    probe, site = info

    try:
        GratiaCore.Initialize(gratia_config)
    except Exception as err:
        DebugPrint(2, "Failed to send alternate records: %s" % str(err))
        DebugPrintTraceback(2)
        raise

    config.Config.setSiteName(site)
    config.Config.setMeterName(probe)
    GratiaCore.Handshake()

    try:
        GratiaCore.SearchOutstandingRecord()
    except Exception as err:
        DebugPrint(2, "Failed to send alternate records: %s" % str(err))
        DebugPrintTraceback(2)
        raise

    GratiaCore.Reprocess()

    DebugPrint(2, "Sending alternate records for probe %s / site %s." % (probe, site))
    DebugPrint(2, "Gratia collector to use: %s" % GratiaCore.Config.get_SOAPHost())

    n_found = 0
    n_submitted = 0
    for n_found, record in enumerate(record_list, start=1):
        record.ProbeName(probe)
        record.SiteName(site)
        response = GratiaCore.Send(record)
        # A response starting with "OK" counts as submitted.
        if response[:2] == 'OK':
            n_submitted += 1
        DebugPrint(4, "Sending record for probe %s in site %s to Gratia: %s."% \
            (probe, site, response))

    DebugPrint(2, "Number of usage records submitted: %d" % n_submitted)
    DebugPrint(2, "Number of usage records found: %d" % n_found)

    GratiaCore.ProcessCurrentBundle()

    os._exit(0)



# Run the probe only when executed as a script, not when imported.
if __name__ == '__main__':
    main()


