#
#       snmp_pdu.py
#
#       2003.9.14
#
#   Copyright (C) Hidetoshi Nakano
#
#   Please use this program at your own risk.
#   Without any warranty.
# 
#
#       RFC-1157
############################
import string
import Lpy.utils.file_utils

# Class bits (the top two bits) of a BER identifier octet.
Tag_Class = dict(APPLICATION=0x40,
                 CONTEXT=0x80,
                 PRIVATE=0xC0)

# Bit set in a BER identifier octet when the encoding is constructed.
Tag_Constructed = 0x20

# Value tags understood by the decoder (oid_data_type looks names up here
# by tag value).  Boolean..Set are universal tag numbers; Sequence and Set
# are listed without the constructed bit.
Tags = dict(
        Boolean=0x00,
        integer=0x02,
        BitString=0x03,
        Octet=0x04,
        Null=0x05,
        ObjectID=0x06,
        Sequence=0x10,
        Set=0x11,
        # application specific tags
        IpAddress=0x40,
        Counter32=0x41,
        Gauge32=0x42,       # also used for Unsigned32
        TimeTick=0x43,
        Opaque=0x44,
        NsapAddress=0x45,
        Counter64=0x46,
        # SNMP v.2 exception tags
        noSuchObject=0x80,
        noSuchInstance=0x81,
        endOfMibView=0x82)

# PDU identifier octets.  Every SNMP PDU is a context-specific, constructed
# type [n], so its identifier octet is 0xa0 | n.  The SNMPv2 PDUs [5]..[7]
# therefore use 0xa5..0xa7; the bare values 0x05..0x07 would be read as the
# universal NULL / OBJECT IDENTIFIER / ObjectDescriptor tags.
Pdu_tags = { # Protocol Data Unit
        'GetRequest':     0xa0,  # [0] RFC-1157
        'GetNextRequest': 0xa1,  # [1]
        'GetResponse':    0xa2,  # [2]
        'SetRequest':     0xa3,  # [3]
        'Trap1':          0xa4,  # [4] SNMPv1 Trap-PDU
        # v2c RFC-1448 (later RFC-1905 / RFC-3416)
        'GetBulkRequest':     0xa5,  # [5]
        'InformationRequest': 0xa6,  # [6] InformRequest-PDU
        'Trap2':              0xa7 } # [7] SNMPv2-Trap-PDU

# error-status values of a response PDU, keyed by numeric code (the list
# position).  Codes 0-5 are from RFC-1157; 6-18 were added by v2 RFC-1905.
Snmp_Errors = dict(enumerate([
        'NoError',              # 0
        'tooBig',               # 1
        'noSuchName',           # 2
        'badValue',             # 3
        'readOnly',             # 4
        'genErr',               # 5
        'noAccess',             # 6
        'wrongType',            # 7
        'wrongLength',          # 8
        'wrongEncoding',        # 9
        'wrongValue',           # 10
        'noCreation',           # 11
        'inconsistentValue',    # 12
        'resourceUnavailable',  # 13
        'commitFailed',         # 14
        'undoFailed',           # 15
        'authorizationError',   # 16
        'notWritable',          # 17
        'inconsistentName']))   # 18

# SNMPv1 generic-trap numbers, assigned 0..6 in list order.  The spellings
# 'authentificationFailure' and 'enterpriseSpecfic' are kept exactly because
# they are lookup keys.
Trap_types = dict(zip([
    'coldStart',                # 0
    'warmStart',                # 1
    'linkDown',                 # 2
    'linkUp',                   # 3
    'authentificationFailure',  # 4
    'egpNeighborLoss',          # 5
    'enterpriseSpecfic'],       # 6
    range(7)))


### encode PDU ###
def get_snmp_version(flag,version):
    """Translate between SNMP version numbers and message version-field values.

    flag == "oid": *version* is the protocol version (1, 2 or 3); returns
        the value for the message's version field: 3 -> 0x03, 2 -> 0x01,
        anything else -> 0x00.
    any other flag: *version* is a decoded version-field value; returns the
        protocol version: 0x03 -> 3, 0x01 -> 2, anything else -> 1.
    """
    if flag != "oid":
        # Decoding direction: wire value -> version number.
        if version == 0x03:
            return 3
        if version == 0x01:
            return 2
        return 1
    # Encoding direction: version number -> wire value.
    if version == 3:
        return 0x03
    if version == 2:
        return 0x01
    return 0x00

