#include "optimise.h"
#include "blocks.h"
//#include "collections.h"
#define DEPTH (IR->x.depth)
#define DONT_CROSS_CALL 1

static void intraBlock(Block b);
extern void increaseTreeDepth(Node n, int depth);

/*
 * Entry point of the "koopman" optimiser: runs intraBlock() on every
 * basic block via flow_applyToBlocks().  Always returns NULL.
 */
char* koop() {
    char* result = NULL;

    flow_applyToBlocks(intraBlock);
    return result;
}

/*
 * Two successive references to the same local temporary (VREG) within one
 * basic block, as recorded by makePair() during getPairs().
 */
typedef struct node_pair {
    int separation; /* scan position of `second` minus that of `first`; applyPair() sets it to 0 */
    int location;   /* scan position (the `location` counter) recorded for `first` */
    Node first;     /* earlier referencing node; applyPair() expects an ASGN (def) or INDIR (use) */
    Node second;    /* later referencing node of the same temporary */
} *NodePair;

/*
 * Temporary index of the variable the pair refers to, read from the first
 * node's kids[0].
 * NOTE(review): applyPair()'s use-use rewrite replaces first->kids[0] with a
 * new INDIR node, so this is presumably only meaningful before applyPair()
 * runs on the pair -- confirm with callers.
 */
int pairID(NodePair pair) {
    return getTempIndex(pair->first->kids[0]);
}

/* Distance, in scan positions, between the pair's two references (0 once applied). */
int pairSeparation(NodePair pair) {
    const int distance = pair->separation;
    return distance;
}

/* Scan position recorded for the pair's first reference. */
int pairLocation(NodePair pair) {
    const int where = pair->location;
    return where;
}

/* Scratch state shared by getPairs(), scanTree() and makePair(). */

/* nodeLocations[id]: value of `location` at the most recent reference to temp `id`. */
static int* nodeLocations;
/* foundNodes[id]: most recent node referencing temp `id`, or NULL if none since the last reset. */
static Node* foundNodes;
/* List of NodePair collected for the block currently being scanned (recreated by getPairs()). */
static ArrayList pairs;
/* Running scan-position counter.  Not reset anywhere in the visible code;
   only differences between positions (separation) are used for ordering. */
static int location = 0;

/*
 * Record a reference by node `second` to the temporary `kid`.
 * If an earlier reference to the same temporary is pending, a NodePair
 * linking the two is appended to `pairs`.  Either way, `second` becomes the
 * pending reference for that temporary at the current `location`.
 */
static void makePair(Node second, Node kid) {
    int id = getTempIndex(kid);
    Node previous = foundNodes[id];

    if (previous != NULL) {
        NodePair pair;

        assert(previous != second);
        pair = allocate(sizeof *pair, FUNC);
        pair->first = previous;
        pair->second = second;
        pair->location = nodeLocations[id];
        pair->separation = location - pair->location;
        list_add(pairs, pair);
    }
    /* Must happen after the pair is built: it reads the old position. */
    foundNodes[id] = second;
    nodeLocations[id] = location;
}

/*
 * Recursively scan expression tree `p` for references to local temporaries.
 * - With DONT_CROSS_CALL set, a CALL node forgets every pending reference, so
 *   no pair spans a call; the CALL's operands are not scanned.
 * - A node whose kids[0] is a VREG is recorded via makePair() (its kids[1] is
 *   not scanned) and advances `location`.
 * - A leaf (no kids[0]) advances `location`.
 * - Any other node scans kids[0] and then kids[1].
 */
static void scanTree(Node p) {
    Node left;

    if (p == NULL)
        return;

    if (DONT_CROSS_CALL && generic(p->op) == CALL) {
        int v;
        for (v = 0; v < local_variable_count; v++)
            foundNodes[v] = NULL;
        return;
    }

    left = p->kids[0];
    if (left == NULL) {
        location++;
        return;
    }

    if (generic(left->op) == VREG) {
        makePair(p, left);
        location++;
        return;
    }

    scanTree(left);
    if (p->kids[1])
        scanTree(p->kids[1]);
}


