/************************************************************************************
*																					*
*    The following routines implement GF(2^n) math via the ring R_p, using J. Silverman's descriptions	*
*  in the paper "Fast Multiplication in Finite Fields GF(2^n)".  The idea here is that the ring taken	*
*  modulo X^p-1, where p is a prime such that p-1 gives a Type I ONB, allows us to compute much faster	*
*  than in the field GF(2^n) (n = p-1) directly.  This magic happens because X^p-1 factors into			*
*  (X-1)*phi(X), and phi(X) is the irreducible polynomial governing the ONB field.					*
*																					*
*									Author = mike rosing							*
*									 date  = July 7, 1999							*
************************************************************************************/

#include "field2n.h"

/*  R_p carries one more bit than GF(2^n).  GHOST_BIT is the bit of e[0] that
	ring_right() wraps bit 0 into and that the conversion back to GF(2^n)
	tests; GHOST_MASK keeps the low GHOST_SHIFT bits of e[0], i.e. the ring
	bits that presumably live in that word.
	The shifts are performed on an ELEMENT (assumed unsigned -- TODO confirm
	in field2n.h).  The former "~0 << GHOST_SHIFT" left-shifted the negative
	int -1, which is undefined behavior, and "1L << UPRSHIFT" could shift into
	the sign bit of long.
*/
#define  GHOST_BIT		((ELEMENT)1 << UPRSHIFT)
#define  GHOST_SHIFT	(field_prime % WORDSIZE)
#define  GHOST_MASK	(~(~(ELEMENT)0 << GHOST_SHIFT))

/*  Lookup tables shared by the ring routines in this file.  They are filled
	once by init_ring_math() and only read afterwards.
	  two_inx / two_bit : word index and bit mask of position 2^i mod
	                      field_prime (used by the quadratic solver)
	  sqr_inx / sqr_bit : destination word and bit of input bit i under
	                      squaring (ring_square, ring_sqroot)
	  shift_by          : count of trailing zero bits of a byte, 8 for 0
	                      (ring_inv)
	The INDEX/ELEMENT tables have one spare slot beyond field_prime entries;
	init_ring_math() writes sqr_* up to index 2*(NUMBITS/2+1)-1.
*/

static	INDEX	two_inx[field_prime+1], sqr_inx[field_prime+1];
static	ELEMENT	two_bit[field_prime+1], sqr_bit[field_prime+1];
static	unsigned char	shift_by[256];

/*  Divide a by x in R_p: rotate the ring element one bit to the right.
	The low bit of e[NUMWORD] wraps around into GHOST_BIT of e[0], the low
	bit of every word e[i] moves into the MSB of e[i+1], and e[0] is finally
	trimmed with GHOST_MASK.  Operates in place.
*/
void ring_right(FIELD2N *a)
{
	INDEX	w;
	ELEMENT	carry, spill;

	carry = 0L;
	if ( a->e[NUMWORD] & 1 )
		carry = GHOST_BIT;			/* bit 0 wraps to the ghost bit */

	SUMLOOP(w)
	{
		spill = ( a->e[w] & 1 ) ? MSB : 0L;	/* lsb moves to the next word */
		a->e[w] = ( a->e[w] >> 1 ) | carry;
		carry = spill;
	}
	a->e[0] &= GHOST_MASK;
}

/*  Set every word of a to zero.
	Prototype-style definition: the old K&R identifier-list form is
	obsolescent and no longer accepted by C23.
*/
void null( FIELD2N *a)
{
        INDEX i;

        SUMLOOP(i)  a->e[i] = 0;
}

/*  Copy a into b word by word (source first, destination second).
	Prototype-style definition replacing the obsolescent K&R form.
*/
void copy( FIELD2N *a, FIELD2N *b)
{
        INDEX i;

        SUMLOOP(i)  b->e[i] = a->e[i];
}

/*  Return the position of the most significant set bit of x, i.e.
	floor(log2(x)), using a binary search within the word.
	Both x = 0 and x = 1 return 0, so callers must handle zero separately.
	Assumes ELEMENT is an unsigned type WORDSIZE bits wide and WORDSIZE is a
	power of two (TODO confirm against field2n.h).  An arithmetic right shift
	would break the mask update.
*/

