#!/usr/bin/python3 -tt
# ---------------------------------------------------------------------
# Copyright 2016 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------
# Description:	Google Cloud Platform - Floating IP Address (Alias)
# ---------------------------------------------------------------------

import json
import logging
import os
import sys
import time

# Directory holding the OCF helper module imported below as `ocf`.
# OCF_FUNCTIONS_DIR from the environment wins; otherwise the path is derived
# from OCF_ROOT. NOTE(review): if OCF_ROOT is unset as well, the fallback
# becomes the literal string "None/lib/heartbeat".
OCF_FUNCTIONS_DIR = os.environ.get("OCF_FUNCTIONS_DIR", "%s/lib/heartbeat"
                                   % os.environ.get("OCF_ROOT"))
# Make the helper directory importable for the `from ocf import *` that follows.
sys.path.append(OCF_FUNCTIONS_DIR)

from ocf import *

try:
  sys.path.insert(0, '/usr/lib/fence-agents/support/google')
  import googleapiclient.discovery
  try:
    from google.oauth2.service_account import Credentials as ServiceAccountCredentials
  except ImportError:
    from oauth2client.service_account import ServiceAccountCredentials
except ImportError:
  pass

if sys.version_info >= (3, 0):
  # Python 3 imports.
  import urllib.parse as urlparse
  import urllib.request as urlrequest
else:
  # Python 2 imports.
  import urllib as urlparse
  import urllib2 as urlrequest


# Constants for alias add/remove modes (the `mode` argument of add_rm_alias)
ADD = 0
REMOVE = 1

# Module state populated by validate(): the Compute API client, the name of
# the instance this agent runs on, and the alias IP range being managed.
CONN = None
THIS_VM = None
ALIAS = None
# Number of attempts made by get_metadata() and create_api_connection().
MAX_RETRIES = 3
# Base delay in seconds between attempts; it is multiplied by the attempt
# number, so the wait grows linearly.
RETRY_BACKOFF_SECS = 1
# Base URL and mandatory request header for the instance metadata server.
METADATA_SERVER = 'http://metadata.google.internal/computeMetadata/v1/'
METADATA_HEADERS = {'Metadata-Flavor': 'Google'}
METADATA = \
'''<?xml version="1.0"?>
<!DOCTYPE resource-agent SYSTEM "ra-api-1.dtd">
<resource-agent name="gcp-vpc-move-vip" version="1.0">
  <version>1.0</version>
  <longdesc lang="en">Floating IP Address or Range on Google Cloud Platform - Using Alias IP address functionality to attach a secondary IP range to a running instance</longdesc>
  <shortdesc lang="en">Floating IP Address or Range on Google Cloud Platform</shortdesc>
  <parameters>
    <parameter name="alias_ip" unique="1" required="1">
      <longdesc lang="en">IP range to be added including CIDR netmask (e.g., 192.168.0.1/32)</longdesc>
      <shortdesc lang="en">IP range to be added including CIDR netmask (e.g., 192.168.0.1/32)</shortdesc>
      <content type="string" default="" />
    </parameter>
    <parameter name="alias_range_name" unique="0" required="0">
      <longdesc lang="en">Subnet name for the Alias IP</longdesc>
      <shortdesc lang="en">Subnet name for the Alias IP</shortdesc>
      <content type="string" default="" />
    </parameter>
    <parameter name="hostlist" unique="0" required="0">
      <longdesc lang="en">List of hosts in the cluster, separated by spaces</longdesc>
      <shortdesc lang="en">Host list</shortdesc>
      <content type="string" default="" />
    </parameter>
    <parameter name="project" unique="0" required="0">
      <longdesc lang="en">
        Project ID of the instance. It can be useful to set this
        attribute if the instance is in a shared service project.
        Otherwise, the agent should be able to determine the project ID
        automatically.
      </longdesc>
      <shortdesc lang="en">Project ID</shortdesc>
      <content type="string" default="default" />
    </parameter>
    <parameter name="serviceaccount">
      <longdesc lang="en">Path to Service account JSON file</longdesc>
      <shortdesc lang="en">Service account JSONfile</shortdesc>
      <content type="string" default="" />
    </parameter>
    <parameter name="stackdriver_logging" unique="0" required="0">
      <longdesc lang="en">If enabled (set to true), IP failover logs will be posted to stackdriver logging</longdesc>
      <shortdesc lang="en">Stackdriver-logging support</shortdesc>
      <content type="boolean" default="" />
    </parameter>
  </parameters>
  <actions>
    <action name="start" timeout="300s" />
    <action name="stop" timeout="15s" />
    <action name="monitor" timeout="15s" interval="60s" depth="0" />
    <action name="meta-data" timeout="15s" />
    <action name="validate-all" timeout="15s" />
  </actions>
</resource-agent>'''


