#!/usr/libexec/platform-python

import os
from ctypes import CDLL
from time import sleep, monotonic, process_time
from sys import stderr, exit, argv
from signal import signal, SIGKILL, SIGTERM, SIGINT, SIGQUIT, SIGHUP
from operator import itemgetter


def get_swappiness():
    """Return the current vm.swappiness value as an int.

    The value is read from procfs via rline(), so an unparsable value
    yields 0.
    """
    swappiness_path = '/proc/sys/vm/swappiness'
    return rline(swappiness_path)


def debug_cg():
    """Print memory.current and memory.high of every configured cgroup.

    Each value is shown in MiB and as a percentage of mem_total. A cgroup
    whose control files do not exist is reported and skipped.
    """
    for name, props in cg_dict.items():
        try:
            current = rline(props['cur_path'])
            high = rline(props['high_path'])
        except FileNotFoundError as err:
            print('  cgroup: {}, {}'.format(name, err))
            continue

        # Control files hold bytes; mem_total is in the kB units of
        # /proc/meminfo, hence the single division by 1024 for percentages.
        report = ('  cg: {}, memory.current: {}M ({}%), memory.high:'
                  ' {}M ({}%)')
        print(report.format(
            name,
            round(current / 1024 / 1024, 1),
            round(current / 1024 / mem_total * 100, 1),
            round(high / 1024 / 1024, 1),
            round(high / 1024 / mem_total * 100, 1)))


def errprint(*text):
    """Print the given objects to stderr and flush immediately.

    Objects are separated by a single space and followed by a newline,
    exactly like the built-in print().
    """
    print(*text, sep=' ', end='\n', file=stderr, flush=True)


def string_to_float_convert_test(string):
    """Convert a string to float, returning None if it is not a number.

    Only ValueError is treated as "not a number"; any other exception
    raised by float() propagates to the caller.
    """
    try:
        value = float(string)
    except ValueError:
        value = None
    return value


def mlockall():
    """Lock the memory of this process via libc's mlockall().

    The call is first attempted with MCL_CURRENT | MCL_FUTURE | MCL_ONFAULT.
    If that fails it is retried without MCL_ONFAULT (presumably for kernels
    that do not know the flag). If the retry fails too, the real errno is
    printed and the process exits with status 1.
    """
    # Local import: the module-level import only brings in CDLL.
    from ctypes import get_errno

    MCL_CURRENT = 1
    MCL_FUTURE = 2
    MCL_ONFAULT = 4

    # use_errno=True makes ctypes keep a private copy of errno that
    # get_errno() returns after a failed call.
    libc = CDLL('libc.so.6', use_errno=True)

    if libc.mlockall(MCL_CURRENT | MCL_FUTURE | MCL_ONFAULT) == 0:
        if debug:
            print('Process memory locked with MCL_CURRENT | MCL_FUTURE | MCL'
                  '_ONFAULT')
        return

    if libc.mlockall(MCL_CURRENT | MCL_FUTURE) != 0:
        # The return value is just -1 on failure; the reason is in errno.
        err = get_errno()
        print('ERROR: cannot lock process memory: [Errno {}] {}'.format(
            err, os.strerror(err)))
        exit(1)

    if debug:
        print('Process memory locked with MCL_CURRENT | MCL_FUTURE')


def rline(path):
    """Read a file and return its whole content parsed as an int.

    The file is opened unbuffered in binary mode. Content that is not a
    valid integer yields 0. Errors from opening the file (for example
    FileNotFoundError) propagate to the caller.
    """
    with open(path, 'rb', buffering=0) as f:
        raw = f.read()

    try:
        return int(raw)
    except ValueError:
        return 0


def check_mem_and_swap():
    """Return (MemAvailable, SwapTotal, SwapFree) from the open meminfo file.

    The already opened fd['mi'] handle is rewound and read again. The text
    is split on ' kB\\n' and the fields are picked by the module-level
    indexes mem_available_index, swap_total_index and swap_free_index.
    Values are ints in the kB units used by the file.
    """
    meminfo = fd['mi']
    meminfo.seek(0)
    fields = meminfo.read().decode().split(' kB\n')

    def field_value(index):
        """Return the integer after the colon of the given field."""
        return int(fields[index].split(':')[1])

    ma = field_value(mem_available_index)
    st = field_value(swap_total_index)
    sf = field_value(swap_free_index)
    return ma, st, sf


