#!/usr/bin/python3

import sys, dbus, json, subprocess, os, glob, traceback, shutil, re, time
import lxml.objectify

sys.path.insert(0, os.path.dirname(__file__))
import gcatypes
sys.path = sys.path[1:]

# Flipped to True by the `test` context manager when a test step raises;
# checked at the very end of the script to choose the exit status.
has_failures = False
# Terminal dimensions (80x25 when they cannot be determined); the width is
# used to right-align the OK/FAIL marker of each test line.
(console_width, console_height) = shutil.get_terminal_size(fallback=(80, 25))

class Interface:
    """A D-Bus interface described as a set of named methods and their arguments.

    Instances are built either from introspection XML (`from_xml`) or from a
    JSON description file (`from_json`) and can be compared with
    `assert_equal`, which raises ValueError on the first difference found.
    """

    class Arg:
        """A single method argument: name, direction and D-Bus type signature."""

        def __init__(self):
            self.name = ''
            self.direction = 'in'
            self.signature = ''

        def assert_equal(self, other):
            """Raise ValueError if *other* differs in direction or signature.

            `self` is treated as the expected argument, *other* as the actual
            one. Argument names are not compared.
            """
            if self.direction != other.direction:
                raise ValueError("Expected direction {0} for argument {1}, but got {2}".format(self.direction, self.name, other.direction))

            if self.signature != other.signature:
                raise ValueError("Expected signature {0} for argument {1}, but got {2}".format(self.signature, self.name, other.signature))

    class Method:
        """A method: its name plus an ordered list of `Arg` objects."""

        def __init__(self):
            self.name = ''
            self.args = []

        def assert_equal(self, other):
            """Raise ValueError if *other* has a different name or argument list.

            Arguments are compared position by position; `self` is the
            expected method, *other* the actual one.
            """
            if self.name != other.name:
                raise ValueError("Methods have different names, expected {0} but got {1}".format(self.name, other.name))

            for i, expected in enumerate(self.args):
                if i >= len(other.args):
                    raise ValueError("Missing expected argument {0}:{1} missing".format(expected.name, expected.signature))

                expected.assert_equal(other.args[i])

            # Anything beyond our own argument count is surplus; report the first.
            if len(other.args) > len(self.args):
                extra = other.args[len(self.args)]
                raise ValueError("Unexpected argument {0}:{1}".format(extra.name, extra.signature))

    def __init__(self):
        self.name = ''
        self.methods = {}

    def assert_equal(self, other):
        """Raise ValueError if *other* differs in name or in its methods.

        `self` is the expected interface and *other* the actual one: a method
        only present in `self` is reported as missing, a method only present
        in *other* as unexpected.
        """
        if self.name != other.name:
            raise ValueError("Interfaces have different names, expected {0} but got {1}".format(self.name, other.name))

        for mname, method in self.methods.items():
            if mname not in other.methods:
                raise ValueError("Missing method {0}".format(mname))

            method.assert_equal(other.methods[mname])

        for mname in other.methods:
            if mname not in self.methods:
                raise ValueError("Unexpected method {0}".format(mname))

    @classmethod
    def from_xml(cls, xml):
        """Parse D-Bus introspection XML into a dict of name -> Interface.

        Returns an empty dict when the document has no <interface> children.
        An <arg> without a `direction` attribute is taken to be an "in"
        argument, as the D-Bus introspection format specifies for methods.
        """
        ret = {}
        obj = lxml.objectify.fromstring(xml)

        # lxml.objectify exposes child elements as attributes; a missing
        # child kind shows up as a missing attribute.
        if not hasattr(obj, 'interface'):
            return ret

        for i in obj.interface:
            ii = cls()

            ii.name = i.attrib['name']
            ret[ii.name] = ii

            if not hasattr(i, 'method'):
                continue

            for m in i.method:
                meth = cls.Method()
                meth.name = m.attrib['name']

                ii.methods[meth.name] = meth

                if not hasattr(m, 'arg'):
                    continue

                for a in m.arg:
                    aa = cls.Arg()
                    # `direction` is optional on method arguments.
                    aa.direction = a.attrib.get('direction', 'in')
                    aa.signature = a.attrib['type']
                    aa.name = a.attrib.get('name', '')

                    meth.args.append(aa)

        return ret

    @classmethod
    def from_json(cls, js):
        """Load interface descriptions from the JSON file at path *js*.

        The file maps interface names to objects that map method names to a
        list of argument objects with a required `type` key and optional
        `direction` (default "in") and `name` (default "") keys.

        Returns a dict of interface name -> Interface.
        """
        with open(js) as f:
            described = json.load(f)

        ret = {}

        for iface_name, methods in described.items():
            ii = cls()
            ii.name = iface_name
            ret[iface_name] = ii

            for mname, margs in methods.items():
                meth = cls.Method()
                meth.name = mname
                ii.methods[mname] = meth

                for a in margs:
                    aa = cls.Arg()
                    aa.direction = a.get('direction', 'in')
                    aa.signature = a['type']
                    aa.name = a.get('name', '')

                    meth.args.append(aa)

        return ret