INDEX log_2( ELEMENT x)
{
	INDEX	k, lg2;
	ELEMENT ebit, bitsave, bitmask;

	lg2 = 0;
	bitsave = x;				/* grab bits we're interested in.  */
	k = WORDSIZE/2;				/* first see if msb is in top half  */
	bitmask = ~(ELEMENT)0 << k;		/* of all bits; shifting an unsigned all-ones
						   value is defined, unlike -1L << k  */
	while (k)
	{
		ebit = bitsave & bitmask;	/* did we hit a bit?  */
		if (ebit)			/* yes  */
		{
			lg2 += k;		/* increment degree by minimum possible offset  */
			bitsave = ebit;		/* and zero out non useful bits  */
		}
		k /= 2;
		bitmask ^= (bitmask >> k);	/* upper half of each remaining window  */
	}
	return( lg2);
}


/*  Initialize the lookup tables used by the ring math routines.  Call once
	before ring_square, ring_sqroot, ring_inv or ring_quadratic is used.

	  two_inx[i], two_bit[i] : word index and bit mask of position
	                           2^i mod field_prime, for i = 0 .. field_prime-1
	  shift_by[x]            : number of trailing zero bits of the byte x,
	                           with shift_by[0] = 8
	  sqr_inx[i], sqr_bit[i] : where input bit i goes when squaring:
	                           position 2i for i < n and position 2(i-n)+1
	                           for the entries i+n, with n = NUMBITS/2 + 1

	Bit position t is stored in word LONGWORD - t/WORDSIZE at bit t % WORDSIZE.
	The single-bit masks are shifted inside the ELEMENT-typed table entry.
	The previous "1L << ..." was undefined whenever the shift count reached
	the sign bit of long.
*/

void init_ring_math(void)
{
	INDEX n, i, j;

	j = 1;					/* j = 2^i mod field_prime */
	for ( i=0;  i<field_prime;  i++ )
	{
		two_inx[i] = LONGWORD-(j / WORDSIZE);
		two_bit[i] = 1;
		two_bit[i] <<= j % WORDSIZE;	/* shift in ELEMENT width */
		j = (j << 1) % field_prime;
	}

	/*  shift_by: start at 1 for x = 0 and 0 elsewhere, then add one for each
		power of two 2 .. 128 that divides x  */

	shift_by[0] = 1;
	for ( i=1; i<256; i++ )
		shift_by[i] = 0;
	for ( j=2; j<256; j+=j )
	{
		for ( i=0; i<256; i+=j )
			shift_by[i]++;
	}

	n = NUMBITS/2 + 1;
	for( i=0; i<n; i++)
	{
		sqr_inx[i] = LONGWORD - ((2*i)/WORDSIZE);
		sqr_bit[i] = 1;
		sqr_bit[i] <<= (2*i) % WORDSIZE;
		sqr_inx[i+n] = LONGWORD - ((2*i+1)/WORDSIZE);
		sqr_bit[i+n] = 1;
		sqr_bit[i+n] <<= (2*i+1) % WORDSIZE;
	}
}

/*  Test a FIELD2N value for zero.
	Returns 1 when every word is zero, and 0 once a nonzero word is found,
	so the name reads naturally in conditions:  if ( iszero( x)) ...
*/

int	iszero(  FIELD2N  *num)
{
	INDEX	w;
	int	zero = 1;

	SUMLOOP(w)
	{
		if ( num->e[w] != 0 )
		{
			zero = 0;
			break;
		}
	}
	return zero;
}

/*  Ring multiply for Type I ONB: c = a * b.
	Taken from J.H. Silverman "Fast Multiplication in Finite Fields GF(2^n)",
	CHES workshop '99 (to appear in LNCS, Springer).  This method only works
	for Type I ONB at present but has complexity n+1, so is twice as fast as
	any other method.

	The field_prime bits of a are visited from bit 0 upward.  When a bit is
	set, b is added into an accumulator.  The accumulator is then divided by
	x (ring_right).  Finally, if the ghost bit is set, every word is
	complemented and e[0] is masked with UPRMASK to return to GF(2^n) form.
	The product is built in a local and copied out at the end, so c may be
	the same object as a or b.  Before, c was cleared first, which wiped an
	aliased input.
*/

