diff --git a/src-py/plexcryptool/scripts/oaep_rsa.py b/src-py/plexcryptool/scripts/oaep_rsa.py index 8d2d7c5..4964a67 100755 --- a/src-py/plexcryptool/scripts/oaep_rsa.py +++ b/src-py/plexcryptool/scripts/oaep_rsa.py @@ -8,6 +8,7 @@ Perform RSA-OAEP """ import hashlib +import random from math import floor # the given key in the assignment @@ -39,21 +40,39 @@ def mgf1(seed: bytearray, length: int, hash_func=hashlib.sha256) -> bytearray: # 4.Output the leading l octets of T as the octet string mask. return bytearray(T[:length]) -def byte_xor(ba1: bytearray, ba2: bytearray): - a = int(ba1.hex(), 16) - b = int(ba2.hex(), 16) - c = a^b - if c.bit_length() / 8 == floor(c.bit_length() / 8): - clen = floor(c.bit_length() / 8) - else: - clen = floor(c.bit_length() / 8) + 1 - return bytearray(c.to_bytes(clen, 'big')) +def byte_xor(ba0: bytearray, ba1: bytearray): + """ + helper function for bytewise xor + """ + ba2: bytearray = bytearray(0) + for (b0, b1) in zip(ba0, ba1): + ba2.append((b0 ^ b1)) + #print("xored:\n%s" % ba2.hex()) + return ba2 +def calclen(n: int) -> int: + """ + helper function to calculate the length in bytes + """ + len = n.bit_length() / 8 + if len > floor(len): + return 1 + floor(len) + else: + return floor(len) -def rsa_oaep_inner(seed: bytearray, datablock: bytearray) -> tuple[bytearray, bytearray, bytearray]: - mgf_seed = mgf1(seed, len(seed)) - masked_db = byte_xor(mgf_seed, datablock) - print(masked_db.hex()) +def rsa_oaep_inner(seed: bytearray, block: bytearray) -> tuple[bytearray, bytearray]: + """ + inner function of rsa-oaep + """ + mgf_seed = mgf1(seed, len(block)) + print("mgf1(seed):\n%s" % mgf_seed.hex()) + masked_db = byte_xor(mgf_seed, block) + print("mgf1(seed) ^ block:\n%s" % masked_db.hex()) + mask_seed = mgf1(masked_db, len(seed)) + print("mgf1(mgf1(seed) ^ block):\n%s" % mask_seed.hex()) + masked_seed = byte_xor(seed, mask_seed) + print("mgf1(mgf1(seed) ^ block) ^ seed:\n%s" % masked_seed.hex()) + return (masked_seed, masked_db) def test_rsa_oaep_inner(): seed: bytearray = bytearray.fromhex("aa1122fe0815beef") @@ -64,8 +83,8 @@ def test_rsa_oaep_inner(): 00000000000000000000000000000000000001466f6f626172203132 33343536373839 """) - print("seed\t%s" % seed.hex()) - print("db\t%s" % db.hex()) + print("seed:\n%s" % seed.hex()) + print("db:\n%s" % db.hex()) result = rsa_oaep_inner(seed, db) @@ -85,8 +104,36 @@ def test_rsa_oaep_inner(): """) GIVEN_MASK_FOR_SEED = bytearray.fromhex("713162084a4e0e6d ") GIVEN_MASKED_SEED = bytearray.fromhex("db2040f6425bb082") - assert result[0] == GIVEN_MASKED_SEED, "is %s" % result[0].hex() - assert result[1] == GIVEN_MASKED_DB, "is %s" % result[1].hex() + assert result[0] == GIVEN_MASKED_SEED, "is\n%s\ninstead of\n%s" % (result[0].hex(), GIVEN_MASKED_SEED.hex()) + assert result[1] == GIVEN_MASKED_DB, "is\n%s\ninstead of\n%s" % (result[1].hex(), GIVEN_MASKED_DB.hex()) + +def rsa_oaep(ha: bytearray, m: bytearray): + # generate a seed + seed = random.randint(0, 2**64 - 1) + seed = bytearray(seed.to_bytes(calclen(seed), 'big')) + # build the message + block: bytearray = bytearray(0) + assert len(block) == 0 + curlen = 0 + curlen += len(ha) + curlen += len(m) + block += ha + block += bytearray(calclen(GIVEN_PUB_KEY[0]) - curlen) + block += m + + assert len(block) == calclen(GIVEN_PUB_KEY[0]), "curlen:\n%s\nmodlen:\n%s" % (curlen, calclen(GIVEN_PUB_KEY[0])) + result = rsa_oaep_inner(seed, block) + print() + print(result[0].hex()) + print(result[1].hex()) + print() + return bytearray(1) + result[0] + result[1] + +def main(): + ha = bytearray(0) + m = bytearray.fromhex("466f6f62617220313233343536373839") + r = rsa_oaep(ha, m) + print("final:\n%s" % r.hex()) if __name__ == "__main__": - test_rsa_oaep_inner() + main()