#include <stdio.h>
#include <walk_heap.h>
#include <jvmti.h>
#include <tijmp.h>
#include <gc.h>
#include <object_info_list.h>

/* JVMTI environment used by every function in this file (defined elsewhere). */
extern jvmtiEnv *jvmti;
/* Source of object tags; fi_heap_iteration_callback() hands them out by
 * post-decrementing this counter (defined elsewhere). */
extern jlong current_object_tag;
/* Next class tag to hand out. Tag 0 means "untagged", so numbering starts
 * at 1 and the per-class result arrays have current_class_tag slots. */
static jlong current_class_tag = 1;

typedef struct {
    /* Running tallies: how often each heap callback was invoked. */
    jint iteration, reference, field, array, string;

    /* Per-class results, indexed by class tag (slot 0 = untagged).
     * Each Java long[] is paired with the native element buffer that the
     * callbacks write into and that is later released back into it. */
    jobjectArray classes;        /* Class object stored at its tag's slot. */
    jlongArray ja_class_count;   /* Java array of instance counts. */
    jlong* class_count;          /* Element buffer written back to ja_class_count. */
    jlongArray ja_total_size;    /* Java array of summed object sizes. */
    jlong* total_size;           /* Element buffer written back to ja_total_size. */
} walk_heap_counter;

/*
 * IterateThroughHeap callback: counts every object and, for objects whose
 * class carries a valid class tag, adds one instance and its size to that
 * class's slot in the per-class arrays.
 *
 * user_data is the walk_heap_counter passed to IterateThroughHeap.
 * Always returns JVMTI_VISIT_OBJECTS (no abort flag), so iteration continues.
 */
static jint JNICALL heap_iteration_callback (jlong class_tag, 
                                             jlong size, 
                                             jlong* tag_ptr, 
                                             jint length, 
                                             void* user_data) {
    walk_heap_counter* whp = (walk_heap_counter*)user_data;
    whp->iteration++;
    /* The per-class arrays have current_class_tag slots (see setup_whp).
     * Skip untagged (0), negative or newer tags instead of writing outside
     * the arrays. */
    if (class_tag > 0 && class_tag < current_class_tag) {
        whp->class_count[class_tag]++;
        whp->total_size[class_tag] += size;
    }
    return JVMTI_VISIT_OBJECTS;
}

/*
 * Heap-reference callback: only bumps the reference tally stored in the
 * walk_heap_counter passed as user_data; all other arguments are unused.
 *
 * NOTE: the JVMTI spec lists this slot as used by FollowReferences and
 * ignored by IterateThroughHeap. walk_heap() drives IterateThroughHeap, so
 * this tally is expected to stay at 0.
 */
static jint JNICALL heap_reference_callback (jvmtiHeapReferenceKind reference_kind, 
                                             const jvmtiHeapReferenceInfo* reference_info, 
                                             jlong class_tag, 
                                             jlong referrer_class_tag, 
                                             jlong size, 
                                             jlong* tag_ptr, 
                                             jlong* referrer_tag_ptr, 
                                             jint length, 
                                             void* user_data) {
    walk_heap_counter* counter = user_data;
    counter->reference += 1;
    return JVMTI_VISIT_OBJECTS;
}

/*
 * Primitive-field callback: tallies one primitive field in the
 * walk_heap_counter passed as user_data. The field's value and type are
 * not inspected.
 */
static jint JNICALL primitive_field_callback (jvmtiHeapReferenceKind kind, 
                                              const jvmtiHeapReferenceInfo* info, 
                                              jlong object_class_tag, 
                                              jlong* object_tag_ptr, 
                                              jvalue value, 
                                              jvmtiPrimitiveType value_type, 
                                              void* user_data) {
    ((walk_heap_counter*)user_data)->field += 1;
    /* No visit-control flags set: keep iterating. */
    return 0;
}