def write(path, string):
    """Write string to the file at path, truncating any existing content.

    The file is opened in text mode; errors from open() or write()
    propagate to the caller.
    """
    with open(path, mode='w') as out:
        out.write(string)


def percent(num):
    """Return num expressed as a percentage, rounded to one decimal."""
    scaled = num * 100
    return round(scaled, 1)


def sleep_after_check_mem(ma):
    """Sleep for a time derived from the amount of available memory.

    With stable_interval set the sleep is always min_interval. Otherwise
    the delay is ma / mem_fill_rate, limited to the range
    [min_interval, max_interval].
    """
    if stable_interval:
        delay = min_interval
    else:
        estimated = ma / mem_fill_rate
        if estimated > max_interval:
            delay = max_interval
        elif estimated < min_interval:
            delay = min_interval
        else:
            delay = estimated

    sleep(delay)


def soft_exit(num):
    """Remove the limits, report CPU usage and exit with status num.

    Writes 'max' to memory.high of every configured cgroup (missing
    cgroups are skipped), closes every file held in fd, prints the CPU
    usage of this process relative to the wall time elapsed since m0 and
    pt0 were taken, then exits.
    """
    for props in cg_dict.values():
        try:
            write(props['high_path'], 'max\n')
        except FileNotFoundError as err:
            if debug:
                print(err)

    for handle in fd.values():
        handle.close()

    wall_now = monotonic()
    cpu_now = process_time()
    cpu_share = (cpu_now - pt0) / (wall_now - m0) * 100

    print('CPU usage by the process since monitoring has started'
          ': {}%; exit.'.format(round(cpu_share, 2)))
    exit(num)


def signal_handler(signum, frame):
    """Handle a termination signal: log it and shut down via soft_exit(0).

    Every signal in sig_list is first rebound to a no-op handler so that
    further signals cannot interrupt the shutdown.
    """
    def ignore_signal(signum, frame):
        """Do nothing; used while the shutdown is in progress."""
        return None

    for sig in sig_list:
        signal(sig, ignore_signal)

    signal_name = sig_dict[signum]
    print('Signal handler called with the {} signal '.format(signal_name))

    soft_exit(0)


def _extend_limits():
    """Set memory.high of every configured cgroup to its lim_max.

    A missing cgroup file is printed in debug mode; it then either ends
    the program through soft_exit(1) or is skipped, depending on
    nonexistent_cgroup_is_eror.
    """
    for cg in cg_dict:
        try:
            write(
                cg_dict[cg]['high_path'],
                str(cg_dict[cg]['lim_max']) + '\n')
        except FileNotFoundError as e:

            if debug:
                print(e)

            if nonexistent_cgroup_is_eror:
                soft_exit(1)


