#!/usr/bin/python
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*-
# vi: set ft=python sts=4 ts=4 sw=4 noet :
#
# This file is part of Fail2Ban.
#
# Fail2Ban is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# Fail2Ban is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Fail2Ban; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""
Fail2Ban reads log files that contain password failure reports
and bans the corresponding IP addresses using firewall rules.

This tool can test regular expressions for "fail2ban".

Report bugs to https://github.com/fail2ban/fail2ban/issues
"""

__author__ = "Cyril Jaquier, Yaroslav Halchenko"
__copyright__ = "Copyright (c) 2004-2008 Cyril Jaquier, 2012-2013 Yaroslav Halchenko"
__license__ = "GPL"

import getopt, sys, time, logging, os, locale, shlex

from optparse import OptionParser, Option

from ConfigParser import NoOptionError, NoSectionError, MissingSectionHeaderError

try:
	from systemd import journal
	from fail2ban.server.filtersystemd import FilterSystemd
except ImportError:
	journal = None

from fail2ban.version import version
from fail2ban.client.configparserinc import SafeConfigParserWithIncludes
from fail2ban.server.filter import Filter
from fail2ban.server.failregex import RegexException

from fail2ban.tests.utils import FormatterWithTraceBack
# Gets the instance of the logger.
logSys = logging.getLogger("fail2ban")

def shortstr(s, l=53):
	"""Return `s` shortened to at most `l` characters.

	A string longer than `l` is cut and terminated with '...', so that the
	result is exactly `l` characters long.  When `l` leaves no room for the
	ellipsis (`l` < 3) the string is simply cut to `l` characters; nothing
	is kept for a negative `l`.  Shorter strings are returned unchanged.
	"""
	if len(s) <= l:
		return s
	if l < 3:
		# s[:l-3] would be a negative slice here and hand back a result
		# that is longer than the requested limit.
		return s[:max(l, 0)]
	return s[:l-3] + '...'

def pprint_list(l, header=None):
	"""Print the strings in `l` as a small tree on standard output.

	Each item goes on its own line prefixed with "|  "; the block is
	closed with "`-".  If `header` is given (and not empty) it is printed
	first, prefixed with "|- ".  Nothing is printed for an empty `l`.
	"""
	if len(l) == 0:
		return
	prefix = ("|- %s\n" % header) if header else ''
	body = "\n|  ".join(l)
	print(prefix + "|  " + body + '\n`-')

def file_lines_gen(hdlr, encoding=None):
	"""Yield the lines of `hdlr`, decoded with `encoding` where possible.

	`hdlr` is any iterable of byte strings (e.g. a file opened in 'rb'
	mode).  `encoding` defaults to the encoding of the module-level
	`fail2banRegex` object created by the command line script; passing it
	explicitly makes the generator usable without that global.

	A line that cannot be decoded strictly is decoded again ignoring the
	offending bytes on Python 3; on Python 2 it is yielded undecoded.
	"""
	if encoding is None:
		# Looked up lazily (on first iteration), i.e. after the script
		# has created the global.
		encoding = fail2banRegex.encoding
	must_decode = sys.version_info >= (3,) # Python 3 must be decoded
	for line in hdlr:
		try:
			line = line.decode(encoding, 'strict')
		except UnicodeDecodeError:
			if must_decode:
				line = line.decode(encoding, 'ignore')
		yield line

def journal_lines_gen(myjournal):
	"""Yield the entries of `myjournal` formatted as log lines.

	Entries are fetched with `myjournal.get_next()` until a false (empty)
	entry is returned.  An OSError raised while fetching is ignored and
	the fetch is tried again.  Every entry is passed through
	FilterSystemd.formatJournalEntry() before being yielded.
	"""
	entry = True
	while entry:
		try:
			entry = myjournal.get_next()
		except OSError:
			# `entry` still holds the previous (true) value: try again
			continue
		if entry:
			yield FilterSystemd.formatJournalEntry(entry)

def get_opt_parser():
	"""Build and return the OptionParser for the command line of this tool.

	The usage text is assembled from the program name (sys.argv[0]), the
	module docstring and a description of the positional arguments
	<LOG>, <REGEX> and [IGNOREREGEX].
	"""
	# use module docstring for help output
	# (all lines of the argument description are indented with spaces only,
	# a tab would misalign the column in the help output)
	usage = ("%s [OPTIONS] <LOG> <REGEX> [IGNOREREGEX]\n" % sys.argv[0]
		+ __doc__
		+ """