def get_metadata(metadata_key, params=None, timeout=None):
  """Performs a GET request against the metadata server, with retries.

  Args:
    metadata_key: string, the metadata path (relative to METADATA_SERVER)
      to perform a GET request on.
    params: dictionary, the query parameters in the GET request.
    timeout: int, timeout in seconds for metadata requests. Defaults to 60.

  Returns:
    The body of the HTTP response, decoded as UTF-8 text.

  Exits the process with OCF_ERR_GENERIC once MAX_RETRIES attempts failed.
  """
  timeout = timeout or 60
  # Build the URL once, before the retry loop. Re-encoding inside the loop
  # used to overwrite `params` with the encoded string, so every retry with
  # query parameters failed in urlencode() instead of re-issuing the request.
  query = urlparse.urlencode(params or {})
  url = '%s?%s' % (os.path.join(METADATA_SERVER, metadata_key), query)

  for i in range(MAX_RETRIES):
    try:
      request = urlrequest.Request(url, headers=METADATA_HEADERS)
      # An empty ProxyHandler bypasses any proxy configured in the
      # environment for this request.
      request_opener = urlrequest.build_opener(urlrequest.ProxyHandler({}))
      return request_opener.open(
          request, timeout=timeout * 1.1).read().decode("utf-8")
    except Exception as e:
      logger.error('Couldn\'t get instance name, is this running inside GCE?: '
                   + str(e))
      # No point in waiting after the last attempt; we exit right below.
      if i < MAX_RETRIES - 1:
        time.sleep(RETRY_BACKOFF_SECS * (i + 1))

  # If the retries are exhausted we exit with a generic error.
  sys.exit(OCF_ERR_GENERIC)


def create_api_connection():
  """Builds a Compute Engine API client, retrying on failure.

  When OCF_RESKEY_serviceaccount is set, credentials are loaded from that
  service account key file; otherwise default credentials are used.

  Returns:
    The 'compute' 'v1' client object created by
    googleapiclient.discovery.build.

  Exits the process with OCF_ERR_GENERIC if the service account file cannot
  be loaded or once MAX_RETRIES attempts have failed.
  """
  for i in range(MAX_RETRIES):
    try:
      serviceaccount = os.environ.get("OCF_RESKEY_serviceaccount")
      if not serviceaccount:
        try:
          from googleapiclient import _auth
          credentials = _auth.default_credentials()
        except Exception:
          # Fall back to oauth2client's application default credentials when
          # googleapiclient's _auth helper is unavailable or fails. The name
          # was previously never imported, so this path raised NameError.
          from oauth2client.client import GoogleCredentials
          credentials = GoogleCredentials.get_application_default()
          logger.debug("using application default credentials")
      else:
        scope = ['https://www.googleapis.com/auth/cloud-platform']
        logger.debug("using credentials from service account")
        try:
          credentials = ServiceAccountCredentials.from_service_account_file(
              filename=serviceaccount, scopes=scope)
        except AttributeError:
          # The oauth2client flavour of ServiceAccountCredentials exposes a
          # differently named constructor.
          credentials = ServiceAccountCredentials.from_json_keyfile_name(
              serviceaccount, scope)
        except Exception as e:
          # A broken key file will not fix itself; don't retry.
          logger.error(str(e))
          sys.exit(OCF_ERR_GENERIC)
      return googleapiclient.discovery.build('compute', 'v1',
                                             credentials=credentials,
                                             cache_discovery=False)
    except Exception as e:
      logger.error('Couldn\'t connect with google api: ' + str(e))
      # No point in waiting after the last attempt; we exit right below.
      if i < MAX_RETRIES - 1:
        time.sleep(RETRY_BACKOFF_SECS * (i + 1))

  # If the retries are exhausted we exit with a generic error.
  sys.exit(OCF_ERR_GENERIC)