def correction():
    """Lower cgroup memory.high limits until available memory recovers.

    Runs its own polling loop. Every pass re-reads MemAvailable and
    SwapFree and handles exactly one of these cases:

    0. vm.swappiness is 0: restore all limits to lim_max, sleep
       max_interval and return.
    1. SwapFree < min_swap_free: restore all limits to lim_max, sleep
       max_interval and return.
    2. MemAvailable < keep_memavail_min: lower memory.high of one cgroup.
       Cgroups are ordered by how far their share of the summed
       memory.current exceeds their configured fraction; the first one
       whose new limit would not drop below lim_min is changed. The step
       is keep_memavail_mid - MemAvailable, capped at max_step.
    3. MemAvailable < keep_memavail_max: only wait.
    4. Otherwise: once correction_exit_time seconds have passed since the
       function was entered or case 2/3 was last hit, restore all limits
       to lim_max and return.
    """

    if debug:
        print('Enter in correction')

    enter_time = monotonic()

    while True:

        ma, st, sf = check_mem_and_swap()

        if debug:
            print('MemAvailable: {}M ({}%), SwapFree: {}M'.format(
                    round(ma / 1024, 1),
                    percent(ma / mem_total),
                    round(sf / 1024, 1)))

            debug_cg()

        # 0
        swappiness = get_swappiness()
        if swappiness == 0:
            if debug:
                print('vm.swappiness=0')

            _extend_limits()

            sleep(max_interval)
            break

        # 1
        if sf < min_swap_free:

            if debug:
                print('SwapFree < $MIN_SWAP_FREE; extend limits')

            _extend_limits()

            sleep(max_interval)

            if debug:
                print('Exit from correction')

            break

        # 2
        if ma < keep_memavail_min:
            enter_time = monotonic()

            c_list = []

            cur_sum = 0

            for cg in cg_dict:

                try:
                    cur = rline(cg_dict[cg]['cur_path'])
                except FileNotFoundError as e:

                    if debug:
                        print(e)

                    if nonexistent_cgroup_is_eror:
                        soft_exit(1)

                    # Forget the reading of an earlier pass: it is not part
                    # of cur_sum and the cgroup no longer exists.
                    cg_dict[cg].pop('real_cur', None)
                    continue

                cg_dict[cg]['real_cur'] = cur
                cur_sum += cur

            for cg in cg_dict:

                if 'real_cur' not in cg_dict[cg]:
                    continue

                # All cgroups may report 0 bytes; avoid dividing by zero.
                if cur_sum:
                    frac_cur = cg_dict[cg]['real_cur'] / cur_sum
                else:
                    frac_cur = 0.0

                cg_dict[cg]['frac_cur'] = frac_cur

                frac = cg_dict[cg]['fraction']

                delta = frac_cur - frac

                cg_dict[cg]['delta'] = delta

                c_list.append((cg, delta))

            if len(c_list) == 0:
                if debug:
                    print('Nothing to limit, no cgroup found')
                    print('Exit from correction')
                sleep(max_interval)
                break

            # Largest excess over the configured fraction first.
            c_list = sorted(c_list, key=itemgetter(1), reverse=True)

            if debug:
                print('Q:', c_list)

            delta = keep_memavail_mid - ma
            if delta > max_step:
                delta = max_step

            # meminfo values are in kB, memory.high is written in bytes.
            delta_b = delta * 1024

            if debug:
                print('step, M:', round(delta / 1024, 1))

            for cg, _ in c_list:
                cur = cg_dict[cg]['real_cur']

                lim_min = cg_dict[cg]['lim_min']

                new_max = cur - delta_b

                if new_max < lim_min:
                    continue

                if debug:
                    print('set new memory.high for {}: {} MiB'.format(
                        cg, round(new_max / 1024 / 1024)))

                try:
                    write(cg_dict[cg]['high_path'], str(new_max) + '\n')
                except FileNotFoundError as e:

                    if debug:
                        print(e)

                    if nonexistent_cgroup_is_eror:
                        soft_exit(1)
                    else:
                        continue

                # Only one cgroup is tightened per pass.
                break

            sleep(min_interval)

        # 3
        elif ma < keep_memavail_max:

            enter_time = monotonic()

            sleep_after_check_mem(ma)

        # 4
        else:

            if monotonic() - enter_time > correction_exit_time:

                if debug:
                    print('$CORRECTION_EXIT_TIME expired; extend limits')

                _extend_limits()

                if debug:
                    print('Exit from correction')
                break

            sleep_after_check_mem(ma)


###############################################################################


# Command line: the only accepted form is "-c <config path>".
a = argv[1:]
la = len(a)
if la == 0:
    errprint('invalid input: missing CLI options')
    exit(1)
if la != 2 or a[0] != '-c':
    errprint('invalid input')
    exit(1)
config = os.path.abspath(a[1])


# Learn the layout of /proc/meminfo once: the positions of the fields
# that are polled later, and the total amount of memory.
with open('/proc/meminfo') as f:
    mem_list = f.readlines()
mem_list_names = [s.split(':')[0] for s in mem_list]
try:
    mem_available_index = mem_list_names.index('MemAvailable')
except ValueError:
    print('ERROR: your Linux kernel is too old, Linux 3.14+ required')
    # Without MemAvailable nothing below can work; stop here instead of
    # failing later with a NameError on mem_available_index.
    exit(1)
swap_total_index = mem_list_names.index('SwapTotal')
swap_free_index = mem_list_names.index('SwapFree')
# The first line is taken to be MemTotal; [:-4] strips the trailing ' kB\n'.
mem_total = mt = int(mem_list[0].split(':')[1][:-4])
mt_b = mt * 1024


# Raw "$KEY" -> value strings taken from the config file.
config_dict = dict()

# Per-cgroup settings, keyed by the cgroup path without surrounding slashes.
cg_dict = dict()