/*
 * Primitive-array callback: tallies one primitive array in the
 * walk_heap_counter passed as user_data. The elements are not read.
 */
static jint JNICALL array_primitive_value_callback (jlong class_tag, 
                                                    jlong size, 
                                                    jlong* tag_ptr, 
                                                    jint element_count, 
                                                    jvmtiPrimitiveType element_type, 
                                                    const void* elements, 
                                                    void* user_data) {
    walk_heap_counter* tally = user_data;
    tally->array += 1;
    return 0;  /* no visit-control flags: continue */
}

/*
 * String-value callback: tallies one java.lang.String in the
 * walk_heap_counter passed as user_data. The characters are not read.
 */
static jint JNICALL string_primitive_value_callback (jlong class_tag, 
                                                     jlong size, 
                                                     jlong* tag_ptr, 
                                                     const jchar* value, 
                                                     jint value_length, 
                                                     void* user_data) {
    walk_heap_counter* tally = user_data;
    tally->string += 1;
    return 0;  /* no visit-control flags: continue */
}

/*
 * Resets the five summary tallies of *whp to zero. The per-class arrays and
 * their buffers are not touched here; setup_whp() fills those in.
 */
static void clear_walk_heap_counter (walk_heap_counter* whp) {
    whp->iteration = whp->reference = whp->field = whp->array = whp->string = 0;
}

/*
 * Hands the results of a heap walk to Java.
 *
 * Writes the native count/size buffers back into their Java arrays and
 * frees them (release mode 0). It then calls the static method
 * tijmp.TIJMPController.heapWalkResult(Class[], long[], long[]).
 * If the controller class or the method cannot be found, nothing is
 * reported and the exception raised by the failed lookup stays pending.
 */
static void report_walk_heap_data (JNIEnv* env, walk_heap_counter* whp) {
    jclass cls;
    jmethodID m_hwr;

    /* A NULL buffer means GetLongArrayElements failed; it must not be released. */
    if (whp->class_count != NULL) {
        (*env)->ReleaseLongArrayElements (env, whp->ja_class_count, 
                                          whp->class_count, 0);
    }
    if (whp->total_size != NULL) {
        (*env)->ReleaseLongArrayElements (env, whp->ja_total_size, 
                                          whp->total_size, 0);
    }

    cls = (*env)->FindClass (env, "tijmp/TIJMPController");
    if (cls == NULL)
        return;  /* lookup failed: an exception is pending, do not use cls */
    m_hwr = (*env)->GetStaticMethodID (env, cls, "heapWalkResult", 
                                       "([Ljava/lang/Class;[J[J)V");
    if (m_hwr != NULL) {
        (*env)->CallStaticVoidMethod (env, cls, m_hwr, whp->classes, 
                                      whp->ja_class_count, whp->ja_total_size);
    }
}

/*
 * Prepares *whp for a heap walk.
 *
 * Allocates a Class[] and two long[] arrays with current_class_tag slots
 * each (one per class tag handed out so far; slot 0 stands for
 * "untagged"). It obtains native element buffers for the two long[] arrays
 * and stores every loaded class at the slot matching its JVMTI tag.
 *
 * class_count / classes: the GetLoadedClasses result produced by
 * tag_classes(). The local references in classes[0] are deleted, and the
 * classes[0] array itself is deallocated here.
 */
