#!/usr/bin/env python
# encoding: utf-8
"""
bgp

Created by Thomas Mangin
Copyright (c) 2013-2017 Exa Networks. All rights reserved.
License: 3-clause BSD. (See the COPYRIGHT file)
"""

import os
import pwd
import sys
import time
import errno
import socket
import thread
import signal
import asyncore
import subprocess
from struct import unpack

import psutil

# TCP port to listen on, taken from the environment (either spelling), default 179
PORT = os.environ.get('exabgp.tcp.port',os.environ.get('exabgp_tcp_port','179'))
# signal name -> signal number for every signal the platform defines.
# The SIG_* names (SIG_DFL, SIG_IGN, SIG_BLOCK, ...) are handler/mask constants,
# not signals, so they are excluded: sending "SIG_IGN" would otherwise send signal 1.
SIGNAL = dict((name,getattr(signal,name)) for name in dir(signal) if name.startswith('SIG') and not name.startswith('SIG_'))


def flushed (*output):
	"""Write every argument, converted with str() and separated by spaces, as one line on stdout.

	stdout is flushed straight away instead of waiting for the buffer to fill.
	"""
	line = ' '.join(map(str, output))
	sys.stdout.write(line + '\n')
	sys.stdout.flush()


def bytestream (value):
	"""Return value as upper-case hexadecimal: two digits per character, no separator."""
	return ''.join(map(lambda char: '%02X' % ord(char), value))


def dump (value):
	"""Return value as upper-case hex, two bytes (four digits) per space-separated group."""
	digits = ['%02X' % ord(char) for char in value]
	groups = [''.join(digits[start:start + 2]) for start in range(0, len(digits), 2)]
	return ' '.join(groups)


def cdr_to_length (cidr):
	"""Return how many address bytes (0 to 4) an IPv4 prefix of length cidr occupies."""
	# first threshold strictly exceeded wins; anything above 24 needs all four bytes
	for octets, floor in ((4, 24), (3, 16), (2, 8), (1, 0)):
		if cidr > floor:
			return octets
	return 0


