#!/usr/bin/python -Es
# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2012 Red Hat, Inc.
#
# Authors:
# Thomas Woerner <twoerner@redhat.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

from gi.repository import GObject
import sys
sys.modules['gobject'] = GObject

import getopt
import dbus
import os

from firewall.client import FirewallClient
from firewall.errors import *

def usage():
    """Display the firewall-cmd man page in place of a built-in help text."""
    man_command = "man firewall-cmd"
    os.system(man_command)

def __fail(msg=None):
    """Abort the command with exit status 2.

    A non-empty msg (string or exception object) is printed to stdout
    before exiting; an empty or missing msg exits silently.
    """
    if msg:
        print(msg)
    # Equivalent to sys.exit(2).
    raise SystemExit(2)

def __parse_port(value):
    """Split a "<port>/<protocol>" argument into its two parts.

    Returns a (port, proto) tuple of strings.  Exits through __fail() with
    a readable message when value does not contain exactly one "/" or when
    either side of it is empty.
    """
    port, sep, proto = value.partition("/")
    # Reject "80", "80/", "/tcp" and "80/tcp/x" up front instead of
    # surfacing a raw tuple-unpacking error text to the user.
    if not sep or not port or not proto or "/" in proto:
        __fail("invalid port '%s'" % value)
    return (port, proto)

def __parse_forward_port(value):
    """Parse a forward-port specification.

    value has the form "port=P:proto=T[:toport=P2][:toaddr=A]" with the
    fields in any order; a field given twice keeps its last value.
    Returns (port, protocol, toport, toaddr); fields not given are None.
    Exits through __fail() on a field that is not "name=value", on an
    unknown field name, or when port, proto, or both of toport/toaddr are
    missing.
    """
    fields = {"port": None, "proto": None, "toport": None, "toaddr": None}
    for arg in value.split(":"):
        try:
            (opt, val) = arg.split("=")
        except ValueError:
            __fail("invalid forward port arg '%s'" % (arg))
        if opt not in fields:
            # Unknown names (e.g. the typo "toadr") used to be ignored
            # silently, which could change where traffic is forwarded.
            __fail("invalid forward port arg '%s'" % (arg))
        fields[opt] = val

    port = fields["port"]
    protocol = fields["proto"]
    toport = fields["toport"]
    toaddr = fields["toaddr"]
    if not port:
        __fail("missing port")
    if not protocol:
        __fail("missing protocol")
    # At least one destination (local port or remote address) is required.
    if not (toport or toaddr):
        __fail("missing destination")
    return (port, protocol, toport, toaddr)

def __list_all(fw, zone):
    """Print a summary of the runtime settings of one zone.

    Shows interfaces, services, ports, forward ports and ICMP blocks, one
    per line.  The zone name is passed to every query unchanged; only the
    heading line substitutes fw.getDefaultZone() when zone is "".
    """
    bound_ifaces = fw.getInterfaces(zone)
    enabled_services = fw.getServices(zone)
    open_ports = fw.getPorts(zone)
    forwards = fw.getForwardPorts(zone)
    blocked_icmp = fw.getIcmpBlocks(zone)

    if zone != "":
        heading = zone
    else:
        heading = fw.getDefaultZone()
    print(heading)

    print("  interfaces: " + " ".join(bound_ifaces))
    print("  services: " + " ".join(enabled_services))
    port_text = " ".join(["%s/%s" % (entry[0], entry[1])
                          for entry in open_ports])
    print("  ports: " + port_text)
    # Multiple forward ports are listed on continuation lines.
    forward_text = "\n\t".join(
        ["port=%s:proto=%s:toport=%s:toaddr=%s" % (fport, fproto, ftoport, ftoaddr)
         for (fport, fproto, ftoport, ftoaddr) in forwards])
    print("  forward-ports: " + forward_text)
    print("  icmp-blocks: " + " ".join(blocked_icmp))

