Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 118 additions & 18 deletions src/ecdsa/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from . import ecdsa
from . import der
from . import rfc6979
from . import ellipticcurve
from .curves import NIST192p, find_curve
from .numbertheory import square_root_mod_prime, SquareRootError
from .ecdsa import RSZeroError
from .util import string_to_number, number_to_string, randrange
from .util import sigencode_string, sigdecode_string
Expand All @@ -23,6 +25,10 @@ class BadDigestError(Exception):
pass


class MalformedPointError(AssertionError):
pass


class VerifyingKey:
def __init__(self, _error__please_use_generate=None):
if not _error__please_use_generate:
Expand All @@ -38,9 +44,8 @@ def from_public_point(klass, point, curve=NIST192p, hashfunc=sha1):
self.pubkey.order = curve.order
return self

@classmethod
def from_string(klass, string, curve=NIST192p, hashfunc=sha1,
validate_point=True):
@staticmethod
def _from_raw_encoding(string, curve, validate_point):
order = curve.order
assert (len(string) == curve.verifying_key_length), \
(len(string), curve.verifying_key_length)
Expand All @@ -50,10 +55,72 @@ def from_string(klass, string, curve=NIST192p, hashfunc=sha1,
assert len(ys) == curve.baselen, (len(ys), curve.baselen)
x = string_to_number(xs)
y = string_to_number(ys)
if validate_point:
assert ecdsa.point_is_valid(curve.generator, x, y)
from . import ellipticcurve
point = ellipticcurve.Point(curve.curve, x, y, order)
if validate_point and not ecdsa.point_is_valid(curve.generator, x, y):
raise MalformedPointError("Point does not lie on the curve")

return ellipticcurve.Point(curve.curve, x, y, order)

@staticmethod
def _from_compressed(string, curve, validate_point):
if string[:1] not in (b('\x02'), b('\x03')):
raise MalformedPointError("Malformed compressed point encoding")

is_even = string[:1] == b('\x02')
x = string_to_number(string[1:])
order = curve.order
p = curve.curve.p()
alpha = (pow(x, 3, p) + (curve.curve.a() * x) + curve.curve.b()) % p
try:
beta = square_root_mod_prime(alpha, p)
except SquareRootError as e:
raise MalformedPointError(
"Encoding does not correspond to a point on curve", e)
if is_even == bool(beta & 1):
y = p - beta
else:
y = beta
if validate_point and not ecdsa.point_is_valid(curve.generator, x, y):
raise MalformedPointError("Point does not lie on curve")
return ellipticcurve.Point(curve.curve, x, y, order)

@classmethod
def _from_hybrid(cls, string, curve, validate_point):
assert string[:1] in (b('\x06'), b('\x07'))

# primarily use the uncompressed as it's easiest to handle
point = cls._from_raw_encoding(string[1:], curve, validate_point)

# but validate if it's self-consistent if we're asked to do that
if validate_point and \
(point.y() & 1 and string[:1] != b('\x07') or
(not point.y() & 1) and string[:1] != b('\x06')):
raise MalformedPointError("Inconsistent hybrid point encoding")

return point

@classmethod
def from_string(klass, string, curve=NIST192p, hashfunc=sha1,
validate_point=True):
sig_len = len(string)
if sig_len == curve.verifying_key_length:
point = klass._from_raw_encoding(string, curve, validate_point)
elif sig_len == curve.verifying_key_length + 1:
if string[:1] in (b('\x06'), b('\x07')):
point = klass._from_hybrid(string, curve, validate_point)
elif string[:1] == b('\x04'):
point = klass._from_raw_encoding(string[1:], curve,
validate_point)
else:
raise MalformedPointError(
"Invalid X9.62 encoding of the public point")
elif sig_len == curve.baselen + 1:
point = klass._from_compressed(string, curve, validate_point)
else:
raise MalformedPointError(
"Length of string does not match lengths of "
"any of the supported encodings of {0} "
"curve.".format(curve.name))

return klass.from_public_point(point, curve, hashfunc)

@classmethod
Expand All @@ -74,14 +141,20 @@ def from_der(klass, string):
if empty != b(""):
raise der.UnexpectedDER("trailing junk after DER pubkey objects: %s" %
binascii.hexlify(empty))
assert oid_pk == oid_ecPublicKey, (oid_pk, oid_ecPublicKey)
if not oid_pk == oid_ecPublicKey:
raise der.UnexpectedDER("Unexpected object identifier in DER "
"encoding: {0!r}".format(oid_pk))
curve = find_curve(oid_curve)
point_str, empty = der.remove_bitstring(point_str_bitstring)
if empty != b(""):
raise der.UnexpectedDER("trailing junk after pubkey pointstring: %s" %
binascii.hexlify(empty))
assert point_str.startswith(b("\x00\x04"))
return klass.from_string(point_str[2:], curve)
# the point encoding is padded with a zero byte
# raw encoding of point is invalid in DER files
if not point_str.startswith(b("\x00")) or \
len(point_str[1:]) == curve.verifying_key_length:
raise der.UnexpectedDER("Malformed encoding of public point")
return klass.from_string(point_str[1:], curve)

@classmethod
def from_public_key_recovery(cls, signature, data, curve, hashfunc=sha1,
Expand Down Expand Up @@ -110,23 +183,49 @@ def from_public_key_recovery_with_digest(klass, signature, digest, curve, hashfu
verifying_keys = [klass.from_public_point(pk.point, curve, hashfunc) for pk in pks]
return verifying_keys

def to_string(self):
# VerifyingKey.from_string(vk.to_string()) == vk as long as the
# curves are the same: the curve itself is not included in the
# serialized form
def _raw_encode(self):
order = self.pubkey.order
x_str = number_to_string(self.pubkey.point.x(), order)
y_str = number_to_string(self.pubkey.point.y(), order)
return x_str + y_str

def _compressed_encode(self):
order = self.pubkey.order
x_str = number_to_string(self.pubkey.point.x(), order)
if self.pubkey.point.y() & 1:
return b('\x03') + x_str
else:
return b('\x02') + x_str

def _hybrid_encode(self):
raw_enc = self._raw_encode()
if self.pubkey.point.y() & 1:
return b('\x07') + raw_enc
else:
return b('\x06') + raw_enc

def to_string(self, encoding="raw"):
# VerifyingKey.from_string(vk.to_string()) == vk as long as the
# curves are the same: the curve itself is not included in the
# serialized form
assert encoding in ("raw", "uncompressed", "compressed", "hybrid")
if encoding == "raw":
return self._raw_encode()
elif encoding == "uncompressed":
return b('\x04') + self._raw_encode()
elif encoding == "hybrid":
return self._hybrid_encode()
else:
return self._compressed_encode()

def to_pem(self):
return der.topem(self.to_der(), "PUBLIC KEY")

def to_der(self):
def to_der(self, point_encoding="uncompressed"):
order = self.pubkey.order
x_str = number_to_string(self.pubkey.point.x(), order)
y_str = number_to_string(self.pubkey.point.y(), order)
point_str = b("\x00\x04") + x_str + y_str
point_str = b("\x00") + self.to_string(point_encoding)
return der.encode_sequence(der.encode_sequence(encoded_oid_ecPublicKey,
self.curve.encoded_oid),
der.encode_bitstring(point_str))
Expand Down Expand Up @@ -247,10 +346,11 @@ def to_pem(self):
# TODO: "BEGIN ECPARAMETERS"
return der.topem(self.to_der(), "EC PRIVATE KEY")

def to_der(self):
def to_der(self, point_encoding="uncompressed"):
# SEQ([int(1), octetstring(privkey),cont[0], oid(secp224r1),
# cont[1],bitstring])
encoded_vk = b("\x00\x04") + self.get_verifying_key().to_string()
encoded_vk = b("\x00") + \
self.get_verifying_key().to_string(point_encoding)
return der.encode_sequence(der.encode_integer(1),
der.encode_octet_string(self.to_string()),
der.encode_constructed(0, self.curve.encoded_oid),
Expand Down
23 changes: 16 additions & 7 deletions src/ecdsa/numbertheory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@

from __future__ import division

from six import integer_types
from six import integer_types, PY3
from six.moves import reduce
try:
xrange
except NameError:
xrange = range

import math

Expand Down Expand Up @@ -62,7 +66,7 @@ def polynomial_reduce_mod(poly, polymod, p):

while len(poly) >= len(polymod):
if poly[-1] != 0:
for i in range(2, len(polymod) + 1):
for i in xrange(2, len(polymod) + 1):
poly[-i] = (poly[-i] - poly[-1] * polymod[-i]) % p
poly = poly[0:-1]

Expand All @@ -86,8 +90,8 @@ def polynomial_multiply_mod(m1, m2, polymod, p):

# Add together all the cross-terms:

for i in range(len(m1)):
for j in range(len(m2)):
for i in xrange(len(m1)):
for j in xrange(len(m2)):
prod[i + j] = (prod[i + j] + m1[i] * m2[j]) % p

return polynomial_reduce_mod(prod, polymod, p)
Expand Down Expand Up @@ -187,7 +191,12 @@ def square_root_mod_prime(a, p):
return (2 * a * modular_exp(4 * a, (p - 5) // 8, p)) % p
raise RuntimeError("Shouldn't get here.")

for b in range(2, p):
if PY3:
range_top = p
else:
# xrange on python2 can take integers representable as C long only
range_top = min(0x7fffffff, p)
for b in xrange(2, range_top):
if jacobi(b * b - 4 * a, p) == -1:
f = (a, -b, 1)
ff = polynomial_exp_mod((0, 1), (p + 1) // 2, f, p)
Expand Down Expand Up @@ -355,7 +364,7 @@ def carmichael_of_factorized(f_list):
return 1

result = carmichael_of_ppower(f_list[0])
for i in range(1, len(f_list)):
for i in xrange(1, len(f_list)):
result = lcm(result, carmichael_of_ppower(f_list[i]))

return result
Expand Down Expand Up @@ -477,7 +486,7 @@ def is_prime(n):
while (r % 2) == 0:
s = s + 1
r = r // 2
for i in range(t):
for i in xrange(t):
a = smallprimes[i]
y = modular_exp(a, r, n)
if y != 1 and y != n - 1:
Expand Down
Loading