using OnlinePCA
using ArgParse:
    ArgParseSettings, parse_args, @add_arg_table!

# main
"""
    main()

Command-line entry point for the OnlinePCA SGD solver.

Obtains the option dictionary from `OnlinePCA.parse_commandline`, echoes every
option to standard output, normalises the option values, calls `OnlinePCA.sgd`
with them, and finally hands the result to `OnlinePCA.output` together with the
`outdir` and `expvar` options.

Option values may arrive either as the parser's default value or as the raw
string typed by the user; strings are parsed into the expected numeric type.

# Throws
- `ArgumentError` if a numeric option is neither a number nor a parseable string.
- `ErrorException` if the `perm` option is not `true` or `false`.
"""
function main()
    pca = OnlinePCA.SGD()
    parsed_args = OnlinePCA.parse_commandline(pca)
    println("Parsed args:")
    for (arg, val) in parsed_args
        println("  $arg  =>  ", repr(val))
    end

    pseudocount = _numeric_option(Float32, parsed_args["pseudocount"])
    dim = _numeric_option(Int64, parsed_args["dim"])
    stepsize = _numeric_option(Float32, parsed_args["stepsize"])

    # NOTE(review): `numbatch` was parsed but never forwarded to `OnlinePCA.sgd`
    # in the original script. It is still validated here; confirm whether `sgd`
    # accepts a `numbatch` keyword before passing it on.
    _numeric_option(Int64, parsed_args["numbatch"])

    numepoch = _numeric_option(Int64, parsed_args["numepoch"])
    g = _numeric_option(Float32, parsed_args["g"])
    epsilon = _numeric_option(Float32, parsed_args["epsilon"])
    lower = _numeric_option(Float32, parsed_args["lower"])
    upper = _numeric_option(Float32, parsed_args["upper"])
    evalfreq = _numeric_option(Int64, parsed_args["evalfreq"])
    offsetStoch = _numeric_option(Float32, parsed_args["offsetStoch"])
    expvar = _numeric_option(Float32, parsed_args["expvar"])
    cper = _numeric_option(Float32, parsed_args["cper"])

    initW = _optional_string(parsed_args["initW"])
    initV = _optional_string(parsed_args["initV"])
    logdir = _optional_string(parsed_args["logdir"])

    perm = _parse_perm(parsed_args["perm"])

    out = OnlinePCA.sgd(input=parsed_args["input"],
        scale=parsed_args["scale"],
        pseudocount=pseudocount,
        rowmeanlist=parsed_args["rowmeanlist"],
        rowvarlist=parsed_args["rowvarlist"],
        colsumlist=parsed_args["colsumlist"],
        dim=dim,
        stepsize=stepsize,
        numepoch=numepoch,
        scheduling=parsed_args["scheduling"],
        g=g,
        epsilon=epsilon,
        lower=lower,
        upper=upper,
        evalfreq=evalfreq,
        offsetStoch=offsetStoch,
        initW=initW,
        initV=initV,
        logdir=logdir,
        perm=perm,
        cper=cper)
    return OnlinePCA.output(parsed_args["outdir"], out, expvar)
end

# Convert a numeric option to a number: strings are parsed as `T`, values that
# are already numeric (e.g. the parser's default) are returned unchanged.
function _numeric_option(::Type{T}, value) where {T<:Number}
    value isa AbstractString && return parse(T, value)
    value isa Number && return value
    throw(ArgumentError("expected a numeric option value, got $(repr(value))"))
end

# Return `nothing` untouched, otherwise convert the option value to a `String`.
_optional_string(value) = value === nothing ? nothing : String(value)

# Interpret the `perm` option, accepting both Bool values and the strings
# "true" / "false" (the original rejected an explicit "false").
function _parse_perm(value)
    (value == false || value == "false") && return false
    (value == true || value == "true") && return true
    error("Please specify the perm option as true or false")
end

# Invoke the entry point at top level; this runs whenever the file is executed
# or included.
main()