static void setup_whp (JNIEnv* env, walk_heap_counter* whp, 
                       jint class_count, jclass** classes) {
    jvmtiError err;
    jclass cls;
    jint i;
    jlong object_tag;

    cls = (*env)->FindClass (env, "java/lang/Class");
    /* Every slot starts as the NULL initial element. */
    whp->classes = (*env)->NewObjectArray (env, current_class_tag, cls, NULL);

    whp->ja_class_count = (*env)->NewLongArray (env, current_class_tag);
    whp->class_count = (*env)->GetLongArrayElements (env, whp->ja_class_count, NULL);

    whp->ja_total_size = (*env)->NewLongArray (env, current_class_tag);
    /* Must come from ja_total_size: report_walk_heap_data() releases this
     * buffer into ja_total_size. Taking it from ja_class_count would make
     * both buffers alias one array on VMs that pin instead of copying. */
    whp->total_size = (*env)->GetLongArrayElements (env, whp->ja_total_size, NULL);

    for (i = 0; i < current_class_tag; i++) {
        whp->class_count[i] = 0;
        whp->total_size[i] = 0;
    }

    if (classes[0] == NULL)
        return;  /* no class list (GetLoadedClasses failed): nothing to place or free */

    for (i = 0; i < class_count; i++) {
        jclass klass = classes[0][i];
        err = (*jvmti)->GetTag (jvmti, klass, &object_tag);
        /* Only store at a valid class slot. A failed GetTag leaves
         * object_tag unspecified, 0 is the untagged slot, and negative or
         * newer tags fall outside the array. */
        if (err == JVMTI_ERROR_NONE && object_tag > 0 && object_tag < current_class_tag) {
            (*env)->SetObjectArrayElement (env, whp->classes, (jint)object_tag, klass);
        }
        /* The Class[] now holds its own reference. Drop the local ref
         * created by GetLoadedClasses to keep the local-ref table small. */
        (*env)->DeleteLocalRef (env, klass);
    }

    (*jvmti)->Deallocate (jvmti, (unsigned char*)classes[0]);
}

/*
 * Fetches all loaded classes and gives every class without a positive JVMTI
 * tag a fresh tag from current_class_tag, so tags are 1, 2, 3, ...
 *
 * On success, *class_count / *classes hold the GetLoadedClasses result.
 * The caller owns classes[0] and must Deallocate it.
 * If GetLoadedClasses fails, the outputs are set to 0 / NULL.
 * JVMTI errors are passed to handle_global_error(). A failure part-way
 * through the loop leaves the remaining classes with their old tags.
 * env is not used.
 */
void JNICALL tag_classes (JNIEnv* env, jint* class_count, jclass** classes) {
    jvmtiError err;
    jlong object_tag;
    jint i;

    err = (*jvmti)->GetLoadedClasses (jvmti, class_count, classes);
    if (err != JVMTI_ERROR_NONE) {
        /* Give the caller a defined, empty result instead of garbage. */
        *class_count = 0;
        *classes = NULL;
        handle_global_error (err);
        return;
    }

    for (i = 0; i < *class_count; i++) {
        err = (*jvmti)->GetTag (jvmti, classes[0][i], &object_tag);
        if (err != JVMTI_ERROR_NONE) {
            handle_global_error (err);
            return;
        }
        /* Only positive tags count as class tags; anything else is retagged. */
        if (object_tag <= 0) {
            object_tag = current_class_tag++;
            err = (*jvmti)->SetTag (jvmti, classes[0][i], object_tag);
            if (err != JVMTI_ERROR_NONE) {
                handle_global_error (err);
                return;
            }
        }
    }
}

/*
 * Tags all loaded classes, then sizes and fills the per-class arrays of *whp.
 */
static void tag_classes_and_setup_whp (JNIEnv* env, walk_heap_counter* whp) {
    /* Start from an empty list so setup_whp() does not read an
     * uninitialised count if tag_classes() bails out before
     * GetLoadedClasses fills these in. */
    jint class_count = 0;
    jclass* cp = NULL;
    jclass** classes = &cp;
    tag_classes (env, &class_count, classes);
    setup_whp (env, whp, class_count, classes);
}

/*
 * Counts live objects and their total size per class, then reports the
 * result to Java.
 *
 * The sequence is:
 *   1. Force a GC.
 *   2. Tag every loaded class.
 *   3. Iterate the whole heap with IterateThroughHeap (no class filter).
 *   4. Call report_walk_heap_data(), which also releases the native
 *      count/size buffers.
 * A failed heap iteration is passed to handle_global_error(). Whatever was
 * counted is still reported, so the buffers are always released.
 */
