26 Apr 2013

pCTF 2013 – giga (crypto 250)

In this challenge we get a network service which generates an RSA keypair, encrypts a flag with it and shows you the ciphertext, and then lets you encrypt as many different plaintexts as you like and view the corresponding ciphertexts. The goal is to somehow decrypt the encrypted flag.

Normally this should not be possible with RSA, so there have to be some bugs. Let’s look at the code:

#!/usr/bin/env python
import os
from Crypto.PublicKey import RSA
from Crypto.Hash import MD5
import SocketServer
import threading
import time
 
rbuf = os.urandom(4096)
hr = MD5.new()
 
flag = open("secret").read()
 
def rng(n):
  global rbuf
  rand = rbuf[:n]
  rbuf = rbuf[n:]
  while (len(rbuf) < 4096):
    hr.update(os.urandom(4096))
    rbuf += rbuf + hr.hexdigest()
  return rand
 
 
class threadedserver(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
    pass
 
class incoming(SocketServer.BaseRequestHandler):
  def handle(self):
    cur_thread = threading.current_thread()
    welcome = """
*******************************************
*** Welcome to GIGA! ***
**the super secure key management service**
*******************************************
 
We are generating an RSA keypair for you now.
(Please be sure to move your mouse to populate the entropy stream)
"""
    self.request.send(welcome)
    rsa = RSA.generate(1024,rng)
    print getattr(rsa,'n')
    #make it look like we're doing hardcore crypto
    for i in xrange(20):
      time.sleep(0.2)
      self.request.send(".")
    self.request.send("\nCongratulations! Key created!\n")
 
    #no one will ever be able to solve our super challenge!
    self.request.send("To prove how secure our service is ")
    self.request.send("here is an encrypted flag:\n")
    self.request.send("==================================\n")
    self.request.send(rsa.encrypt(flag,"")[0].encode("hex"))
    self.request.send("\n==================================\n")
    self.request.send("Find the plaintext and we'll give you points\n\n")
 
    #now they can be safe from the FBI too!
    while True:
      self.request.send("\nNow enter a message you wish to encrypt: ")
      m = self.request.recv(1024)
      self.request.send("Your super unreadable ciphertext is:\n")
      self.request.send("==================================\n")
      self.request.send(rsa.encrypt(m,"")[0].encode("hex"))
      self.request.send("\n==================================\n")
 
server = threadedserver(("0.0.0.0", 4321), incoming)
server.timeout = 4
server_thread = threading.Thread(target=server.serve_forever)
server_thread.daemon = True
server_thread.start()
 
server_thread.join()

Now there are a couple of things we noticed immediately:

* The server reads messages of up to 1024 bytes, but the RSA key is only 1024 *bits*, so it can only safely encrypt 127 bytes (not even 128, because the plaintext has to be lower than the RSA modulus).
* No RSA padding is applied (this is textbook RSA), which is always a very bad idea.
* There is a custom random number generator, which looks suspicious.

A closer look at the random number generator reveals a subtle bug:

def rng(n):
  global rbuf
  rand = rbuf[:n]
  rbuf = rbuf[n:]
  while (len(rbuf) < 4096):
    hr.update(os.urandom(4096))
    rbuf += rbuf + hr.hexdigest()   # this uses both "rbuf +=" and "rbuf + hr.hexdigest"!
  return rand

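As you can see, the refill line both appends to rbuf (+=) and concatenates rbuf again, so every refill turns the leftover buffer into two copies of itself plus only 32 hex characters of MD5 output. The “random” stream therefore consists mostly of reused, duplicated old data instead of fresh randomness. Since this stream is used to generate RSA keys, this is very bad!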
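To see how little fresh data the refill loop adds, here is a minimal Python 3 re-creation of rng(). The class name BuggyRNG is ours, hashlib.md5 stands in for PyCrypto's MD5, and the hex digest is encoded to bytes because Python 3 keeps str and bytes apart; it is a sketch for illustration, not the server code itself.

```python
import hashlib
import os


class BuggyRNG:
    """Python 3 re-creation of the challenge's rng(), bug included (hypothetical name)."""

    def __init__(self):
        self.rbuf = os.urandom(4096)
        self.hr = hashlib.md5()

    def __call__(self, n):
        rand = self.rbuf[:n]
        self.rbuf = self.rbuf[n:]
        while len(self.rbuf) < 4096:
            self.hr.update(os.urandom(4096))
            # The bug: the leftover buffer is doubled, and only 32 hex
            # characters of fresh material are added per refill.
            self.rbuf += self.rbuf + self.hr.hexdigest().encode()
        return rand


rng = BuggyRNG()
rng(200)          # leaves 3896 bytes, which the refill then duplicates
a = rng(3896)
b = rng(3896)
print(a == b)     # True: two consecutive "random" draws are identical
```

Because rbuf and the MD5 object are globals shared by every connection, successive key generations draw on heavily overlapping data, which plausibly explains why different keys end up with a common prime.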

This randomness is used to generate two primes which are multiplied together to form the RSA modulus. RSA is secure because it is very hard to find these two primes from just the modulus (the public key). But if there are two different keys which have a common prime (because of the low level of randomness), we can easily find this common prime using the Greatest Common Divisor function.
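A toy illustration of the GCD trick, using small Mersenne primes instead of real 512-bit primes (the specific numbers are only an assumption for readability): two moduli built with a common prime give that prime away through a single gcd call.

```python
from math import gcd

# Toy primes (Mersenne primes), far too small for real RSA.
p = 2**61 - 1    # the prime both keys accidentally share
q1 = 2**89 - 1
q2 = 2**107 - 1

n1 = p * q1
n2 = p * q2

shared = gcd(n1, n2)
print(shared == p)          # True: the common prime falls right out
print(n1 // shared == q1)   # True: dividing gives the other prime
```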

As described in the great FactHacks talk at 29C3 by Daniel J. Bernstein and Tanja Lange, RSA keys generated with low entropy (not enough randomness) are actually a common problem for many devices on the internet, especially embedded ones. Their presentation goes into a lot of detail about how you can process a huge number of keys at once and break the ones that share so much random data that their moduli share an entire prime.
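The batch approach mentioned there (Bernstein's product tree / remainder tree technique) can be sketched in a few lines of Python. This is our own minimal version for illustration, not code from the talk: for every modulus n_i it computes gcd(n_i, P / n_i), where P is the product of all moduli, without dividing P by each modulus separately.

```python
from math import gcd


def product_tree(xs):
    """Levels from the leaves (xs) up to the root [product of all xs]."""
    tree = [list(xs)]
    while len(tree[-1]) > 1:
        level = tree[-1]
        tree.append([
            level[i] * level[i + 1] if i + 1 < len(level) else level[i]
            for i in range(0, len(level), 2)
        ])
    return tree


def batch_gcd(moduli):
    """Return gcd(n, P // n) for each n, where P is the product of all moduli."""
    tree = product_tree(moduli)
    rems = tree[-1]  # [P]
    # Walk back down: each node gets its parent's remainder reduced mod node**2.
    for level in reversed(tree[:-1]):
        rems = [rems[i // 2] % (level[i] ** 2) for i in range(len(level))]
    # rems[i] == P mod n_i**2, so rems[i] // n_i == (P // n_i) mod n_i.
    return [gcd(n, r // n) for n, r in zip(moduli, rems)]


p = 2**61 - 1  # toy prime shared by the first two moduli
moduli = [p * (2**89 - 1), p * (2**107 - 1), (2**127 - 1) * (2**31 - 1)]
print(batch_gcd(moduli) == [p, p, 1])  # True
```

A result of 1 means the modulus shares nothing with the others; a result equal to the modulus itself (for example a duplicated key) means both primes are shared and needs a pairwise follow-up.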

Because this is a CTF challenge, however, we do not expect to deal with massive numbers of keys, so we can probably get away with running the simple GCD algorithm on a bunch of keys gathered from the service.

In fact, when we modified the program to just dump out a bunch of public keys, we immediately saw many moduli that share a prime. So this attack should work! But there is a problem: the service never shows us the public key, so how can we get it?

Our first idea was to see whether the service responds differently when we send it a value that is larger than the modulus. As mentioned before, such values cannot legally be encrypted, so we would expect an error of some kind. Using this information we could recover the modulus with a simple binary search (checking for each bit whether setting that bit would make our guess larger than the modulus).

Unfortunately, the server raises an exception when we give it a value higher than the modulus, which kills our connection. And when we reconnect, it generates a new key, so we have learned nothing useful.

We were stuck here for a while until we decided that it really should not be that hard to figure out the modulus if you can get the server to encrypt arbitrary values, and sat down with a notepad and a python interpreter shell.

This is what we came up with (and yes, other people have figured this out long ago, we just didn’t google correctly):

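# Read pow(m, 0x10001, n) as "ask the server to encrypt m";
# n itself is unknown, it is what we want to find.
a = pow(2,   0x10001, n)   # encryption of 2
b = pow(2*2, 0x10001, n)   # encryption of 4
x = (a*a - b)              # a multiple of n

a = pow(3,   0x10001, n)   # encryption of 3
b = pow(3*3, 0x10001, n)   # encryption of 9
y = (a*a - b)              # another multiple of n

n = gcd(x,y)               # with high probability, the modulus itself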

The pow() function here is simply RSA encryption: 0x10001 is the default public exponent used by most cryptographic libraries these days (including the one used by the server), and n is the RSA modulus. We don’t know n, but we can get the result of pow(x, 0x10001, n) by simply asking the server to encrypt x for us.

Now, an encryption is a modular exponentiation. That basically means the following, for some unknown integer u:

pow(x, 0x10001, n) == y
pow(x, 0x10001)    == n * u + y

If we take the square of that, it becomes:

pow(pow(x, 0x10001), 2) == pow(n * u + y, 2)
pow(pow(x, 0x10001), 2) == (n * u + y) * (n * u + y)
pow(pow(x, 0x10001), 2) == n*n*u*u + 2*n*u*y + y*y
pow(pow(x, 0x10001), 2) == n*(n*u*u + 2*u*y) + y*y

Now, on the other hand, let’s write the encryption of pow(x, 2), for some unknown integer v:

pow(pow(x, 2), 0x10001) == n * v + z

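We don’t know u, and we don’t know v, but both expressions describe the same number: pow(pow(x, 2), 0x10001) and pow(pow(x, 0x10001), 2) are both equal to pow(x, 2*0x10001). Comparing the two expansions gives n * v + z == n*(n*u*u + 2*u*y) + y*y, so there is an interesting relationship between z and y*y (with w = v - n*u*u - 2*u*y):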

y * y == n * w + z

Since we actually know y and z (they are the results of the encryption of x and pow(x,2), respectively) we can calculate n * w:

n * w == y * y - z
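A quick numeric sanity check of this identity on a tiny textbook key (n = 61 * 53 = 3233, chosen only so the numbers stay readable): for every small x, y * y - z comes out as an exact multiple of n.

```python
e = 0x10001
n = 61 * 53  # toy modulus, 3233

for x in range(2, 50):
    y = pow(x, e, n)        # what the server would return for x
    z = pow(x * x, e, n)    # what the server would return for x*x
    w, rem = divmod(y * y - z, n)
    assert rem == 0         # y*y - z is always n * w
print("identity holds for x = 2..49")
```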

So now we have a number that is a multiple of n. Recall that n is the product of two large prime numbers, so its only nontrivial divisors are those two primes. Thus we can do this whole procedure for two different values of x and take the gcd of the resulting values of n * w. This yields n, unless the two values of w happen to share a common factor, in which case we get a multiple of n. So it is worth checking that the result is indeed 1024 bits long, but with high probability we get exactly n, the RSA modulus.
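The procedure can be rehearsed offline against a local encryption oracle. In this sketch recover_modulus, oracle and the toy key are our own hypothetical names and values; the oracle simply stands in for "ask the server to encrypt m". It uses four sample values instead of two and then strips small factors, which makes a stray common factor in the w values less likely to survive. Stripping is safe because the real primes of n are far larger than the trial-division bound.

```python
from math import gcd

E = 0x10001


def recover_modulus(encrypt, samples=(2, 3, 5, 7), small_bound=10000):
    """Recover an unknown RSA modulus from an encryption oracle (sketch)."""
    g = 0
    for m in samples:
        c = encrypt(m)
        # c*c and encrypt(m*m) agree modulo n, so their difference is n * w.
        g = gcd(g, c * c - encrypt(m * m))
    if g == 0:
        raise ValueError("no information gained; try other sample values")
    # Divide out small factors contributed by the w values; the primes of n
    # are far larger than small_bound, so n itself is never divided out.
    for k in range(2, small_bound):
        while g % k == 0:
            g //= k
    return g


# Local stand-in for the server, using a toy key made of Mersenne primes.
n_secret = (2**89 - 1) * (2**107 - 1)


def oracle(m):
    return pow(m, E, n_secret)


n_found = recover_modulus(oracle)
print(n_found == n_secret, n_found.bit_length())
```

Against the real service, encrypt would wrap something like the encryptval helper from the harvesting script, and checking that n_found.bit_length() == 1024 is a cheap sanity test.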

Now it is time to build the complete attack. First we build a script that connects to the server, reads the encrypted flag, recovers the RSA modulus using the above procedure, and saves the modulus together with the encrypted flag. We run this script a bunch of times in the background to collect a series of (encrypted flag, RSA modulus) pairs.

Then we write a processing script that loops through all the RSA moduli and finds two that share a prime, using the GCD algorithm. Once we have the shared prime p, the other prime of a modulus n is simply n / p, and we can easily compute the RSA private key using the same method that is used when generating new RSA keys:

from Crypto.Util.number import inverse

def getpriv(p,q):
    return inverse(0x10001, (p-1) * (q-1))

The inverse function computes the modular inverse of a number, if it exists. It’s not hard to write yourself, but PyCrypto (Crypto.Util.number) already offers a perfectly fine implementation.
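For completeness, here is one way to write that modular inverse by hand with the extended Euclidean algorithm. modinv is our own hypothetical name; on Python 3.8+ the built-in pow(a, -1, m) does the same job.

```python
def modinv(a, m):
    """Inverse of a modulo m via the extended Euclidean algorithm."""
    old_r, r = a % m, m
    old_s, s = 1, 0
    # Invariant: old_s * a == old_r (mod m) and s * a == r (mod m).
    while r:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_s, s = s, old_s - q * s
    if old_r != 1:
        raise ValueError("a has no inverse modulo m")
    return old_s % m


def getpriv(p, q):
    # Same formula as above, with the hand-written inverse.
    return modinv(0x10001, (p - 1) * (q - 1))


print(modinv(17, 3120))  # 2753, the classic p = 61, q = 53 textbook example
```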

It is then only a matter of decrypting the encrypted flag that we stored with the modulus we just cracked.
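The whole pipeline (find the shared prime, rebuild d, decrypt) can be rehearsed on toy keys. The primes, the flag string and the variable names below are made up for this sketch, and int.to_bytes plays the role of PyCrypto's long_to_bytes.

```python
from math import gcd

E = 0x10001


def getpriv(p, q):
    # Same formula as in the writeup; pow(x, -1, m) needs Python 3.8+.
    return pow(E, -1, (p - 1) * (q - 1))


# Two toy keys that share the prime p (stand-ins for harvested moduli).
p, q1, q2 = 2**107 - 1, 2**89 - 1, 2**127 - 1
n1, n2 = p * q1, p * q2

flag = b"flag{toy}"                                # made-up plaintext
eflag = pow(int.from_bytes(flag, "big"), E, n1)   # what the server would show

shared = gcd(n1, n2)                # recovers p
d = getpriv(n1 // shared, shared)   # private exponent for n1
m = pow(eflag, d, n1)
recovered = m.to_bytes((m.bit_length() + 7) // 8, "big")
print(recovered)                    # b'flag{toy}'
```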

The full modulus/encrypted flag harvesting script:


from Crypto.Util.number import long_to_bytes, bytes_to_long, GCD

import socket
import sys

def waitfor(s,txt):
    res = ''
    while True:
        tmp = s.recv(1)
        if not tmp: break
        if tmp == '\n': print "RECV: %r" % (res.split('\n')[-1])
        res += tmp
        if res.endswith(txt): break
    return res

s = socket.create_connection(('184.73.59.25',4321))

header = waitfor(s,"==================================\n")
eflag = waitfor(s,"==================================\n").split('\n')[0]
eflag = bytes_to_long(eflag.decode('hex'))

print 'eflag = %x' % (eflag,)

def encryptval(val):
    waitfor(s,"Now enter a message you wish to encrypt:")
    s.send(long_to_bytes(val,1024))
    header = waitfor(s,"==================================\n")
    etxt = waitfor(s,"==================================\n").split('\n')[0]
    return bytes_to_long(etxt.decode('hex'))

a = encryptval(2)
b = encryptval(4)

x = (a*a - b)

a = encryptval(3)
b = encryptval(9)

y = (a*a - b)

n = GCD(x,y)

print "n = %x" % n
print "bits = %d" % n.bit_length()

with open("pubdumps/" + ("%x" % n)[:32], 'w+') as f:
    f.write("%d,%d" % (n,eflag))

The postprocessing script that finds keys which share a prime, cracks them, and decrypts the flag:

import os

from Crypto.Util.number import GCD, inverse, long_to_bytes

os.chdir("./pubdumps/")

fns = filter(os.path.isfile, os.listdir('.'))

pairs = []
for fn in fns:
    with open(fn) as f: 
        pairs.append(tuple(map(int,f.read().split(','))))

print "%d files loaded" % len(pairs)

def getpriv(p,q):
    return inverse(0x10001, (p-1) * (q-1))

total = 1
for n,x in pairs:
    d = GCD(n,total)
    if d != 1 and d.bit_length() < 1024 and d.bit_length() > 200:
        for m,y in pairs:
            if m % d == 0:
                print m, y
                print repr(long_to_bytes(pow(y, getpriv(m / d, d), m)))
        print "WOOT", d
        break
    total *= n

print "processing done"