def test(name):
    """Decorator factory that turns a method into a reported test step.

    Calling the decorated method does not run it; it returns a context
    manager object that captures the call arguments. Entering the context
    prints a "TEST <name> (<args>)" label, and leaving it prints a
    right-aligned OK or FAIL marker. The object itself is callable:

    * called with no arguments, it runs the wrapped function with the
      captured arguments;
    * called with arguments, it runs the wrapped function with the first
      captured argument (the bound instance) followed by the new arguments.

    An ordinary exception raised inside the `with` body is reported (message
    and traceback), recorded in the module-level `has_failures` flag and
    suppressed so later steps still run. Exceptions that do not derive from
    `Exception` (KeyboardInterrupt, SystemExit) are reported as a failure
    too but are allowed to propagate so the run can actually be aborted.
    """
    def decorator(f):
        def dummy(*args, **kwargs):
            class C:
                # Matches ANSI colour escape sequences, so the visible width
                # of a line can be measured without them.
                ansire = re.compile('\033\\[[^m]*m')

                def __init__(self, args, kwargs):
                    self.args = args
                    self.kwargs = kwargs

                def __call__(self, *args, **kwargs):
                    if not args and not kwargs:
                        return f(*self.args, **self.kwargs)

                    # New arguments replace the captured ones, except for the
                    # first captured argument (the instance).
                    return f(self.args[0], *args, **kwargs)

                def __enter__(self):
                    shown = ', '.join([str(a) for a in self.args[1:]])
                    text = '  TEST {0} ({1})'.format(name, shown)

                    self.enter_text = text
                    sys.stdout.write(text)
                    # No newline was written, so flush explicitly; otherwise
                    # the label stays buffered while the step is running.
                    sys.stdout.flush()
                    return self

                def __exit__(self, typ, value, tb):
                    global has_failures

                    failed = value is not None

                    if failed:
                        has_failures = True
                        rettext = '[\033[31mFAIL\033[0m]'
                    else:
                        rettext = '[\033[32mOK\033[0m]'

                    # Width of label + marker as displayed (colour codes removed),
                    # used to size the dotted filler between them.
                    cleart = self.ansire.sub('', (self.enter_text + rettext))

                    print(' \033[30m{0}\033[0m {1}'.format('.' * (console_width - len(cleart) - 2), rettext))

                    if failed:
                        print('      \033[31m{0}\n      {1}\033[0m'.format(str(value), ''.join(traceback.format_tb(tb)).replace("\n", '\n      ')))

                    # Swallow ordinary failures only; never hide Ctrl-C or
                    # an explicit interpreter exit.
                    return (not failed) or isinstance(value, Exception)

            return C(args, kwargs)

        return dummy

    return decorator

