#!/bin/python

import sys
import os
from optparse import OptionParser
from dcbnetlink import DcbController
from collections import defaultdict
from subprocess import Popen, PIPE
from tc_wrap import *

# DCBX capability/mode bits. They are combined and tested with bitwise
# operators when talking to the DCB controller.
# NOTE(review): values presumably mirror the DCB_CAP_DCBX_* definitions in the
# Linux <linux/dcbnl.h> header - confirm against the kernel headers.
DCB_CAP_DCBX_HOST = 1 << 0
DCB_CAP_DCBX_LLD_MANAGED = 1 << 1
DCB_CAP_DCBX_VER_CEE = 1 << 2
DCB_CAP_DCBX_VER_IEEE = 1 << 3
DCB_CAP_DCBX_STATIC = 1 << 4

# Transmission selection algorithm (TSA) identifiers stored per traffic class.
# Only STRICT and ETS are used by this script.
(IEEE_8021QAZ_TSA_STRICT,
 IEEE_8021QAZ_TSA_CB_SHAPER,
 IEEE_8021QAZ_TSA_ETS) = range(3)
IEEE_8021QAZ_TSA_VENDOR = 0xFF

class Maxrate:
    """Common interface for reading and writing per-TC rate limits.

    Subclasses provide get() and set(); prepare() is shared and relies on
    the subclass's get().
    """

    def get(self):
        """Return the current list of rate limits (overridden by subclasses)."""
        pass

    def set(self, ratelimit):
        """Apply the given list of rate limits (overridden by subclasses)."""
        pass

    def prepare(self, ratelimit):
        """Return ``ratelimit`` padded with the currently configured values.

        Entries the caller did not supply (positions ``len(ratelimit)`` up to
        index 7) are taken from ``self.get()`` so they stay unchanged when
        the result is written back. A list longer than 8 entries is returned
        as-is, without truncation.

        The caller's sequence is never modified; a new list is returned.
        """
        current = self.get()
        # Build a fresh list instead of extending the argument in place.
        return list(ratelimit) + list(current[len(ratelimit):8])

class MaxrateNL(Maxrate):
    """Rate-limit backend that delegates to a DCB controller object."""

    def __init__(self, ctrl):
        """Remember the controller used for all subsequent get/set calls."""
        self.ctrl = ctrl

    def get(self):
        """Return the rate limits reported by ``ctrl.get_ieee_maxrate()``."""
        # Use the controller this instance was built with, not whatever
        # module-level ``ctrl`` happens to exist.
        return self.ctrl.get_ieee_maxrate()

    def set(self, ratelimit):
        """Pad ``ratelimit`` with the current values and apply it."""
        self.ctrl.set_ieee_maxrate(self.prepare(ratelimit))

class MaxrateSysfs(Maxrate):
    """Rate-limit backend that reads and writes a whitespace-separated file."""

    def __init__(self, path):
        """Remember the path of the file holding the rate limits."""
        self.path = path

    def get(self):
        """Return the values stored in the file as a list of floats.

        Raises whatever ``open`` raises if the file cannot be read, and
        ValueError if an entry is not a number.
        """
        # ``with`` guarantees the handle is closed even when parsing fails.
        with open(self.path, "r") as f:
            return [float(item) for item in f.read().split()]

    def set(self, ratelimit):
        """Pad ``ratelimit`` with the current values and write it to the file.

        The values are written space-separated on a single line.
        """
        # prepare() reads the file, so it must run before the file is
        # reopened (and truncated) for writing.
        ratelimit = self.prepare(ratelimit)
        # NOTE(review): padded entries come from get() as floats and are
        # therefore written like "0.0" - confirm the reader accepts that.
        with open(self.path, "w") as f:
            f.write(" ".join(str(r) for r in ratelimit))

