#!/usr/bin/env python3
#need python3
import sys,io,time,os
import re
import subprocess as sp # dusage command
from netCDF4 import Dataset
import gzip
import json

def byte2tb(nbyte):
    """Return ``nbyte`` bytes as a string in units of 2**40 bytes, two decimals."""
    return f'{nbyte / 2 ** 40:3.2f}'
def dict_to_pie_texstr(d, totalsize=0):
    """Render a mapping of label -> bytes as a pgf-pie chart in a tikzpicture.

    Parameters
    ----------
    d : dict
        Label -> size in bytes.
    totalsize : number, optional
        Size that corresponds to 100 %; when falsy the values of ``d`` are summed.

    Returns
    -------
    str
        LaTeX source, or 'No data' when the total is zero. Slices whose rounded
        percentage is 3 % or less are merged into a single 'others' slice; when
        more than 1 % of ``totalsize`` is left over it is drawn as a 'free'
        slice. Sizes are printed in units of 2**40 bytes ("T").
    """
    if not totalsize:
        for nbytes in d.values():
            totalsize += nbytes
    if not totalsize:
        return 'No data'

    unit = 2 ** 40
    small_cutoff = 3  # percent; slices not above this are lumped into 'others'
    left_perc = 100   # percentage not yet covered by any slice
    small_perc = 0
    small_tb = 0
    pieces = []
    for label, nbytes in sorted(d.items(), key=lambda item: item[1], reverse=True):
        perc_str = '{:2.2f}'.format(nbytes / totalsize * 100)
        tb_str = '{:3.2f}'.format(nbytes / unit)
        perc_val = float(perc_str)
        left_perc -= perc_val
        if perc_val > small_cutoff:
            # LaTeX would treat '_' as a subscript marker
            pieces.append('{}/{}T {}'.format(perc_str, tb_str, label.replace('_', r'\_')))
        else:
            small_perc += perc_val
            small_tb += float(tb_str)

    if small_perc:
        pieces.append('{:2.2f}/{:3.2f}T others'.format(small_perc, small_tb))
    if left_perc > 1:
        free_tb = left_perc * totalsize / 100 / unit
        pieces.append('{:2.2f}/{:3.2f}T free'.format(left_perc, free_tb))

    total_label = '{:3.2f}T'.format(totalsize / unit)
    return (r'\begin{tikzpicture}' + '\n'
            + r'\pie[text=pin,explode=0.1]{' + ', '.join(pieces) + '};\n'
            + r'\node at (current bounding box.north) {Total size: ' + total_label + '};\n'  # title
            + r'\end{tikzpicture}' + '\n')

# Characters with special meaning in LaTeX text, mapped to their escaped forms.
# str.translate applies this in one pass, so replacements are never re-escaped.
_LATEX_ESCAPES = str.maketrans({
    '\\': r'\textbackslash{}',
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
})


def dict_to_table(d, totalsize=0, otherThreshold=1):
    """Render a mapping of label -> bytes as a LaTeX table (label, size, percent).

    Parameters
    ----------
    d : dict
        Label -> size in bytes.
    totalsize : number, optional
        Size that corresponds to 100 %; when falsy the values of ``d`` are summed.
    otherThreshold : number, optional
        Entries whose rounded percentage is not above this value are merged
        into one 'others' row.

    Returns
    -------
    str
        LaTeX source, or 'No data' when the total is zero. Rows are sorted by
        size (largest first). If more than 1 % of ``totalsize`` is not covered
        by ``d``, a 'free' row is appended. A "total size" line follows the
        table. Sizes are printed in units of 2**40 bytes ("T").
    """
    if not totalsize:
        totalsize = sum(d.values())
    if not totalsize:
        return 'No data'

    unit = 2 ** 40

    def row(label, nbytes):
        # one tabular row: label & size & percentage, terminated by '\\'
        perc = nbytes / totalsize * 100
        return r'{} & {:3.2f}T & {:2.2f}\%\\'.format(label, nbytes / unit, perc)

    lines = [r'\begin{table}[]', r'\begin{tabular}{lrr}']
    other_bytes = 0
    for name in sorted(d, key=d.get, reverse=True):
        # compare the rounded percentage, as it is displayed
        if float('{:2.2f}'.format(d[name] / totalsize * 100)) > otherThreshold:
            lines.append(row(str(name).translate(_LATEX_ESCAPES), d[name]))
        else:
            # accumulate raw bytes, not rounded TB strings, to avoid drift
            other_bytes += d[name]

    if other_bytes:
        lines.append(row('others', other_bytes))

    free_bytes = totalsize - sum(d.values())
    if free_bytes / totalsize * 100 > 1:
        lines.append(row('free', free_bytes))

    lines.append(r'\end{tabular}')
    lines.append(r'\end{table}')
    lines.append('total size: {:3.2f}T '.format(totalsize / unit))
    return '\n'.join(lines) + '\n'

