use core::fmt;
use std::hash::{Hash, Hasher};
use std::ops::{Add, Div, Mul, Neg, Sub};
/// An algebraic field whose elements support `+`, `-`, `*`, `/` and unary
/// negation through the `std::ops` traits, and whose elements can be
/// enumerated via [`Field::elements`].
pub trait Field:
    Copy
    + Eq
    + fmt::Debug
    + Neg<Output = Self>
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Div<Output = Self>
{
    /// The characteristic of the field (`P` for `PrimeField<P>`).
    const CHARACTERISTIC: u64;
    /// The additive identity.
    const ZERO: Self;
    /// The multiplicative identity.
    const ONE: Self;

    /// Iterator type produced by [`Field::elements`].
    type ElementsIter: Iterator<Item = Self>;

    /// Returns the multiplicative inverse of `self`.
    /// (The `PrimeField` implementation panics when `self` is zero.)
    fn inverse(self) -> Self;

    /// Returns `self` scaled by the integer `n`, i.e. the `n`-fold sum of
    /// `self` (with the sign of `n` applied).
    fn integer_mul(self, n: i64) -> Self;

    /// Embeds the integer `n` into the field, defined as `ONE` scaled by `n`.
    fn from_integer(n: i64) -> Self {
        let one = Self::ONE;
        one.integer_mul(n)
    }

    /// Returns an iterator over the elements of the field.
    fn elements() -> Self::ElementsIter;
}
/// An element of the integers modulo `P` (intended to be prime; primality is
/// not checked).
///
/// The stored representative is any `i64` congruent to the element and is not
/// kept in canonical form; [`PrimeField::to_integer`] yields the canonical
/// value in `[0, P)`.
#[derive(Copy, Clone)]
pub struct PrimeField<const P: u64> {
    /// Some integer congruent to this element modulo `P` (possibly unreduced,
    /// possibly negative).
    a: i64,
}

impl<const P: u64> PrimeField<P> {
    /// Returns the same element with its representative moved into `[0, P)`.
    ///
    /// # Panics
    /// Panics if `P` does not fit in an `i64`, or if `P` is zero
    /// (`rem_euclid` by zero).
    fn reduce(self) -> Self {
        let modulus = i64::try_from(P).expect("module not fitting into signed 64 bit");
        let a = self.a.rem_euclid(modulus);
        assert!(a >= 0);
        PrimeField { a }
    }

    /// Returns the canonical representative of this element, in `[0, P)`.
    pub fn to_integer(&self) -> u64 {
        let PrimeField { a } = self.reduce();
        // `reduce` yields a non-negative value, so this cast is lossless.
        a as u64
    }
}
impl<const P: u64> From<i64> for PrimeField<P> {
fn from(a: i64) -> Self {
Self { a }
}
}
/// Two elements are equal when their representatives are congruent modulo
/// `P`, i.e. when they reduce to the same value in `[0, P)`.
impl<const P: u64> PartialEq for PrimeField<P> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.reduce(), other.reduce());
        lhs.a == rhs.a
    }
}

/// Comparing reduced integers is a full equivalence relation.
impl<const P: u64> Eq for PrimeField<P> {}
impl<const P: u64> Neg for PrimeField<P> {
    type Output = Self;

    /// Returns the additive inverse.
    ///
    /// The raw representative is negated directly. The one value whose
    /// negation overflows (`i64::MIN`, reachable via `From` or `Sub`) is first
    /// reduced into `[0, P)`, whose negation always fits in an `i64`.
    ///
    /// # Panics
    /// Panics if that reduction is needed and `P` does not fit in an `i64`.
    fn neg(self) -> Self::Output {
        let a = self.a.checked_neg().unwrap_or_else(|| -self.reduce().a);
        Self { a }
    }
}
impl<const P: u64> Add for PrimeField<P> {
    type Output = Self;

    /// Adds the raw representatives.
    ///
    /// If that overflows `i64`, both operands are reduced into `[0, p)` and
    /// combined as `x + (y - p)`. `y - p` lies in `[-p, 0)`, so the result lies
    /// in `[-p, p)`. That always fits in `i64`, even for `p` close to
    /// `i64::MAX`, where a plain `x + y` would overflow again. The result is
    /// congruent to the true sum, which is all the representation requires.
    ///
    /// # Panics
    /// Panics if the fallback is needed and `P` does not fit in an `i64`.
    fn add(self, rhs: Self) -> Self::Output {
        let a = self.a.checked_add(rhs.a).unwrap_or_else(|| {
            let x = self.reduce().a;
            let y = rhs.reduce().a;
            // `reduce` has already verified that `P` fits in `i64`, so this
            // cast is exact.
            let p = P as i64;
            x + (y - p)
        });
        Self { a }
    }
}
impl<const P: u64> Sub for PrimeField<P> {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
Self {
a: self.a.checked_sub(rhs.a).unwrap_or_else(|| {
let x = self.reduce();
let y = rhs.reduce();
x.a - y.a
}),
}
}
}
impl<const P: u64> Mul for PrimeField<P> {
    type Output = Self;