def get_instance(project, zone, instance):
  """Fetches the Compute API resource for a single instance.

  Args:
    project: string, the project in which the instance resides.
    zone: string, the zone in which the instance resides.
    instance: string, the name of the instance.

  Returns:
    The result of executing the instances().get request.
  """
  return CONN.instances().get(
      project=project, zone=zone, instance=instance).execute()


def get_network_ifaces(project, zone, instance):
  """Returns the 'networkInterfaces' entry of an instance resource.

  Args:
    project: string, the project in which the instance resides.
    zone: string, the zone in which the instance resides.
    instance: string, the name of the instance.
  """
  instance_resource = get_instance(project, zone, instance)
  return instance_resource['networkInterfaces']


def wait_for_operation(project, zone, operation):
  """Polls a zone operation once per second until its status is 'DONE'.

  Args:
    project: string, the project the operation belongs to.
    zone: string, the zone the operation belongs to.
    operation: dictionary with a 'name' key identifying the operation.

  Raises:
    Exception: if the finished operation carries an 'error' entry.
  """
  while True:
    poll = CONN.zoneOperations().get(
        project=project, zone=zone, operation=operation['name'])
    state = poll.execute()

    if state['status'] != 'DONE':
      time.sleep(1)
      continue

    if 'error' in state:
      raise Exception(state['error'])
    return


def set_aliases(project, zone, instance, aliases, fingerprint):
  """Replaces the alias IP ranges on nic0 of an instance and waits for it.

  Args:
    project: string, the project in which the instance resides.
    zone: string, the zone in which the instance resides.
    instance: string, the name of the instance.
    aliases: list, the dictionaries describing the alias IP ranges the
      interface should carry after the update.
    fingerprint: string, the fingerprint of the network interface.
  """
  update_call = CONN.instances().updateNetworkInterface(
      instance=instance, networkInterface='nic0', project=project, zone=zone,
      body={'aliasIpRanges': aliases, 'fingerprint': fingerprint})
  # Block until the update operation has finished.
  wait_for_operation(project, zone, update_call.execute())


def add_rm_alias(mode, project, zone, instance, alias, alias_range_name=None):
  """Adds an alias IP range to, or removes it from, a GCE instance.

  Only the first network interface of the instance is considered.

  Args:
    mode: int, either ADD (0) or REMOVE (1).
    project: string, the project in which the instance resides.
    zone: string, the zone in which the instance resides.
    instance: string, the name of the instance.
    alias: string, the alias IP range to add or remove.
    alias_range_name: string, the subnet name for the alias IP range.

  Returns:
    True if the instance's alias IP ranges were changed, False if they
    already matched the requested state.

  Raises:
    ValueError: if alias is non-empty and mode is neither ADD nor REMOVE.
  """
  def _range_of(entry):
    return entry.get('ipCidrRange')

  first_nic = get_network_ifaces(project, zone, instance)[0]
  fingerprint = first_nic['fingerprint']
  current = first_nic.get('aliasIpRanges', [])

  # Start from the current ranges minus the one being managed; this already
  # is the desired end state for REMOVE.
  desired = [entry for entry in current if entry['ipCidrRange'] != alias]

  if alias:
    if mode == ADD:
      addition = {'ipCidrRange': alias}
      if alias_range_name:
        addition['subnetworkRangeName'] = alias_range_name
      desired.append(addition)
    elif mode != REMOVE:
      raise ValueError('Invalid value for mode: {}'.format(mode))

  # Order is irrelevant: compare both lists sorted by range.
  if sorted(desired, key=_range_of) == sorted(current, key=_range_of):
    return False

  set_aliases(project, zone, instance, desired, fingerprint)
  return True


