#!/usr/bin/python3
import random

# RSA Motivation
# a=12
# n=9973
# [ a**i % n  for i in range(n) ]

'''
-------------------------------------------------------
RSA 
-------------------------------------------------------
RSA Fundamental Theorem:
Let p,q be any distinct prime numbers. Let m in N be s.t. gcd(m,p*q)=1
then m**((p-1)*(q-1))=1 mod (p*q)

Example: 
m = 56
p,q = 101, 103
(56 ** ((p-1)*(q-1)))% (p*q)

Note: 
m**(1+(p-1)*(q-1))=m mod (p*q)
Also
m**(1+k*(p-1)*(q-1))=m mod (p*q)

This suggests a crypto system.

Find e,d such that 1=e*d+k*(p-1)*(q-1), that is...
e*d = 1-k*(p-1)*(q-1)

Then you can encrypt with 

m**e mod (p*q) and decrypt with (m**e)**d mod (p*q)

RSA is the following...

0) Choose large primes p and q
   FACT: Can find large primes quickly.

1) Choose encryption key e in 1,...,(p-1)*(q-1)
   such that gcd(e,(p-1)*(q-1))=1

2) Extended Euclidean Algorithm says that 

   if gcd(x,y)=z then we can find a,b in Z such that
   z = a*x+b*y

   In our setting, we have 
   if gcd(e,(p-1)*(q-1))=1 then we can find d,k in Z such that
   1 = d*e+k*(p-1)*(q-1)

   that is e*d = 1 - k*(p-1)*(q-1)

   FACT: The Extended Euclidean Algorithm finds d from e easily (=quickly)

3) Keys:

   public_key = (e,p*q)
   private_key = (d,p*q)

4) Encrypt
   c = (m**e)%(p*q)

5) Decrypt
   Let's decrypt...
   (write e*d = 1+k*(p-1)*(q-1); this is step 2 with k replaced by -k)
   (c**d)%(p*q) = ((m**e)**d)%(p*q)
                = (m**(e*d))%(p*q)
                = (m**(1+k*(p-1)*(q-1)))%(p*q)
                = (m*m**(k*(p-1)*(q-1)))%(p*q)
                = (m*(m**((p-1)*(q-1)))**k)%(p*q)
                = (m*1**k)%(p*q)
                = m                    (as long as 0 <= m < p*q)

EXERCISE: Compute (x**y)%n quickly. Python has pow function for this.

--------------------------------------------------------
Security of RSA
--------------------------------------------------------
Say Eve intercepts ciphertext c

Eve
Knows                       Doesn't Know
c
(e,n)                       (d,n)


To recover d, Eve must solve the equations...
e*d = 1 mod (p-1)*(q-1)
n = p*q
without knowing d,p,q

It seems that Eve needs to find
(p-1) and (q-1)
to do this, Eve seems to need to find
p,q such that p*q = n

That is, to factor n. 

FACT: No one knows a fast way to factor in general (YET). 

FACT: Factoring is in NP intersect CoNP. Several problems that were long
known only to be in NP intersect CoNP (e.g. primality testing, linear
programming) have since fallen to P. This is a worry.

-------------------------------------------------------
Background
-------------------------------------------------------

[ 2**(i) % 20 for i in range(10) ] 
[ 3**(i) % 20 for i in range(10) ] 
[ 4**(i) % 20 for i in range(10) ] 
[ 5**(i) % 20 for i in range(10) ] 

FACT: if m and n do not share factors, then 

[ m**(i) % n for i in range(1,n) ] is interesting, in fact

set([ m**(i) % n for i in range(1,n) ]) forms a GROUP, 
with respect to multiplication

The important part of this is...
set([ 7**(i) % 20 for i in range(1,20) ]) == { 1,3,7,9 }
1) 1 is in the set
2) The set is closed under multiplication
   Example (9*3)%20 == 7
3) Every element has its inverse in the set
   Example: 9*9 % 20 == 1
   Example: 3*7 % 20 == 1

Example: 
p = 101
for j in range(2,100):
    group  = set([ j**i % p for i in range(1,p) ])
    print(j,len(group), group)

NOTE: if gcd(m,n)!=1 then the above is NOT true. For example.
set([ (2*3)**i % (3*5) for i in range(1,3*5) ])

For what follows, N={0,1,2,...} is the set of Natural Numbers.

Def: for p,n in N, p divides n if there is k in N
     such that n = p * k.

Note: We write p is a divisor of n, p is a factor of n,
      n is a multiple of p
    
'''
def divides_slow(p,n):
	''' Brute-force test of whether p divides n, straight from the definition.

	Tries every k in 0..n and reports whether some p*k equals n.
	Returns False when p is 0, and also when n is negative, since then
	there is no k to try.
	'''
	if p == 0:
		return False
	return any(p*k == n for k in range(n+1))