def pretty_print(prio_tc, tsa, tcbw, ratelimit):
    """Print every traffic class with its rate limit, TSA and mapped UPs.

    prio_tc   -- sequence indexed by user priority (UP) giving its TC
    tsa       -- sequence indexed by TC giving its TSA identifier
    tcbw      -- sequence indexed by TC giving its ETS bandwidth percentage
    ratelimit -- sequence indexed by TC giving its rate limit; may be shorter
                 than the number of TCs (or empty) when limits are unavailable

    Reads the module-level ``printall`` flag (show all 8 TCs, even those with
    no UP mapped) and the module-level ``up2skprio`` mapping (UP -> iterable
    of socket priorities). Traffic classes are printed in ascending order.
    """
    tc2up = defaultdict(list)
    if printall:
        # Make sure TCs without any UP still get a line.
        tc2up.update((tc, []) for tc in range(8))
    for up, tc in enumerate(prio_tc):
        tc2up[int(tc)].append(up)

    for tc in sorted(tc2up):
        # The TC number is always printed; the rate limit is added only when
        # there is an entry for this TC.
        try:
            limit = ratelimit[tc]
        except (IndexError, KeyError, TypeError):
            msg = "tc: %d tsa: " % tc
        else:
            rate = "unlimited"
            if limit > 0:
                rate = "%d Gbps" % (limit / 1024 / 1024)
            msg = "tc: %d ratelimit: %s, tsa: " % (tc, rate)

        try:
            algo = tsa[tc]
            if algo == IEEE_8021QAZ_TSA_ETS:
                msg += "ets, bw: %s%%" % (tcbw[tc],)
            elif algo == IEEE_8021QAZ_TSA_STRICT:
                msg += "strict"
            else:
                msg += "unknown"
        except (IndexError, KeyError, TypeError):
            # No TSA/bandwidth entry for this TC: print what is known so far.
            pass

        print(msg)

        for up in tc2up[tc]:
            print("\t up:  %d" % up)
            # A UP without a socket-priority mapping must not hide the UPs
            # that follow it, so the lookup is guarded per UP.
            try:
                skprios = up2skprio[int(up)]
            except (IndexError, KeyError, TypeError):
                continue
            for skprio in skprios:
                print("\t\t skprio: " + str(skprio))

def parse_int(str, min, max, description):
    """Parse ``str`` as an integer within ``min``..``max`` (inclusive).

    Returns the parsed integer. If the text is not an integer or is out of
    range, prints "Bad value for <description>: <reason>", prints the usage
    of the module-level ``parser`` and exits the process with status 1.

    The parameter names shadow builtins; they are kept because they are part
    of the function's existing signature.
    """
    try:
        value = int(str)

        if not (min <= value <= max):
            # Raised inside the try on purpose so that range errors are
            # reported through the same handler as malformed numbers.
            raise ValueError("%d is not in the range %d..%d" % (value, min, max))

        return value
    except ValueError as e:
        print("Bad value for %s: %s" % (description, e))
        parser.print_usage()
        sys.exit(1)

# Command-line definition. Every LIST option takes comma-separated values,
# one per UP (for --prio_tc) or per TC (for the others).
parser = OptionParser(usage="%prog -i <interface> [options]",
                      version="%prog 1.0")

parser.add_option(
    "-p", "--prio_tc", dest="prio_tc", metavar="LIST",
    help=("maps UPs to TCs. LIST is 8 comma separated TC numbers. "
          "Example: 0,0,0,0,1,1,1,1 maps UPs 0-3 to TC0, and UPs 4-7 to "
          "TC1"))
parser.add_option(
    "-s", "--tsa", dest="tsa", metavar="LIST",
    # The accepted names are "strict" and "ets"; the help used to say "etc".
    help=("Transmission algorithm for "
          "each TC. LIST is comma separated algorithm names for each TC. "
          "Possible algorithms: strict, ets. Example: ets,strict,ets sets "
          "TC0,TC2 to ETS and TC1 to strict. The rest are unchanged."))
parser.add_option(
    "-t", "--tcbw", dest="tc_bw", metavar="LIST",
    help=("Set minimal guaranteed %BW for ETS TCs. LIST is comma "
          "separated percents for each TC. Values set to TCs that are "
          "not configured to ETS algorithm are ignored, but must be "
          "present. Example: if TC0,TC2 are set to ETS, then 10,0,90 "
          "will set TC0 to 10% and TC2 to 90%. Percents must sum to "
          "100."))
parser.add_option(
    "-r", "--ratelimit", dest="ratelimit", metavar="LIST",
    help=("Rate limit for TCs (in Gbps). LIST is a comma separated "
          "Gbps limit for each TC. Example: 1,8,8 will limit TC0 to "
          "1Gbps, and TC1,TC2 to 8 Gbps each."))

parser.add_option("-i", "--interface", dest="intf",
                  help="Interface name")

parser.add_option("-a", action="store_true", dest="printall", default=False,
                  help="Show all interface's TCs")

(options, args) = parser.parse_args()

# The tool takes options only; any positional argument is a usage error.
if args:
    print("Bad arguments")
    parser.print_usage()
    sys.exit(1)