def add_alias(project, zone, instance, alias, alias_range_name=None):
  """Adds an alias IP range to an instance; see add_rm_alias for details."""
  return add_rm_alias(
      mode=ADD, project=project, zone=zone, instance=instance, alias=alias,
      alias_range_name=alias_range_name)


def remove_alias(project, zone, instance, alias):
  """Removes an alias IP range from an instance; see add_rm_alias."""
  return add_rm_alias(
      mode=REMOVE, project=project, zone=zone, instance=instance, alias=alias)


def get_aliases(project, zone, instance):
  """Lists the alias IP CIDR ranges on an instance's first interface.

  Args:
    project: string, the project in which the instance resides.
    zone: string, the zone in which the instance resides.
    instance: string, the name of the instance.

  Returns:
    A list of 'ipCidrRange' strings, or an empty list when a lookup of
    'aliasIpRanges' or 'ipCidrRange' raises KeyError.
  """
  ifaces = get_network_ifaces(project, zone, instance)
  cidr_ranges = []
  try:
    for entry in ifaces[0]['aliasIpRanges']:
      cidr_ranges.append(entry['ipCidrRange'])
  except KeyError:
    return []
  return cidr_ranges


def get_localhost_aliases():
  """Reads this instance's alias IPs from the metadata server.

  Returns:
    The 'ipAliases' value of the first network interface, or an empty list
    when there is no first interface or it has no such key.
  """
  raw = get_metadata('instance/network-interfaces', {'recursive': True})
  interfaces = json.loads(raw)
  try:
    first_iface = interfaces[0]
    return first_iface['ipAliases']
  except (KeyError, IndexError):
    return []


def get_zone(project, instance):
  """Finds the zone of an instance by name via the aggregated instance list.

  Args:
    project: string, the project to search.
    instance: string, the name of the instance.

  Returns:
    The last path component of the matching instance's 'zone' value.

  Raises:
    Exception: if no page of the listing contains the instance.
  """
  request = CONN.instances().aggregatedList(
      project=project, filter='name="%s"' % instance)
  while request is not None:
    response = request.execute()
    for scope in response.get('items', {}).values():
      for candidate in scope.get('instances', []):
        if candidate['name'] != instance:
          continue
        zone_url = candidate['zone']
        return zone_url.split("/")[-1]
    # Advance to the next result page; None once the listing is exhausted.
    request = CONN.instances().aggregatedList_next(
        previous_request=request, previous_response=response)
  raise Exception("Unable to find instance %s" % (instance))


def get_instances_list(project, exclude):
  """Lists the names of all instances in a project except one.

  Args:
    project: string, the project whose instances are listed.
    exclude: string, an instance name to leave out of the result.

  Returns:
    A list of instance names. Empty when the API answers with HTTP 404.

  Raises:
    googleapiclient.errors.HttpError: for any API error other than 404.
  """
  hostlist = []
  request = CONN.instances().aggregatedList(project=project)
  while request is not None:
    try:
      response = request.execute()
    except googleapiclient.errors.HttpError as e:
      if e.resp.status == 404:
        logger.debug('get_instances_list(): no instances found')
        return []
      # Any other error used to be swallowed, after which the loop carried
      # on with an undefined (or stale) response. Surface it instead.
      raise

    zones = response.get('items', {})
    for zone in zones.values():
      for inst in zone.get('instances', []):
        if inst['name'] != exclude:
          hostlist.append(inst['name'])
    # Advance to the next result page; None once the listing is exhausted.
    request = CONN.instances().aggregatedList_next(
        previous_request=request, previous_response=response)
  return hostlist


