#
# Aquerre Technologies
# awscredentials.py -- module to manage AWS session credentials
# Revision: 2018-01-06
#
from boto3 import client
from service import service
from datetime import datetime
from os import path
import json
import json_mod

class userkeys:
    """Long-term IAM user keys and MFA device serial for a single user.

    The values are loaded from USERDATAFILE through service().loadUserData()
    and looked up with .get(), so a missing entry leaves the corresponding
    attribute set to None (assumes loadUserData returns a dict-like object --
    confirm against service.py).

    Attributes:
        AWSACCESSKEYID: the 'AccessKeyId' entry of the user data.
        AWSSECRETACCESSKEY: the 'SecretAccessKey' entry of the user data.
        TOKENSERIALNUMBER: the 'MfaSerialNumber' entry of the user data.
    """

    # Placeholder owner ID; presumably a 12-digit AWS account ID -- confirm.
    # It is not referenced by the code in this module.
    OWNERID='xxxxxxxxxxxx'

    def __init__(self,USERDATAFILE):
        """Load the user's keys and MFA serial from USERDATAFILE."""
        SERVICE=service()
        USERDATA=SERVICE.loadUserData(USERDATAFILE)

        self.AWSACCESSKEYID=USERDATA.get('AccessKeyId')
        self.AWSSECRETACCESSKEY=USERDATA.get('SecretAccessKey')
        self.TOKENSERIALNUMBER=USERDATA.get('MfaSerialNumber')
   