/*
 * Scan block `b` in x.next order and collect every pair of successive
 * references to the same local temporary (VREG).
 *
 * Resets the file-level scratch arrays (allocated with allocate(..., FUNC),
 * one slot per local temporary) and returns the freshly created pair list,
 * which is also stored in the file-level `pairs`.
 */
ArrayList getPairs(Block b) {
    Node p, end;
    int i, size;

    size = local_variable_count;
    pairs = list_new(size);
    nodeLocations = allocate(size * sizeof *nodeLocations, FUNC);
    foundNodes = allocate(size * sizeof *foundNodes, FUNC);
    for (i = 0; i < size; i++) {
        foundNodes[i] = NULL;
    }
    end = block_lastNode(b)->x.next;
    for (p = block_firstNode(b); p != end; p = p->x.next) {
        if (generic(p->op) == ASGN) {
            /* Classify the destination by its opcode.  Previously this
               passed the Node pointer itself to generic(), so stores to
               temporaries were never recognised as definitions. */
            if (generic(p->kids[0]->op) == VREG) {
                /* Scan the stored value first so uses inside it get earlier
                   positions than the definition itself. */
                scanTree(p->kids[1]);
                makePair(p, p->kids[0]);
                location++;
            } else {
                scanTree(p->kids[0]);
                scanTree(p->kids[1]);
            }
        } else {
            scanTree(p);
        }
    }
    return pairs;
}

/*
 * Call incrementDepth() on every node along the x.next chain strictly after
 * `from`, up to and including `to`.  Both endpoints must be leaf VREG or
 * STACK nodes.  If the chain ends (x.next == NULL) before reaching `to`, the
 * offending node is printed with prettytree() and an assertion fails.
 */
void increaseDepth(Node from, Node to) {
    Node cur;
    int kind; /* not named `generic`, to avoid shadowing the generic() macro */

    kind = generic(from->op);
    assert(kind == VREG || kind == STACK);
    assert(from->kids[0] == NULL);

    kind = generic(to->op);
    assert(kind == VREG || kind == STACK);
    assert(to->kids[0] == NULL);

    cur = from->x.next;
    while (cur != to->x.next) {
        if (cur->x.next == NULL && to->x.next != NULL) {
            /* Broken chain: ran off the end without reaching `to`. */
            prettytree(cur);
            assert(0);
        } else {
            incrementDepth(cur);
        }
        cur = cur->x.next;
    }
}


/*
 * Apply the stack transformation to one reference pair.
 *
 * Cases, by the kinds of pair->first / pair->second:
 *   def-def (ASGN, ASGN): no transformation (dead-store removal is commented out).
 *   def-use (ASGN, other): the stored value is wrapped in a TUCK-family node
 *       carrying depth1, the later VREG leaf becomes a STACK leaf, and depths
 *       between the two references are increased.
 *   use-def (INDIR, ASGN): no transformation.
 *   use-use (INDIR, other): the first INDIR becomes a TUCK-family node over a
 *       fresh copy of the INDIR, the later VREG leaf becomes a STACK leaf, and
 *       depths between the two references are increased.
 * The rewrites only happen when both references' current depths are below
 * DEPTH (IR->x.depth).  In all cases pair->separation is set to 0 on exit;
 * on entry it must be > 0.
 *
 * b: the block containing the pair (not used in the live code; only by the
 *    commented-out removeDeadTree call).
 */