def gcp_alias_start(alias):
  """Attaches the alias IP range to this instance (OCF 'start' action).

  If the range is already attached locally the process exits with
  OCF_SUCCESS. Otherwise the range is detached from the first other host
  found holding it, added to this instance, and the result is verified
  through the metadata server.

  Args:
    alias: string, the alias IP range to attach.

  Exits the process with OCF_ERR_CONFIGURED when adding the alias returns
  HTTP 404, and with OCF_ERR_GENERIC when verification fails.
  """
  my_aliases = get_localhost_aliases()
  my_zone = get_metadata('instance/zone').split('/')[-1]
  project = os.environ.get(
        'OCF_RESKEY_project', get_metadata('project/project-id'))

  if alias in my_aliases:
    # TODO: Do we need to check alias_range_name?
    logger.info(
        '%s already has %s attached. No action required' % (THIS_VM, alias))
    sys.exit(OCF_SUCCESS)

  # If the alias is currently attached to another host, detach it.
  hostlist = os.environ.get('OCF_RESKEY_hostlist', '')
  if hostlist:
    # Drop this VM by whole-name comparison. A substring replace() would
    # mangle other hosts whose names contain ours (e.g. node1 vs node10).
    hostlist = [host for host in hostlist.split() if host != THIS_VM]
  else:
    hostlist = get_instances_list(project, THIS_VM)
  for host in hostlist:
    host_zone = get_zone(project, host)
    host_aliases = get_aliases(project, host_zone, host)
    if alias in host_aliases:
      logger.info(
          '%s is attached to %s - Removing %s from %s' %
          (alias, host, alias, host))
      remove_alias(project, host_zone, host, alias)
      break

  # Add alias IP range to localhost
  try:
    add_alias(
        project, my_zone, THIS_VM, alias,
        os.environ.get('OCF_RESKEY_alias_range_name'))
  except googleapiclient.errors.HttpError as e:
    logger.error(
        'Failed to add alias IP range %s to %s: %s' % (alias, THIS_VM, e))
    if e.resp.status == 404:
      sys.exit(OCF_ERR_CONFIGURED)
    # Other HTTP errors fall through to the verification below, which
    # reports the failure and exits.

  # Verify that the IP range has been added
  my_aliases = get_localhost_aliases()
  if alias in my_aliases:
    logger.info('Finished adding %s to %s' % (alias, THIS_VM))
  else:
    if my_aliases:
      # Adjacent literals are joined before '%' is applied; with an explicit
      # '+' the format only saw the second literal and raised TypeError.
      logger.error(
          'Failed to add alias IP range %s. %s has alias IP ranges attached but'
          ' they don\'t include %s' % (alias, THIS_VM, alias))
    else:
      logger.error(
          'Failed to add IP range %s. %s has no alias IP ranges attached'
           % (alias, THIS_VM))
    sys.exit(OCF_ERR_GENERIC)


def gcp_alias_stop(alias):
  """Detaches the alias IP range from this instance (OCF 'stop' action).

  Nothing is changed when the range is not attached locally.

  Args:
    alias: string, the alias IP range to detach.
  """
  my_aliases = get_localhost_aliases()
  my_zone = get_metadata('instance/zone').split('/')[-1]
  project = os.environ.get(
        'OCF_RESKEY_project', get_metadata('project/project-id'))

  if alias not in my_aliases:
    logger.info(
        '%s is not attached to %s. No action required'
        % (alias, THIS_VM))
    return

  logger.info('Removing %s from %s' % (alias, THIS_VM))
  remove_alias(project, my_zone, THIS_VM, alias)