# Parse the regular command line with getopt.  "--direct" commands carry raw
# firewall-rule arguments that getopt would try to interpret as options, so
# for those the positional arguments are kept verbatim and parsed below.
if "--direct" not in sys.argv[1:]:
    try:
        (opts, args) = \
            getopt.getopt(sys.argv[1:], "hv",
                          [ "help", "version", "timeout=",
                            "reload", "complete-reload", "state",
                            "get-default-zone", "set-default-zone=",
                            "get-zones", "get-active-zones",
                            "get-zone-of-interface=",
                            "enable-panic", "disable-panic", "query-panic",
                            "get-services", "get-icmptypes", "list-all-zones",
                            # zone
                            "zone=",
                            # permanent
                            "permanent",
                            # modes (exactly one of those)
                            "add-service=", "remove-service=", "query-service=",
                            "list-services",
                            "add-port=", "remove-port=", "query-port=",
                            "list-ports",
                            "add-interface=", "remove-interface=",
                            "query-interface=", "list-interfaces",
                            "change-interface=",
                            "add-masquerade", "remove-masquerade",
                            "query-masquerade",
                            "add-icmp-block=", "remove-icmp-block=",
                            "query-icmp-block=", "list-icmp-blocks",
                            "add-forward-port=", "remove-forward-port=",
                            "query-forward-port=",
                            "list-forward-ports",
                            "list-all",
                            ])
    except getopt.GetoptError as msg:
        # Unknown option or an option missing its required argument.
        print(msg)
        usage()
        sys.exit(1)

    # Positional arguments alone do not select anything.
    if not opts:
        usage()
        sys.exit(1)
else:
    opts = [ ]
    args = sys.argv[1:]

# Settings collected from the command line.
timeout = 0
mode = None
value = None
zone = ""
permanent = False

# Direct interface: "--direct <command> <ipv> ..." with a fixed positional
# layout per command.  The mode is the command name without its leading "--".
if len(args) > 2 and args[0] == "--direct":
    direct_ipv = args[2]
    if args[1] == "--passthrough" and len(args) > 3:
        mode = args[1][2:]
        direct_args = args[3:]
    elif args[1] in [ "--add-chain", "--remove-chain", "--query-chain" ] and \
            len(args) == 5:
        mode = args[1][2:]
        direct_table = args[3]
        direct_chain = args[4]
    elif args[1] == "--get-chains" and len(args) == 4:
        mode = args[1][2:]
        direct_table = args[3]
    elif args[1] == "--add-rule" and len(args) > 7:
        # <ipv> <table> <chain> <priority> <args...>
        mode = args[1][2:]
        direct_table = args[3]
        direct_chain = args[4]
        try:
            direct_priority = int(args[5])
        except ValueError:
            # Report a non-numeric priority instead of dying with a traceback.
            __fail("invalid priority '%s'" % args[5])
        direct_args = args[6:]
    elif args[1] in [ "--remove-rule", "--query-rule" ] and \
            len(args) > 6:
        # <ipv> <table> <chain> <args...>
        mode = args[1][2:]
        direct_table = args[3]
        direct_chain = args[4]
        direct_args = args[5:]
    elif args[1] == "--get-rules" and len(args) == 5:
        mode = args[1][2:]
        direct_table = args[3]
        direct_chain = args[4]
    else:
        usage()
        sys.exit(1)

# Walk the parsed options.  Only one mode may be selected; a second mode
# option aborts through __fail().
for (opt, val) in opts:
    if opt in ["-h", "--help"]:
        usage()
        sys.exit(0)
    elif opt in ["-v", "--version"]:
        if mode:
            __fail()
        mode = "version"

    elif opt == "--permanent":
        permanent = True

    # modes without an argument
    elif opt in [ "--reload", "--complete-reload", "--state",
                  "--get-default-zone", "--get-zones", "--get-active-zones",
                  "--enable-panic", "--disable-panic", "--query-panic",
                  "--list-services", "--list-ports", "--list-interfaces",
                  "--list-icmp-blocks", "--list-forward-ports",
                  "--list-all",
                  "--get-services", "--get-icmptypes", "--list-all-zones",
                  "--add-masquerade", "--remove-masquerade",
                  "--query-masquerade", ]:
        if mode:
            __fail()
        mode = opt[2:]

    # modes taking an argument, which is kept in value
    elif opt in [ "--set-default-zone", "--get-zone-of-interface",
                  "--add-service", "--remove-service", "--query-service",
                  "--add-port", "--remove-port", "--query-port",
                  "--add-interface", "--remove-interface",
                  "--query-interface",
                  "--change-interface",
                  "--add-icmp-block", "--remove-icmp-block",
                  "--query-icmp-block",
                  "--add-forward-port", "--remove-forward-port",
                  "--query-forward-port", ]:
        if mode:
            __fail()
        mode = opt[2:]
        value = val

    # zone
    elif opt == "--zone":
        if zone:
            __fail()
        zone = val

    # timeout: must be a positive integer
    elif opt == "--timeout":
        try:
            timeout = int(val)
        except ValueError:
            usage()
            sys.exit(2)
        if timeout < 1:
            __fail("Timeout not valid")