def check_nc_netcdf4(fn):
    """Return True if ``fn`` is a '.nc' file whose data model contains 'NETCDF4'.

    Files without the '.nc' suffix, and files that cannot be opened as netCDF
    at all, count as not NetCDF4 (best effort, no exception).

    Side effect: each NetCDF4 path found is appended, one per line, to
    'nc4files.txt' in the current directory. An error while writing that file
    is propagated rather than silently ignored.
    """
    if not fn.endswith('.nc'):
        return False

    try:
        # the context manager closes the dataset even if data_model access fails
        with Dataset(fn) as nc:
            isnc4 = 'NETCDF4' in nc.data_model
    except Exception:  # unreadable / not netCDF: treat as "not nc4"
        return False

    if isnc4:
        with open('nc4files.txt', 'a') as f:
            f.write(fn + '\n')
    return isnc4

debug = 0          # truthy: stop after the first 30000 input lines (quick test runs)
progress = False   # True: print a timing line every 10000 input lines

# Classification rules, defined once because they are applied to every input line.
_FILETYPES = ('rest', 'hist', 'misc')
_COMPFLAGS = ('compressed', 'uncompressed')
_REST_RE = re.compile(r'[/.]rest[/.]')
_HIST_RE = re.compile(r'[/.]hist[/.]')
_MEMBER_RE = re.compile(r'mem\d\d')
_COMPRESSED_SUFFIXES = ('.gz', '.tar', '.zip', '.bz2')
_SHARED_PREFIX = '/datalake/NS9039K/shared/'
_NORCPM_PREFIX = '/datalake/NS9039K/shared/norcpm/cases/'
_NC4DB_FILE = 'nc4files.txt'  # NetCDF4 paths, one per line (created with is_nc4.py)


def _parse_line(line):
    """Split one listing line into ``(user, size, nlink, path)``.

    The expected shape is '<user> <size> <hlink> <path>', where <size> is a
    decimal byte count and <hlink> carries the hard-link count after a
    6-character prefix. Returns None when the line does not have that shape.
    """
    fields = line.split(maxsplit=3)
    if len(fields) != 4:
        return None
    user, size, hlink, fn = fields
    if not size.isdecimal():  # isdecimal (unlike isdigit) guarantees int() works
        return None
    try:
        nlink = int(hlink[6:])
    except ValueError:
        return None
    if nlink < 1:  # would divide by zero when sharing size among links
        return None
    return user, int(size), nlink, fn.rstrip()


def _file_type(fn):
    """Classify ``fn`` as 'hist', 'rest' or 'misc'.

    'hist'/'rest' must be delimited by '/' or '.' on both sides. A path that
    matches both patterns is counted as 'hist'.
    """
    if _HIST_RE.search(fn):
        return 'hist'
    if _REST_RE.search(fn):
        return 'rest'
    return 'misc'


def _compression_flag(fn, nc4db):
    """Return 'compressed' or 'uncompressed' for ``fn``.

    Archives/compressed streams are recognised by suffix. '.nc' files count as
    compressed only when listed in ``nc4db`` (a set of NetCDF4 paths).
    """
    if fn.endswith(_COMPRESSED_SUFFIXES) or (fn.endswith('.nc') and fn in nc4db):
        return 'compressed'
    return 'uncompressed'


def _norcpm_case(fn):
    """Return the NorCPM case key for ``fn``, or None if it is outside the cases dir.

    The key is the two path components below the case directory joined by '-'.
    If that key names an ensemble member (contains 'memNN'), the case
    directory itself is used instead.
    """
    if not fn.startswith(_NORCPM_PREFIX):
        return None
    parts = fn.split('/')
    key = '-'.join(parts[7:9])
    if _MEMBER_RE.search(key):
        key = parts[6]
    return key