void ring_mul( FIELD2N *a,  FIELD2N *b, FIELD2N *c)
{
	INDEX	i, j, k;
	ELEMENT	mask;
	FIELD2N	acc;

/*  a zero factor gives a zero product  */

	if ( iszero( a) || iszero( b))
	{
		null( c);
		return;
	}

	null( &acc);
	mask = 1L;
	j = NUMWORD;
	for( i=0; i<field_prime; i++)    /*  i counts over bits, j is the ELEMENT index */
	{
		if( mask & a->e[j])
			SUMLOOP(k) acc.e[k] ^= b->e[k];
		ring_right( &acc);
		mask <<= 1;
		if( !mask)
		{
			mask = 1L;
			j--;
		}
	}

/*  if the ghost bit is set, flip back to GF(2^n).  Otherwise we're already there */

	if( acc.e[0] & GHOST_BIT)
	{
		SUMLOOP(k) acc.e[k] ^= ~0;
		acc.e[0] &= UPRMASK;
	}
	SUMLOOP(k) c->e[k] = acc.e[k];
}

/*  Square a number in ring R_p: c = a^2.  Since this is modulo the
	cyclotomic polynomial, the bottom half bits spread out and the upper
	half bits interleave between them.  init_ring_math() stores this
	permutation as a word index (sqr_inx) and bit mask (sqr_bit) for each
	input bit.  If the ghost bit ends up set, every word is complemented and
	e[0] is masked with UPRMASK to return to GF(2^n) form.
	The square is built in a local, so c may be the same object as a.
	Before, c was cleared first, which wiped an aliased input.
*/

void ring_square( FIELD2N *a, FIELD2N *c)
{
	INDEX	i, j;
	ELEMENT	 mask;
	FIELD2N	sq;

	null( &sq);
	j = NUMWORD;
	mask = 1;
	for( i=0; i<field_prime; i++)
	{
		if( a->e[j] & mask)
			sq.e[sqr_inx[i]] |= sqr_bit[i];
		mask <<= 1;
		if( !mask)
		{
			mask = 1;
			j--;
		}
	}
	if( sq.e[0] & GHOST_BIT)
	{
		SUMLOOP(i)  sq.e[i] ^= ~0;
		sq.e[0] &= UPRMASK;
	}
	SUMLOOP(i)  c->e[i] = sq.e[i];
}

/*  Compute the square root of a number in R_p: c = a^(1/2).
	This is the inverse of the ring_square permutation.  Output bit i is
	read from input word sqr_inx[i] under mask sqr_bit[i].  If the ghost
	bit ends up set, every word is complemented and e[0] is masked with
	UPRMASK to return to GF(2^n) form.
	The root is built in a local, so c may be the same object as a.
	Before, c was cleared first, which wiped an aliased input.
*/

void ring_sqroot( FIELD2N *a, FIELD2N *c)
{
	INDEX	i, j;
	ELEMENT	mask;
	FIELD2N	root;

	null( &root);
	j = NUMWORD;
	mask = 1;
	for( i=0; i<field_prime; i++)
	{
		if( a->e[sqr_inx[i]] & sqr_bit[i])
			root.e[j] |= mask;
		mask <<= 1;
		if( !mask)
		{
			mask = 1;
			j--;
		}
	}
	if( root.e[0] & GHOST_BIT)
	{
		SUMLOOP(i)  root.e[i] ^= ~0;
		root.e[0] &= UPRMASK;
	}
	SUMLOOP(i)  c->e[i] = root.e[i];
}

/*  Routine to divide input a by x^k in R_p.
	Enter with pointer to a, value of k and pointer to result b.
	Returns with b = a/x^k: all field_prime bits of a rotated right by k
	places, so bit t of b is bit (t + k) mod field_prime of a.  Bit t is
	stored in word LONGWORD - t/WORDSIZE at position t % WORDSIZE, the
	same layout init_ring_math() uses for its tables.
	k is reduced modulo field_prime, and a negative k rotates left.

	The rotation is done one bit at a time into a local.  The former
	word-wise version shifted by a full WORDSIZE whenever k or
	field_prime - k was a multiple of WORDSIZE.  That is undefined
	behavior, and on targets that mask the shift count it produced a wrong
	result.  Building into a local also lets b be the same object as a.
*/