try:
    with open(config) as f:
        for line in f:

            # "$KEY = value": a global option, kept as a raw string and
            # validated later.
            if line[0] == '$' and '=' in line:
                key, _, value = line.partition('=')
                key = key.rstrip()
                value = value.strip()
                if key in config_dict:
                    errprint('config key {} duplication'.format(key))
                    exit(1)
                config_dict[key] = value

            # "@LIMIT ..." must carry exactly four KEY=value tokens; CGROUP,
            # FRACTION, MIN_PERCENT and MAX_PERCENT are read from them.
            if line[0] == '@' and '=' in line:
                if line.startswith('@LIMIT '):
                    a_list = line.partition('@LIMIT ')[2].split()
                    if len(a_list) != 4:
                        errprint('invalid conf')
                        exit(1)

                    a_dict = dict()
                    for pair in a_list:
                        key, _, value = pair.partition('=')
                        a_dict[key] = value

                    # A missing key or a non-numeric value is a config
                    # error, not a reason for a traceback.
                    try:
                        cg = a_dict['CGROUP'].strip('/')
                        fraction = float(a_dict['FRACTION'])
                        min_percent = a_dict['MIN_PERCENT']
                        max_percent = a_dict['MAX_PERCENT']
                        # Percent of total memory, converted to bytes.
                        m_min = int(
                            mem_total * 1024 / 100 * float(min_percent))
                        m_max = int(
                            mem_total * 1024 / 100 * float(max_percent))
                    except (KeyError, ValueError, OverflowError) as e:
                        errprint('invalid conf: bad @LIMIT line ({!r}): '
                                 '{}'.format(line.strip(), e))
                        exit(1)

                    if cg in cg_dict:
                        errprint('err, invalid conf, cgroup ({}) dupli'
                                 'cation'.format(cg))
                        exit(1)

                    cur_path = '/sys/fs/cgroup/{}/memory.current'.format(cg)
                    high_path = '/sys/fs/cgroup/{}/memory.high'.format(cg)

                    cg_dict[cg] = {
                        'cur_path': cur_path,
                        'high_path': high_path,
                        'fraction': fraction,
                        'lim_min': m_min,
                        'lim_max': m_max}
                else:
                    # Unknown "@" directive: reported, but not fatal.
                    print('invalid conf')

except (PermissionError, UnicodeDecodeError, IsADirectoryError,
        IndexError, FileNotFoundError) as e:
    errprint('Invalid config: {}. Exit.'.format(e))
    exit(1)


if len(cg_dict) == 0:
    print('ERROR: invalid config, no one cgroup set to limit')
    exit(1)


# Normalise the configured fractions so that they sum to 1.
frac_sum = sum(cg_dict[cg]['fraction'] for cg in cg_dict)
if frac_sum == 0:
    # All FRACTION values are zero: normalising would divide by zero.
    errprint('invalid config: the sum of FRACTION values must not be 0')
    exit(1)
for cg in cg_dict:
    cg_dict[cg]['fraction'] = round(cg_dict[cg]['fraction'] / frac_sum, 4)


# $DEBUG: mandatory, must be the literal string 'True' or 'False'.
if '$DEBUG' not in config_dict:
    errprint('missing $DEBUG key')
    exit(1)
debug = config_dict['$DEBUG']
if debug not in ('True', 'False'):
    errprint('invalid $DEBUG value')
    exit(1)
debug = debug == 'True'


# $MIN_SWAP_FREE_MIB: mandatory number, stored multiplied by 1024.
if '$MIN_SWAP_FREE_MIB' not in config_dict:
    errprint('missing $MIN_SWAP_FREE_MIB key')
    exit(1)
string = config_dict['$MIN_SWAP_FREE_MIB']
min_swap_free_mib = string_to_float_convert_test(string)
if min_swap_free_mib is None:
    errprint('invalid $MIN_SWAP_FREE_MIB value')
    exit(1)
min_swap_free = int(min_swap_free_mib * 1024)


# $MAX_STEP_MIB: mandatory number, stored multiplied by 1024.
if '$MAX_STEP_MIB' not in config_dict:
    errprint('missing $MAX_STEP_MIB key')
    exit(1)
string = config_dict['$MAX_STEP_MIB']
max_step_mib = string_to_float_convert_test(string)
if max_step_mib is None:
    errprint('invalid $MAX_STEP_MIB value')
    exit(1)
max_step = int(max_step_mib * 1024)


# $CORRECTION_EXIT_TIME: mandatory number, compared against monotonic()
# differences in correction().
if '$CORRECTION_EXIT_TIME' not in config_dict:
    errprint('missing $CORRECTION_EXIT_TIME key')
    exit(1)