void JNICALL walk_heap (JNIEnv *env) {
    jvmtiError err;
    /* Zero-initialise so unused and reserved callback slots are NULL
     * rather than stack garbage. */
    jvmtiHeapCallbacks callbacks = {0};
    walk_heap_counter whp;

    /* force gc to remove all garbage. */
    force_gc ();
    
    clear_walk_heap_counter (&whp);
    tag_classes_and_setup_whp (env, &whp);

    callbacks.heap_iteration_callback = heap_iteration_callback; 
    callbacks.heap_reference_callback = heap_reference_callback; 
    callbacks.primitive_field_callback = primitive_field_callback;
    callbacks.array_primitive_value_callback = array_primitive_value_callback;
    callbacks.string_primitive_value_callback = string_primitive_value_callback;
    err = (*jvmti)->IterateThroughHeap (jvmti, 0, NULL, &callbacks, &whp);
    if (err != JVMTI_ERROR_NONE)
        handle_global_error (err);
    report_walk_heap_data (env, &whp);
}

/*
 * IterateThroughHeap callback used by find_all_instances().
 *
 * Objects that are still untagged get the current value of
 * current_object_tag, and the counter then moves down by one. Objects that
 * already carry a tag keep it. The object's size, length and tag are then
 * appended to the object_info_list passed as user_data.
 * NOTE: append_objects() later resets every resolved tag to 0, so a
 * pre-existing tag on such an object is lost.
 */
static jint JNICALL fi_heap_iteration_callback (jlong class_tag, 
                                                jlong size, 
                                                jlong* tag, 
                                                jint length, 
                                                void* user_data) {
    if (*tag == 0) {
        *tag = current_object_tag;
        current_object_tag--;
    }
    add_object_info_to_list ((object_info_list*)user_data, size, length, *tag);
    return 0;  /* no visit-control flags: continue */
}

/*
 * Passes the instances collected in *oil to Java.
 *
 * Calls tijmp.TIJMPController.instances(Class, Object[], long[], int[])
 * with the objects, their sizes and their lengths (oil->count entries each).
 * If an array allocation or the class/method lookup fails, nothing is
 * reported and the exception raised by JNI stays pending for the caller.
 */
static void report_instances (JNIEnv* env, object_info_list* oil) {
    jclass cls;
    jmethodID m_hwr;
    jlongArray la; 
    jintArray ia;
    jobjectArray oa;
    jint i;
    const char* sign = "(Ljava/lang/Class;[Ljava/lang/Object;[J[I)V";

    oa = (*env)->NewObjectArray (env, oil->count, oil->clz, NULL);
    if (oa == NULL)
        return;  /* allocation failed: exception pending */
    for (i = 0; i < oil->count; i++)
        (*env)->SetObjectArrayElement (env, oa, i, oil->objects[i]);

    la = (*env)->NewLongArray (env, oil->count);
    if (la == NULL)
        return;
    (*env)->SetLongArrayRegion (env, la, 0, oil->count, oil->sizes);

    ia = (*env)->NewIntArray (env, oil->count);
    if (ia == NULL)
        return;
    (*env)->SetIntArrayRegion (env, ia, 0, oil->count, oil->lengths);

    cls = (*env)->FindClass (env, "tijmp/TIJMPController");    
    if (cls == NULL)
        return;  /* lookup failed: exception pending, do not use cls */
    m_hwr = (*env)->GetStaticMethodID (env, cls, "instances", sign);
    if (m_hwr != NULL)
        (*env)->CallStaticVoidMethod (env, cls, m_hwr, oil->clz, oa, la, ia);
}

/*
 * Linear search of the first oil->count entries of oil->tags.
 * Returns the index of the first entry equal to tag, or -1 if none matches.
 */
