# -*- coding: utf-8 -*-
"""
blacklist tests

@copyright: Copyright  2004 by Nir Soffer <nirs@freeshell.org>
@license: GNU GPL, see COPYING for details
"""

import time
import sys
try:
    import cPickle as pickle
except ImportError:
    import pickle

import unittest
from blacklist.blacklist import BlackList
from blacklist import ipv4


# Blacklist source text used as a fixture by the tests below. It mixes
# entries written with wiki list markup (" * ", " - "), an entry without
# markup, and lines the tests expect to be ignored: a line starting with
# "#", free text and an empty line. The trailing "# ..." notes are part of
# the text itself.
text = """
 * 127.0.0.1            # full address
 - 127.0.0.128/28       # netblock
192.114.128.0/20        # netblock without wiki list markup
# 127.0.0.2 This address should be ignored
This line should be ignored and so is the next line

"""

class BlackListTestCase(unittest.TestCase):
    """ Membership tests on a black list built from the module level text

    The black list is created again for every test, so the entries added
    by testUpdateFromText can not leak into the other tests and the
    results do not depend on the order the tests run in.
    """

    def setUp(self):
        """ Create a fresh black list for each test """
        self.bl = BlackList(text)

    def testMatchSingle(self):
        """ blacklist: match single ip """
        self.assertTrue('127.0.0.1' in self.bl)
        # Same address, looked up as an integer instead of a string
        self.assertTrue(0x7f000001 in self.bl)
        # 127.0.0.2 appears in the text only on a line starting with '#'
        self.assertFalse('127.0.0.2' in self.bl)

    def testMatchInNetblock(self):
        """ blacklist: match ip in netblock x.x.x.x/x """
        # First, inner and last address of both netblocks in the text:
        # 127.0.0.128/28 and 192.114.128.0/20
        inside = (
            '127.0.0.128', '127.0.0.129', '127.0.0.143',
            '192.114.128.0', '192.114.134.51', '192.114.143.255',
            )
        # Addresses just below and just above each netblock
        outside = (
            '127.0.0.127', '127.0.0.144',
            '192.114.127.0', '192.114.144.0',
            )
        for address in inside:
            self.assertTrue(address in self.bl, address)
        for address in outside:
            self.assertFalse(address in self.bl, address)

    def testUpdateFromText(self):
        """ blacklist: update blacklist from text """
        self.bl.updateFromText('127.0.0.100\n10.0.0.48/28')
        self.assertTrue('127.0.0.100' in self.bl)
        # The new netblock 10.0.0.48/28 and its direct neighbours
        self.assertFalse('10.0.0.47' in self.bl)
        self.assertTrue('10.0.0.48' in self.bl)
        self.assertTrue('10.0.0.49' in self.bl)
        self.assertTrue('10.0.0.63' in self.bl)
        self.assertFalse('10.0.0.64' in self.bl)


class IleagalAddressTestCase(unittest.TestCase):
    """ Illegal addresses and netblocks are expected to raise ValueError """

    # Illegal address and netblocks
    bad = ('127.0.0.256', '10.10.10.10/45')

    def testBadAddress(self):
        """ blacklist: illegal address should raise ValueError """
        for item in self.bad:
            # Each bad entry is used as the whole text of a new black list
            self.assertRaises(ValueError, BlackList, item)

    def testBadAddressLookup(self):
        """ blacklist: illegal address lookup raise ValueError """
        # The empty black list is created outside of assertRaises, only
        # the lookup itself is expected to raise.
        lookup = BlackList().__contains__
        self.assertRaises(ValueError, lookup, '127.0.0.256')


class PickleTestCase(unittest.TestCase):
    def testPickleTestCase(self):
        """ blacklist: pickle, unpickle, update, pickle, unpickle, match """
        self.doTest(BlackList(text), 'blacklist.pickle', '113.114.115.116') 
          
    def doTest(self, bl, name, needle):
        """ Helper for tests """
        pickle.dump(bl, file(name, 'w'), pickle.HIGHEST_PROTOCOL)
        bl = pickle.load(file(name))
        self.assert_('127.0.0.1' in bl)
        assert needle not in bl
        bl.updateFromText(needle)
        self.assert_(needle in bl)
        pickle.dump(bl, file(name, 'w'), pickle.HIGHEST_PROTOCOL)
        bl = pickle.load(file(name))
        self.assert_(needle in bl)
                      

