home - links - about - /git/ - rss

Zer0pts 2021 - OT or not OT [EN]

OT or not OT

This was a crypto challenge at zer0pts CTF 2021.

We are given the script below and an endpoint where it is running, crypto.ctf.zer0pts.com:10130.

import os
import signal
import random
from base64 import b64encode
from Crypto.Util.number import getStrongPrime, bytes_to_long
from Crypto.Util.Padding import pad
from Crypto.Cipher import AES
from flag import flag

# Challenge server: encrypts the flag under a random 32-byte (AES-256) CBC key,
# then leaks that key two bits per round through the exchange below.
p = getStrongPrime(1024)

key = os.urandom(32)
iv = os.urandom(AES.block_size)
aes = AES.new(key=key, mode=AES.MODE_CBC, iv=iv)
# NOTE: `c` (the ciphertext) is reused below as a client input; the ciphertext
# has already been printed by then.
c = aes.encrypt(pad(flag, AES.block_size))

# From here on the key is handled as an integer so its bits can be shifted out.
key = bytes_to_long(key)
print("Encrypted flag: {}".format(b64encode(iv + c).decode()))
print("p = {}".format(p))
# Tells the client how many rounds to play: ceil(bit_length / 2), since the
# loop below consumes two bits per round and stops once key == 0.
print("key.bit_length() = {}".format(key.bit_length()))

# Terminate the session after 600 seconds (default SIGALRM action).
signal.alarm(600)
while key > 0:
    # Fresh exponents every round; only t is revealed to the client.
    # NOTE: `random` is Python's non-cryptographic Mersenne Twister PRNG.
    r = random.randint(2, p-1)
    s = random.randint(2, p-1)
    t = random.randint(2, p-1)
    print("t = {}".format(t))

    # Client-chosen values, reduced mod p. They must exceed 1 and be pairwise
    # distinct; these are asserts, so a violation ends the session.
    a = int(input("a = ")) % p
    b = int(input("b = ")) % p
    c = int(input("c = ")) % p
    d = int(input("d = ")) % p
    assert all([a > 1 , b > 1 , c > 1 , d > 1])
    assert len(set([a,b,c,d])) == 4

    # u and v share the factor c^s. Only their lowest bit is masked, by XOR
    # with the two lowest remaining key bits.
    u = pow(a, r, p) * pow(c, s, p) % p
    v = pow(b, r, p) * pow(c, s, p) % p
    x = u ^ (key & 1)
    y = v ^ ((key >> 1) & 1)
    # z = d^r * t^s mod p is sent unmasked.
    z = pow(d, r, p) * pow(t, s, p) % p

    # Discard the two bits leaked this round.
    key = key >> 2

    # Note: input() prompts are printed without a newline, so "a = b = c = d = "
    # ends up at the start of the "x = ..." output line.
    print("x = {}".format(x))
    print("y = {}".format(y))
    print("z = {}".format(z))

Analysis

For each pair of bits in the AES key, the server generates random \(r,s,t \in [2, p-1]\) and gives us only \(t\).

Then we have to provide \(a,b,c,d\). After reduction modulo \(p\), they must lie in \([2,p-1]\) and be pairwise distinct.

The server will compute

\begin{eqnarray} u = a^r \times c^s \mod p \\ v = b^r \times c^s \mod p \\ z = d^r \times t^s \mod p \\ \end{eqnarray}

Finally, we are given \(x = u \oplus k_0,y = v \oplus k_1\) and \(z\), where \(k_0\) and \(k_1\) are two key bits.

Choosing \(a,b,c,d\)

By choosing specific values for \(a,b,c,d\), it's possible to recover the original values of \(u\) and \(v\) and compare them to the given ones to recover \(k_0\) and \(k_1\).

First, let's set \(d = p - 1\), i.e. \(d \equiv -1 \pmod p\).

By doing this, we have

\begin{cases} d^r = 1 \mod p & r \text{ even} \\d^r = p - 1 \mod p & r \text{ odd} \end{cases}

thus,

\begin{cases} z = t^s \mod p & r \text{ even} \\z = (p-1) \times t^s \mod p & r \text{ odd} \end{cases}

Now, let's set \(c = t\) and \(b = 2a\). For simplicity, let's fix \(a = 2\) (so \(b = 4\)). This gives us

\begin{eqnarray} u = 2^r \times t^s \mod p \\ v = 4^r \times t^s \mod p \\ \end{eqnarray}

We can now try both possibilities for the parity of \(r\).

  • If \(r\) is even, \(z = t^s \mod p\), so we can easily get \(t^{-s}\) by inverting \(z\).

  • Otherwise, if \(r\) is odd, we have \(t^s = z \times (p-1)^{-1} \mod p\). Since \(p-1 \equiv -1\) is its own inverse, this is simply \(t^s \equiv -z\), and inverting it again gives \(t^{-s}\).

Now that we have \(t^{-s}\), we can try all \((k_0', k_1') \in \{0,1\}^2\) to find the key bits.

To validate a guess, we compute the following:

\begin{eqnarray} v_0 = (x \oplus k_0') \times t^{-s} \mod p = ((2^r \times t^s) \oplus k_0 \oplus k_0') \times t^{-s} \mod p \\ v_1 = (y \oplus k_1') \times t^{-s} \mod p = ((4^r \times t^s) \oplus k_1 \oplus k_1') \times t^{-s} \mod p \\ \end{eqnarray}

If our guess is correct, i.e. \(k_0' = k_0\) and \(k_1' = k_1\), the XOR masks cancel and everything simplifies to:

\begin{eqnarray} v_0 = 2^r \times t^s \times t^{-s} = 2^r \mod p \\ v_1 = 4^r \times t^s \times t^{-s} = 4^r \mod p \\ \end{eqnarray}

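Since \(4^r = (2^r)^2\), checking a guess only requires verifying whether \(v_0^2 \equiv v_1 \pmod p\). A wrong guess leaves a stray \(\pm t^{-s}\) term, so it satisfies this relation only with negligible probability.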

Here is an implementation of the above solution. It starts by guessing that \(r\) is even. If the relation \(v_0^2 = v_1 \mod p\) doesn't hold for any \((k_0',k_1')\), it tries the same with \(r\) odd.

from pwn import remote, log
import random
from sage.all import inverse_mod
from Crypto.Util.number import long_to_bytes
from Crypto.Cipher import AES
import base64

def key_rounds(bitlen):
    """Return how many rounds the server plays for a key of ``bitlen`` bits.

    The server loops ``while key > 0`` and shifts the key right by two bits
    per round, so it plays ceil(bitlen / 2) rounds. ``bitlen // 2`` would skip
    the final round, and with it the key's top bit, whenever ``bitlen`` is odd.
    """
    return (bitlen + 1) // 2


def find_bits(x, y, p, ts_inv):
    """Return ``2*k1 + k0`` for the key-bit guess consistent with ``x`` and ``y``.

    With a = 2, b = 4 and c = t, the unmasked values are u = 2^r * t^s and
    v = 4^r * t^s (mod p). Multiplying by the candidate ``ts_inv`` (= t^-s)
    leaves 2^r and 4^r, so the right guess satisfies v0^2 == v1 (mod p).

    Returns None when no guess fits, which happens when ``ts_inv`` was
    derived from the wrong parity of r.
    """
    for k0 in (0, 1):
        for k1 in (0, 1):
            v0 = (x ^ k0) * ts_inv % p
            v1 = (y ^ k1) * ts_inv % p
            if v0 * v0 % p == v1:
                return 2 * k1 + k0
    return None


def recover_bits(x, y, z, p):
    """Recover the two key bits leaked in one round, trying both parities of r.

    Raises ValueError if neither parity yields a consistent guess, instead of
    failing later with an obscure ``None << n`` TypeError.
    """
    # Case r even: d^r = 1, so z = t^s.
    bits = find_bits(x, y, p, inverse_mod(z, p))
    if bits is None:
        # Case r odd: d^r = p - 1, so t^s = z * (p-1)^-1 (mod p).
        ts = z * inverse_mod(p - 1, p) % p
        bits = find_bits(x, y, p, inverse_mod(ts, p))
    if bits is None:
        raise ValueError("no key-bit guess is consistent with this round")
    return bits


if __name__ == "__main__":
    io = remote("crypto.ctf.zer0pts.com", 10130)
    encflag = base64.b64decode(io.recvline().decode().split(": ")[1])
    p = int(io.recvline().decode().split(" = ")[1])
    bitlen = int(io.recvline().decode().split(" = ")[1])

    # Fixed choices from the analysis: a = 2, b = 2a, d = p - 1 (= -1 mod p).
    a = 2
    b = 2 * a
    d = p - 1

    key = 0
    for i in range(key_rounds(bitlen)):
        t = int(io.recvline().decode().split(" = ")[1])
        c = t  # c = t makes u and v share the t^s factor that z also contains
        for value in (a, b, c, d):
            io.sendline(str(value).encode())

        # The server's "a = b = c = d = " prompts have no trailing newline and
        # get prefixed to the x line, so take the last " = " field.
        x = int(io.recvline().decode().split(" = ")[-1])
        y = int(io.recvline().decode().split(" = ")[-1])
        z = int(io.recvline().decode().split(" = ")[-1])

        key |= recover_bits(x, y, z, p) << (2 * i)
        log.info("bits {} & {} found, key={:b}".format(2*i, 2*i+1, key))

    key = long_to_bytes(key, 32)
    iv, flag = encflag[:16], encflag[16:]
    aes = AES.new(key=key, mode=AES.MODE_CBC, iv=iv)

    print(aes.decrypt(flag))  # zer0pts{H41131uj4h_H41131uj4h}

Creative Commons License