class ServiceTest:
    """Runs the checks described by one backend test description over D-Bus.

    The description (*test*, a dict parsed from JSON) must contain a
    `language` key, which determines the bus name and object path of the
    service, and may contain `interfaces`, `document_interfaces` and
    `diagnostics` keys that select additional checks (see `run`).

    Methods decorated with `@test` do not execute when called; they return a
    context manager that reports OK/FAIL for the step (see `test`).
    """

    # Expected interface definitions, loaded once from the interfaces.json
    # file that sits next to this script.
    interfaces = Interface.from_json(os.path.join(os.path.dirname(__file__), 'interfaces.json'))

    def __init__(self, bus, test):
        self.test = test
        self.language = test['language']
        self.name = 'org.gnome.CodeAssist.v1.' + self.language
        self.path = '/org/gnome/CodeAssist/v1/' + self.language
        self.bus = bus

    def full_path(self, path):
        """Return the absolute object path for *path* relative to the service root.

        '/' maps to the service root itself.
        """
        if path != '/':
            return self.path + path
        else:
            return self.path

    @test('object')
    def test_object(self, path):
        """Fetch and return the proxy object at the service-relative *path*."""
        return self.get_object(path)

    @test('interface')
    def test_interface(self, path, interface):
        """Check that the object at *path* implements *interface* as expected.

        The object is introspected and the named interface is compared with
        the definition in `self.interfaces`.

        Raises ValueError if the interface is not exported, is not known to
        interfaces.json, or differs from the expected definition.
        """
        obj = self.get_object(path)

        intro = dbus.Interface(obj, 'org.freedesktop.DBus.Introspectable')
        xml = intro.Introspect()
        introf = Interface.from_xml(xml)

        if not interface in introf:
            raise ValueError("Missing interface {0}".format(interface))

        if not interface in self.interfaces:
            raise ValueError("Unknown interface {0}".format(interface))

        # Interface.assert_equal treats its receiver as the expected side, so
        # the known definition must be the receiver for the "Missing" /
        # "Unexpected" / "Expected ... but got ..." messages to read correctly.
        self.interfaces[interface].assert_equal(introf[interface])

    def assert_source_location(self, l1, l2):
        """Raise ValueError unless *l2* has the line and column of expected *l1*."""
        if l1.line != l2.line:
            raise ValueError("Expected source location line {0} but got {1}".format(l1.line, l2.line))

        if l1.column != l2.column:
            raise ValueError("Expected source location column {0} but got {1}".format(l1.column, l2.column))

    def assert_source_range(self, r1, r2):
        """Raise ValueError unless *r2* matches expected range *r1* (file, start, end)."""
        if r1.file != r2.file:
            raise ValueError("Expected source range file {0} but got {1}".format(r1.file, r2.file))

        self.assert_source_location(r1.start, r2.start)
        self.assert_source_location(r1.end, r2.end)

    def assert_fixit(self, f1, f2):
        """Raise ValueError unless *f2* matches expected fixit *f1* (location, replacement)."""
        self.assert_source_range(f1.location, f2.location)

        if f1.replacement != f2.replacement:
            raise ValueError("Expected fixit replacement {0} but got {1}".format(f1.replacement, f2.replacement))

    def assert_diagnostic(self, d1, d2):
        """Raise ValueError unless *d2* matches expected diagnostic *d1*.

        Compares, in this order: severity, fixits (pairwise, same count),
        locations (pairwise, same count) and message.
        """
        if d1.severity != d2.severity:
            raise ValueError("Expected severity '{0}' but got '{1}'".format(d1.severity, d2.severity))

        for i, v in enumerate(d1.fixits):
            if i >= len(d2.fixits):
                raise ValueError("Expected fixit {0}".format(v))

            self.assert_fixit(v, d2.fixits[i])

        # Raises on the first surplus item, if any.
        for v in d2.fixits[len(d1.fixits):]:
            raise ValueError("Unexpected fixit {0}".format(v))

        for i, v in enumerate(d1.locations):
            if i >= len(d2.locations):
                raise ValueError("Expected source range {0}".format(v))

            self.assert_source_range(v, d2.locations[i])

        for v in d2.locations[len(d1.locations):]:
            raise ValueError("Unexpected source range {0}".format(v))

        if d1.message != d2.message:
            raise ValueError("Expected message '{0}' but got '{1}'".format(d1.message, d2.message))

    @test('diagnostics')
    def test_diagnostics(self, d1, d2):
        """Compare the expected diagnostics *d1* with the received ones *d2*.

        Raises ValueError on the first mismatch, missing or surplus diagnostic.
        """
        for i, v in enumerate(d1):
            if i >= len(d2):
                raise ValueError("Expected diagnostic {0} (got {1} diagnostics instead of {2})".format(v, len(d2), len(d1)))

            self.assert_diagnostic(v, d2[i])

        for v in d2[len(d1):]:
            raise ValueError("Unexpected diagnostic {0}".format(v))

    def file_path(self, p):
        """Return the absolute path of *p* inside the `backends` directory next to this script."""
        return os.path.abspath(os.path.join(os.path.dirname(__file__), 'backends', p))

    def run_parse(self, p):
        """Ask the service to parse the file named by `p['path']`.

        Returns a tuple of the service-relative path of the resulting
        document object and the proxy for that object.
        """
        path = p['path']
        obj = self.get_object('/')

        iface = dbus.Interface(obj, 'org.gnome.CodeAssist.v1.Service')
        doc = iface.Parse(self.file_path(path), '', (0, 0), {})

        # Parse returns an absolute object path; strip the service prefix.
        doc = doc[len(self.path):]

        with self.test_object(doc) as t:
            obj = t()

        return doc, obj

    @test('remote documents')
    def test_remote_documents(self, orig, docmap):
        """Check that *docmap* has exactly the documents listed in *orig*.

        *orig* holds file names relative to the `backends` directory and
        *docmap* is keyed by absolute file path. The caller's *docmap* is
        left untouched.

        Raises ValueError on the first missing or surplus document.
        """
        remaining = dict(docmap)

        for f in orig:
            ff = self.file_path(f)

            if not ff in remaining:
                raise ValueError('Expected remote document {0}'.format(f))

            del remaining[ff]

        for k in remaining:
            raise ValueError('Unexpected remote document {0}'.format(k))

    def get_object(self, path):
        """Return the bus proxy for the service-relative object *path*."""
        return self.bus.get_object(self.name, self.full_path(path))

    def run_parse_all(self, p):
        """Call ParseAll for `p['path']` together with the files in `p['documents']`.

        The main file is sent as the first open document. Returns whatever
        the ParseAll call returns.
        """
        obj = self.get_object('/')

        path = self.file_path(p['path'])
        docs = [gcatypes.OpenDocument(self.file_path(d)) for d in p['documents']]
        docs.insert(0, gcatypes.OpenDocument(path))

        iface = dbus.Interface(obj, 'org.gnome.CodeAssist.v1.Project')
        # `path` is already absolute; no need to resolve it a second time.
        return iface.ParseAll(path, [d.to_tuple() for d in docs], (0, 0), {})

    def verify_parse_diagnostics(self, path, obj, diagnostics):
        """Check the Diagnostics interface of *obj* and the diagnostics it reports.

        *path* is the service-relative object path of *obj*; *diagnostics* is
        the expected list in its JSON form.
        """
        with self.test_interface(path, 'org.gnome.CodeAssist.v1.Diagnostics') as t:
            t()

        diag = dbus.Interface(obj, 'org.gnome.CodeAssist.v1.Diagnostics')

        ret = [gcatypes.Diagnostic.from_tuple(dd) for dd in diag.Diagnostics()]
        orig = [gcatypes.Diagnostic.from_json(dd) for dd in diagnostics]

        # The context is created with just the path (for the printed label);
        # the actual comparison arguments are supplied in the call.
        with self.test_diagnostics(path) as t:
            t(orig, ret)

    @test('exited')
    def test_exited(self):
        """Check that the service no longer owns its bus name.

        Raises ValueError if the name still has an owner.
        """
        try:
            self.bus.get_name_owner(self.name)
        except dbus.DBusException:
            # The bus reports an error for a name without an owner, which is
            # the outcome we want here. Other exceptions are not treated as
            # proof that the service is gone.
            return

        raise ValueError('Service did not exit')

    @test('dispose')
    def test_dispose(self, path):
        """Ask the service to dispose the document for file *path*."""
        obj = self.get_object('/')
        iface = dbus.Interface(obj, 'org.gnome.CodeAssist.v1.Service')
        iface.Dispose(path)
        # Short pause after disposing; presumably gives the service time to
        # shut down before `test_exited` runs -- TODO confirm.
        time.sleep(0.1)

    def test_parse(self, d):
        """Run a single-file scenario: parse, verify diagnostics, dispose, check exit."""
        path, parsed = self.run_parse(d['parse'])
        self.verify_parse_diagnostics(path, parsed, d['diagnostics'])

        with self.test_dispose(self.file_path(d['parse']['path'])) as t:
            t()

        with self.test_exited() as t:
            t()

    def test_parse_all(self, d):
        """Run a multi-file scenario.

        Calls ParseAll, checks the returned set of remote documents, verifies
        the diagnostics of every file listed in `d['diagnostics']`, disposes
        every returned document and finally checks that the service exited.
        """
        docs = self.run_parse_all(d['parse_all'])

        def make_doc_map():
            # file path -> service-relative object path of the remote document
            return {str(d[0]): str(d[1])[len(self.path):] for d in docs}

        # The document names are passed to the context only for the label.
        with self.test_remote_documents(*d['documents']) as t:
            t(d['documents'], make_doc_map())

        docmap = make_doc_map()
        diags = d['diagnostics']

        for k in diags:
            remote = docmap[self.file_path(k)]

            with self.test_object(remote) as t:
                obj = t()

            self.verify_parse_diagnostics(remote, obj, diags[k])

        for k in docmap:
            with self.test_dispose(k) as t:
                t()

        with self.test_exited() as t:
            t()

    def run_diagnostic(self, d):
        """Dispatch one diagnostics scenario on its `parse` or `parse_all` key.

        Raises ValueError when the scenario has neither key.
        """
        if 'parse' in d:
            self.test_parse(d)
        elif 'parse_all' in d:
            self.test_parse_all(d)
        else:
            raise ValueError('Do not know how to parse diagnostic')

    def run(self):
        """Run every check configured for this service, printing results as it goes."""
        print('\n\033[1mTESTING {0}\033[0m'.format(self.name))

        # Test for default interfaces and paths
        with self.test_object('/') as t:
            t()

        with self.test_interface('/', 'org.gnome.CodeAssist.v1.Service') as t:
            t()

        if 'interfaces' in self.test:
            for interface in self.test['interfaces']:
                with self.test_interface('/', interface) as t:
                    t()

        with self.test_object('/document') as t:
            t()

        with self.test_interface('/document', 'org.gnome.CodeAssist.v1.Document') as t:
            t()

        if 'document_interfaces' in self.test:
            for interface in self.test['document_interfaces']:
                with self.test_interface('/document', interface) as t:
                    t()

        if 'diagnostics' in self.test:
            diagnostics = self.test['diagnostics']

            for d in diagnostics:
                self.run_diagnostic(d)