void ring_divk( FIELD2N *a, INDEX k, FIELD2N *b)
{
	FIELD2N	rot;
	INDEX	src, dst;

	k = k % field_prime;
	if ( k < 0 )
		k += field_prime;
	if ( !k )
	{
		copy( a, b);
		return;
	}

	null( &rot);
	src = k;				/* source bit for destination bit 0 */
	for ( dst=0; dst<field_prime; dst++ )
	{
		if ( a->e[LONGWORD - src/WORDSIZE] & ((ELEMENT)1 << (src % WORDSIZE)) )
			rot.e[LONGWORD - dst/WORDSIZE] |= (ELEMENT)1 << (dst % WORDSIZE);
		if ( ++src == field_prime )	/* wrap around the ring */
			src = 0;
	}
	copy( &rot, b);
}

/* This algorithm is the Almost Inverse Algorithm of Schroeppel, et al. given
   in "Fast Key Exchange with Elliptic Curve Systems".
   Code originally written by Dave Dahm.
   Modified to work with R_p by removing conversion to polynomial basis.
   Conversion to R_p only requires one extra bit.

   Computes dest = 1/a.  f starts as a, g as (u^p-1)/(u-1), b as 1 and c
   as 0.  f and g are driven down until one of them is 1, at which point
   the matching cofactor (b or c) satisfies a*b = u^n.  The result is
   divided by u^n with ring_divk() and reduced back to GF(2^n) form.

   Requires init_ring_math() (for shift_by) and a nonzero a: with a = 0 the
   first shift loop never terminates.
   NOTE(review): a equal to the all-ones polynomial (0 in GF(2^n)) presumably
   never terminates either (f + g becomes 0); not checked here.
*/