# Consistency checks on the collected settings.
if not mode:
    __fail("No mode.")

# --timeout is only accepted together with an add-* mode.
if timeout != 0:
    if not mode.startswith("add"):
        __fail("Timeout only valid in add mode.")

# Global modes do not accept --zone.
if zone and mode in [ "state", "reload", "complete-reload",
                      "enable-panic", "disable-panic", "query-panic",
                      "get-default-zone", "set-default-zone",
                      "get-zones", "get-active-zones", "get-zone-of-interface",
                      "get-services", "get-icmptypes", "list-all-zones", ]:
    usage()
    sys.exit(2)

# --permanent works on the stored configuration: no timeout, and only the
# modes listed here are supported.  (get-zones/get-services/get-icmptypes
# with a zone were already rejected by the check above, and so were
# get-services/get-icmptypes/list-all-zones with a zone in runtime mode.)
if permanent:
    if timeout != 0:
        usage()
        sys.exit(1)
    if mode not in [
        "get-zones", "get-services", "get-icmptypes",
        "add-service", "remove-service", "query-service", "list-services",
        "add-port", "remove-port", "query-port", "list-ports",
        "add-icmp-block", "remove-icmp-block", "query-icmp-block",
        "list-icmp-blocks",
        "add-masquerade", "remove-masquerade", "query-masquerade",
        "add-forward-port", "remove-forward-port", "query-forward-port",
        "list-forward-ports"]:
        usage()
        sys.exit(1)

#print("ZONE='%s', MODE='%s'" % (zone, mode))