void applyPair(Block b, NodePair pair) {
//    fprintf(stderr, "Applying transforamtion to ");
//    printPair(pair);
    assert(pair->separation > 0);
    Node kid1, kid2;
    int depth1, depth2;
    if (generic(pair->first->op) == ASGN) {
        if (generic(pair->second->op) == ASGN) { // def-def
            // removeDeadTree(b, pair->first);
        } else { // def-use
            kid1 = pair->first->kids[0];
            depth1 = getDepth(kid1);
            kid2 = pair->second->kids[0];
            // The second node may already have been turned into a TUCK by an
            // earlier pair; the VREG leaf is then one level further down.
            if (generic(pair->second->op) == TUCK)
                kid2 = kid2->kids[0];
            assert(kid2->op == VREG + P);
            depth2 = getDepth(kid2);
            if (depth1 < DEPTH && depth2 < DEPTH) {
                // Map the ASGN opcode onto the corresponding TUCK opcode.
                int op = pair->first->op + TUCK - ASGN;
                // NOTE(review): `vreg` is not defined in this file; presumably
                // it comes from optimise.h and `1-vreg` selects the stored
                // value (kids[1]) -- confirm.
                pair->first->kids[1] = newnode(op, pair->first->kids[1-vreg], NULL, intconst(depth1));
                // The later use now reads the value from the stack.
                kid2->op = STACK + P;
                setDelve(kid2, depth2 + 1);
                incrementDepth(kid1);
                increaseDepth(kid1, kid2);
            }
            assert(pair->first->syms[2] == NULL);
//            pair->first->syms[2] = LIVE;
        }
    } else {
        assert(generic(pair->first->op) == INDIR);
        if (generic(pair->second->op) == ASGN) { // use-def
//            pair->first->kids[0]->syms[2] = FINAL;
        } else { // use-use
            assert(pair->first != pair->second);
            kid1 = pair->first->kids[0];
            depth1 = getDepth(kid1);
            kid2 = pair->second->kids[0];
            // As above: skip over a TUCK created by an earlier pair.
            if (generic(pair->second->op) == TUCK)
                kid2 = kid2->kids[0];
            assert(kid2->op == VREG + P);
            depth2 = getDepth(kid2);
            if (depth1 < DEPTH && depth2 < DEPTH) {
                // Map the INDIR opcode onto the corresponding TUCK opcode.
                int op = pair->first->op + TUCK - INDIR;
                // Push a copy of the INDIR below the node, then retag the
                // original node as the TUCK.
                pair->first->kids[0] = newnode(pair->first->op, kid1, NULL, NULL);
                pair->first->op = op;
                // The TUCK depth operand is depth1 when kid1 is already a
                // STACK leaf, otherwise depth1 + 1.
                if (generic(kid1->op) == STACK)
                    pair->first->syms[0] = intconst(depth1);
                else 
                    pair->first->syms[0] = intconst(depth1 + 1);
                kid2->op = STACK + P;
                setDelve(kid2, depth2 + 1);
                increaseDepth(kid1, kid2);
            }
//            kid2->syms[2] = NON_FINAL;
        }
    }
    // Mark the pair as processed.
    pair->separation = 0;
}

static int compare(const void* v1, const void* v2) {
    const NodePair n1 = *(NodePair*)v1;
    const NodePair n2 = *(NodePair*)v2;
    return n2->separation - n1->separation;
}

/*
 * Per-block driver for the pass:
 *   1. collect the reference pairs of block `b` with getPairs(),
 *   2. order them with compare() (decreasing separation),
 *   3. apply the stack transformation to each one with applyPair().
 * Does nothing if the function has no local temporaries.
 */
static void intraBlock(Block b) {
    ArrayList found; /* not named `pairs`: avoid shadowing the file-level list */
    NodePair* items;
    int count, i;

    if (local_variable_count == 0)
        return;
    found = getPairs(b);
    items = (NodePair*) list_array(found);
    count = list_size(found);
    /* Nothing to apply.  This also keeps a possibly NULL array away from qsort(). */
    if (count <= 0)
        return;
    qsort(items, (size_t) count, sizeof *items, compare);
    /* Bound the walk by the element count instead of relying on a NULL
       sentinel after the last element. */
    for (i = 0; i < count; i++)
        applyPair(b, items[i]);
}

/*
 * Registration record for this pass: short name, description, an integer
 * field set to 0 (its meaning is defined by the `optimiser` type -- confirm
 * in optimise.h), and the entry point koop().
 */
optimiser koopman = {
    "koopman",
    "Eliminate some local variables using Phil Koopman's intra-block algorithm",
    0,
    koop
};