void ring_inv(FIELD2N *a, FIELD2N *dest)
{
	FIELD2N		f, b, c, g;
	INDEX		i, j, n, f_top, c_top;
	ELEMENT		bits;
	ELEMENT		m;	/* carry between words while shifting.  It is an
				   ELEMENT, not an INDEX, so the saved word is
				   never truncated or sign-extended. */

	/* f, b, c, and g are not in optimal normal basis format: they are held
	    in 'customary format', i.e. a0 + a1*u^1 + a2*u^2 + ...; For the
	    comments in this routine, the polynomials are assumed to be
	    polynomials in u. */

	/* Set g to polynomial (u^p-1)/(u-1) */

	for ( i=1; i<=LONGWORD; i++ )
	    g.e[i] = ~0;
	g.e[0] = GHOST_MASK;

	/* Set c to 0, b to 1, and n to 0, f to a*/

	null(&c);
	null(&b);
	b.e[LONGWORD] = 1;
	n = 0;
	copy( a, &f);

	/* Now find a polynomial b, such that a*b = u^n */

	/* f and g shrink, b and c grow.  The code takes advantage of this.
	c_top and f_top are the variables which control this behavior */

	c_top = LONGWORD;
	f_top = 0;
	do {
	    i = shift_by[f.e[LONGWORD] & 0xff];

	    /* f already odd: nothing to divide out.  Running the shift with
	       i = 0 would shift m left by WORDSIZE, which is undefined
	       behavior (on targets that mask the shift count it ORs the
	       neighbouring word into f). */
	    if ( i == 0 ) break;
	    n+=i;
    /* Shift f right i (divide by u^i) */
	    m = 0;
	    for ( j=f_top; j<=LONGWORD; j++ ) {
		bits = f.e[j];
		f.e[j] = (bits>>i) | (m << (WORDSIZE-i));
		m = bits;
	    }
	} while ( i == 8 && (f.e[LONGWORD] & 1) == 0 );
	for ( j=0; j<LONGWORD; j++ )
	    if ( f.e[j] ) break;
	if ( j<LONGWORD || f.e[LONGWORD] != 1 )
	{
	/* There are two loops here: whenever we need to exchange f with g and
		b with c, jump to the other loop which has the names reversed! */

	    do
	    {
	    /* Shorten f and g when possible */

		while ( f.e[f_top] == 0 && g.e[f_top] == 0 ) f_top++;

	    /* f needs to be bigger - if not, exchange f with g and b with c.
	       (Actually jump to the other loop instead of doing the exchange)
	       The published algorithm requires deg f >= deg g, but we don't
	       need to be so fine */

		if ( f.e[f_top] < g.e[f_top] ) goto loop2;
loop1:
	    /* f = f+g, making f divisible by u (both are odd here, so the
	       shift counts i below are at least 1) */

		for ( i=f_top; i<=LONGWORD; i++ )
		    f.e[i] ^= g.e[i];

	    /* b = b+c */

		for ( i=c_top; i<=LONGWORD; i++ )
		    b.e[i] ^= c.e[i];
		do
		{
		    i = shift_by[f.e[LONGWORD] & 0xff];
		    n+=i;

	    /* Shift c left i (multiply by u^i), lengthening it if needed */

		    m = 0;
		    for ( j=LONGWORD; j>=c_top; j-- )
		    {
			bits = c.e[j];
			c.e[j] = (bits<<i) | m;
			m = bits >> (WORDSIZE-i);
		    }
		    if ( m )
		    {
		    	c.e[j] = m;
		    	c_top=j;
		    }

	    /* Shift f right i (divide by u^i) */

		    m = 0;
		    for ( j=f_top; j<=LONGWORD; j++ )
		    {
			bits = f.e[j];
			f.e[j] = (bits>>i) | (m << (WORDSIZE-i));
			m = bits;
		    }
		} while ( i == 8 && (f.e[LONGWORD] & 1) == 0 );

	    /* Check if we are done (f=1) */

		for ( j=f_top; j<LONGWORD; j++ )
		    if ( f.e[j] ) break;
	    } while ( j<LONGWORD || f.e[LONGWORD] != 1 );

	    /* f == 1 here, so b already holds the answer.  The old
	       "if ( j>0 )" guard fell into the g loop when LONGWORD == 0. */
	    goto done;

	    do
	    {
	    /* Shorten f and g when possible */

		while ( g.e[f_top] == 0 && f.e[f_top] == 0 ) f_top++;

	    /* g needs to be bigger - if not, exchange f with g and b with c.
	       (Actually jump to the other loop instead of doing the exchange)
	       The published algorithm requires deg g >= deg f, but we don't
	       need to be so fine */

		if ( g.e[f_top] < f.e[f_top] ) goto loop1;
loop2:
	    /* g = f+g, making g divisible by u */

		for ( i=f_top; i<=LONGWORD; i++ )
		    g.e[i] ^= f.e[i];

	    /* c = b+c */

		for ( i=c_top; i<=LONGWORD; i++ )
		    c.e[i] ^= b.e[i];
		do
		{
		    i = shift_by[g.e[LONGWORD] & 0xff];
		    n+=i;

	    /* Shift b left i (multiply by u^i), lengthening it if needed */

		    m = 0;
		    for ( j=LONGWORD; j>=c_top; j-- )
		    {
			bits = b.e[j];
			b.e[j] = (bits<<i) | m;
			m = bits >> (WORDSIZE-i);
		    }
		    if ( m )
		    {
		    	b.e[j] = m;
		    	c_top=j;
		    }

	    /* Shift g right i (divide by u^i) */

		    m = 0;
		    for ( j=f_top; j<=LONGWORD; j++ )
		    {
			bits = g.e[j];
			g.e[j] = (bits>>i) | (m << (WORDSIZE-i));
			m = bits;
		    }
		} while ( i == 8 && (g.e[LONGWORD] & 1) == 0 );

	    /* Check if we are done (g=1) */

		for ( j=f_top; j<LONGWORD; j++ )
		    if ( g.e[j] ) break;

	    } while ( j<LONGWORD || g.e[LONGWORD] != 1 );

	    /* g == 1: the cofactor is in c, move it to b */
	    copy(&c, &b);
	}
done:
	/* Now b is a polynomial such that a*b = u^n, so multiply b by u^(-n) */

	ring_divk(&b,  n, &c);

	/* Convert the result (now in c) back from the ring to GF(2^n) form
	   and store it in dest */

	if (c.e[0] & GHOST_BIT)
	{
		SUMLOOP(i) c.e[i] ^= ~0;
		c.e[0] &= UPRMASK;
	}
	copy( &c, dest);
} /* ring_inv */

