#!/usr/bin/env python
#
# Copyright 2012 Red Hat, Inc
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import base64
import hmac
import hashlib
import time
import sys

import authhub

def hex2bin(hex, bytes=0):
    """Convert a hexadecimal string into a string of raw bytes.

    The input is left-padded with "0" until it has at least ``bytes`` hex
    digits AND an even number of digits, so "1" becomes "\\x01".  Note that
    ``bytes`` counts hex digits, not output bytes.

    Each output character encodes one byte (code point 0-255).  ``int()``
    raises ValueError for digit pairs that are not valid hexadecimal.
    """
    # Target length: max(len(hex), bytes), rounded up to an even count --
    # the same length the original one-zero-at-a-time padding loop reached.
    width = max(len(hex), bytes)
    width += width % 2
    hex = "0" * (width - len(hex)) + hex

    # Step through digit pairs directly (no `/`, which is a float on
    # Python 3) and join once instead of concatenating inside the loop.
    return "".join(chr(int(hex[i:i + 2], 16)) for i in range(0, len(hex), 2))

# Maps the optional "<encoding>:" prefix of a configured key to the function
# that turns the remainder into raw key bytes.  Keys without a recognised
# prefix are decoded as hex.
DECODER = dict(
    hex=hex2bin,
    base32=base64.b32decode,
    base64=base64.b64decode,
)