def divides(p,n):
	''' Return True when p divides n, i.e. when n is a multiple of p.

	p == 0 is answered with False instead of raising ZeroDivisionError.
	'''
	return p != 0 and n%p == 0

def divisors(n):
	''' Return the list of all positive divisors of n, in increasing order.

	The result is empty when n < 1, because there are no candidates to try.
	'''
	# FACT: No one knows how to compute this function quickly (YET)
	candidates = range(1, n+1)
	return [ d for d in candidates if divides(d,n) ]

'''
Def: n in N is PRIME if n>=2 and the only divisors
     of n are 1 and n.
'''

def is_prime(n):
	''' Return True when n is prime, checked directly from the definition:
	n >= 2 and the only divisors of n are 1 and n.

	Note this is not a fast primality testing algorithm; it lists every
	divisor of n.
	'''
	# FACT: There is a very fast primality testing algorithm
	if not n >= 2:
		return False
	return divisors(n) == [1, n]

def primes(n):
	''' Return the list of the primes less than or equal to n, in increasing order. '''
	return list(filter(is_prime, range(2, n+1)))

'''
Fundamental Theorem of Arithmetic:
    Every n in N, n>=2 can be decomposed into prime factors.
    The decomposition is unique, up to the ordering of factors.

    That is, n can be written as the product of prime numbers.
    The product is unique up to the ordering of the primes.
'''

def prime_factorization(n):
	''' Return the prime factors of n, with multiplicity, in non-decreasing order.

	Returns [] when n < 2.

	The algorithm below suggests an inductive proof! Peel off the smallest
	factor, then factor what is left.
	'''
	pf = []
	if n<2: return pf
	p = 2
	# Nothing below p divides n any more, so once p*p > n the remaining n
	# has no factor <= sqrt(n): it is either 1 or a single prime. Stopping
	# here avoids counting all the way up to a large prime factor.
	while p*p <= n:
		if n % p == 0:
			pf.append(p)
			n = n // p
		else:
			p = p + 1
	if n > 1:
		pf.append(n)
	return pf

'''
Division Theorem: For every n,m in N with m>0, there is a unique q,r in N
such that n = m*q+r and 0<=r<m. q is the quotient, r is the remainder.
'''

def division(n,m):
	''' Return the pair (q, r): the quotient and remainder of n divided by m,
	so that n == m*q + r.

	Raises ZeroDivisionError when m is 0.
	'''
	return divmod(n, m)

'''
Def: We write a = b mod n if there is some k in Z such that a-b=k*n
     In other words, a and b differ by a multiple of n.
     In other words, if n divides (a-b)

     Alternatively: a=b mod n, if a%n=b%n
     Proof: Exercise
'''
def is_congruent(a,b,n):
	''' Return True when a = b mod n, that is, when n divides a-b.

	For n != 0 this is the same test as a%n == b%n. For n == 0 the answer
	is False, because divides() rejects a zero divisor.
	'''
	difference = a - b
	return divides(n, difference)

'''
Examples:
[ i for i in range(100) if is_congruent(7,i,23) ]
[ i for i in range(100) if is_congruent(i,7,23) ]
[ 7+23*i for i in range(5) ]
'''

''' Some properties of congruence
1) a = a mod n
2) a = b mod n                 =>    b = a mod n
3) a = b mod n and b = c mod n =>    a = c mod n
4) a = b mod n                 =>  a+c = b+c mod n
5) a = b mod n                 =>  a*c = b*c mod n
6) a = b mod n                 => a**k = b**k mod n
7) a = b mod n and c = d mod n =>  a+c = b+d mod n
8) a = b mod n and c = d mod n =>  a*c = b*d mod n

Exercise: Prove 7, 8 and then 6
'''

def factorial(n):
	''' Return n! = 1*2*...*n for an integer n >= 0, with 0! == 1.

	Computed with a loop rather than by recursion, so that large n does
	not run into Python's recursion limit.

	Raises ValueError when n is negative (n! is not defined there).
	'''
	if n < 0:
		raise ValueError("factorial is not defined for negative numbers")
	product = 1
	for k in range(2, n+1):
		product *= k
	return product

# Example (REPL): the first ten factorials as (i, i!) pairs.
# As a module-level statement the list is built and then discarded;
# paste the line into an interpreter to see the values.
[ (i,factorial(i)) for i in range(10) ]

# choose(m,n) = m choose n = m!/(n!(m-n)!)
def choose(m,n):
	''' Return the binomial coefficient "m choose n" as an exact integer.

	The value equals m!/(n!(m-n)!), but it is built up with the
	multiplicative formula using integer arithmetic only. True division
	(/) would return a float, which loses precision for large arguments
	and makes residues such as choose(101,i)%101 meaningless.

	Returns 0 when n > m (there is no way to pick more than m items).
	Raises ValueError when m or n is negative.
	'''
	if m < 0 or n < 0:
		raise ValueError("choose requires non-negative arguments")
	if n > m:
		return 0
	result = 1
	for k in range(1, n+1):
		# After this step result == choose(m-n+k, k), which is an
		# integer, so the floor division is always exact.
		result = result * (m-n+k) // k
	return result