class BGPHandler(asyncore.dispatcher_with_send):
	"""One accepted BGP session.

	handle_open answers the peer OPEN, then handle_keepalive decodes every
	following message and checks it against the expected messages given to
	setup(), which are grouped into steps consumed in sorted key order.
	"""

	# count of decoded messages (the += in handle_keepalive makes it per instance)
	counter = 0

	# KEEPALIVE: 16 byte 0xFF marker, length 19 (0x0013), type 4
	keepalive = chr(0xFF)*16 + chr(0x0) + chr(0x13) + chr(0x4)

	# message type byte (offset 18 of the header) -> display name
	_name = {
		chr(1): 'OPEN',
		chr(2): 'UPDATE',
		chr(3): 'NOTIFICATION',
		chr(4): 'KEEPALIVE',
	}

	def signal (self,myself,signal_name='SIGUSR1'):
		"""Send signal_name to the single matching ExaBGP process.

		The target is found with psutil: a python/pypy process whose first
		argument ends with '/bgp.py' and whose last argument ends with '.conf'
		and contains the name of our own configuration file (sys.argv[1],
		basename up to the first '.').  Exits with status 1 when the signal
		name is unknown or when zero or several processes match.
		myself is accepted for the callers but not used.
		"""
		signal_number = SIGNAL.get(signal_name,'')
		if not signal_number:
			self.announce('invalid signal name in configuration : %s' % signal_name)
			self.announce('options are: %s' % ','.join(SIGNAL.keys()))
			sys.exit(1)

		conf_name = sys.argv[1].split('/')[-1].split('.')[0]

		processes = []

		for pid in psutil.pids():
			try:
				process = psutil.Process(pid)

				process_name = process.name().lower()
				if 'python' not in process_name and 'pypy' not in process_name:
					continue

				cmdline = process.cmdline()
				if len(cmdline) < 2:
					continue

				if not cmdline[1].endswith('/bgp.py'):
					continue

				if conf_name not in cmdline[-1]:
					continue

				if not cmdline[-1].endswith('.conf'):
					continue

				processes.append(pid)

			except psutil.NoSuchProcess:
				self.announce('signal/psutil: no such processes')
			except psutil.AccessDenied:
				self.announce('signal/psutil: access denied')

		if len(processes) == 0:
			self.announce('no running process found, this should not happend, quitting')
			sys.exit(1)

		if len(processes) > 1:
			self.announce('more than one process running, this should not happend, quitting')
			sys.exit(1)

		try:
			self.announce('sending signal %s to ExaBGP (pid %s)\n' % (signal_name,processes[0]))
			os.kill(int(processes[0]),signal_number)
		except Exception as exc:
			self.announce('\n     failed: %s' % str(exc))

	def kind (self, header):
		"""Return the message type byte of a 19 byte BGP header."""
		return header[18]

	def isupdate (self, header):
		"""True when header is the header of an UPDATE (type 2)."""
		return header[18] == chr(2)

	def isnotification (self, header):
		"""True when header is the header of a NOTIFICATION (type 3)."""
		# type 3 is NOTIFICATION (see _name); 4 is KEEPALIVE
		return header[18] == chr(3)

	def name (self, header):
		"""Return the display name of the message type in header."""
		return self._name.get(header[18],'SOME WEIRD RFC PACKET')

	def _prefixes (self, data):
		"""Consume a list of NLRI byte values, yielding 'a.b.c.d/len' for each IPv4 prefix.

		A truncated list raises IndexError, as the loops this replaces did.
		"""
		while data:
			cdr = data.pop(0)
			size = cdr_to_length(cdr)
			# missing trailing octets of the address are zero
			r = [0,0,0,0]
			for index in range(size):
				r[index] = data.pop(0)
			yield '.'.join(str(_) for _ in r) + '/' + str(cdr)

	def routes (self, header, body):
		"""Yield one item per IPv4 prefix of an UPDATE body.

		Items are 'withdraw:a.b.c.d/len' then 'announce:a.b.c.d/len'.  An UPDATE
		with neither yields 'eor:1:1' (4 byte body), 'eor:<x>:<y>' built from
		the last two bytes of an 11 byte body (presumably low AFI byte and SAFI
		of an MP_UNREACH end-of-RIB), or 'mp:' for anything else.
		"""
		len_w = unpack('!H',body[0:2])[0]
		withdrawn = [ord(_) for _ in body[2:2+len_w]]
		len_a = unpack('!H',body[2+len_w:2+len_w+2])[0]
		# announced NLRI follow the path attributes
		announced = [ord(_) for _ in body[2+len_w + 2+len_a:]]

		if not withdrawn and not announced:
			if len(body) == 4:
				yield 'eor:1:1'
			elif len(body) == 11:
				yield 'eor:%d:%d' % (ord(body[-2]),ord(body[-1]))
			else:  # undecoded MP route
				yield 'mp:'
			return

		for prefix in self._prefixes(withdrawn):
			yield 'withdraw:%s' % prefix

		for prefix in self._prefixes(announced):
			yield 'announce:%s' % prefix

	def notification (self, header, body):
		"""Yield 'notification:<code>,<subcode>' for a NOTIFICATION body."""
		# yield a plain string: handle_keepalive calls startswith() on every
		# item, which a (text, hexdump) tuple does not support
		yield 'notification:%d,%d' % (ord(body[0]),ord(body[1]))

	def announce (self, *args):
		"""Print args prefixed with the peer address and port."""
		flushed('    ',self.ip,self.port,' '.join(str(_) for _ in args) if len(args) > 1 else args[0])

	def check_signal (self):
		"""If the next expected item is 'signal:<NAME>', consume it and send that signal."""
		if self.messages and self.messages[0].startswith('signal:'):
			name = self.messages.pop(0).split(':')[-1]
			self.signal(os.getppid(),name)

	def setup (self, ip, port, messages, options):
		"""Configure the session and return self.

		messages are '<step>:<expected item>' strings; an item starting with
		'raw:' is a ':'-separated hex dump and switches the session to raw
		matching.  options is the server option dict (shared, not copied).
		"""
		self.ip = ip
		self.port = port
		self.options = options
		self.handle_read = self.handle_open
		self.sequence = {}
		self.raw = False
		for rule in messages:
			sequence,announcement = rule.split(':',1)
			if announcement.startswith('raw:'):
				self.raw = True
				announcement = announcement[4:].replace(':','')
			self.sequence.setdefault(sequence,[]).append(announcement)
		self.update_sequence()
		return self

	def update_sequence (self):
		"""Load the lowest remaining step into self.messages.

		Returns True when a step was loaded (or in sink/echo mode, where nothing
		is expected), False when no step is left.
		"""
		if self.options['sink'] or self.options['echo']:
			self.messages = []
			return True
		keys = sorted(list(self.sequence))
		if keys:
			key = keys[0]
			self.messages = self.sequence[key]
			self.step = key
			del self.sequence[key]

			self.check_signal()
			# we had a list with only one signal
			if not self.messages:
				return self.update_sequence()
			return True
		return False

	def _read_exactly (self, size):
		"""Read size bytes, retrying on EWOULDBLOCK/EAGAIN; None when the peer closed."""
		data = ''
		while len(data) < size:
			try:
				chunk = self.recv(size-len(data))
			except socket.error as exc:
				if exc.args[0] in (errno.EWOULDBLOCK,errno.EAGAIN):
					continue
				raise
			if not chunk:
				# an empty read means the TCP session is gone
				return None
			data += chunk
		return data

	def read_message (self):
		"""Read one BGP message; return (header, body), or (None, None) if the peer closed."""
		header = self._read_exactly(19)
		if header is None:
			return None,None

		# the length field counts the 19 byte header too
		length = unpack('!H',header[16:18])[0] - 19

		body = self._read_exactly(length)
		if body is None:
			return None,None

		return header,body

	def handle_open (self):
		"""Answer the peer OPEN, then hand further reads to handle_keepalive.

		The reply is an IBGP OPEN with the same capabilities (only the last
		byte of the router ID is incremented), optionally extended with an
		unknown capability 66, followed by a KEEPALIVE and optionally a
		default route UPDATE.
		"""
		header,body = self.read_message()
		if header is None:
			# the peer went away before sending its OPEN
			self.announce('connection closed')
			self.close()
			return

		routerid = chr((ord(body[8])+1) & 0xFF)
		o = header+body[:8]+routerid+body[9:]

		if self.options['send-unknown-capability']:
			# hack capability 66 into the message
			content = 'loremipsum'
			cap66 = chr(66) + chr(len(content)) + content
			param = chr(2) + chr(len(cap66)) + cap66
			# grow the 2 byte message length (offset 16) and the 1 byte optional
			# parameter length (offset 28); rewrite both length bytes so a
			# message growing past 255 bytes does not overflow the low byte
			length = unpack('!H',o[16:18])[0] + len(param)
			o = o[:16] + chr(length >> 8) + chr(length & 0xFF) + o[18:28] + \
				chr(ord(o[28])+len(param)) + o[29:] + param

		self.send(o)
		self.send(self.keepalive)

		if self.options['send-default-route']:
			self.send(
				chr(0xFF)*16 +
				chr(0x00) + chr(0x31) +
				chr(0x02) +
				chr(0x00) + chr(0x00) +
				chr(0x00) + chr(0x15) +
				'' + chr(0x40) + chr(0x01) + chr(0x01) + chr(0x00) +
				'' + chr(0x40) + chr(0x02) + chr(0x00) +
				'' + chr(0x40) + chr(0x03) + chr(0x04) + chr(0x7F) + chr(0x00) + chr(0x00) + chr(0x01) +
				'' + chr(0x40) + chr(0x05) + chr(0x04) + chr(0x00) + chr(0x00) + chr(0x00) + chr(0x64) +
				chr(0x20) + chr(0x00) + chr(0x00) + chr(0x00) + chr(0x00)
			)
			self.announce('sending default-route\n')

		self.handle_read = self.handle_keepalive

	def _split_hex (self, message):
		"""Format a hex message as marker:length:type:body (32, 4, 2 and remaining digits)."""
		return '%s:%s:%s:%s' % (message[:32],message[32:36],message[36:38],message[38:])

	def _raw_fields (self, header, body):
		"""Hex-dump a whole message, split as marker:length:type:body."""
		return self._split_hex(bytestream(header + body))

	def _report_expected (self):
		"""Print what the current step still expects (after a mismatch)."""
		if len(self.messages) > 1:
			self.announce('expected one of the following :')
			for message in self.messages:
				if message.startswith('F'*32):
					self.announce('                 %s' % self._split_hex(message))
				else:
					self.announce('                 %s' % message)
		elif self.messages:
			message = self.messages[0].upper()
			if message.startswith('F'*32):
				# same four field split as every other raw display
				self.announce('expected       : %s' % self._split_hex(message))
			else:
				self.announce('expected       : %s' % message)
		else:
			# can happen when the thread is still running
			self.announce('extra data')

	def handle_keepalive (self):
		"""Handle each message received after the OPEN exchange.

		Sink mode logs and answers with a KEEPALIVE; echo mode sends the message
		back.  Otherwise each decoded item is matched against the current step:
		the process exits 1 on the first unexpected item, and exits 0 once every
		step is done when the 'exit' option is set (or after the first step in
		single-shot mode).
		"""
		header,body = self.read_message()

		if header is None:
			self.announce('connection closed')
			self.close()
			if self.options['send-notification']:
				self.announce('successful')
				sys.exit(0)
			return

		if self.raw:
			# raw mode: the whole message, as one hex string, is the item to match
			def parser (self, header, body):
				if body:
					yield bytestream(header + body)
		else:
			parser = self._decoder.get(self.kind(header),None)

		if self.options['sink']:
			self.announce('received %d: %s' % (self.counter,self._raw_fields(header,body)))
			self.send(self.keepalive)
			return

		if self.options['echo']:
			self.announce('received %d: %s' % (self.counter,self._raw_fields(header,body)))
			self.send(header+body)
			self.announce('sent     %d: %s' % (self.counter,self._raw_fields(header,body)))
			return

		if not parser:
			# no decoder for this type (only UPDATE and NOTIFICATION have one): just answer
			self.send(self.keepalive)
		else:
			for announcement in parser(self,header,body):
				self.send(self.keepalive)

				if announcement.startswith('eor:'):  # skip EOR
					self.announce('skipping eor',announcement)
					continue

				if announcement.startswith('mp:'):  # skip unparsed MP
					self.announce('skipping multiprotocol :',dump(body))
					continue

				self.counter += 1

				if announcement not in self.messages:
					if self.raw:
						self.announce('received %d (%1s%s):' % (self.counter,self.options['letter'],self.step),self._raw_fields(header,body))
					else:
						self.announce('received %d     :' % self.counter,announcement)
					self._report_expected()
					sys.exit(1)

				self.messages.remove(announcement)
				if self.raw:
					self.announce('received %d (%1s%s):' % (self.counter,self.options['letter'],self.step),self._split_hex(announcement))
				else:
					self.announce('received %d (%1s%s):' % (self.counter,self.options['letter'],self.step), announcement)
				self.check_signal()

				if not self.messages:
					if self.options['single-shot']:
						self.announce('successful (partial test)')
						sys.exit(0)

					if not self.update_sequence():
						if self.options['exit']:
							self.announce('successful')
							sys.exit(0)

		if self.options['send-notification']:
			notification = 'closing session because we can'
			self.send(
				chr(0xFF)*16 +
				chr(0x00) + chr(19+2+len(notification)) +
				chr(0x03) +
				chr(0x06) +
				chr(0x00) +
				notification
			)

	# message type byte -> decoder generator, called as decoder(self, header, body)
	_decoder = {
		chr(2): routes,
		chr(3): notification,
	}