def gcp_alias_status(alias):
  """Checks whether the alias IP range is attached to this instance.

  Args:
    alias: string, the alias IP range to look for.

  Exits the process with OCF_NOT_RUNNING when the range is not attached.
  """
  if alias not in get_localhost_aliases():
    sys.exit(OCF_NOT_RUNNING)
  logger.info('%s has the correct IP range attached' % THIS_VM)


def validate():
  """Initialises the module globals and checks the required parameter.

  Sets CONN to a new API client, THIS_VM to the instance name reported by
  the metadata server, and ALIAS to the value of OCF_RESKEY_alias_ip.

  Exits the process with OCF_ERR_CONFIGURED when alias_ip is missing.
  """
  global ALIAS, THIS_VM, CONN

  CONN = create_api_connection()
  THIS_VM = get_metadata('instance/name')
  ALIAS = os.environ.get('OCF_RESKEY_alias_ip')
  if ALIAS:
    return

  logger.error('Missing alias_ip parameter')
  sys.exit(OCF_ERR_CONFIGURED)


def configure_logs():
  """Sets up logging, optionally attaching a Cloud Logging handler.

  The googleapiclient logger is always limited to warnings. The Cloud
  Logging handler is only attached when OCF_RESKEY_stackdriver_logging
  contains 'yes', 'true' or 'enabled' (case-insensitive); in that case the
  global logger is rebound to an adapter around `log`.
  """
  global logger
  logging.getLogger('googleapiclient').setLevel(logging.WARN)

  setting = os.environ.get('OCF_RESKEY_stackdriver_logging')
  if not setting:
    return

  setting = setting.lower()
  if not any(word in setting for word in ['yes', 'true', 'enabled']):
    return

  try:
    import google.cloud.logging.handlers
    cloud_client = google.cloud.logging.Client()
    cloud_handler = google.cloud.logging.handlers.CloudLoggingHandler(
        cloud_client, name=THIS_VM)
    cloud_handler.setLevel(logging.INFO)
    cloud_handler.setFormatter(logging.Formatter('gcp:alias "%(message)s"'))
    log.addHandler(cloud_handler)
    logger = logging.LoggerAdapter(
        log, {'OCF_RESOURCE_INSTANCE': OCF_RESOURCE_INSTANCE})
  except ImportError:
    logger.error('Couldn\'t import google.cloud.logging, '
        'disabling Stackdriver-logging support')


def main():
  """Dispatches the OCF action named by the first command-line argument.

  'meta-data' prints the agent metadata without touching the API.
  Every other action first runs validate(); 'validate-all' stops there,
  while 'start', 'stop' and 'status'/'monitor' call the matching handler.

  Exits the process with OCF_ERR_UNIMPLEMENTED when no action is given or
  the action is not recognised, so the caller does not see a success code.
  """
  if len(sys.argv) < 2:
    logger.error('no action specified')
    # NOTE(review): OCF_ERR_UNIMPLEMENTED is expected to come from the
    # `ocf` star-import like the other OCF_* codes - confirm.
    sys.exit(OCF_ERR_UNIMPLEMENTED)

  action = sys.argv[1]
  if 'meta-data' in action:
    print(METADATA)
    return

  validate()
  if 'validate-all' in action:
    return

  configure_logs()
  if 'start' in action:
    gcp_alias_start(ALIAS)
  elif 'stop' in action:
    gcp_alias_stop(ALIAS)
  elif 'status' in action or 'monitor' in action:
    gcp_alias_status(ALIAS)
  else:
    logger.error('no such function %s' % str(action))
    # Falling off the end here would exit 0 and report success for an
    # action the agent never performed.
    sys.exit(OCF_ERR_UNIMPLEMENTED)


# Run the agent only when executed as a script, not when imported.
if __name__ == "__main__":
  main()