string = config_dict['$CORRECTION_EXIT_TIME']
correction_exit_time = string_to_float_convert_test(string)
if correction_exit_time is None:
    errprint('invalid $CORRECTION_EXIT_TIME value')
    exit(1)


# $MEM_FILL_RATE: mandatory number, stored multiplied by 1024.
if '$MEM_FILL_RATE' not in config_dict:
    errprint('missing $MEM_FILL_RATE key')
    exit(1)
string = config_dict['$MEM_FILL_RATE']
mem_fill_rate = string_to_float_convert_test(string)
if mem_fill_rate is None:
    errprint('invalid $MEM_FILL_RATE value')
    exit(1)
mem_fill_rate = 1024 * mem_fill_rate


# $MIN_INTERVAL: mandatory number, the shortest sleep between checks.
if '$MIN_INTERVAL' not in config_dict:
    errprint('missing $MIN_INTERVAL key')
    exit(1)
string = config_dict['$MIN_INTERVAL']
min_interval = string_to_float_convert_test(string)
if min_interval is None:
    errprint('invalid $MIN_INTERVAL value')
    exit(1)


# $MAX_INTERVAL: mandatory number, the longest sleep between checks.
if '$MAX_INTERVAL' not in config_dict:
    errprint('missing $MAX_INTERVAL key')
    exit(1)
string = config_dict['$MAX_INTERVAL']
max_interval = string_to_float_convert_test(string)
if max_interval is None:
    errprint('invalid $MAX_INTERVAL value')
    exit(1)


# Equal bounds mean a fixed polling period; sleep_after_check_mem() then
# skips the rate-based computation.
stable_interval = min_interval == max_interval


# $KEEP_MEM_MIN: mandatory; either "<number>%" (percent of mem_total) or
# "<number>M" (stored multiplied by 1024).
if '$KEEP_MEM_MIN' not in config_dict:
    errprint('missing $KEEP_MEM_MIN key')
    exit(1)
string = config_dict['$KEEP_MEM_MIN']
if string.endswith('%'):
    keep_memavail_min_percent = string_to_float_convert_test(
        string[:-1].rstrip())
    if keep_memavail_min_percent is None:
        errprint('invalid $KEEP_MEM_MIN value')
        exit(1)
    keep_memavail_min = int(mem_total / 100 * keep_memavail_min_percent)
elif string.endswith('M'):
    keep_memavail_min_mib = string_to_float_convert_test(
        string[:-1].rstrip())
    if keep_memavail_min_mib is None:
        errprint('invalid $KEEP_MEM_MIN value')
        exit(1)
    keep_memavail_min = int(keep_memavail_min_mib * 1024)
else:
    errprint('invalid $KEEP_MEM_MIN value')
    exit(1)


# $KEEP_MEM_MID: mandatory; either "<number>%" (percent of mem_total) or
# "<number>M" (stored multiplied by 1024).
if '$KEEP_MEM_MID' not in config_dict:
    errprint('missing $KEEP_MEM_MID key')
    exit(1)
string = config_dict['$KEEP_MEM_MID']
if string.endswith('%'):
    keep_memavail_mid_percent = string_to_float_convert_test(
        string[:-1].rstrip())
    if keep_memavail_mid_percent is None:
        errprint('invalid $KEEP_MEM_MID value')
        exit(1)
    keep_memavail_mid = int(mem_total / 100 * keep_memavail_mid_percent)
elif string.endswith('M'):
    keep_memavail_mid_mib = string_to_float_convert_test(
        string[:-1].rstrip())
    if keep_memavail_mid_mib is None:
        errprint('invalid $KEEP_MEM_MID value')
        exit(1)
    keep_memavail_mid = int(keep_memavail_mid_mib * 1024)
else:
    errprint('invalid $KEEP_MEM_MID value')
    exit(1)


# $KEEP_MEM_MAX: mandatory; either "<number>%" (percent of mem_total) or
# "<number>M" (stored multiplied by 1024).
if '$KEEP_MEM_MAX' not in config_dict:
    errprint('missing $KEEP_MEM_MAX key')
    exit(1)
string = config_dict['$KEEP_MEM_MAX']
if string.endswith('%'):
    keep_memavail_max_percent = string_to_float_convert_test(
        string[:-1].rstrip())
    if keep_memavail_max_percent is None:
        errprint('invalid $KEEP_MEM_MAX value')
        exit(1)
    keep_memavail_max = int(mem_total / 100 * keep_memavail_max_percent)