LOG:
    string                  a string representing a log line
    filename                path to a log file (/var/log/auth.log)
    "systemd-journal"       search systemd journal (systemd-python required)

REGEX:
    string                  a string representing a 'failregex'
    filename                path to a filter file (filter.d/sshd.conf)

IGNOREREGEX:
    string                  a string representing an 'ignoreregex'
    filename                path to a filter file (filter.d/sshd.conf)
""")
	p = OptionParser(usage=usage, version="%prog " + version)

	options = [
		# settings handed over to the filter
		Option("-d", "--datepattern",
			help="set custom pattern used to match date/times"),
		Option("-e", "--encoding",
			help="File encoding. Default: system locale"),
		Option("-L", "--maxlines", type=int, default=0,
			help="maxlines for multi-line regex"),
		Option("-m", "--journalmatch",
			help="journalctl style matches overriding filter file. "
			"\"systemd-journal\" only"),
		Option("-v", "--verbose", action='store_true',
			help="Be verbose in output"),

		# logging and reporting
		Option('-l', "--log-level", type="choice",
			dest="log_level",
			choices=('heavydebug', 'debug', 'info', 'warning', 'error', 'fatal'),
			default=None,
			help="Log level for the Fail2Ban logger to use"),
		Option("--print-all-missed", action='store_true',
			help="Either to print all missed lines"),
		Option("--print-all-ignored", action='store_true',
			help="Either to print all ignored lines"),
		Option("-t", "--log-traceback", action='store_true',
			help="Enrich log-messages with compressed tracebacks"),
		Option("--full-traceback", action='store_true',
			help="Either to make the tracebacks full, not compressed (as by default)"),
		]
	p.add_options(options)

	return p


class RegexStat(object):
	"""Hit statistics of a single regular expression.

	Holds the regex string given at construction, a hit counter
	increased with inc() and the list of values recorded with appendIP().
	"""

	def __init__(self, failregex):
		self._failregex = failregex
		self._stats = 0
		self._ipList = []

	def __str__(self):
		fields = (self.__class__, self._failregex, self._stats, self._ipList)
		return "%s(%r) %d failed: %s" % fields

	def inc(self):
		"""Count one more hit."""
		self._stats = self._stats + 1

	def getStats(self):
		"""Return the number of hits counted so far."""
		return self._stats

	def getFailRegex(self):
		"""Return the regular expression given at construction."""
		return self._failregex

	def appendIP(self, value):
		"""Record `value` in the list returned by getIPList()."""
		self._ipList.append(value)

	def getIPList(self):
		"""Return the list of values recorded with appendIP()."""
		return self._ipList


class LineStats(object):
	"""Plain container for the per-line counters of a test run.

	`tested` and `matched` are counters, `missed_lines` and
	`ignored_lines` collect the lines themselves; `missed` and `ignored`
	are the lengths of those two lists.
	"""
	def __init__(self):
		self.tested = 0
		self.matched = 0
		self.ignored_lines = []
		self.missed_lines = []

	@property
	def missed(self):
		return len(self.missed_lines)

	@property
	def ignored(self):
		return len(self.ignored_lines)

	def __getitem__(self, key):
		# attribute access by name, so that the instance can be used as
		# the mapping of a "%(name)d" format string (see __str__)
		return getattr(self, key)

	def __str__(self):
		template = "%(tested)d lines, %(ignored)d ignored, %(matched)d matched, %(missed)d missed"
		return template % self


class Fail2banRegex(object):

	CONFIG_DEFAULTS = {'configpath' : "/etc/fail2ban/"}

	def __init__(self, opts):
		"""Set up the tester from the parsed command line options.

		The attributes of `opts` read here are: verbose, print_all_missed,
		print_all_ignored, datepattern, encoding, maxlines and
		journalmatch.  Without an explicit encoding the preferred encoding
		of the locale is used.
		"""
		self._verbose = opts.verbose
		self._print_all_missed = opts.print_all_missed
		self._print_all_ignored = opts.print_all_ignored
		self._maxlines_set = False		  # so we allow to override maxlines in cmdline
		self._journalmatch = None

		# The filter has to exist before any option is applied:
		# setDatePattern() and setMaxLines() both delegate to self._filter,
		# so creating it later makes a given date pattern fail with an
		# AttributeError.
		self._filter = Filter(None)
		self._ignoreregex = list()
		self._failregex = list()
		self._line_stats = LineStats()

		if opts.encoding:
			self.encoding = opts.encoding
		else:
			self.encoding = locale.getpreferredencoding()

		if opts.datepattern:
			self.setDatePattern(opts.datepattern)
		if opts.maxlines:
			self.setMaxLines(opts.maxlines)
		if opts.journalmatch is not None:
			# NOTE(review): split on whitespace only, whereas readRegex()
			# uses shlex.split() for the value from a filter file
			self.setJournalMatch(opts.journalmatch.split())

	def setDatePattern(self, pattern):
		"""Hand the custom date `pattern` over to the underlying filter."""
		flt = self._filter
		flt.setDatePattern(pattern)

	def setMaxLines(self, v):
		"""Set the filter's maxlines to int(`v`), first call wins.

		Once a value has been set, later calls are ignored; this lets a
		value from the command line take precedence over one read from a
		filter file.  Raises ValueError if `v` cannot be converted.
		"""
		if self._maxlines_set:
			return
		self._filter.setMaxLines(int(v))
		self._maxlines_set = True
		print("Use         maxlines : %d" % self._filter.getMaxLines())

	def setJournalMatch(self, v):
		"""Remember the journal match `v` unless one has been set already."""
		if self._journalmatch is not None:
			# the first value wins
			return
		self._journalmatch = v

	def readRegex(self, value, regextype):
		"""Load the 'fail' or 'ignore' regexes given by `value`.

		If `value` is the path of an existing file it is read as a filter
		file: the regexes come from the "failregex"/"ignoreregex" option
		of its [Definition] section (one per line, empty lines skipped)
		and optional "maxlines" and "journalmatch" settings are taken
		from its [Init] section.  Otherwise `value` itself is the single
		regex.  The regexes are stored on this object and added to the
		filter.

		Returns True on success, False (after printing a message) if the
		file lacks the needed section/option or holds an invalid maxlines.
		"""
		assert(regextype in ('fail', 'ignore'))
		optname = regextype + 'regex'
		if not os.path.isfile(value):
			print("Use %11s line : %s" % (optname, shortstr(value)))
			regex_values = [RegexStat(value)]
		else:
			cfg = SafeConfigParserWithIncludes(defaults=self.CONFIG_DEFAULTS)
			try:
				cfg.read(value)
				print("Use %11s file : %s" % (optname, value))
				# TODO: reuse functionality in client
				definition = cfg.get("Definition", optname)
				regex_values = [
					RegexStat(entry)
					for entry in definition.split('\n')
					if entry != ""]
			except NoSectionError:
				print("No [Definition] section in %s" % value)
				return False
			except NoOptionError:
				print("No %s option in %s" % (optname, value))
				return False
			except MissingSectionHeaderError:
				print("No section headers in %s" % value)
				return False

			# Optional [Init].maxlines
			try:
				maxlines = cfg.get("Init", "maxlines")
			except (NoSectionError, NoOptionError):
				# No [Init].maxlines found.
				pass
			else:
				try:
					self.setMaxLines(maxlines)
				except ValueError:
					print("ERROR: Invalid value for maxlines (%(maxlines)r) "
						"read from %(value)s"
						% {'maxlines': maxlines, 'value': value})
					return False

			# Optional [Init].journalmatch
			try:
				journalmatch = cfg.get("Init", "journalmatch")
			except (NoSectionError, NoOptionError):
				# No [Init].journalmatch found.
				pass
			else:
				self.setJournalMatch(shlex.split(journalmatch))

		# keep the statistics objects as self._failregex / self._ignoreregex
		setattr(self, "_" + optname, regex_values)
		# addFailRegex / addIgnoreRegex of the filter
		add_regex = getattr(self._filter, 'add%sRegex' % regextype.title())
		for stat in regex_values:
			add_regex(stat.getFailRegex())
		return True

	def testIgnoreRegex(self, line):
		"""Tell whether `line` is matched by one of the ignore regexes.

		The hit counter of the matching regex is increased.  A
		RegexException raised by the filter is printed and treated as
		"not ignored".
		"""
		try:
			idx = self._filter.ignoreLine(line)
			if idx is None:
				return False
			self._ignoreregex[idx].inc()
		except RegexException as e:
			print(e)
			return False
		return True

	def testRegex(self, line):
		"""Run `line` through all fail regexes and record the hits.

		Every match returned by the filter gets a flag appended that tells
		whether more than one regex matched the line, and is recorded in
		the statistics of its regex (match[0] is used as the index into
		the fail regex list).

		Lines that were in the filter's line buffer before the call and
		are gone afterwards are looked up among the missed lines; each
		one found there is removed and counted as matched.

		Returns True if at least one regex matched, False otherwise or
		after a (printed) RegexException / IndexError.
		"""
		flt = self._filter
		# snapshot of the buffer as it was before processing this line
		previousBuffer = flt._Filter__lineBuffer
		bufferWasFull = len(previousBuffer) >= flt.getMaxLines()
		try:
			ret = flt.processLine(line, checkAllRegex=True)
			multiple = len(ret) > 1
			for match in ret:
				# Append True/False flag depending if line was matched by
				# more than one regex
				match.append(multiple)
				stat = self._failregex[match[0]]
				stat.inc()
				stat.appendIP(match)
		except RegexException as e:
			print(e)
			return False
		except IndexError:
			print("Sorry, but no <host> found in regex")
			return False
		# the first buffered line is skipped when the buffer was full
		first = 1 if bufferWasFull else 0
		for bufLine in previousBuffer[first:]:
			if bufLine in flt._Filter__lineBuffer:
				continue
			if self.removeMissedLine(bufLine):
				self._line_stats.matched += 1
		return len(ret) > 0

	def removeMissedLine(self, line):
		"""Remove `line` from missed lines, by comparing without time match

		The missed lines are searched from the newest to the oldest one.
		Each candidate is compared with `line` after cutting out the part
		matched by the date detector (if any) and stripping trailing
		line-end characters.  The first equal candidate is removed.

		Returns True if a line was removed, False otherwise.
		"""
		missed = self._line_stats.missed_lines
		for idx in reversed(range(len(missed))):
			candidate = missed[idx]
			timeMatch = self._filter.dateDetector.matchTime(
				candidate, incHits=False)
			if timeMatch:
				head = candidate[:timeMatch.start()]
				tail = candidate[timeMatch.end():]
				logLine = head + tail
			else:
				logLine = candidate
			if logLine.rstrip("\r\n") == line:
				missed.pop(idx)
				return True
		return False

	def process(self, test_lines):
		"""Test every line of the iterable `test_lines` and update the stats.

		Lines starting with '#' and blank lines are skipped and not
		counted.  Every other line is checked against the ignore regexes
		and the fail regexes and is counted as ignored, matched or
		missed accordingly.
		"""
		for line_no, line in enumerate(test_lines):
			if line.startswith('#') or not line.strip():
				# skip comment and empty lines
				continue
			# Use this instance, not the script's module-level object: the
			# global only exists when the file is run as a script.
			is_ignored = self.testIgnoreRegex(line)
			if is_ignored:
				self._line_stats.ignored_lines.append(line)

			if self.testRegex(line):
				assert(not is_ignored)
				self._line_stats.matched += 1
			else:
				if not is_ignored:
					self._line_stats.missed_lines.append(line)
			self._line_stats.tested += 1

			# re-sort the date templates on every 10th input line
			if line_no % 10 == 0:
				self._filter.dateDetector.sortTemplate()

	def printLines(self, ltype):
		"""Print the collected lines of kind `ltype` ('ignored' or 'missed').

		Nothing is printed if there are none.  With 20 or more lines only
		a hint is printed unless the matching --print-all-<ltype> option
		was given.
		"""
		stats = self._line_stats
		assert(len(stats.missed_lines) == stats.tested - (stats.matched + stats.ignored))
		lines = stats[ltype + '_lines']
		if not len(lines):
			return
		header = "%s line(s):" % (ltype.capitalize(),)
		if len(lines) < 20 or getattr(self, '_print_all_' + ltype):
			pprint_list([x.rstrip() for x in lines], header)
			return
		print("%s: too many to print.  Use --print-all-%s "
			"to print all %d lines" % (header, ltype, len(lines)))

	def printStats(self):
		"""Print the results of the run; always returns True.

		The report lists the hits per fail and ignore regex, the hits per
		date template, the line counters and finally the ignored and
		missed lines.  In verbose mode regexes and date templates without
		hits are listed too, as well as the recorded matches per regex.
		"""
		print("")
		print("Results")
		print("=======")

		def print_failregexes(title, failregexes):
			"""Print the hit list of `failregexes`, return the hit total."""
			total = 0
			out = []
			for num, stat in enumerate(failregexes, 1):
				hits = stat.getStats()
				total += hits
				if hits or self._verbose:
					out.append("%2d) [%d] %s" % (num, hits, stat.getFailRegex()))

				if not self._verbose:
					continue
				# presumably ip[1] is the host and ip[2] a timestamp;
				# ip[3] is read as the "multiple regex matched" flag
				for ip in stat.getIPList():
					timeString = time.strftime(
						"%a %b %d %H:%M:%S %Y", time.localtime(ip[2]))
					suffix = " (multiple regex matched)" if ip[3] else ""
					out.append("    %s  %s%s" % (ip[1], timeString, suffix))

			print("\n%s: %d total" % (title, total))
			pprint_list(out, " #) [# of hits] regular expression")
			return total

		# the returned totals are not needed here
		print_failregexes("Failregex", self._failregex)
		print_failregexes("Ignoreregex", self._ignoreregex)

		print("\nDate template hits:")
		out = []
		for template in self._filter.dateDetector.getTemplates():
			hits = template.getHits()
			if not (self._verbose or hits):
				continue
			out.append("[%d] %s" % (template.getHits(), template.getName()))
		pprint_list(out, "[# of hits] date format")

		print("\nLines: %s" % self._line_stats)

		for ltype in ('ignored', 'missed'):
			self.printLines(ltype)

		return True


# Script entry point: parse the command line, load the regexes, pick the
# source of the log lines (file, systemd journal or a single line), test
# all lines and print the statistics.  The exit status is -1 on errors.
if __name__ == "__main__":

	parser = get_opt_parser()
	(opts, args) = parser.parse_args()

	# file_lines_gen() refers to this object by its module-level name, so
	# the name has to stay `fail2banRegex`.
	fail2banRegex = Fail2banRegex(opts)

	# We need 2 or 3 parameters
	if len(args) not in (2, 3):
		sys.stderr.write("ERROR: provide both <LOG> and <REGEX>.\n\n")
		parser.print_help()
		sys.exit(-1)

	# TODO: taken from -testcases -- move common functionality somewhere
	if opts.log_level is not None: # pragma: no cover
		# so we had explicit settings
		logSys.setLevel(getattr(logging, opts.log_level.upper()))
	else: # pragma: no cover
		# suppress the logging but it would leave unittests' progress dots
		# ticking, unless like with '-l fatal' which would be silent
		# unless error occurs
		logSys.setLevel(getattr(logging, 'FATAL'))

	# Add the default logging handler
	stdout = logging.StreamHandler(sys.stdout)

	fmt = 'D: %(message)s'

	if opts.log_traceback:
		Formatter = FormatterWithTraceBack
		fmt = (opts.full_traceback and ' %(tb)s' or ' %(tbc)s') + fmt
	else:
		Formatter = logging.Formatter

	# Custom log format for the verbose tests runs
	if opts.verbose: # pragma: no cover
		stdout.setFormatter(Formatter(' %(asctime)-15s %(thread)s' + fmt))
	else: # pragma: no cover
		# just prefix with the space
		stdout.setFormatter(Formatter(fmt))
	logSys.addHandler(stdout)

	print("")
	print("Running tests")
	print("=============")
	print("")

	cmd_log, cmd_regex = args[:2]

	fail2banRegex.readRegex(cmd_regex, 'fail') or sys.exit(-1)

	if len(args) == 3:
		fail2banRegex.readRegex(args[2], 'ignore') or sys.exit(-1)

	# Only set when the lines come from a file; closed after processing.
	hdlr = None
	if os.path.isfile(cmd_log):
		try:
			hdlr = open(cmd_log, 'rb')
			print("Use         log file : %s" % cmd_log)
			print("Use         encoding : %s" % fail2banRegex.encoding)
			test_lines = file_lines_gen(hdlr)
		except IOError as e:
			print(e)
			sys.exit(-1)
	elif cmd_log == "systemd-journal":
		if not journal:
			print("Error: systemd library not found. Exiting...")
			sys.exit(-1)
		myjournal = journal.Reader(converters={'__CURSOR': lambda x: x})
		# None when neither the command line nor the filter file gave a
		# match; fall back to an empty list so that joining it for the
		# message below cannot raise a TypeError.
		journalmatch = fail2banRegex._journalmatch or []
		if journalmatch:
			try:
				for element in journalmatch:
					if element == "+":
						myjournal.add_disjunction()
					else:
						myjournal.add_match(element)
			except ValueError:
				print("Error: Invalid journalmatch: %s" % shortstr(" ".join(journalmatch)))
				sys.exit(-1)
		print("Use    journal match : %s" % " ".join(journalmatch))
		test_lines = journal_lines_gen(myjournal)
	else:
		print("Use      single line : %s" % shortstr(cmd_log))
		test_lines = [ cmd_log ]
	print("")

	try:
		fail2banRegex.process(test_lines)
	finally:
		# the lines are read lazily, so the file has to stay open until
		# processing is over
		if hdlr is not None:
			hdlr.close()

	fail2banRegex.printStats() or sys.exit(-1)