# Execute the selected mode.  query-* modes report their result only through
# the exit status: sys.exit(not result) exits 0 when true and 1 when false.
try:
    fw = FirewallClient()
    if not fw.connected:
        print("Couldn't connect to FirewallD, it's probably not running.")
        sys.exit(NOT_RUNNING)

    if permanent:
        # --permanent: operate on the stored configuration via fw.config().
        config = fw.config()
        if mode == "get-zones":
            names = [config.getZone(z).get_property("name")
                     for z in config.listZones()]
            if names:
                print(" ".join(names))
        elif mode == "get-services":
            names = [config.getService(s).get_property("name")
                     for s in config.listServices()]
            if names:
                print(" ".join(names))
        elif mode == "get-icmptypes":
            names = [config.getIcmpType(i).get_property("name")
                     for i in config.listIcmpTypes()]
            if names:
                print(" ".join(names))
        else:
            # Zone settings; without --zone the default zone is used.
            if not zone:
                zone = fw.getDefaultZone()
            fw_zone = config.getZoneByName(zone)
            fw_settings = fw_zone.getSettings()

            # list-* and query-* modes exit inside the chain below; modes
            # that modify fw_settings fall through to fw_zone.update().

            # service
            if mode == "list-services":
                items = fw_settings.getServices()
                if items:
                    print(" ".join(items))
                sys.exit(0)
            elif mode == "add-service":
                fw_settings.addService(value)
            elif mode == "remove-service":
                fw_settings.removeService(value)
            elif mode == "query-service":
                sys.exit(value not in fw_settings.getServices())

            # port
            elif mode == "list-ports":
                items = fw_settings.getPorts()
                if items:
                    print(" ".join(["%s/%s" % (port[0], port[1]) for port in items]))
                sys.exit(0)
            elif mode == "add-port":
                (port, proto) = __parse_port(value)
                fw_settings.addPort(port, proto)
            elif mode == "remove-port":
                (port, proto) = __parse_port(value)
                fw_settings.removePort(port, proto)
            elif mode == "query-port":
                (port, proto) = __parse_port(value)
                sys.exit((port, proto) not in fw_settings.getPorts())

            # masquerade
            elif mode == "add-masquerade":
                fw_settings.setMasquerade(True)
            elif mode == "remove-masquerade":
                fw_settings.setMasquerade(False)
            elif mode == "query-masquerade":
                sys.exit(not fw_settings.getMasquerade())

            # forward port
            elif mode == "list-forward-ports":
                items = fw_settings.getForwardPorts()
                if items:
                    print("\n".join(["port=%s:proto=%s:toport=%s:toaddr=%s" % (port, protocol, toport, toaddr) for (port, protocol, toport, toaddr) in items]))
                sys.exit(0)
            elif mode == "add-forward-port":
                (port, protocol, toport, toaddr) = __parse_forward_port(value)
                fw_settings.addForwardPort(port, protocol, toport, toaddr)
            elif mode == "remove-forward-port":
                (port, protocol, toport, toaddr) = __parse_forward_port(value)
                fw_settings.removeForwardPort(port, protocol, toport, toaddr)
            elif mode == "query-forward-port":
                (port, protocol, toport, toaddr) = __parse_forward_port(value)
                sys.exit(not fw_settings.queryForwardPort(port, protocol, toport, toaddr))

            # block icmp
            elif mode == "list-icmp-blocks":
                items = fw_settings.getIcmpBlocks()
                if items:
                    print(" ".join(items))
                sys.exit(0)
            elif mode == "add-icmp-block":
                fw_settings.addIcmpBlock(value)
            elif mode == "remove-icmp-block":
                fw_settings.removeIcmpBlock(value)
            elif mode == "query-icmp-block":
                sys.exit(value not in fw_settings.getIcmpBlocks())

            # Write the modified settings back to the stored zone.
            fw_zone.update(fw_settings)

    elif mode == "version":
        print(fw.get_property("version"))
        sys.exit(0)
    elif mode == "state":
        state = fw.get_property("state")
        if state != "RUNNING":
            sys.exit(1)
    elif mode == "reload":
        if not fw.reload():
            sys.exit(1)
    elif mode == "complete-reload":
        fw.complete_reload()

    # direct interface
    elif mode == "passthrough":
        print(fw.passthrough(direct_ipv, direct_args))
    elif mode == "add-chain":
        fw.addChain(direct_ipv, direct_table, direct_chain)
    elif mode == "remove-chain":
        fw.removeChain(direct_ipv, direct_table, direct_chain)
    elif mode == "query-chain":
        sys.exit(not fw.queryChain(direct_ipv, direct_table, direct_chain))
    elif mode == "get-chains":
        print(" ".join(fw.getChains(direct_ipv, direct_table)))
    elif mode == "add-rule":
        fw.addRule(direct_ipv, direct_table, direct_chain, direct_priority,
                   direct_args)
    elif mode == "remove-rule":
        fw.removeRule(direct_ipv, direct_table, direct_chain, direct_args)
    elif mode == "query-rule":
        sys.exit(not fw.queryRule(direct_ipv, direct_table, direct_chain,
                                  direct_args))
    elif mode == "get-rules":
        rules = fw.getRules(direct_ipv, direct_table, direct_chain)
        for rule in rules:
            print(" ".join(rule))

    # zones
    elif mode == "get-default-zone":
        print(fw.getDefaultZone())
    elif mode == "set-default-zone":
        fw.setDefaultZone(value)
    elif mode == "get-zones":
        print(" ".join(fw.getZones()))
    elif mode == "get-active-zones":
        active_zones = fw.getActiveZones()
        for active_zone in active_zones:
            print("%s: %s" % (active_zone, " ".join(active_zones[active_zone])))
    elif mode == "get-zone-of-interface":
        # Best effort: a D-Bus error here (for example an interface not
        # bound to any zone) just produces no output.
        try:
            print(fw.getZoneOfInterface(value))
        except dbus.DBusException:
            pass
    elif mode == "get-services":
        items = fw.listServices()
        if items:
            print(" ".join(items))
    elif mode == "get-icmptypes":
        items = fw.listIcmpTypes()
        if items:
            print(" ".join(items))

    # panic
    elif mode == "enable-panic":
        fw.enablePanicMode()
    elif mode == "disable-panic":
        fw.disablePanicMode()
    elif mode == "query-panic":
        sys.exit(not fw.queryPanicMode())

    # interface
    elif mode == "list-interfaces":
        items = fw.getInterfaces(zone)
        if items:
            print(" ".join(items))
    elif mode == "add-interface":
        fw.addInterface(zone, value)
    elif mode == "change-interface":
        fw.changeZone(zone, value)
    elif mode == "remove-interface":
        fw.removeInterface(zone, value)
    elif mode == "query-interface":
        sys.exit(not fw.queryInterface(zone, value))

    # service
    elif mode == "list-services":
        items = fw.getServices(zone)
        if items:
            print(" ".join(items))
    elif mode == "add-service":
        fw.addService(zone, value, timeout)
    elif mode == "remove-service":
        fw.removeService(zone, value)
    elif mode == "query-service":
        sys.exit(not fw.queryService(zone, value))

    # port
    elif mode == "list-ports":
        items = fw.getPorts(zone)
        if items:
            print(" ".join(["%s/%s" % (port[0], port[1]) for port in items]))
    elif mode == "add-port":
        (port, proto) = __parse_port(value)
        fw.addPort(zone, port, proto, timeout)
    elif mode == "remove-port":
        (port, proto) = __parse_port(value)
        fw.removePort(zone, port, proto)
    elif mode == "query-port":
        (port, proto) = __parse_port(value)
        sys.exit(not fw.queryPort(zone, port, proto))

    # masquerade
    elif mode == "add-masquerade":
        fw.addMasquerade(zone, timeout)
    elif mode == "remove-masquerade":
        fw.removeMasquerade(zone)
    elif mode == "query-masquerade":
        sys.exit(not fw.queryMasquerade(zone))

    # forward port
    elif mode == "list-forward-ports":
        items = fw.getForwardPorts(zone)
        if items:
            print("\n".join(["port=%s:proto=%s:toport=%s:toaddr=%s" % (port, protocol, toport, toaddr) for (port, protocol, toport, toaddr) in items]))
    elif mode == "add-forward-port":
        (port, protocol, toport, toaddr) = __parse_forward_port(value)
        fw.addForwardPort(zone, port, protocol, toport, toaddr, timeout)
    elif mode == "remove-forward-port":
        (port, protocol, toport, toaddr) = __parse_forward_port(value)
        fw.removeForwardPort(zone, port, protocol, toport, toaddr)
    elif mode == "query-forward-port":
        (port, protocol, toport, toaddr) = __parse_forward_port(value)
        sys.exit(not fw.queryForwardPort(zone, port, protocol,
                                         toport, toaddr))

    # block icmp
    elif mode == "list-icmp-blocks":
        items = fw.getIcmpBlocks(zone)
        if items:
            print(" ".join(items))
    elif mode == "add-icmp-block":
        fw.addIcmpBlock(zone, value, timeout)
    elif mode == "remove-icmp-block":
        fw.removeIcmpBlock(zone, value)
    elif mode == "query-icmp-block":
        sys.exit(not fw.queryIcmpBlock(zone, value))

    # list all
    elif mode == "list-all":
        __list_all(fw, zone)

    # list everything
    elif mode == "list-all-zones":
        for zone_name in fw.getZones():
            __list_all(fw, zone_name)
            print("")

except dbus.DBusException as e:
    # get_dbus_message() is the portable accessor for the error text; the
    # plain .message attribute does not exist on Python 3 exceptions, which
    # made every error fall back to UNKNOWN_ERROR there.
    dbus_message = e.get_dbus_message()
    if "NotAuthorizedException" in dbus_message:
        print("Authorization failed.")
        sys.exit(NOT_AUTHORIZED)
    try:
        code = FirewallError.get_code(dbus_message)
    except Exception:
        code = UNKNOWN_ERROR
        print("Error: %s" % e)
    else:
        # "Already enabled"-style errors are reported as warnings with
        # exit status 0.
        if code in [ ALREADY_ENABLED, NOT_ENABLED, ZONE_ALREADY_SET ]:
            print("Warning: %s" % dbus_message)
            sys.exit(0)
        print("Error: %s" % dbus_message)
    sys.exit(code)

sys.exit(0)