elif string.endswith('M'):
    keep_memavail_max_mib = string_to_float_convert_test(
        string[:-1].rstrip())
    if keep_memavail_max_mib is None:
        errprint('invalid $KEEP_MEM_MAX value')
        exit(1)
    keep_memavail_max = int(keep_memavail_max_mib * 1024)
else:
    errprint('invalid $KEEP_MEM_MAX value')
    exit(1)


# $NONEXISTENT_CGROUP_IS_ERROR: mandatory, must be the literal string
# 'True' or 'False'.
if '$NONEXISTENT_CGROUP_IS_ERROR' not in config_dict:
    errprint('missing $NONEXISTENT_CGROUP_IS_ERROR key')
    exit(1)
nonexistent_cgroup_is_eror = config_dict['$NONEXISTENT_CGROUP_IS_ERROR']
if nonexistent_cgroup_is_eror not in ('True', 'False'):
    errprint('invalid $NONEXISTENT_CGROUP_IS_ERROR value')
    exit(1)
nonexistent_cgroup_is_eror = nonexistent_cgroup_is_eror == 'True'


# In debug mode, dump the effective settings and the per-cgroup table.
if debug:
    for label, setting in (
            ('keep_memavail_min:', keep_memavail_min),
            ('keep_memavail_mid:', keep_memavail_mid),
            ('keep_memavail_max:', keep_memavail_max),
            ('max_step:', max_step),
            ('min_swap_free:', min_swap_free),
            ('correction_exit_time:', correction_exit_time),
            ('mem_fill_rate:', mem_fill_rate),
            ('min_interval:', min_interval),
            ('max_interval:', max_interval),
            ('nonexistent_cgroup_is_error:', nonexistent_cgroup_is_eror),
            ('debug:', debug)):
        print(label, setting)
    for cg, x in cg_dict.items():
        print('cgroup: {} {}'.format(cg, '{'))
        for i, item in x.items():
            print('  {}: {}'.format(i, item))
        print('}')


###############################################################################


mlockall()


# Files kept open for the lifetime of the process; soft_exit() closes them.
fd = dict()
fd['mi'] = open('/proc/meminfo', 'rb', buffering=0)


# Signal number -> name, used by signal_handler() for its log message.
sig_dict = {
    SIGKILL: 'SIGKILL',
    SIGINT: 'SIGINT',
    SIGQUIT: 'SIGQUIT',
    SIGHUP: 'SIGHUP',
    SIGTERM: 'SIGTERM'
}


sig_list = [SIGTERM, SIGINT, SIGQUIT, SIGHUP]


# Take the reference timestamps before installing the handlers:
# soft_exit() reads m0 and pt0, so they must exist when a signal arrives.
m0 = monotonic()
pt0 = process_time()

for i in sig_list:
    signal(i, signal_handler)


if not os.path.exists('/sys/fs/cgroup/cgroup.procs'):
    print('ERROR: unified cgroup hierarchy not found')
    exit(1)


# Start with every cgroup at its configured maximum.
for cg in cg_dict:
    try:
        write(cg_dict[cg]['high_path'], '{}\n'.format(
            cg_dict[cg]['lim_max']))
    except Exception as e:
        # Report the failure in debug mode even when it is tolerated,
        # as correction() does.
        if debug:
            print(e)
        if nonexistent_cgroup_is_eror:
            soft_exit(1)


swappiness = get_swappiness()
if debug:
    print('vm.swappiness:', swappiness)
if swappiness == 0:
    print('WARNING: vm.swappiness=0')


print('Monitoring has started!')


# Main monitoring loop: poll memory and enter correction() when
# MemAvailable drops below keep_memavail_min.
while True:
    ma, st, sf = check_mem_and_swap()

    if debug:
        print('MemAvailable: {}M ({}%), SwapFree: {}M'.format(
            round(ma / 1024, 1),
            percent(ma / mem_total),
            round(sf / 1024, 1)))
        debug_cg()

    # Too little free swap: do not tighten anything, just wait.
    if sf < min_swap_free:
        sleep(max_interval)
        continue

    # Enough memory available: wait and poll again.
    if ma >= keep_memavail_min:
        sleep_after_check_mem(ma)
        continue

    swappiness = get_swappiness()
    if swappiness == 0:
        if debug:
            print('vm.swappiness=0')
        sleep(max_interval)
        continue

    correction()