class ststoken:
    """Obtain and cache AWS STS session credentials for an MFA-protected user.

    On construction the credentials are loaded from FILENAME when that file
    exists and has not expired. Otherwise a new session token is requested
    from STS (prompting on stdin for an MFA code) and saved to FILENAME. The
    resulting credentials are exposed as SESSIONTOKEN, SECRETACCESSKEY and
    ACCESSKEYID.
    """

    # Set to 1 to print diagnostic output -- including secret values -- to stdout.
    DEBUG=0

    def getNewSessionToken(self,FILENAME,TOKENDURATIONSECONDS):
        """Request a new STS session token and cache it in FILENAME.

        Prompts for the MFA code, then calls STS get_session_token using the
        long-term keys and MFA serial in self.USERKEYS. The full response is
        stored in self.STSTOKEN and its 'Credentials' entry in
        self.CREDENTIALS. The response is written to FILENAME as JSON.

        Raises:
            ValueError: if the MFA code entered is not an integer.
        """
        import os

        print("awscredentials: Obtaining new session credentials and saving to file "+FILENAME+"...")

        try:
            read_line=raw_input  # Python 2
        except NameError:
            read_line=input  # Python 3 (input() no longer evaluates the text)

        TOKENCODE0=read_line("Enter the MFA code: ")
        # The int() round-trip rejects non-numeric input and strips whitespace;
        # zfill(6) restores any leading zeros the conversion dropped.
        TOKENCODE=str(int(TOKENCODE0)).zfill(6)

        if self.DEBUG == 1:
            print("TOKENCODE=%s" % TOKENCODE)

        STSCLIENT=client('sts',aws_access_key_id=self.USERKEYS.AWSACCESSKEYID,aws_secret_access_key=self.USERKEYS.AWSSECRETACCESSKEY)
        self.STSTOKEN=STSCLIENT.get_session_token(DurationSeconds=TOKENDURATIONSECONDS,SerialNumber=self.USERKEYS.TOKENSERIALNUMBER,TokenCode=TOKENCODE)
        self.CREDENTIALS=self.STSTOKEN.get('Credentials')

        if self.DEBUG == 1:
            print(self.STSTOKEN.keys())
            print(self.CREDENTIALS.keys())

        # The cache holds live secret credentials, so restrict it to the owner.
        # The mode given to os.open only applies when the file is created;
        # fchmod (POSIX only) also tightens a pre-existing file.
        fd=os.open(FILENAME, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        if hasattr(os, 'fchmod'):
            os.fchmod(fd, 0o600)
        with os.fdopen(fd, 'w') as f:
            f.write(json.dumps(self.STSTOKEN, default=json_mod.json_serial))

    def loadExistingSessionToken(self,FILENAME):
        """Load a previously cached STS response from the JSON file FILENAME.

        Sets self.STSTOKEN to the parsed object and self.CREDENTIALS to its
        'Credentials' entry (None if absent).
        """
        print("awscredentials: Reading existing session credentials from file "+FILENAME+"...")

        with open(FILENAME, 'r') as f:
            self.STSTOKEN=json.load(f)

        self.CREDENTIALS=self.STSTOKEN.get('Credentials')

    def _isSessionValid(self):
        """Return True if self.CREDENTIALS has not yet expired, printing its status.

        The current UTC time (naive) is compared against
        CREDENTIALS['Expiration'], parsed as "%Y-%m-%d %H:%M:%S".
        """
        TIMESTAMP_DT=datetime.utcnow()
        TIMESTAMP_STR=TIMESTAMP_DT.strftime("%Y-%m-%d %H:%M:%S")
        if self.DEBUG == 1:
            print("TIMESTAMP_STR=%s" % TIMESTAMP_STR)
            print("")

        SESSIONEXPIRE_STR=self.CREDENTIALS.get('Expiration')
        if self.DEBUG == 1:
            print("SESSIONEXPIRE_STR=%s" % SESSIONEXPIRE_STR)

        # NOTE(review): assumes json_mod.json_serial wrote Expiration as a
        # naive UTC "%Y-%m-%d %H:%M:%S" string -- confirm. An ISO-8601 string
        # with an offset would make strptime raise ValueError here.
        SESSIONEXPIRE_DT=datetime.strptime(SESSIONEXPIRE_STR,"%Y-%m-%d %H:%M:%S")
        if self.DEBUG == 1:
            print("SESSIONEXPIRE_DT=")
            print(SESSIONEXPIRE_DT)
            print("")

        VALIDSESSION=(TIMESTAMP_DT<SESSIONEXPIRE_DT)
        if self.DEBUG == 1:
            print("VALIDSESSION=%s" % VALIDSESSION)

        if VALIDSESSION:
            print("awscredentials: Session credentials valid until %s UTC (Current time is %s UTC)" % (SESSIONEXPIRE_STR,TIMESTAMP_STR))
        else:
            print("awscredentials: Session credentials expired at %s UTC (Current time is %s UTC)" % (SESSIONEXPIRE_STR,TIMESTAMP_STR))

        return VALIDSESSION

    def __init__(self,FILENAME,TOKENDURATIONSECONDS,USERDATAFILE):
        """Load cached session credentials, refreshing them if missing or expired.

        Args:
            FILENAME: path of the JSON file used to cache the STS response.
            TOKENDURATIONSECONDS: DurationSeconds passed to STS when a new
                token has to be requested.
            USERDATAFILE: file handed to userkeys() for the long-term keys
                and MFA serial number.
        """
        self.USERKEYS=userkeys(USERDATAFILE)

        if not path.isfile(FILENAME):
            self.getNewSessionToken(FILENAME,TOKENDURATIONSECONDS)
        else:
            self.loadExistingSessionToken(FILENAME)
            if not self._isSessionValid():
                self.getNewSessionToken(FILENAME,TOKENDURATIONSECONDS)

        self.SESSIONTOKEN=self.CREDENTIALS.get('SessionToken')
        self.SECRETACCESSKEY=self.CREDENTIALS.get('SecretAccessKey')
        self.ACCESSKEYID=self.CREDENTIALS.get('AccessKeyId')

        if self.DEBUG == 1:
            print("")
            print("SESSIONTOKEN=%s" % self.SESSIONTOKEN)
            print("")
            print("SECRETACCESSKEY=%s" % self.SECRETACCESSKEY)
            print("")
            print("ACCESSKEYID=%s" % self.ACCESSKEYID)
            print("")
