diff --git a/barretenberg/cpp/src/barretenberg/ecc/curves/bn254/c_bind.cpp b/barretenberg/cpp/src/barretenberg/ecc/curves/bn254/c_bind.cpp new file mode 100644 index 000000000000..bf0807a4e68e --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/ecc/curves/bn254/c_bind.cpp @@ -0,0 +1,19 @@ +#include "../bn254/fr.hpp" +#include "barretenberg/common/wasm_export.hpp" + +using namespace bb; + +WASM_EXPORT void bn254_fr_sqrt(uint8_t const* input, uint8_t* result) +{ + using serialize::write; + auto input_fr = from_buffer(input); + auto [is_sqr, root] = input_fr.sqrt(); + + uint8_t* is_sqrt_result_ptr = result; + uint8_t* root_result_ptr = result + 1; + + write(is_sqrt_result_ptr, is_sqr); + write(root_result_ptr, root); +} + +// NOLINTEND(cert-dcl37-c, cert-dcl51-cpp, bugprone-reserved-identifier) \ No newline at end of file diff --git a/yarn-project/foundation/src/crypto/random/index.ts b/yarn-project/foundation/src/crypto/random/index.ts index a64dc4f4a957..76ee5a9a6a9b 100644 --- a/yarn-project/foundation/src/crypto/random/index.ts +++ b/yarn-project/foundation/src/crypto/random/index.ts @@ -74,3 +74,12 @@ export const randomBigInt = (max: bigint) => { const randomBigInt = BigInt(`0x${randomBuffer.toString('hex')}`); // Convert buffer to a large integer. return randomBigInt % max; // Use modulo to ensure the result is less than max. }; + +/** + * Generate a random boolean value. + * @returns A random boolean value. + */ +export const randomBoolean = () => { + const randomByte = randomBytes(1)[0]; // Generate a single random byte. + return randomByte % 2 === 0; // Use modulo to determine if the byte is even or odd. +}; diff --git a/yarn-project/foundation/src/fields/fields.test.ts b/yarn-project/foundation/src/fields/fields.test.ts index 2f8686bb4c5b..f07db9fe8992 100644 --- a/yarn-project/foundation/src/fields/fields.test.ts +++ b/yarn-project/foundation/src/fields/fields.test.ts @@ -109,7 +109,7 @@ describe('Bn254 arithmetic', () => { expect(actual).toEqual(expected); }); - it('High Bonudary', () => { + it('High Boundary', () => { // -1 - (-1) = 0 const a = new Fr(Fr.MODULUS - 1n); const b = new Fr(Fr.MODULUS - 1n); @@ -184,6 +184,30 @@ describe('Bn254 arithmetic', () => { }); }); + describe('Square root', () => { + it.each([ + [new Fr(0), 0n], + [new Fr(4), 2n], + [new Fr(9), 3n], + [new Fr(16), 4n], + ])('Should return the correct square root for %p', (input, expected) => { + const actual = input.sqrt()!.toBigInt(); + + // The square root can be either the expected value or the modulus - expected value + const isValid = actual == expected || actual == Fr.MODULUS - expected; + + expect(isValid).toBeTruthy(); + }); + + it('Should return the correct square root for random value', () => { + const a = Fr.random(); + const squared = a.mul(a); + + const actual = squared.sqrt(); + expect(actual!.mul(actual!)).toEqual(squared); + }); + }); + describe('Comparison', () => { it.each([ [new Fr(5), new Fr(10), -1], diff --git a/yarn-project/foundation/src/fields/fields.ts b/yarn-project/foundation/src/fields/fields.ts index 436003adbfda..5fa3c9cac295 100644 --- a/yarn-project/foundation/src/fields/fields.ts +++ b/yarn-project/foundation/src/fields/fields.ts @@ -1,3 +1,5 @@ +import { BarretenbergSync } from '@aztec/bb.js'; + import { inspect } from 'util'; import { toBigIntBE, toBufferBE } from '../bigint-buffer/index.js'; @@ -280,6 +282,25 @@ export class Fr extends BaseField { return new Fr(this.toBigInt() / rhs.toBigInt()); } + /** + * Computes a square root of the field element. + * @returns A square root of the field element (null if it does not exist). + */ + sqrt(): Fr | null { + const wasm = BarretenbergSync.getSingleton().getWasm(); + wasm.writeMemory(0, this.toBuffer()); + wasm.call('bn254_fr_sqrt', 0, Fr.SIZE_IN_BYTES); + const isSqrtBuf = Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES, Fr.SIZE_IN_BYTES + 1)); + const isSqrt = isSqrtBuf[0] === 1; + if (!isSqrt) { + // Field element is not a quadratic residue mod p so it has no square root. + return null; + } + + const rootBuf = Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES + 1, Fr.SIZE_IN_BYTES * 2 + 1)); + return Fr.fromBuffer(rootBuf); + } + toJSON() { return { type: 'Fr', diff --git a/yarn-project/foundation/src/fields/point.test.ts b/yarn-project/foundation/src/fields/point.test.ts new file mode 100644 index 000000000000..6fa64160b414 --- /dev/null +++ b/yarn-project/foundation/src/fields/point.test.ts @@ -0,0 +1,35 @@ +import { Fr } from './fields.js'; +import { Point } from './point.js'; + +describe('Point', () => { + it('converts to and from x and sign of y coordinate', () => { + const p = new Point( + new Fr(0x30426e64aee30e998c13c8ceecda3a77807dbead52bc2f3bf0eae851b4b710c1n), + new Fr(0x113156a068f603023240c96b4da5474667db3b8711c521c748212a15bc034ea6n), + false, + ); + + const [x, sign] = p.toXAndSign(); + const p2 = Point.fromXAndSign(x, sign); + + expect(p.equals(p2)).toBeTruthy(); + }); + + it('creates a valid random point', () => { + expect(Point.random().isOnGrumpkin()).toBeTruthy(); + }); + + it('converts to and from buffer', () => { + const p = Point.random(); + const p2 = Point.fromBuffer(p.toBuffer()); + + expect(p.equals(p2)).toBeTruthy(); + }); + + it('converts to and from compressed buffer', () => { + const p = Point.random(); + const p2 = Point.fromCompressedBuffer(p.toCompressedBuffer()); + + expect(p.equals(p2)).toBeTruthy(); + }); +}); diff --git a/yarn-project/foundation/src/fields/point.ts b/yarn-project/foundation/src/fields/point.ts index 26c84d88ec06..3bcf4a00ede1 100644 --- a/yarn-project/foundation/src/fields/point.ts +++ b/yarn-project/foundation/src/fields/point.ts @@ -1,4 +1,4 @@ -import { poseidon2Hash } from '../crypto/index.js'; +import { poseidon2Hash, randomBoolean } from '../crypto/index.js'; import { BufferReader, FieldReader, serializeToBuffer } from '../serialize/index.js'; import { Fr } from './fields.js'; @@ -10,6 +10,7 @@ import { Fr } from './fields.js'; export class Point { static ZERO = new Point(Fr.ZERO, Fr.ZERO, false); static SIZE_IN_BYTES = Fr.SIZE_IN_BYTES * 2; + static COMPRESSED_SIZE_IN_BYTES = Fr.SIZE_IN_BYTES + 1; /** Used to differentiate this class from AztecAddress */ public readonly kind = 'point'; @@ -37,8 +38,17 @@ export class Point { * @returns A randomly generated Point instance. */ static random() { - // TODO make this return an actual point on curve. - return new Point(Fr.random(), Fr.random(), false); + while (true) { + try { + return Point.fromXAndSign(Fr.random(), randomBoolean()); + } catch (e: any) { + if (!(e instanceof NotOnCurveError)) { + throw e; + } + // The random point is not on the curve - we try again + continue; + } + } } /** @@ -53,6 +63,18 @@ export class Point { return new this(Fr.fromBuffer(reader), Fr.fromBuffer(reader), false); } + /** + * Create a Point instance from a compressed buffer. + * The input 'buffer' should have exactly 33 bytes representing the x coordinate and the sign of the y coordinate. + * + * @param buffer - The buffer containing the x coordinate and the sign of the y coordinate. + * @returns A Point instance. + */ + static fromCompressedBuffer(buffer: Buffer | BufferReader) { + const reader = BufferReader.asReader(buffer); + return this.fromXAndSign(Fr.fromBuffer(reader), reader.readBoolean()); + } + /** * Create a Point instance from a hex-encoded string. * The input 'address' should be prefixed with '0x' or not, and have exactly 128 hex characters representing the x and y coordinates. @@ -78,6 +100,46 @@ export class Point { return new this(reader.readField(), reader.readField(), reader.readBoolean()); } + /** + * Uses the x coordinate and isPositive flag (+/-) to reconstruct the point. + * @dev The y coordinate can be derived from the x coordinate and the "sign" flag by solving the grumpkin curve + * equation for y. + * @param x - The x coordinate of the point + * @param sign - The "sign" of the y coordinate - note that this is not a sign as is known in integer arithmetic. + * Instead it is a boolean flag that determines whether the y coordinate is <= (Fr.MODULUS - 1) / 2 + * @returns The point as an array of 2 fields + */ + static fromXAndSign(x: Fr, sign: boolean) { + // Calculate y^2 = x^3 - 17 + const ySquared = x.square().mul(x).sub(new Fr(17)); + + // Calculate the square root of ySquared + const y = ySquared.sqrt(); + + // If y is null, the x-coordinate is not on the curve + if (y === null) { + throw new NotOnCurveError(); + } + + const yPositiveBigInt = y.toBigInt() > (Fr.MODULUS - 1n) / 2n ? Fr.MODULUS - y.toBigInt() : y.toBigInt(); + const yNegativeBigInt = Fr.MODULUS - yPositiveBigInt; + + // Choose the positive or negative root based on isPositive + const finalY = sign ? new Fr(yPositiveBigInt) : new Fr(yNegativeBigInt); + + // Create and return the new Point + return new this(x, finalY, false); + } + + /** + * Returns the x coordinate and the sign of the y coordinate. + * @dev The y sign can be determined by checking if the y coordinate is greater than half of the modulus. + * @returns The x coordinate and the sign of the y coordinate. + */ + toXAndSign(): [Fr, boolean] { + return [this.x, this.y.toBigInt() <= (Fr.MODULUS - 1n) / 2n]; + } + /** * Returns the contents of the point as BigInts. * @returns The point as BigInts @@ -111,6 +173,14 @@ export class Point { return buf; } + /** + * Converts the Point instance to a compressed Buffer representation of the coordinates. + * @returns A Buffer representation of the Point instance + */ + toCompressedBuffer() { + return serializeToBuffer(this.toXAndSign()); + } + /** * Convert the Point instance to a hexadecimal string representation. * The output string is prefixed with '0x' and consists of exactly 128 hex characters, @@ -194,3 +264,10 @@ export function isPoint(obj: object): obj is Point { const point = obj as Point; return point.kind === 'point' && point.x !== undefined && point.y !== undefined; } + +class NotOnCurveError extends Error { + constructor() { + super('The given x-coordinate is not on the Grumpkin curve'); + this.name = 'NotOnCurveError'; + } +}