# Examples (REPL): choose(r,i) for i in 0..r with r = 6, 7 and 101, first as
# formatted strings, then reduced modulo r. As module-level statements the
# lists are built and then discarded; paste them into an interpreter to see
# the values.
# They illustrate the proposition stated below: for a prime p, choose(p,i)
# is a multiple of p for i in 1..p-1 (7 and 101 are prime, 6 is not).
# NOTE(review): the residues are only meaningful if choose() returns exact
# integers; a float-valued choose() loses precision for r = 101.
[ "({},{})={}".format(6,i,choose(6,i)) for i in range(6+1) ]
[ "({},{})%6={}".format(6,i,choose(6,i)%6) for i in range(6+1) ]
[ choose(6,i)%6 for i in range(6+1) ]


[ "({},{})={}".format(7,i,choose(7,i)) for i in range(7+1) ]
[ "({},{})%7={}".format(7,i,choose(7,i)%7) for i in range(7+1) ]
[ choose(7,i)%7 for i in range(7+1) ]

[ "({},{})={}".format(101,i,choose(101,i)) for i in range(101+1) ]
[ "({},{})%101={}".format(101,i,choose(101,i)%101) for i in range(101+1) ]
[ choose(101,i)%101 for i in range(101+1) ]

'''
Proposition: if p is prime, n in {1,2,...,p-1} 
then choose(p,n)%p==0. That is, choose(p,n) is divisible by p 

Proof: exercise

#----------------------------------------------------------------
Fermat's little Theorem (p1): 
Let p be any prime, a any integer, then a**p = a mod p 
#----------------------------------------------------------------

#----------------------------------------------------------------
Example (p1): 
p = 19 
p in primes(100)

a = 142
a**p % p == a % p
is_congruent(a**p, a, p)
#----------------------------------------------------------------
Proof: 
https://artofproblemsolving.com/wiki/index.php?title=Fermat%27s_Little_Theorem

We prove this by induction, on a 
Base Case: 1**p % p = 1

Induction Hyp: Assume a**p % p = a for a>=1
Induction Step: 

	(a+1)**p = (a+1)*(a+1)*...*(a+1)

	         We write (p,2) for choose(p,2) below...
	         by the binomial theorem ...

	         = a**p (p,0) + a**(p-1) (p,1) + a**(p-2) (p,2) + ... + a**1 (p,1) + a**0 (p,0)
                                           ^                ^     ...          ^
                         divisible by p----+----------------+------------------+
	so 
	(a+1)**p % p = (a**p + 1) % p
                     = (a + 1) % p  by Induction Hypothesis.

QED

#----------------------------------------------------------------
Fermat's little Theorem (p2): 
Let p be any prime, a any integer such that p does not divide a
then a**(p-1) = 1 mod p
#----------------------------------------------------------------
#----------------------------------------------------------------
Example (p2):

p = 19 
p in primes(100)
a = 142

divides(p,a)
a**(p-1) % p == 1
is_congruent(a**(p-1), 1, p)

[a**i%p for i in range(1,p+1)]
#----------------------------------------------------------------

Proof: 
From FLT(p1), we have 
    a**p % p     = a
    (a**p - a) % p = 0
    (a**(p-1) -1) * a % p =0
    now if p does not divide a, then p divides 
    (a**(p-1) -1)
    so 
    (a**(p-1) -1) % p = 0
    so
    a**(p-1) % p = 1
QED

#----------------------------------------------------------------
RSA Fundamental Theorem:
Let p,q be any distinct prime numbers. Let m in N be s.t. gcd(m,p*q)=1
then m**((p-1)*(q-1))=1 mod (p*q)

Proof: 
    Fermat's Little Theorem applied to a=m**(p-1), prime q yields
    (m**(p-1))**(q-1) = 1 mod q 
    so
    m**((p-1)*(q-1)) = 1 mod q 
    so q divides m**((p-1)*(q-1)) - 1  
    so m**((p-1)*(q-1)) - 1  = k1*q
    
    Fermat's Little Theorem applied to a=m**(q-1), prime p yields
    (m**(q-1))**(p-1) = 1 mod p
    so 
    m**((q-1)*(p-1)) = 1 mod p
    so
    m**((p-1)*(q-1)) = 1 mod p
    
    so p divides m**((p-1)*(q-1)) - 1  
    so m**((p-1)*(q-1)) - 1 = k2*p
    
    So k1*q = k2*p
    p and q are distinct primes, so p does not divide q, but then
    p divides k1, since p divides k1*q, so k1=k3*p, so k1*q = k3*p*q
    but then p*q divides m**((p-1)*(q-1))-1 so
    m**((p-1)*(q-1))=1 mod pq
QED

http://www.mathaware.org/mam/06/Kaliski.pdf
https://www.cs.utexas.edu/~mitra/honors/soln.html
'''
def egcd(a, b):
	''' Extended Euclidean algorithm.

	Return a triple (g, x, y) with a*x + b*y == g, where g is the
	greatest common divisor of a and b (for non-negative a and b).

	Written as a loop rather than by recursion, so that very large
	operands (such as realistic RSA moduli) cannot exceed Python's
	recursion limit. The quotients and the returned triple are the same
	as in the recursive formulation.
	'''
	# FACT: the extended Euclidean algorithm (below) is fast
	# Invariant, with a0, b0 the original arguments:
	#     b == x*a0 + y*b0   and   a == u*a0 + v*b0
	x, y, u, v = 0, 1, 1, 0
	while a != 0:
		q, r = b // a, b % a
		x, y, u, v = u, v, x - u*q, y - v*q
		b, a = a, r
	return (b, x, y)