def _shared_dir(fn):
    """Return the first directory below the shared area for ``fn``, or None."""
    if not fn.startswith(_SHARED_PREFIX):
        return None
    return fn.split('/')[4]


def _new_user_record():
    """Return an empty per-user record: bytes per filetype/compression plus 'nfile'."""
    record = {ft: {c: 0 for c in _COMPFLAGS} for ft in _FILETYPES}
    record['nfile'] = 0
    return record


def _format_size_table(rows):
    """Right-align every cell of ``rows`` in equal-width columns, one line per row.

    The width fits the longest cell (plus two spaces of padding), so adjacent
    numbers never run together.
    """
    width = max((len(str(cell)) for row in rows for cell in row), default=0) + 2
    return ''.join(''.join(str(cell).rjust(width) for cell in row) + '\n'
                   for row in rows)


def _write_text(path, text):
    """Write ``text`` to ``path``, closing the file immediately."""
    with open(path, 'w') as f:
        f.write(text)


def main():
    """Aggregate a file listing and write the usage reports.

    Input: the gzip file named in sys.argv[1], or stdin.
    Outputs: latex/*.tex, sizeTable.txt and a timestamped JSON file.
    Chart percentages are relative to the summed usage of the listed files.
    """
    if len(sys.argv) > 1:
        gzfilelist = sys.argv[1]
        print('reading file: ' + gzfilelist, end='', flush=True)
        start = time.time()
        with gzip.open(gzfilelist, 'rt', encoding='utf-8', errors='ignore') as f:
            filelist = f.readlines()
        print(' in %2.2f secs' % (time.time() - start), flush=True)
    else:
        # re-wrap stdin so non-ASCII file names decode as UTF-8 regardless of locale
        filelist = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8', errors='ignore')

    # None means "no NetCDF4 database"; an existing empty file is a valid empty DB.
    nc4db = None
    if os.path.isfile(_NC4DB_FILE):
        print('reading file: ' + _NC4DB_FILE, end='', flush=True)
        start = time.time()
        with open(_NC4DB_FILE, 'r') as f:
            nc4db = {entry.rstrip('\n') for entry in f}
        print(' in %2.2f secs' % (time.time() - start), flush=True)

    Sizes = {}                # user -> record from _new_user_record()
    NorCPMSizes = {'not': 0}  # NorCPM case key -> bytes
    sharedSizes = {'not': 0}  # directory directly under shared/ -> bytes

    start = time.time()
    for lineno, line in enumerate(filelist, start=1):
        if debug and lineno > 30000:
            break
        if progress and lineno % 10000 == 0:
            end = time.time()
            print('\r' + str(lineno) + ' with %2.2f sec/10000                      ' % (end - start),
                  end='', flush=True)
            start = end

        parsed = _parse_line(line)
        if parsed is None:
            print('ignore parse error at line: %d' % lineno)
            continue
        user, size, nhlink, fn = parsed
        if size == 0:
            continue

        record = Sizes.setdefault(user, _new_user_record())
        record['nfile'] += 1

        if nc4db is None:
            # NOTE(review): without the NetCDF4 database no bytes are added to the
            # per-user sizes (behaviour kept from the original) -- confirm intent.
            print(fn + ' COMPFLAG False')
        else:
            # do not recount hard links: each of the nhlink entries gets 1/nhlink
            record[_file_type(fn)][_compression_flag(fn, nc4db)] += size / nhlink

        case = _norcpm_case(fn)
        if case is not None:
            NorCPMSizes[case] = NorCPMSizes.get(case, 0) + size
        shared = _shared_dir(fn)
        if shared is not None:
            sharedSizes[shared] = sharedSizes.get(shared, 0) + size
    print('')

    dic_analysis = {'units': 'byte'}  # everything published below, dumped to JSON

    def publish(texfn, texstr, data):
        # write one LaTeX snippet and remember its source data for the JSON dump
        _write_text(texfn, texstr)
        dic_analysis[texfn] = dict(data)

    # 01: total usage of each user
    sizes_01 = {u: sum(rec[ft][c] for ft in _FILETYPES for c in _COMPFLAGS)
                for u, rec in Sizes.items()}
    publish('latex/01_total_users.tex',
            dict_to_pie_texstr(sizes_01, sum(sizes_01.values())), sizes_01)

    # 02: total usage per file type
    sizes_02 = {ft: sum(rec[ft][c] for rec in Sizes.values() for c in _COMPFLAGS)
                for ft in _FILETYPES}
    publish('latex/02_total_filetypes.tex',
            dict_to_pie_texstr(sizes_02, sum(sizes_02.values())), sizes_02)

    # 03-05: per-user usage of each file type
    per_type = {ft: {u: sum(rec[ft].values()) for u, rec in Sizes.items()}
                for ft in _FILETYPES}
    publish('latex/03_hist_users.tex', dict_to_pie_texstr(per_type['hist']), per_type['hist'])
    publish('latex/04_rest_users.tex', dict_to_pie_texstr(per_type['rest']), per_type['rest'])
    publish('latex/05_misc_users.tex', dict_to_pie_texstr(per_type['misc']), per_type['misc'])

    # 06: compressed vs uncompressed overall
    sizes_06 = {c: sum(rec[ft][c] for rec in Sizes.values() for ft in _FILETYPES)
                for c in _COMPFLAGS}
    sumval = sum(sizes_06.values())
    publish('latex/06_total_comp.tex',
            "Not done yet." if sumval == 0 else dict_to_pie_texstr(sizes_06, sumval),
            sizes_06)

    # 07: uncompressed usage per user (compressed kept for the text table)
    sizes_07 = {u: sum(rec[ft]['uncompressed'] for ft in _FILETYPES)
                for u, rec in Sizes.items()}
    sizes_07_c = {u: sum(rec[ft]['compressed'] for ft in _FILETYPES)
                  for u, rec in Sizes.items()}
    publish('latex/07_uncomp_users.tex',
            "Not done yet." if sum(sizes_07.values()) == 0 else dict_to_pie_texstr(sizes_07),
            sizes_07)

    # 08: NorCPM cases
    publish('latex/08_NorCPM_case.tex', dict_to_table(NorCPMSizes), NorCPMSizes)

    # 09: plain-text usage table, 'all' first, then users by total usage
    stamp = time.strftime("%Y%m%d-%H%M")  # one timestamp for table and JSON name
    rows = [
        ['USER', 'TOTAL', 'HIST', 'REST', 'MISC', 'COMP', 'UNCOMP', 'NFILE'],
        ['all',
         byte2tb(sum(sizes_01.values())),
         byte2tb(sizes_02['hist']),
         byte2tb(sizes_02['rest']),
         byte2tb(sizes_02['misc']),
         byte2tb(sizes_06['compressed']),
         byte2tb(sizes_06['uncompressed']),
         sum(rec['nfile'] for rec in Sizes.values())],
    ]
    for u in sorted(sizes_01, key=sizes_01.get, reverse=True):
        rows.append([u,
                     byte2tb(sizes_01[u]),
                     byte2tb(per_type['hist'][u]),
                     byte2tb(per_type['rest'][u]),
                     byte2tb(per_type['misc'][u]),
                     byte2tb(sizes_07_c[u]),
                     byte2tb(sizes_07[u]),
                     Sizes[u]['nfile']])

    # NOTE(review): the title/JSON name say NS2345K while the paths are NS9039K -- confirm.
    legend = [
        'NS2345K datalake disk usage table. Created at %s' % stamp,
        'TOTAL: total disk usage',
        'HIST: size of NorESM/CESM output (path contains "/hist/")',
        'REST: size of NorESM/CESM restart (path contains "/rest/")',
        'MISC: size of files that are neither HIST nor REST',
        'COMP: compressed files, including netCDF4, .gz, .tar, .zip and .bz2',
        'UNCOMP: uncompressed files',
        'NFILE: number of files',
    ]
    _write_text('sizeTable.txt', '\n'.join(legend) + '\n' + _format_size_table(rows))

    # 10: shared directories (every non-zero entry gets its own row)
    publish('latex/10_shared_size.tex', dict_to_table(sharedSizes, otherThreshold=0), sharedSizes)

    # all data as JSON
    jsonfn = 'NS2345K_datalake_analysis_%s.json' % stamp
    with open(jsonfn, 'w') as f:
        json.dump(dic_analysis, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    main()