    /// Multiplies the raw representatives.
    ///
    /// If that overflows `i64`, both operands are reduced into `[0, p)` and
    /// multiplied in `i128`. The product is below `p^2 < 2^126`, so it cannot
    /// overflow. It is then reduced back into `[0, p)`. Multiplying the reduced
    /// values in `i64` would still overflow for any `p` above roughly `3.04e9`.
    ///
    /// # Panics
    /// Panics if the fallback is needed and `P` does not fit in an `i64`.
    fn mul(self, rhs: Self) -> Self::Output {
        let a = self.a.checked_mul(rhs.a).unwrap_or_else(|| {
            let x = i128::from(self.reduce().a);
            let y = i128::from(rhs.reduce().a);
            let product = (x * y).rem_euclid(i128::from(P));
            // `product` is in `[0, P)` and `reduce` verified `P <= i64::MAX`,
            // so the narrowing cast is lossless.
            product as i64
        });
        Self { a }
    }
}
impl<const P: u64> Div for PrimeField<P> {
    type Output = Self;

    /// Divides by multiplying with the multiplicative inverse of `rhs`.
    ///
    /// # Panics
    /// Panics if `rhs` is zero (via `inverse`).
    // Clippy flags `*` inside a `Div` impl; here it is the intended definition.
    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: Self) -> Self::Output {
        let rhs_inverse = rhs.inverse();
        self * rhs_inverse
    }
}
impl<const P: u64> fmt::Debug for PrimeField<P> {
    /// Formats the canonical representative in `[0, P)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // One reduction is enough. Delegating to `i64`'s `Debug` impl, instead
        // of `write!(f, "{}", ..)`, also honours formatter flags such as width
        // and fill.
        fmt::Debug::fmt(&self.reduce().a, f)
    }
}
impl<const P: u64> Field for PrimeField<P> {
    const CHARACTERISTIC: u64 = P;
    const ZERO: Self = Self { a: 0 };
    const ONE: Self = Self { a: 1 };

    /// Multiplicative inverse, computed with the extended Euclidean algorithm.
    ///
    /// # Panics
    /// - If `self` is zero modulo `P`. This covers every representative of
    ///   zero (`0`, `P`, `-P`, ...), not just a stored `0`.
    /// - If `P` does not fit in an `i64`.
    /// - If the operand is not invertible modulo a non-prime `P` (the
    ///   assertion inside `mod_inverse`).
    fn inverse(self) -> Self {
        // Reduce first so that unreduced representatives of zero are rejected
        // here with a clear message, rather than tripping the gcd assertion in
        // `mod_inverse`.
        let x = self.reduce();
        assert_ne!(x.a, 0, "zero has no multiplicative inverse");
        Self {
            a: mod_inverse(
                x.a,
                P.try_into().expect("module not fitting into signed 64 bit"),
            ),
        }
    }

    /// Returns `self` added to itself `|n|` times, negated when `n < 0`.
    ///
    /// Uses double-and-add, so only O(log |n|) field additions are performed.
    /// All of `i64` is accepted, including `i64::MIN`.
    fn integer_mul(self, n: i64) -> Self {
        if n == 0 {
            return Self::ZERO;
        }
        // Work on the reduced value: its negation always fits in `i64`,
        // whereas the raw representative might be `i64::MIN`.
        let mut base = self.reduce();
        if n < 0 {
            base = Self { a: -base.a };
        }
        // `unsigned_abs` is defined for `i64::MIN`, unlike `-n`.
        let mut remaining = n.unsigned_abs();
        let mut acc = Self::ZERO;
        while remaining > 0 {
            if remaining & 1 == 1 {
                acc = acc + base;
            }
            remaining >>= 1;
            if remaining > 0 {
                base = base + base;
            }
        }
        acc
    }

    type ElementsIter = PrimeFieldElementsIter<P>;

    /// Iterates over `0, 1, ..., P - 1`.
    fn elements() -> Self::ElementsIter {
        PrimeFieldElementsIter::default()
    }
}
/// Iterator over every element of `PrimeField<P>`, yielding the canonical
/// values `0, 1, ..., P - 1` in order. Returned by `Field::elements`.
#[derive(Debug, Clone, Default)]
pub struct PrimeFieldElementsIter<const P: u64> {
    /// The next integer to yield. Once it equals `P`, the iterator is
    /// exhausted.
    x: i64,
}