if options.intf is None:
    print("Interface name is required")
    parser.print_usage()
    sys.exit(1)

ratelimit_path = "/sys/class/net/" + options.intf + "/qos/maxrate"
skprio2up_path = "/sys/class/net/" + options.intf + "/qos/skprio2up"

# Build the UP -> socket-priority map from sysfs when the file exists,
# otherwise fall back to the mqprio-based helper.
try:
    if os.path.exists(skprio2up_path):
        up2skprio = up2skprio_sysfs(skprio2up_path, options.intf)
    else:
        up2skprio = up2skprio_mqprio(options.intf)
except Exception as e:
    print("couldn't read skprio2up: %s" % e)
    sys.exit(1)


# Fallback ETS state, used when the device configuration cannot be read:
# every TC strict, no bandwidth shares, every UP mapped to TC 0.
tsa = [IEEE_8021QAZ_TSA_STRICT] * 8
tc_bw = [0] * 8
prio_tc = [0] * 8
printall = False
# Bound up front so the final report always has a value to work with, even
# when the rate limits cannot be read back.
ratelimit = []

ctrl = DcbController(options.intf)

# Use the sysfs maxrate file when it exists, the controller otherwise.
if os.path.exists(ratelimit_path):
    maxrate = MaxrateSysfs(ratelimit_path)
else:
    maxrate = MaxrateNL(ctrl)

if options.ratelimit:
    requested = []
    for item in options.ratelimit.split(","):
        # Same 8-entry limit as the other per-TC list options.
        if len(requested) >= 8:
            print("Too many items for ratelimit")
            sys.exit(1)
        # The user supplies Gbps; the stored value is scaled by 1024 * 1024
        # (pretty_print divides by the same factor when reporting).
        requested.append(parse_int(item, 0, 1000000, "ratelimit") * 1024 * 1024)

    try:
        maxrate.set(requested)
    except Exception:
        print("Rate limit is not supported on your system!")

# Always report what is actually configured, never the requested values.
try:
    ratelimit = maxrate.get()
except Exception:
    print("Rate limit is not supported on your system!")
    ratelimit = []

try:
    # Switch the interface to host-managed IEEE DCBX if it is not already
    # in IEEE mode, then read the current ETS configuration.
    if not (ctrl.get_dcbx() & DCB_CAP_DCBX_VER_IEEE):
        ctrl.set_dcbx(DCB_CAP_DCBX_VER_IEEE | DCB_CAP_DCBX_HOST)

    prio_tc, tsa, tc_bw = ctrl.get_ieee_ets()
except Exception:
    # Keep going with the fallback state defined above.
    print("ETS features are not supported on your system")

# Overlay the user's --tsa list on the current per-TC algorithms; TCs beyond
# the end of the list keep their current setting.
if options.tsa:
    for i, name in enumerate(options.tsa.split(",")):
        if i >= 8:
            print("Too many items for TSA")
            sys.exit(1)

        if name == "strict":
            tsa[i] = IEEE_8021QAZ_TSA_STRICT
        elif name == "ets":
            tsa[i] = IEEE_8021QAZ_TSA_ETS
        else:
            print("Bad TSA value:  %s" % name)
            parser.print_usage()
            sys.exit(1)

if options.printall:
    printall = True

# Overlay the user's --tcbw list; a strict TC must be given 0.
if options.tc_bw:
    for i, item in enumerate(options.tc_bw.split(",")):
        if i >= 8:
            print("Too many items for ETS BW")
            sys.exit(1)

        bw = parse_int(item, 0, 100, "ETS BW")

        if tsa[i] == IEEE_8021QAZ_TSA_STRICT and bw != 0:
            print("ETS BW for a strict TC must be 0")
            parser.print_usage()
            sys.exit(1)

        tc_bw[i] = bw

# Overlay the user's --prio_tc list on the current UP => TC mapping.
if options.prio_tc:
    for i, item in enumerate(options.prio_tc.split(",")):
        if i >= 8:
            print("Too many items in UP => TC mapping")
            sys.exit(1)

        prio_tc[i] = parse_int(item, 0, 7, "UP => TC mapping")

# Push the ETS configuration only when the user asked to change something.
if options.tsa or options.tc_bw or options.prio_tc:
    try:
        ctrl.set_ieee_ets(_prio_tc=prio_tc, _tsa=tsa, _tc_bw=tc_bw)
    except OSError as e:
        print(e)
        sys.exit(1)

pretty_print(prio_tc, tsa, tc_bw, ratelimit)