def run_test(testfile):
    """Run the service checks described by the JSON file *testfile*.

    Opens a private session-bus connection for the run and closes it again
    afterwards, even when the run raises.
    """
    with open(testfile, 'r') as j:
        spec = json.load(j)

    bus = dbus.SessionBus(private=True)

    try:
        ServiceTest(bus, spec).run()
    finally:
        # A private connection is not shared, so it has to be closed here.
        bus.close()

dirname = os.path.dirname(__file__)

# Without arguments every backends/*.json description is run; otherwise each
# argument names one description (without the .json suffix).
if len(sys.argv) <= 1:
    testfiles = glob.glob(os.path.join(dirname, 'backends', '*.json'))
else:
    testfiles = [os.path.join(dirname, 'backends', d + '.json') for d in sys.argv[1:]]

# Start a dedicated bus daemon for the test run. Its stderr is discarded via
# subprocess.DEVNULL, so no devnull file handle is left open in this process.
daemon = subprocess.Popen(['dbus-daemon', '--print-address', '--nofork', '--config-file', os.path.join(dirname, 'dbus.conf')], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, close_fds=True)

try:
    # With --print-address the daemon writes its bus address to stdout.
    address = daemon.stdout.readline().rstrip()

    # Point session-bus connections made from here on at the private daemon.
    os.environ['DBUS_SESSION_BUS_ADDRESS'] = address.decode('utf-8')
    os.environ['DBUS_SESSION_BUS_PID'] = str(daemon.pid)

    for testfile in testfiles:
        run_test(testfile)
finally:
    daemon.terminate()
    # Echo anything else the daemon wrote to stdout before it stopped.
    print(daemon.stdout.read().decode('utf-8'))
    daemon.wait()

if has_failures:
    sys.exit(1)

# vi:ts=4:et