class BGPServer (asyncore.dispatcher):
	"""Listening socket for the fake BGP peer.

	Each accepted connection is handed the next lettered group of expected
	messages (A, B, ... in order) through a new BGPHandler.
	"""

	def announce (self, *args):
		"""Print args, space separated, with a four space indent."""
		# the conditional is parenthesised so the indent applies to one or many args
		flushed('    ' + (' '.join(str(_) for _ in args) if len(args) > 1 else str(args[0])))

	def signal (self, myself, signal_name='SIGUSR1'):
		"""Send signal_name to the ExaBGP process under test.

		Called by the SIGUSR1 timer thread started in __init__.  Reuses the
		BGPHandler implementation, which needs only self.announce from its
		instance; vars() fetches the plain function so it can take this object.
		"""
		return vars(BGPHandler)['signal'](self, myself, signal_name)

	def __init__ (self, host, port, messages, options):
		"""Bind and listen on (host, port) and parse the expected messages.

		host: an address containing ':' is treated as IPv6.
		messages: 'option:...' lines set options; other lines are grouped by
		  an optional leading letter (default 'A') into self.messages.
		options: initial option values, overriding the defaults below.
		"""
		asyncore.dispatcher.__init__(self)

		if ':' in host:
			self.create_socket(socket.AF_INET6, socket.SOCK_STREAM)
		else:
			self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
		self.set_reuse_addr()
		self.bind((host,port))
		self.listen(5)

		self.messages = {}

		self.options = {
			'send-unknown-capability': False,  # add an unknown capability to the open message
			'send-default-route': False,       # send a default route to the peer
			'send-notification': False,        # send notification messages to the backend
			'signal-SIGUSR1': 0,               # send SIGUSR1 after X seconds
			'single-shot': False,              # we can not test signal on python 2.6
			'sink': False,                     # just accept whatever is sent
			'echo': False,                     # just accept whatever is sent
		}
		self.options.update(options)

		for message in messages:
			if message.strip() == 'option:open:send-unknown-capability':
				self.options['send-unknown-capability'] = True
				continue
			if message.strip() == 'option:update:send-default-route':
				self.options['send-default-route'] = True
				continue
			if message.strip() == 'option:notification:send-notification':
				self.options['send-notification'] = True
				continue
			if message.strip().startswith('option:SIGUSR1:'):
				def notify (delay,myself):
					time.sleep(delay)
					self.signal(myself)
					time.sleep(10)
					thread.exit()

				# Python 2.6 can not perform this test as it misses the function
				if 'check_output' in dir(subprocess):
					thread.start_new_thread(notify,(int(message.split(':')[-1]),os.getpid()))
				else:
					self.options['single-shot'] = True
				continue

			# an optional leading letter selects which connection gets the line
			if message[0].isalpha():
				index,content = message[:1].upper(), message[1:]
			else:
				index,content = 'A',message
			self.messages.setdefault(index,[]).append(content)

	def handle_accept (self):
		"""Accept a connection and attach a BGPHandler with the next message group."""
		messages = None
		for number in range(ord('A'),ord('Z')+1):
			letter = chr(number)
			if letter in self.messages:
				messages = self.messages[letter]
				del self.messages[letter]
				break

		if self.options['sink']:
			flushed('\nsink mode - send us whatever, we can take it ! :p\n')
			messages = []
		elif self.options['echo']:
			flushed('\necho mode - send us whatever, we can parrot it ! :p\n')
			messages = []
		elif not messages:
			self.announce('we used all the test data available, can not handle this new connection')
			sys.exit(1)
		else:
			flushed('using :\n   ', '\n    '.join(messages),'\n\nconversation:\n')

		# exit once this connection is done if no other group is left
		self.options['exit'] = not self.messages
		self.options['letter'] = letter

		pair = self.accept()
		if pair is not None:
			sock,addr = pair
			# asyncore keeps the dispatcher registered; no reference needed here
			BGPHandler(sock).setup(
				*addr[:2],
				messages=messages,
				options=self.options
			)