def modinv(a, m):
	''' Return the inverse of a mod m, or None when it does not exist.

	The inverse is the value for which
	    a*modinv(a,m) = 1 mod m
	or, equivalently,
	    a*modinv(a,m) % m == 1
	It exists exactly when egcd reports gcd(a, m) == 1.
	'''
	g, coeff_a, _ = egcd(a, m)
	if g == 1:
		# a*coeff_a + m*(something) == 1, so coeff_a is the inverse;
		# reduce it into the range of residues mod m.
		return coeff_a % m
	return None  # modular inverse does not exist

# egcd and modinv above are taken from
# https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm

def rsa_keys(p,q):
	''' Build an RSA key pair from the primes p and q.

	Returns (public_key, private_key) == ((e, n), (d, n)) where n = p*q
	and e*d = 1 mod (p-1)*(q-1).

	Returns None when p and q cannot be used: one of them is not prime,
	they are equal, or (p-1)*(q-1) is too small to offer any exponent.
	'''
	# FACT: Can find large primes quickly, can test for primality quickly
	if not(is_prime(p) and is_prime(q)): return None
	# The RSA Fundamental Theorem needs two DISTINCT primes. With p == q the
	# exponents are not inverses modulo the right number, so decryption
	# would not return the original message.
	if p == q: return None
	n = p*q # modulus
	phi = (p-1)*(q-1)
	# e is drawn from 2..phi-1, which is empty when phi <= 2 (p,q = 2,3);
	# random.randrange would raise on the empty range.
	if phi <= 2: return None
	# find e,d such that e*d = 1 mod (p-1)*(q-1)
	# That is e*d - 1 = k*(p-1)*(q-1)
	# that is e*d - k*(p-1)*(q-1) = 1
	# Can do this by picking e, then finding d and k (don't need k)
	# using the Extended Euclidean algorithm
	# NOTE: random is not a cryptographically secure generator. That is fine
	# for teaching; real keys should be drawn with the secrets module.
	while True:
		e = random.randrange(2,phi)
		d = modinv(e,phi)
		# modinv returns None when gcd(e, phi) != 1; try another e
		if d is not None: break
	# Now know that e*d = 1 mod (p-1)*(q-1)
	public_key = (e,n) # e is exponent, n is modulus 
	private_key = (d,n)
	return public_key, private_key
	
def encrypt(m, public_key):
	''' Encrypt the integer message m with public_key == (e, n).

	Returns the ciphertext c = m**e mod n, or None when m is outside
	0..n-1. Such a message cannot be recovered: decryption works mod n,
	so it could only give back m%n, not m itself.
	'''
	(e,n) = public_key
	# m == n is rejected too: it is congruent to 0 and would encrypt to 0.
	if not 0 <= m < n: return None
	# if m and n have a common divisor: return None
	# this is unlikely, but if it happens, add some padding
	# c = m**e % n # can be computed more efficiently
	c = pow(m,e,n)
	return c

def decrypt(c,private_key):
	''' Decrypt the ciphertext c with private_key == (d, n).

	Returns m = c**d mod n. The three-argument pow keeps every
	intermediate value reduced mod n, unlike c**d % n.
	'''
	exponent, modulus = private_key
	return pow(c, exponent, modulus)

'''
References: http://certauth.epfl.ch/rsa/

Example:

import rsa
rsa.rsa_keys(61519, 61561)   # e and d are chosen at random, so your keys will differ
((675825377, 3787171159), (3601085393, 3787171159))
rsa.encrypt(1234, (675825377, 3787171159))
3615735883
rsa.decrypt(3615735883, (3601085393, 3787171159))
1234
'''