static jint find_pos (object_info_list* oil, jlong tag) {
    jint pos = 0;
    while (pos < oil->count) {
        if (oil->tags[pos] == tag)
            return pos;
        ++pos;
    }
    return -1;
}

/*
 * Resolves the tags collected in oil->tags back into object references.
 *
 * Each object returned by GetObjectsWithTags is stored in oil->objects at
 * the index of its tag, and its tag is then reset to 0. Returned tags that
 * match no list entry are reported on stderr.
 * JVMTI errors are passed to handle_global_error(). In that case
 * oil->objects is left untouched.
 */
static void append_objects (object_info_list* oil) {
    jvmtiError err;
    jlong* ids = NULL;
    jint count = 0;
    jobject* objects = NULL;
    jlong* tags = NULL;
    jint i;

    /* Nothing was tagged. JVMTI Allocate returns NULL for size 0, and
     * GetObjectsWithTags would then get a NULL tag array and fail. The old
     * code went on to use count/objects/tags uninitialised. */
    if (oil->count <= 0)
        return;

    err = (*jvmti)->Allocate (jvmti, sizeof (*ids) * oil->count, 
                              (unsigned char**)&ids);
    if (err != JVMTI_ERROR_NONE) {
        handle_global_error (err);
        return;
    }
    for (i = 0; i < oil->count; i++)
        ids[i] = oil->tags[i];

    err = (*jvmti)->GetObjectsWithTags (jvmti, oil->count, ids, &count,
                                        &objects, &tags);
    if (err != JVMTI_ERROR_NONE) {
        handle_global_error (err);
        (*jvmti)->Deallocate (jvmti, (unsigned char*)ids);
        return;
    }

    for (i = 0; i < count; i++) {
        jint pos = find_pos (oil, tags[i]);
        if (pos > -1) {
            oil->objects[pos] = objects[i];
            (*jvmti)->SetTag (jvmti, objects[i], 0);
        } else {
            // TODO: handle this better?
            /* jlong is long on some platforms and long long on others;
             * cast so the conversion specifier always matches. */
            fprintf (stderr, "failed to find pos for tag: %lld\n", (long long)tags[i]);
        }
    }
    (*jvmti)->Deallocate (jvmti, (unsigned char*)objects);
    (*jvmti)->Deallocate (jvmti, (unsigned char*)tags);
    (*jvmti)->Deallocate (jvmti, (unsigned char*)ids);    
}

/*
 * Collects the instances of clz and reports them to Java.
 *
 * The sequence is:
 *   1. Force a GC.
 *   2. Walk the heap restricted to clz. fi_heap_iteration_callback records
 *      each instance's size and length and tags untagged instances.
 *   3. Resolve the tags back into object references with append_objects(),
 *      which also resets those tags to 0.
 *   4. Hand everything to report_instances().
 */
void JNICALL find_all_instances (JNIEnv *env, jclass clz) {
    jvmtiError err;
    /* Zero-initialise: only heap_iteration_callback is used, and the unused
     * and reserved callback slots must not hold stack garbage. */
    jvmtiHeapCallbacks callbacks = {0};
    object_info_list* oil;
    
    /* force gc to remove all garbage. */
    force_gc ();
    
    // find all instance sizes and tag the objects
    oil = create_object_info_list (clz);
    if (oil == NULL)
        return;  /* defensive; assumes NULL signals allocation failure -- confirm */
    callbacks.heap_iteration_callback = fi_heap_iteration_callback;

    /* This will tag the objects. */
    err = (*jvmti)->IterateThroughHeap (jvmti, 0, clz, &callbacks, oil);
    if (err != JVMTI_ERROR_NONE)
        handle_global_error (err);

    /* now we use the tags to find the actual objects. */
    create_object_store (oil);
    append_objects (oil);

    report_instances (env, oil);
    free_object_info_list (oil);
}