/*  Solve the quadratic equation y^2 + a*y + b = 0 in R_p.
	Expands on Silverman's description: the change of variables y = a*z,
	c = b/a^2 turns the equation into z^2 + z + c = 0, which is then
	solved one bit at a time.

	Enter with a, b and a pointer to storage for two results, y[0] and y[1].
	Returns 0 on success, with the two roots in y[0] and y[1]
	(y[1] = y[0] + a).  Returns 1 when there is no solution, with
	y[0] = y[1] = 0.
	If a = 0 the equation is y^2 = b, and b^(1/2) is stored in both y[0]
	and y[1].
	Uses the two_inx/two_bit tables, so init_ring_math() must have run.
*/

int ring_quadratic( FIELD2N *a, FIELD2N *b, FIELD2N *y)
{
	INDEX	i, j, bits;
	ELEMENT	r, mask;
	FIELD2N	c, temp1, temp2, z;

/*  test for a = 0 and return y = b^(1/2)  (square root of b)  */

	r = 0;
	SUMLOOP(i) r |= a->e[i];
	if  ( !r)
	{
		ring_sqroot( b, y);
		copy( &y[0], &y[1]);
		return (0);
	}

/*  compute c = b/a^2  */

	ring_square( a, &temp1);
	ring_inv( &temp1, &temp2);
	ring_mul(b, &temp2, &c);

/*  z^2 + z = c forces bit 0 of c to be 0.  If it is set, complement c
	(add the all-ones polynomial, giving the other R_p representative of the
	same GF(2^n) element) so that bit 0 is clear before the trace test.  */

	if( c.e[NUMWORD] & 1)
	{
		SUMLOOP(i) c.e[i] ^= ~0;
		c.e[0] &= GHOST_MASK;
	}

/*  next compute Tr(c): XOR all words together, then fold the word in half
	repeatedly until the parity of all bits is left in bit 0  */

	r = 0;
	SUMLOOP(i) r ^= c.e[i];
	mask = ~0;
	for( bits=WORDSIZE/2; bits>0; bits >>= 1)
	{
		mask >>= bits;
		r = ( ( r & mask) ^ ( r >> bits));
	}

/*  if not zero, there is no solution: return error code  */

	if ( r)
	{
		null( &y[0]);
		null( &y[1]);
		return(1);
	}

/*  Good chance we have a solution, so let's try.
	Set z_0 = 0 and z_1 = 1, then walk the positions 2^bits mod field_prime
	for bits = 1 .. NUMBITS-1 using z_(2k) = c_(2k) + z_k.
*/

	null( &z);
	z.e[NUMWORD] = 2;
	j = NUMWORD;
	for( bits=1; bits<NUMBITS; bits++)
	{
		i = j;			/* word holding position 2^(bits-1) */
		j = two_inx[bits];
		r = ( c.e[j] & two_bit[bits]) ? 1 : 0;
		mask = ( z.e[i] & two_bit[bits - 1]) ? 1 : 0;
		if( r ^ mask) z.e[j] |= two_bit[bits];
	}

/*  Check that the final bit is consistent with the full solution.  The walk
	ends at position NUMBITS/2 + 1, whose double is position 1 (assuming
	field_prime = NUMBITS + 1), so z there must equal z_1 + c_1.
	If not, then this ring element has no solution.
*/

	j = NUMWORD - ((NUMBITS/2 + 1) / WORDSIZE);
	r = (ELEMENT)1 << ( (NUMBITS/2 + 1) % WORDSIZE);	/* unsigned: no sign-bit overflow */
	mask = (z.e[NUMWORD] ^ c.e[NUMWORD]) & 2L;
	if( !(z.e[j] & r) != !mask)
	{
		null( &y[0]);
		null( &y[1]);
		return (1);
	}

/*  convert back to GF(2^n)  */

	if( z.e[0] & GHOST_BIT)
	{
		SUMLOOP(i) z.e[i] ^= ~0;
		z.e[0] &= UPRMASK;
	}

/*  return correct y = az values: y[0] = a*z and y[1] = a*z + a  */

	ring_mul( &z, a,  &y[0]);
	SUMLOOP(i) y[1].e[i] = y[0].e[i] ^ a->e[i];
	return (0);
}

/*  Set a to the element 1: every word cleared, then only the low bit of
	e[NUMWORD] set.  Several ONB routines need this constant.  Building it
	on each call is not the most efficient approach, but it saves rewriting
	a lot of code.
*/
void one( FIELD2N *a)
{
	null( a);		/* start from zero ... */
	a->e[NUMWORD] = 1;	/* ... and set bit 0 */
}
	
	