def get_requestID(id):
    """Return *id* as a BER INTEGER TLV with a fixed four-octet body.

    The output is 0x02 0x04 followed by the low 32 bits of *id*, most
    significant octet first.  No leading zero octet is added, so an id with
    bit 31 set encodes a negative INTEGER.
    """
    octets = [0x02, 0x04]  # INTEGER tag, length 4
    for shift in (24, 16, 8, 0):
        octets.append((id >> shift) & 0xff)
    return "".join(["%c" % octet for octet in octets])

def _encode_subid(value):
    """Return the base-128 BER encoding of one non-negative sub-identifier."""
    if value < 0:
        raise ValueError("negative OID sub-identifier: %d" % value)
    octets = [value & 0x7f]                     # final octet: high bit clear
    value >>= 7
    while value:
        octets.append((value & 0x7f) | 0x80)    # earlier octets: high bit set
        value >>= 7
    octets.reverse()                            # most significant group first
    return "".join(["%c" % octet for octet in octets])


def get_oid_pair(name):
    """Encode a dotted OID as an OBJECT IDENTIFIER TLV followed by a NULL TLV.

    name -- dotted OID string such as ".1.3.6.1.2.1.1.1.0"; the leading dot
            is optional.

    Returns 0x06 <len> <sub-identifiers> 0x05 0x00, i.e. the contents of one
    VarBind whose value is NULL.  The first two arcs X.Y are combined into
    the single sub-identifier 40*X + Y (0x2b for .1.3).  Every
    sub-identifier is written in base 128, most significant group first,
    with the high bit set on all octets but the last, so arbitrarily large
    arcs are supported.

    Raises ValueError if the OID has fewer than two arcs, invalid leading
    arcs, a negative arc, or encodes to more than 127 octets (only the
    short length form is produced).
    """
    if name.startswith('.'):
        name = name[1:]
    arcs = [int(part) for part in name.split('.')]
    if len(arcs) < 2:
        raise ValueError("OID needs at least two arcs: %r" % (name,))
    first, second = arcs[0], arcs[1]
    if first < 0 or second < 0 or first > 2 or (first < 2 and second >= 40):
        raise ValueError("invalid leading OID arcs: %d.%d" % (first, second))

    body = _encode_subid(40 * first + second)
    body += "".join([_encode_subid(arc) for arc in arcs[2:]])
    if len(body) > 0x7f:
        raise ValueError("OID too long for a short-form length: %d octets"
                         % len(body))
    return "%c%c" % (0x06, len(body)) + body + "%c%c" % (0x05, 0x00)

def get_oid(oid_name):
    """Return a VarBindList holding one VarBind for *oid_name* with a NULL value.

    Layout: 0x30 <len+2> 0x30 <len> <get_oid_pair(oid_name)>.  The outer
    SEQUENCE is the list and the inner SEQUENCE is the single VarBind.  Both
    headers use a one-octet length, so the encoded pair must be at most
    125 octets for the result to be valid BER.
    """
    varbind = get_oid_pair(oid_name)
    inner = "%c%c" % (0x30, len(varbind)) + varbind
    return "%c%c" % (0x30, len(inner)) + inner

class NameSpace:
    """Empty attribute container; the module-level PDU_conf and Mib records are instances."""

# Shared request record; get_pdu_data() overwrites its fields and returns it.
PDU_conf = NameSpace()
# error-status = 0 and error-index = 0, each a one-octet INTEGER TLV.
PDU_conf.errInfo = "\x02\x01\x00\x02\x01\x00"
PDU_conf.version = PDU_conf.community = PDU_conf.id = None
PDU_conf.reqID = PDU_conf.oid = PDU_conf.data = None

