package sparkling.serialization;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.ClassNotFoundException;
import java.lang.Iterable;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Collections;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Set;

import clojure.lang.IFn;
import clojure.lang.IPersistentMap;
import clojure.lang.Keyword;
import clojure.lang.Ref;
import clojure.lang.RT;
import clojure.lang.Symbol;
import clojure.lang.Var;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Helpers for Java-serializing Clojure functions ({@link IFn}) together with the set of
 * namespaces they reference. The reading side can then {@code require} those namespaces
 * when the function is read back.
 */
public class Utils {
    final static Logger logger = LoggerFactory.getLogger(Utils.class);

    static final Var require = RT.var("clojure.core", "require");
    static final Var symbol = RT.var("clojure.core", "symbol");
    static final Var vals = RT.var("clojure.core", "vals");
    static final Var seq = RT.var("clojure.core", "seq");

    private Utils() {
    }

    /**
     * Loads {@code namespace} via {@code clojure.core/require} while holding
     * {@code RT.REQUIRE_LOCK}.
     *
     * <p>Best effort: any exception is logged at WARN level and swallowed, so a namespace
     * that fails to load does not abort deserialization.
     *
     * @param namespace the namespace symbol to require
     */
    public static void requireNamespace(Symbol namespace) {
        try {
            logger.debug("(require {})", namespace.getName());
            synchronized (RT.REQUIRE_LOCK) {
                require.invoke(namespace);
            }
        } catch (Exception e) {
            logger.warn ("Error deserializing function (require " + namespace +")   " ,e);
        }
    }

    /**
     * Collects the names of the namespaces of all Vars reachable from {@code f}.
     *
     * @param f the root object (normally an {@link IFn}) to traverse
     * @return a new mutable set of namespace symbols; it is written as-is to the stream
     */
    private static HashSet<Symbol> getReferencedNamespaces(Object f) {
        HashSet<Symbol> set = new HashSet<Symbol>();
        // Identity semantics: we care about "same object already walked", not equality.
        Set<Object> visited = Collections.newSetFromMap(new IdentityHashMap<Object, Boolean>());
        addReferencedNamespaces(f, set, visited);
        return set;
    }

    /**
     * Recursively adds to {@code set} the namespace of every Var reachable from {@code f}.
     *
     * <p>Vars contribute their namespace name. Keywords, Symbols and Refs are leaves. For
     * maps and records, the map values are walked along with the class's static fields.
     * For every other object, all declared fields are walked. Only values that are
     * {@link IFn} or {@link IPersistentMap} are followed.
     *
     * @param f       the object being inspected
     * @param set     accumulator for namespace symbols
     * @param visited objects already inspected (identity-based); prevents infinite
     *                recursion on reference cycles such as mutually recursive
     *                {@code letfn} closures
     */
    private static void addReferencedNamespaces(Object f, HashSet<Symbol> set, Set<Object> visited) {
        if (!visited.add(f)) {
            // Already walked (or currently being walked higher up the stack).
            return;
        }
        if (f instanceof Var) {
            Var v = (Var) f;
            // Anonymous vars (Var.create / with-local-vars) have no namespace to require.
            if (v.ns != null) {
                set.add(v.ns.getName());
            }
        } else if (!(f instanceof Keyword || f instanceof Symbol || f instanceof Ref)) {
            // special case maps and records to traverse over their contents in addition
            // to their fields
            if (f instanceof IPersistentMap) {
                // vals returns null for empty maps
                Iterable<?> values = (Iterable<?>) vals.invoke(f);
                if (values != null) {
                    for (Object val : values) {
                        if ((val instanceof IFn || val instanceof IPersistentMap) && !val.equals(f)) {
                            addReferencedNamespaces(val, set, visited);
                        }
                    }
                }
            }
            for (Field field : f.getClass().getDeclaredFields()) {
                // only traverse static fields of maps, otherwise traverse all fields
                if (!(f instanceof IPersistentMap) || Modifier.isStatic(field.getModifiers())) {
                    try {
                        field.setAccessible(true);
                        Object val = field.get(f);
                        if ((val instanceof IFn || val instanceof IPersistentMap) && !val.equals(f)) {
                            addReferencedNamespaces(val, set, visited);
                        }
                    } catch (IllegalAccessException e) {
                        logger.warn("Error resolving namespaces references in IFn " + f, e);
                    }
                }
            }
        }
    }

    /**
     * Writes {@code f} to {@code out} as three consecutive objects: the function's class
     * name, a {@link HashSet} of the namespace symbols it references, and the function
     * itself.
     *
     * @param out the stream to write to
     * @param f   the function to serialize
     * @throws IOException      if writing fails (logged, then rethrown)
     * @throws RuntimeException if namespace discovery overflows the stack, or if
     *                          serialization fails with an unchecked exception (logged,
     *                          then rethrown)
     */
    public static void writeIFn(ObjectOutputStream out, IFn f) throws IOException {
        try {
            logger.debug("Serializing {}", f);
            out.writeObject(f.getClass().getName());
            try {
                out.writeObject(getReferencedNamespaces(f));
            } catch (StackOverflowError e) {
                // Still possible for extremely deep (acyclic) object graphs.
                String msg = "Stack overflow resolving namespaces references in IFn " + f;
                logger.error(msg, e);
                throw new RuntimeException(msg, e);
            }
            out.writeObject(f);
        } catch (IOException e) {
            logger.error("Error serializing IFn " + f,e);
            throw e;
        } catch (RuntimeException e) {
            logger.error("Error serializing IFn " + f,e);
            throw e;
        }
    }

    /**
     * Reads a function written by {@link #writeIFn}. It reads the class name (used only
     * in error messages), the set of referenced namespaces, and the function. It then
     * requires each namespace (best effort, see {@link #requireNamespace}) before
     * returning the function.
     *
     * @param in the stream to read from
     * @return the deserialized function
     * @throws IOException            if reading fails (logged, then rethrown)
     * @throws ClassNotFoundException if a serialized class cannot be resolved (logged,
     *                                then rethrown)
     */
    @SuppressWarnings("unchecked")
    public static IFn readIFn(ObjectInputStream in) throws IOException, ClassNotFoundException {
        String clazz = "";
        try {
            clazz = (String) in.readObject();
            HashSet<Symbol> required = (HashSet<Symbol>) in.readObject();
            IFn f = (IFn) in.readObject();
            logger.debug("Deserializing {}", f);
            for (Symbol ns : required) {
                requireNamespace(ns);
            }
            return f;
        } catch (IOException e) {
            logger.error("Error deserializing function (clazz: " + clazz + ")", e);
            throw e;
        } catch (ClassNotFoundException e) {
            logger.error("Error deserializing function (clazz: " + clazz + ")", e);
            throw e;
        } catch (RuntimeException e) {
            logger.error("Error deserializing function (clazz: " + clazz + ")", e);
            throw e;
        }
    }
}