impl<const P: u64> Iterator for PrimeFieldElementsIter<P> {
    type Item = PrimeField<P>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.x as u64 == P {
            return None;
        }
        // `x` is already the canonical representative, so it is wrapped
        // directly. Going through `from_integer` would cost O(log x) field
        // additions per element.
        let element = PrimeField::<P>::from(self.x);
        self.x += 1;
        Some(element)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // `x` never exceeds `P`, so this subtraction cannot underflow.
        let remaining = P - self.x as u64;
        match usize::try_from(remaining) {
            Ok(n) => (n, Some(n)),
            Err(_) => (usize::MAX, None),
        }
    }
}
/// Hashes the canonical representative. `==` also compares reduced values,
/// so elements that compare equal hash identically.
impl<const P: u64> Hash for PrimeField<P> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let canonical = self.reduce().a;
        state.write_i64(canonical);
    }
}
/// Returns an inverse of `a` modulo `b` via the extended Euclidean algorithm.
///
/// The result is congruent to `a^{-1} (mod b)` but is not reduced; it may be
/// negative (e.g. `mod_inverse(2, 7) == -3`).
///
/// # Panics
/// Panics if `gcd(a, b)` is not 1, i.e. `a` is not invertible modulo `b`.
fn mod_inverse(mut a: i64, mut b: i64) -> i64 {
    // Invariant: `coef * a_original ≡ a (mod b_original)`. `next_coef` plays
    // the same role for `b`.
    let (mut coef, mut next_coef): (i64, i64) = (1, 0);
    while b != 0 {
        let q = a / b;
        (a, b) = (b, a - q * b);
        (coef, next_coef) = (next_coef, coef - q * next_coef);
    }
    // With Rust's truncating division the final gcd may come out as -1;
    // multiplying by it fixes the sign.
    assert!(a == 1 || a == -1);
    a * coef
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Every non-zero element generates the whole additive group.
    #[test]
    fn test_field_elements() {
        fn test<const P: u64>() {
            let expected: HashSet<PrimeField<P>> = (0..P as i64).map(Into::into).collect();
            // `gen` is a reserved keyword in Rust 2024, hence `generator`.
            // The range includes `P - 1`, which is also a generator.
            for generator in 1..P {
                let generator = PrimeField::from(generator as i64);
                let mut generated: HashSet<PrimeField<P>> = [generator].into_iter().collect();
                let mut x = generator;
                for _ in 0..P {
                    x = x + generator;
                    generated.insert(x);
                }
                assert_eq!(generated, expected);
            }
        }
        test::<5>();
        test::<7>();
        test::<11>();
        test::<13>();
        test::<17>();
        test::<19>();
        test::<23>();
        test::<71>();
        test::<101>();
    }

    /// `Field::elements` yields each canonical residue exactly once, in order.
    #[test]
    fn elements_enumerates_residues_in_order() {
        type F = PrimeField<7>;
        let elements: Vec<F> = F::elements().collect();
        assert_eq!(elements.len(), 7);
        for (i, element) in elements.iter().enumerate() {
            assert_eq!(element.to_integer(), i as u64);
        }
    }

    #[test]
    fn large_prime_field() {
        const P: u64 = 2_u64.pow(63) - 25;
        type F = PrimeField<P>;
        let x = F::from(P as i64 - 1);
        let y = x.inverse();
        assert_eq!(x * y, F::ONE);
    }

    #[test]
    fn inverse() {
        fn test<const P: u64>() {
            for x in -7..7 {
                let x = PrimeField::<P>::from(x);
                if x != PrimeField::ZERO {
                    assert_eq!(x.inverse() * x, PrimeField::ONE);
                    assert_eq!(x * x.inverse(), PrimeField::ONE);
                    assert_eq!((x.inverse().a * x.a).rem_euclid(P as i64), 1);
                    assert_eq!(x / x, PrimeField::ONE);
                }
                assert_eq!(x + (-x), PrimeField::ZERO);
                assert_eq!((-x) + x, PrimeField::ZERO);
                assert_eq!(x - x, PrimeField::ZERO);
            }
        }
        test::<5>();
        test::<7>();
        test::<11>();
        test::<13>();
        test::<17>();
        test::<19>();
        test::<23>();
        test::<71>();
        test::<101>();
    }

    #[test]
    fn test_mod_inverse() {
        assert_eq!(mod_inverse(-6, 7), 1);
        assert_eq!(mod_inverse(-5, 7), -3);
        assert_eq!(mod_inverse(-4, 7), -2);
        assert_eq!(mod_inverse(-3, 7), 2);
        assert_eq!(mod_inverse(-2, 7), 3);
        assert_eq!(mod_inverse(-1, 7), -1);
        assert_eq!(mod_inverse(1, 7), 1);
        assert_eq!(mod_inverse(2, 7), -3);
        assert_eq!(mod_inverse(3, 7), -2);
        assert_eq!(mod_inverse(4, 7), 2);
        assert_eq!(mod_inverse(5, 7), 3);
        assert_eq!(mod_inverse(6, 7), -1);
    }

    #[test]
    fn integer_mul() {
        type F = PrimeField<23>;
        for x in 0..23 {
            let x = F { a: x };
            for n in -7..7 {
                assert_eq!(x.integer_mul(n), F { a: n * x.a });
            }
        }
    }

    #[test]
    fn from_integer() {
        type F = PrimeField<23>;
        for x in -100..100 {
            assert_eq!(F::from_integer(x), F { a: x });
        }
        assert_eq!(F::from(0), F::ZERO);
        assert_eq!(F::from(1), F::ONE);
    }
}