def get_pdu_data(request,oids,community,version):
    """Build an SNMP request message and record its parts on ``PDU_conf``.

    request   -- key of ``Pdu_tags`` (e.g. 'GetRequest', 'GetNextRequest').
    oids      -- one dotted OID string, encoded with get_oid().
    community -- community string, copied into the message as-is.
    version   -- protocol version number, mapped with get_snmp_version().

    Returns the module-level ``PDU_conf`` object, or None (after printing a
    message) when *request* is not in ``Pdu_tags``.  The same object is
    returned on every call, so each call overwrites the previous result.
    ``PDU_conf.data`` holds the complete encoded message.  The request id is
    0x06050403 plus a random offset from Lpy.utils.file_utils.

    TLV lengths use the BER short form below 128 and the long form
    otherwise, so long communities or PDUs stay well-formed.
    """
    def _tlv_header(tag, length):
        # Identifier octet plus BER definite-length octets for *length*.
        if length < 0x80:
            return "%c%c" % (tag, length)
        octets = ""
        while length:
            octets = ("%c" % (length & 0xff)) + octets
            length >>= 8
        return "%c%c" % (tag, 0x80 | len(octets)) + octets

    if request not in Pdu_tags:
        print("get_pdu() request Error.%s " % request)
        return None
    pdu = Pdu_tags[request]

    PDU_conf.id = 0x06050403 + Lpy.utils.file_utils.get_random_number(0xff,0x0fffffff)
    PDU_conf.version = "%c%c%c" % (0x02,0x01,get_snmp_version("oid",version))
    PDU_conf.community = _tlv_header(0x04, len(community)) + community

    PDU_conf.reqID = get_requestID(PDU_conf.id)
    PDU_conf.oid = get_oid(oids)

    # PDU body: request-id, error-status/index, VarBindList.
    data = PDU_conf.reqID + PDU_conf.errInfo + PDU_conf.oid
    data1 = _tlv_header(pdu, len(data)) + data
    # Message: SEQUENCE { version, community, PDU }.
    data2 = PDU_conf.version + PDU_conf.community + data1
    PDU_conf.data = _tlv_header(0x30, len(data2)) + data2

    return PDU_conf

##### mib data decode ###
def set_number_mib(data):
    """Return the octets of *data* read as one big-endian unsigned integer.

    data -- string of octets (one character per octet); empty gives 0.

    Written without the ``0L`` literal and the builtin ``reduce``, both of
    which exist only in Python 2, so this helper is valid on Python 2 and 3.
    """
    value = 0
    for octet in data:
        value = (value << 8) | ord(octet)
    return value

def get_oid_number(dat):
    """Decode the base-128 OID sub-identifier at the start of *dat*.

    Octets with the high bit set are continuation octets; the sub-identifier
    ends at the first octet whose high bit is clear.  Any length is handled,
    so arcs of 2**21 and above decode correctly.

    Returns (ii, value), where ii is the number of octets used beyond the
    first (i.e. the encoding spans ii + 1 octets).  Raises IndexError if
    *dat* ends in the middle of a sub-identifier.
    """
    value = 0
    index = 0
    while True:
        octet = ord(dat[index])
        value = (value << 7) | (octet & 0x7f)
        if not octet & 0x80:
            return index, value
        index += 1