def drop (names=('nobody',)):
	"""Give up root user/group ids by switching to an unprivileged account.

	Does nothing unless the current uid or gid is 0.  Otherwise the first
	existing account in names is used: the gid is changed first (it can no
	longer be changed once the uid is not root), then the uid.

	names: candidate account names, tried in order (default: nobody).
	If none of them exists, the current ids are kept.
	"""
	uid = os.getuid()
	gid = os.getgid()

	if uid and gid:
		return

	for name in names:
		try:
			user = pwd.getpwnam(name)
		except KeyError:
			continue
		nuid = int(user.pw_uid)
		# the group id comes from pw_gid, not pw_uid
		ngid = int(user.pw_gid)
		break
	else:
		# no candidate account exists: keep the current ids rather than fail
		# with an unbound name
		return

	if not gid:
		os.setgid(ngid)
	if not uid:
		os.setuid(nuid)


def main ():
	"""Entry point: read the expected messages, listen on IPv4/IPv6 localhost and serve.

	sys.argv[1] is --sink, --echo, or a file of expected messages (blank
	lines and lines containing '#' are ignored).  Without an argument the
	usage is printed and the process exits with status 1.
	"""
	if len(sys.argv) < 2:
		usage = (
			'--sink   accept any BGP messages and reply with a keepalive',
			'--echo   accept any BGP messages send it back to the emiter',
			'a list of expected route announcement/withdrawl in the format <number>:announce:<ipv4-route> <number>:withdraw:<ipv4-route> <number>:raw:<exabgp hex dump : separated>',
		)
		for text in usage:
			flushed(text)
		flushed('for example:',sys.argv[0],'1:announce:10.0.0.0/8 1:announce:192.0.2.0/24 2:withdraw:10.0.0.0/8 ')
		flushed('routes with the same <number> can arrive in any order')
		sys.exit(1)

	mode = sys.argv[1]
	options = {
		'sink': mode == '--sink',
		'echo': mode == '--echo',
	}

	if options['sink'] or options['echo']:
		messages = []
	else:
		try:
			with open(mode) as source:
				stripped = (line.strip() for line in source)
				messages = [text for text in stripped if text and '#' not in text]
		except IOError:
			flushed('could not open file', mode)
			sys.exit(1)

	port = int(PORT)
	try:
		for host in ('127.0.0.1', '::1'):
			BGPServer(host,port,messages,options)
		drop()
		asyncore.loop()
	except socket.error as exc:
		reasons = {
			errno.EACCES: 'most likely not run as root',
			errno.EADDRINUSE: 'port already in use',
		}
		reason = reasons.get(exc.errno)
		if reason is None:
			flushed('failure', str(exc))
		else:
			flushed('failure: could not bind to port %s - %s' % (PORT, reason))


# run the fake peer only when executed as a script, not when imported
if __name__ == "__main__":
	main()