class TOTPHandler(authhub.Handler):
    """authhub handler for time-based one-time passwords (TOTP).

    Tokens are described by ``params["config"]``:

    * ``key``    -- shared secret; may carry a ``hex:``, ``base32:`` or
                    ``base64:`` prefix naming its encoding, otherwise it is
                    decoded as hex.
    * ``pin``    -- optional PIN, either plaintext or
                    ``<hashlib constructor name>:<hexdigest>``.
    * ``vendor`` -- vendor string reported by getTokenInfo ("TOTP" if unset).
    * ``step``   -- time step in seconds (default 30, must be positive).
    * ``floor``  -- Unix time at which step counting starts (default 0).
    * ``hash``   -- HMAC hash: "sha1" (default), "sha256" or "sha512".
    * ``digits`` -- length of the code, 4 to 8 (default 6).
    """

    def getTokenInfo(self, params):
        """Return token metadata for ``params["config"]``.

        Returns ``[]`` when no ``key`` is configured; otherwise a one-item
        list holding the authhub flags, the vendor string and a decimal OTP
        format.
        """
        cfg = params.get("config", {})
        pin = cfg.get("pin", None)
        key = cfg.get("key", None)
        vnd = cfg.get("vendor", None)

        if key is None:
            return []

        if vnd is None:
            vnd = "TOTP"

        # verifyRequest checks a configured PIN against the request's
        # "otp-pin" field, so only ask for a PIN when one is configured.
        flags = 0
        if pin is not None:
            flags |= authhub.Flags.collectPIN
            flags |= authhub.Flags.separatePINRequired
        else:
            flags |= authhub.Flags.doNotCollectPIN

        return [{"flags": flags, "otp-vendor": vnd,
                 "otp-format": authhub.Format.decimal}]

    def verifyRequest(self, params):
        """Check an OTP request against the configured token.

        ``params["request"]`` supplies ``otp-value`` (and ``otp-pin`` when a
        PIN is configured); ``params["config"]`` supplies the settings listed
        in the class docstring.  Returns True only when the configuration is
        valid, the PIN (if configured) matches, and ``otp-value`` equals the
        code for the current time step.  Every failure -- including an
        undecodable key -- returns False rather than raising.
        """
        req = params.get("request", {})
        cfg = params.get("config", {})
        pin = cfg.get("pin", None)
        key = cfg.get("key", None)
        stp = cfg.get("step", 30)
        flr = cfg.get("floor", 0)
        hsh = cfg.get("hash", "sha1")
        dig = cfg.get("digits", 6)

        # Key must be a string
        if not isinstance(key, basestring):
            return False

        # Step and Floor must be integers; the step must also be positive
        # (a zero step would raise ZeroDivisionError below).
        if type(stp) != int or type(flr) != int or stp <= 0:
            return False

        # Digits must be between 4 and 8
        if type(dig) != int or dig < 4 or dig > 8:
            return False

        # Hash must be one of these
        if hsh not in ("sha1", "sha256", "sha512"):
            return False

        # Any configured PIN (getTokenInfo requests one whenever pin is not
        # None) must be a string and must match.  A non-string PIN fails
        # closed instead of silently disabling the PIN check.
        if pin is not None and not self._pinMatches(pin, req.get("otp-pin", None)):
            return False

        token = self._generateToken(key, stp, flr, hsh, dig)
        if token is None:
            return False

        return self._safeEquals(token, req.get("otp-value", None))

    @staticmethod
    def _toBytes(text):
        """Return ``text`` as a byte string, UTF-8 encoding it if it is text."""
        if isinstance(text, bytes):
            return text
        return text.encode("utf-8")

    def _safeEquals(self, expected, supplied):
        """Constant-time string comparison; False if ``supplied`` is not a string."""
        if not isinstance(supplied, basestring):
            return False
        # compare_digest avoids leaking, via response timing, how many
        # leading characters of a guessed PIN/OTP were correct.
        return hmac.compare_digest(self._toBytes(expected),
                                   self._toBytes(supplied))

    def _pinMatches(self, pin, otp_pin):
        """Return True if the plaintext ``otp_pin`` matches the configured ``pin``.

        When ``pin`` looks like ``<name>:<hexdigest>`` and ``name`` is an
        attribute of hashlib, ``otp_pin`` is hashed (as UTF-8) with it and
        compared to the hexdigest; otherwise the values are compared directly.
        """
        if not isinstance(pin, basestring) or not isinstance(otp_pin, basestring):
            return False

        # If our pin is hashed, hash the otp_pin which is in plaintext
        if ":" in pin:
            name, digest = pin.split(":", 1)
            hasher = getattr(hashlib, name, None)
            if hasher:
                pin = digest
                otp_pin = hasher(self._toBytes(otp_pin)).hexdigest()

        return self._safeEquals(pin, otp_pin)

    def _generateToken(self, key, stp, flr, hsh, dig):
        """Return the zero-padded ``dig``-digit code for the current time step.

        Returns None when ``key`` cannot be decoded with its encoding.
        """
        # Get the key's encoding (hex unless a known prefix is given)
        enc = "hex"
        if ":" in key and key.split(":", 1)[0] in DECODER:
            enc, key = key.split(":", 1)

        try:
            K = DECODER[enc](key)
        except (TypeError, ValueError):
            # Malformed key material (non-hex digits, bad base32/base64
            # alphabet or padding): fail closed instead of raising.
            return None

        # Counter = whole steps since the floor, as an 8-byte big-endian value
        T = hex2bin("%016X" % ((int(time.time()) - flr) // stp))
        digest = bytearray(hmac.HMAC(K, T, getattr(hashlib, hsh)).digest())

        # Truncate the digest: the low nibble of the last byte picks four
        # bytes, whose top bit is masked off to give a 31-bit integer.
        offset = digest[-1] & 0xf
        binary = ((digest[offset] & 0x7f) << 24 |
                  digest[offset + 1] << 16 |
                  digest[offset + 2] << 8 |
                  digest[offset + 3])

        # Keep the last `dig` decimal digits, left-padded with zeros
        return "%0*d" % (dig, binary % (10 ** dig))

if __name__ == "__main__":
    # Wrap a TOTPHandler in an authhub plugin and hand control to its loop
    # (runForever presumably does not return under normal operation).
    plugin = authhub.Plugin(TOTPHandler())
    plugin.runForever()