def decode_oid_data(oid):
    """Turn the content octets of an OBJECT IDENTIFIER into a dotted string.

    oid -- the OID's content octets (no tag or length), one char per octet.

    The first sub-identifier carries the first two arcs as 40*X + Y (so
    0x2b decodes as ".1.3" and 0x00 as ".0.0").  Every sub-identifier,
    including the first, is decoded with get_oid_number.

    Returns a string such as ".1.3.6.1.2.1.1.1.0", or None when *oid* is
    empty.
    """
    if not oid:
        return None

    used, first = get_oid_number(oid)
    if first < 80:
        arcs = [first // 40, first % 40]
    else:
        # X is 2 for every value of 80 and above; the remainder is Y.
        arcs = [2, first - 80]

    position = used + 1
    while position < len(oid):
        used, value = get_oid_number(oid[position:])
        arcs.append(value)
        position += 1 + used
    return "".join([".%d" % arc for arc in arcs])

def octet_hex(mib):
    """Render each octet of *mib* with hex() and join them with single spaces.

    Returns "" for an empty string.
    """
    return ' '.join([hex(ord(octet)) for octet in mib])

def octet_ipaddress(mib):
    """Render each octet of *mib* as a decimal number, joined with ':'.

    Used for IpAddress and NsapAddress values; note the separator is a
    colon, not the dot of dotted-quad notation.  Returns "" when empty.
    """
    return ':'.join([str(ord(octet)) for octet in mib])

def check_printable(data):
    """Return 1 if every character of *data* is in string.printable, else None.

    An empty *data* counts as printable and returns 1.
    """
    for char in data:
        if char not in string.printable:
            break
    else:
        return 1
    return None

def oid_data_type(flag,mib):
    """Convert the raw value octets of a VarBind into a Python value.

    flag -- the value's tag octet, looked up by value in ``Tags``.
    mib  -- the value's content octets (one character per octet).

    Returns (key, data).  key is the ``Tags`` name for *flag*, or 'Opaque'
    when the tag is not listed.  data depends on key:
      ObjectID                             -> decode_oid_data(mib)
      Null                                 -> None
      IpAddress, NsapAddress               -> octet_ipaddress(mib)
      integer                              -> signed two's-complement value
      Counter32, Counter64, Gauge32, TimeTick -> unsigned value
      anything else                        -> mib itself if printable,
                                              else octet_hex(mib); None if empty
    An empty numeric value decodes to 0.
    """
    key = 'Opaque'
    for name, tag in Tags.items():
        if tag == flag:
            key = name
            break

    if key == 'ObjectID':
        data = decode_oid_data(mib)
    elif key == 'Null':
        data = None
    elif key in ('IpAddress', 'NsapAddress'):
        data = octet_ipaddress(mib)
    elif key == 'integer':
        data = set_number_mib(mib)
        # Sign bit set in the first octet: value is negative.
        if mib and ord(mib[0]) >= 0x80:
            data -= 256 ** len(mib)
    elif key in ('Counter64', 'Counter32', 'Gauge32', 'TimeTick'):
        data = set_number_mib(mib)
    elif not mib:
        # 'Octet' and other string-like types.
        data = None
    elif check_printable(mib):
        data = mib
    else:
        data = octet_hex(mib)
    return key, data

def set_oid_pair(data):
    """Decode the VarBind (SEQUENCE { OID, value }) at the start of *data*.

    data -- octets starting at a VarBind's SEQUENCE tag.

    Returns (consumed, (oid, value)).
      consumed -- octets taken by the whole VarBind TLV, header included.
      oid      -- result of decode_oid_data.
      value    -- second element returned by oid_data_type.
    Returns (None, None) when *data* does not start with a decodable VarBind.
    A message is printed when the VarBind is truncated or its OID cannot be
    decoded.  All three lengths may use either BER short or long form.
    """
    def read_length(pos):
        # Decode the length octets at *pos*; return (length, index of the
        # first content octet), or (None, None) if missing or indefinite.
        if pos >= len(data):
            return None, None
        first = ord(data[pos])
        if first < 0x80:
            return first, pos + 1
        count = first & 0x7f
        if count == 0 or pos + 1 + count > len(data):
            return None, None
        length = 0
        for octet in data[pos + 1:pos + 1 + count]:
            length = (length << 8) | ord(octet)
        return length, pos + 1 + count

    if len(data) < 2 or ord(data[0]) != 0x30:
        return None, None

    pair_length, start = read_length(1)
    if pair_length is None:
        return None, None
    end = start + pair_length
    if len(data) < end:
        print("MIB pair data size Error.")
        return None, None

    if start >= end or ord(data[start]) != 0x06:
        return None, None
    oid_length, oid_start = read_length(start + 1)
    if oid_length is None:
        return None, None
    obj_start = oid_start + oid_length      # tag octet of the value

    oid = decode_oid_data(data[oid_start:obj_start])
    if oid is None:
        print("MIB pair OID data Error.")
        return None, None

    if obj_start >= len(data):
        return None, None
    flag = ord(data[obj_start])  # data type
    value_length, value_start = read_length(obj_start + 1)
    if value_length is None:
        return None, None
    mib = data[value_start:value_start + value_length]
    key, dat = oid_data_type(flag, mib)
    return end, (oid, dat)

# Shared decode result; decode_mib_data() overwrites its fields and returns it.
Mib = NameSpace()
Mib.version = Mib.community = Mib.id = None
Mib.oid = Mib.error = None
Mib.error_info = Mib.error_position = None

def decode_mib_data(data,flag = 0):
    """Decode an SNMP response message into the module-level ``Mib`` record.

    data -- the complete message, one character per octet.
    flag -- not used.

    Sets these fields and returns ``Mib`` (the same object on every call):
      Mib.version, Mib.community, Mib.id
      Mib.error_info      -- 0, or the Snmp_Errors name of a non-zero status
      Mib.error_position  -- error-index when the status is non-zero, else None
      Mib.oid             -- (oid, value) of the first VarBind
      Mib.error           -- a message if the VarBindList cannot be decoded

    Returns None, after printing a diagnostic, when:
      - the outer length does not match len(data) or uses an unsupported form;
      - the PDU tag is not GetResponse or GetNextRequest.

    The version, community, request-id, error-status and error-index fields
    are assumed to use one-octet lengths.
    """
    if len(data) < 2:
        print("MIB data size Error.")
        return None

    # Outer SEQUENCE length.  add_num counts the extra long-form length octets.
    length_octet = ord(data[1])
    if length_octet < 0x80:
        add_num = 0
        if length_octet != len(data[2:]):
            print("MIB data size Error.")
            return None
    else:
        add_num = 0
        declared = None
        if length_octet == 0x81 and len(data) > 2:
            add_num = 1
            declared = ord(data[2])
        elif length_octet == 0x82 and len(data) > 3:
            add_num = 2
            declared = ord(data[2]) * 256 + ord(data[3])
        if declared is None or declared != len(data[2 + add_num:]):
            details = [len(data[2:]), length_octet]
            if len(data) > 2:
                details.append(ord(data[2]))
            print("MIB data " + " ".join([str(item) for item in details]))
            print("MIB data size Error.")
            # Decimal/hex dump of at most the first 60 octets.
            for i in range(min(60, len(data))):
                print("%d :%d :%x" % (i, ord(data[i]), ord(data[i])))
            return None

    # version INTEGER value at 4+add_num; community OCTET STRING length at
    # 6+add_num and value from 7+add_num; the PDU starts right after it.
    pdu = 7 + add_num + ord(data[6 + add_num])
    Mib.community = data[7 + add_num:pdu]
    Mib.version = get_snmp_version("mib", ord(data[4 + add_num]))

    # NOTE(review): only GetResponse and GetNextRequest are accepted here;
    # GetRequest (0xa0) may also have been intended -- confirm.
    if ord(data[pdu]) not in (Pdu_tags['GetNextRequest'], Pdu_tags['GetResponse']):
        print("MIB PDU Error.(%d:%d)" % (pdu, ord(data[pdu])))
        return None

    Mib.oid = None
    Mib.error = None
    Mib.error_position = None

    # Skip extra long-form PDU length octets so the offsets below line up.
    pdu_length_octet = ord(data[pdu + 1])
    if pdu_length_octet == 0x81:
        pdu += 1
    elif pdu_length_octet == 0x82:
        pdu += 2

    # request-id, error-status, error-index, then the VarBindList.
    err_p = pdu + 4 + ord(data[pdu + 3])            # error-status tag
    err_pp = err_p + 2 + ord(data[err_p + 1])       # error-index tag
    Mib.id = set_number_mib(data[pdu + 4:err_p])
    oid_start = err_pp + 2 + ord(data[err_pp + 1])  # VarBindList tag

    ## if error
    error_info = set_number_mib(data[err_p + 2:err_pp])
    if error_info == 0x00:
        Mib.error_info = 0x00
    else:
        Mib.error_info = Snmp_Errors.get(error_info, "unknown Error")
        Mib.error_position = set_number_mib(data[err_pp + 2:oid_start])

    ## OID data start
    if ord(data[oid_start]) != 0x30:
        Mib.error = "oid start: Error"
        return Mib
    list_length_octet = ord(data[oid_start + 1])
    if list_length_octet < 0x81:
        start = oid_start + 2
    elif list_length_octet == 0x81:
        start = oid_start + 3
    elif list_length_octet == 0x82:
        start = oid_start + 4
    else:
        Mib.error = "oid start: Error"
        return Mib

    if start < len(data):
        length, oid = set_oid_pair(data[start:])
        if length:
            Mib.oid = oid
        else:
            Mib.error = "oid pair data: Error"
    else:
        Mib.error = "oid pair data size: Error"
    return Mib