class TypicalBlackListTestCase(unittest.TestCase):
    """ Time common operations on a black list of typical size

    The test text is generated once, in the class body, and holds 100
    single addresses and 100 netblocks. Each test writes the time of one
    call into the tests output and then checks the result of the call.
    """
    # Lookup values: the first two are expected to be found in the black
    # list built from the generated text, the last one is not.
    addressMatch = '192.114.1.234'
    netblockMatch = '200.0.48.207'
    noMatch = '254.254.254.254'

    # Create typical test text
    lines = []
    # Add 100 single addresses
    count = 0
    min, max = ipv4.netblock('192.114.0.0/16')
    for address in range(min, max, 10):
        count += 1
        lines.append(' * %s # say no to address %d' %
                     (ipv4.addressToString(address), count))
        if count == 100: break

    # And 100 netblocks
    # NOTE(review): the step is 255, not 256, so consecutive entries are
    # not 256 addresses apart - confirm this is intended.
    count = 0
    min, max = ipv4.netblock('200.200.0.0/8')
    for address in range(min, max, 255):
        count += 1
        lines.append(' * %s/24 # say no to netblock %d' %
                     (ipv4.addressToString(address), count))
        if count == 100: break

    text = '\n'.join(lines)

    def testCreateFromTypicalText(self):
        """ blacklist: time create typical black list """
        self.time(BlackList, self.text)

    def testAddressMatchInTypicalBlackList(self):
        """ blacklist: time one address match in typical black list """
        bl = BlackList(self.text)
        self.time(bl.__contains__, self.addressMatch)
        self.assertTrue(self.addressMatch in bl)

    def testNetblockMatchInTypicalBlackList(self):
        """ blacklist: time one netblock match in typical black list """
        bl = BlackList(self.text)
        self.time(bl.__contains__, self.netblockMatch)
        self.assertTrue(self.netblockMatch in bl)

    def testNoMatchInTypicalBlackList(self):
        """ blacklist: time no match in typical black list """
        bl = BlackList(self.text)
        self.time(bl.__contains__, self.noMatch)
        self.assertFalse(self.noMatch in bl)

    def testUnPickle(self):
        """ blacklist: time unpickle typical blacklist """
        name = 'blacklist.pickle'
        bl = BlackList(self.text)
        # HIGHEST_PROTOCOL is a binary format: use binary mode for both
        # files and close them explicitly.
        out = open(name, 'wb')
        try:
            pickle.dump(bl, out, pickle.HIGHEST_PROTOCOL)
        finally:
            out.close()
        src = open(name, 'rb')
        try:
            self.time(pickle.load, src)
        finally:
            src.close()

    def time(self, callable, *args):
        """ blacklist: time calls and insert timing in the tests output

        Call callable(*args) once and write the elapsed wall clock time to
        stdout, without a newline, so it shows up inside the line of the
        running test. The result of the call is discarded.

        @param callable: the object to call
        @param args: positional arguments passed to callable
        """
        # Inside a method the name time is the module, not this method
        start = time.time()
        callable(*args)
        stop = time.time()
        sys.stdout.write('%0.8fs ... ' % (stop - start))
        sys.stdout.flush()
    
        

def suite():
    """ Collect the tests of all the test cases defined in this module

    Every module level object whose name ends with 'TestCase' is loaded
    with the default test loader, which picks the methods whose names
    start with 'test'. The objects are visited sorted by name, so the
    order of the suite does not depend on dictionary ordering.

    @rtype: unittest.TestSuite
    @return: suite holding one sub suite per test case class
    """
    # loadTestsFromTestCase replaces unittest.makeSuite, which is
    # deprecated and removed in Python 3.13.
    load = unittest.defaultTestLoader.loadTestsFromTestCase
    test_cases = [load(obj)
                  for name, obj in sorted(globals().items())
                  if name.endswith('TestCase')]
    return unittest.TestSuite(test_cases)
    
if __name__ == '__main__':
    # Run all the tests verbosely and report failure through the exit
    # status, so a calling script or build tool can detect it.
    result = unittest.TextTestRunner(verbosity=2).run(suite())
    if not result.wasSuccessful():
        sys.exit(1)
