/************************************************************************************
*																					*
*		The purpose of this file is basic manipulation of multivariate polynomials.					*
*  The calls to the underlying field are assumed ONB, but these can be replaced 						*
*  with any field. 	Multivariates are assumed variable length arrays, so watch for					*
*  memory leaks!!																		*
*																					*
*									Author = Mike Rosing								*
*									 date  = May 10, 1999								*
*  Modified to use R_p field, July 20, 1999													*
*																					*
************************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include "field2n.h"
#include "eliptic.h"

#include "multipoly.h"

/*  RAM pool used by the memory-management layer (RAMDATA comes from the
	included headers).  No function in this section names it directly;
	presumably the Address/AddressOf macros or get_space/free_space index
	into it through a MULTIPOLY's memdex -- TODO confirm.  */
extern RAMDATA ram_block[];

/*  add two multivariate polynomials.  C = A + B
	Creates new output space.  takes care of memory management,
	but ASSUMES A and B were previously allocated.
	Coefficients are added word by word with XOR.  Leading zero
	coefficients of the sum are trimmed (down to degree 0 at most).
	In-place use such as multi_add(a, b, &a) or multi_add(a, a, &a) is
	allowed: storage of an input that C already refers to is released
	exactly once after the sum is formed.  C->memdex is read for that
	comparison, so it must not accidentally equal the memdex of an input
	the caller wants to keep.
	Returns 1 on success, 0 on failure to create C (C is left untouched).
*/

int multi_add( MULTIPOLY A, MULTIPOLY B, MULTIPOLY *C)
{
	ELEMENT		i, j;
	MULTIPOLY	shortpoly, longpoly, result;
	FIELD2N		*Result, *Sptr, *Lptr;

/*  find out which polynomial is larger and allocate that much space  */

	if (A.degree > B.degree)
	{
		shortpoly = B;
		longpoly = A;
	}
	else
	{
		shortpoly = A;
		longpoly = B;
	}
	result.degree = longpoly.degree;
	if (!get_space( &result)) return 0;

/*  sum the overlapping coefficients  */

	Result = Address( result);
	Sptr = Address( shortpoly);
	Lptr = Address( longpoly);
	for( i=0; i<=shortpoly.degree; i++)
		SUMLOOP(j) Result[i].e[j] = Sptr[i].e[j] ^ Lptr[i].e[j];

/*  coefficients only the longer polynomial has are copied unchanged  */

	multi_copy( longpoly.degree - shortpoly.degree,
				Lptr + shortpoly.degree + 1, Result + shortpoly.degree + 1);

/*  trim leading zero coefficients  */

	while( result.degree && !zero_check( &Result[result.degree]))
		result.degree--;

/*  take care of memory management.  When A and B share one block
	(e.g. multi_add(a, a, &a)) it must be released only once.  */

	if ( A.memdex == C->memdex ) free_space( &A);
	if ( (B.memdex == C->memdex) && (B.memdex != A.memdex) ) free_space( &B);
	C->degree = result.degree;
	C->memdex = result.memdex;
	return 1;
}

/*  Multiply two multivariate polynomials. Uses Knuth's construction
	from page 399 of Seminumerical Algorithms (sect. 4.6)
	computes C = A*B: coefficient k of C is the XOR-sum of A_i * B_j
	over all i + j = k.  Leading zero coefficients are trimmed.
	Output space is newly allocated.  In-place use such as
	multi_mul(a, b, &a) or squaring with multi_mul(a, a, &a) is allowed:
	storage of an input that C already refers to is released exactly
	once after the product is formed.  C->memdex is read for that
	comparison.
	returns 1 on success, 0 if C can't be allocated (C left untouched).
*/

int multi_mul( MULTIPOLY A, MULTIPOLY B, MULTIPOLY *C)
{
	ELEMENT		i, j, k, lo, hi;
	MULTIPOLY	shortmulti, longmulti, result;
	FIELD2N		*Short, *Long, *Result;
	FIELD2N		temp;

/*  create space for result using sum of degrees of source polynomials */

	result.degree = A.degree + B.degree;
	if ( !get_space( &result) ) return 0;
	if (A.degree > B.degree)
	{
		longmulti = A;
		shortmulti = B;
	}
	else
	{
		longmulti = B;
		shortmulti = A;
	}
	Short = Address( shortmulti);
	Long = Address( longmulti);
	Result = Address( result);

/*  one pass per output coefficient.  A term Short[i] * Long[k-i] exists
	only when 0 <= i <= shortmulti.degree and 0 <= k-i <= longmulti.degree,
	which gives the bounds lo..hi below (all indices are unsigned, so the
	lower bound is computed without going negative).  */

	for( k=0; k<=result.degree; k++)
	{
		lo = (k > longmulti.degree) ? k - longmulti.degree : 0;
		hi = (k < shortmulti.degree) ? k : shortmulti.degree;
		null( &Result[k]);
		for( i=lo; i<=hi; i++)
		{
			ring_mul( &Short[i], &Long[k - i], &temp);
			SUMLOOP(j) Result[k].e[j] ^= temp.e[j];
		}
	}

/*  trim leading zero coefficients  */

	while( result.degree && !zero_check( &Result[result.degree]))
		result.degree--;
	
/*  take care of memory management.  When A and B share one block
	(in-place squaring) it must be released only once.  */

	if (A.memdex == C->memdex) free_space( &A);
	if ((B.memdex == C->memdex) && (B.memdex != A.memdex)) free_space( &B);
	C->memdex = result.memdex;
	C->degree = result.degree;
	return 1;
}

/*  Test whether a FIELD2N value is zero.
	Words are scanned from e[0] upward; e[0] is first masked with
	UPRMASK (presumably the valid bits of the top word -- confirm in
	field2n.h).  The first non-zero word found is returned, or 0 when
	every word is zero, so the result works as a boolean as well.
*/

ELEMENT zero_check( FIELD2N *z)
{
	ELEMENT	word = z->e[0] & UPRMASK;
	INDEX	next = 1;

	while( !word && next < MAXLONG)
		word = z->e[next++];
	return word;
}
	
/*  Divide two multivariate polynomials with schoolbook long division
	(Knuth, Algorithm D, pg 402 of Seminumerical Algorithms, sect. 4.6.1).
	
	Quotient = Top/Bottom, Remainder = Top mod Bottom, i.e.
	Top = Quotient*Bottom + Remainder with deg(Remainder) < deg(Bottom).
	
	Each quotient coefficient is the current leading coefficient of the
	partial remainder times the inverse of Bottom's leading coefficient,
	so Bottom need not be monic and the remainder is not rescaled.
	Subtraction is XOR, as everywhere in this file.  The remainder's
	leading zero coefficients are trimmed; the quotient's are not.
	
	Memory: new space is allocated for both outputs.  If Quotient or
	Remainder already refers to Top's or Bottom's storage (in-place use),
	that input is released exactly once.  Quotient->memdex and
	Remainder->memdex are read for this comparison.  The outputs are
	only written on success.
	
	returns 0 if can't allocate space (nothing leaked), 1 on success and
	-1 if Bottom's leading coefficient is 0 (attempting to divide by 0).
*/

int multi_div( MULTIPOLY Top, MULTIPOLY Bottom, 
				MULTIPOLY *Quotient, MULTIPOLY *Remainder)
{
	INDEX		j;
	ELEMENT		i, k, nonzero;
	FIELD2N		temp, lead_inv;
	MULTIPOLY	Q, R, T;
	FIELD2N		*q, *r, *t, *bot, *top;

/*  check to see if attempting to divide by 0 (only the leading
	coefficient of Bottom is examined)  */

	nonzero = 0;
	bot = Address( Bottom) + Bottom.degree;
	SUMLOOP(j) nonzero |= bot->e[j];
	if (!nonzero) return (-1);

/*  dividing by a scalar: Q = Top / b_0, R = 0  */

	if ( !Bottom.degree )
	{
		Q.degree = Top.degree;
		if ( !get_space( &Q)) return 0;
		R.degree = 0;
		if ( !get_space( &R))
		{
			free_space( &Q);
			return 0;
		}
		bot = Address( Bottom);
		ring_inv( bot, &lead_inv);
		top = Address( Top);
		q = Address( Q);
		for( i=0; i<=Q.degree; i++)
			ring_mul( &lead_inv, &top[i], &q[i]);
		null( Address( R));
		goto divfree;
	}
	
/*  if Top has lower degree than Bottom the result is trivial:
	Q = 0, R = Top  */

	if ( Top.degree < Bottom.degree )
	{
		Q.degree = 0;
		if ( !get_space( &Q) ) return 0;
		R.degree = Top.degree;
		if ( !get_space( &R) )
		{
			free_space( &Q);
			return 0;
		}
		top = Address( Top);
		r = Address( R);
		multi_copy( R.degree+1, top, r);
		null( Address( Q));
		goto divfree;
	}

/*  allocate a work copy T of Top (reduced in place to the remainder)
	and the result polynomials  */

	T.degree = Top.degree;
	if ( !get_space( &T)) return 0;
	Q.degree = Top.degree - Bottom.degree;
	if( !get_space( &Q) )
	{
		free_space( &T);
		return 0;
	}
	R.degree = Bottom.degree - 1;
	if( !get_space( &R) )
	{
		free_space( &Q);
		free_space( &T);
		return 0;
	}
	top = Address( Top);
	t = Address( T);
	multi_copy( T.degree+1, top, t);
	bot = Address( Bottom);
	ring_inv( &bot[Bottom.degree], &lead_inv);
	q = Address( Q);

/*  core of division algorithm.  For k = deg Q down to 0, the current
	leading coefficient t[deg B + k] fixes q[k]; subtracting
	q[k] * x^k * Bottom from T cancels it.  k is unsigned, so the loop
	tests before decrementing in order to include k = 0.  */

	k = Q.degree + 1;
	while( k-- > 0)
	{
		ring_mul( &lead_inv, &t[Bottom.degree + k], &q[k]);
		for( i=k; i<Bottom.degree+k; i++)
		{
			ring_mul( &q[k], &bot[i - k], &temp);
			SUMLOOP(j) t[i].e[j] ^= temp.e[j];
		}
	}

/*  set final degree of remainder  */

	i = R.degree;
	while( !zero_check( &t[i]) && i!=0) i--;
	R.degree = i;
	
/*  copy remainder and deallocate work space  */

	r = Address( R);
	multi_copy( R.degree+1, t, r);
	free_space( &T);
divfree:
	if( (Quotient->memdex == Top.memdex) || 
		(Remainder->memdex == Top.memdex) ) 
			free_space( &Top);
	/* if Top and Bottom share one block it was already released above */
	if( (Bottom.memdex != Top.memdex) &&
		((Quotient->memdex == Bottom.memdex) || 
		 (Remainder->memdex == Bottom.memdex)) ) 
			free_space( &Bottom);
	Quotient->degree = Q.degree;
	Quotient->memdex = Q.memdex;
	Remainder->degree = R.degree;
	Remainder->memdex = R.memdex;
	return 1;
}

/*  compute greatest common divisor of two multivariate polynomials
	with the Euclidean algorithm (see Knuth pg 405): (u, v) is replaced
	by (v, u mod v) until the remainder is zero, and the final v is
	returned.  The result is not normalized, and a degree-0 final v is
	reported as the constant 1.
	If either input has degree 0 the gcd is taken to be 1 (this includes
	a zero polynomial stored with degree 0).
	A and B are not modified; *gcd receives newly allocated space.
	returns 0 if space is not available or a division fails (nothing
	is leaked), 1 otherwise
*/

int	multi_gcd( MULTIPOLY A, MULTIPOLY B, MULTIPOLY *gcd)
{
	MULTIPOLY	big, small, u, v, r, dummy;
	FIELD2N		*uptr, *vptr, *rptr, *src;
	int			status;

	if( !A.degree || !B.degree)
	{
		gcd->degree = 0;
		if ( !get_space( gcd) )return 0;
		uptr = AddressOf( gcd);
		one( uptr);
		return 1;
	}

/*  copy data so we can clobber it; u gets the higher degree input  */

	if( A.degree >= B.degree)
	{
		big = A;
		small = B;
	}
	else
	{
		big = B;
		small = A;
	}
	u.degree = big.degree;
	v.degree = small.degree;
	if( !get_space( &u)) return 0;
	if( !get_space( &v))
	{
		free_space( &u);
		return 0;
	}
	src = Address( big);
	uptr = Address( u);
	multi_copy( u.degree+1, src, uptr);
	src = Address( small);
	vptr = Address( v);
	multi_copy( v.degree+1, src, vptr);

/*  dummy and r are multi_div outputs; multi_div compares their memdex
	against u and v to decide what to free.  Start them as copies of A,
	a live block distinct from the fresh copies u and v, so the first
	comparison is made on defined values and frees nothing.  A itself is
	never released: it is neither the dividend nor the divisor.  */

	dummy = A;
	r = A;

/*  Basic gcd algorithm loop.  divide until remainder is 0  */

	while( 1)
	{
		if( multi_div( u, v, &dummy, &r) != 1)
		{
			free_space( &v);
			free_space( &u);
			return 0;
		}
		if( !r.degree && !zero_check( Address(r)))
		{
			gcd->degree = v.degree;
			status = get_space( gcd) ? 1 : 0;
			if( status)
			{
				if (!v.degree) 
					one( AddressOf(gcd));
				else
					multi_copy( v.degree+1, Address(v), AddressOf(gcd));
			}
			free_space( &v);
			free_space( &u);
			free_space( &r);
			free_space( &dummy);
			return status;
		}

/*  shift: u <- v, v <- r, reusing the existing (large enough) storage  */

		u.degree = v.degree;
		vptr = Address(v);
		uptr = Address( u);
		multi_copy( u.degree+1, vptr, uptr);
		v.degree = r.degree;
		rptr = Address( r);
		multi_copy( v.degree+1, rptr, vptr);
		free_space( &r);
		free_space( &dummy);
	}
}

/*  Compute *out = p^2 * q * r, multiplying in that order.  *out is the
	destination of every multi_mul step, so each intermediate product is
	released as soon as the next one exists.  Returns 1 on success; on
	failure returns 0 with nothing left allocated in *out.  As for any
	multi_mul destination, out->memdex is compared with p's on the first
	step, so it must not equal p's memdex.  */

static int mul_square_by( MULTIPOLY p, MULTIPOLY q, MULTIPOLY r, MULTIPOLY *out)
{
	if( !multi_mul( p, p, out)) return 0;
	if( !multi_mul( q, *out, out) || !multi_mul( r, *out, out))
	{
		free_space( out);
		return 0;
	}
	return 1;
}

/*  The purpose of this routine is to generate a table of
	division polynomials.
	Enter with a pointer to an array of MULTIPOLY which has
	already been allocated, the length of that array and
	the corresponding curve to compute the polynomials over.
	Returns number of elements successfully allocated: entries
	f[0] .. f[returned-1] are valid and owned by the caller.  Generation
	stops at the first allocation or division failure, and any
	intermediate terms are released.  Note that size of f_n(x) is
	proportional to n^2.
	A progress line is printed for every entry from f[5] on.
*/

int gen_division_polynomial( MULTIPOLY *f, int length, CURVE curv)
{
	int			numgen, n;
	ELEMENT		i;
	MULTIPOLY	term1, term2, quot;
	FIELD2N		*fptr;
	
/* generate first five entries as described in Menezes "Elliptic
	Curve Public Key Crypto Systems", pg 102:
	f_0 = 0, f_1 = 1, f_2 = x, f_3 = x^4 + x^3 + a6, f_4 = x^6 + a6 x^2
*/
	numgen = 0;
	f[0].degree = 0;
	if (!get_space(f)) return numgen;
	fptr = Address( f[0]);
	null( fptr);
	numgen++;
	if( numgen == length) return numgen;
	
	f[1].degree = 0;
	if( !get_space( &f[1])) return numgen;
	fptr = Address( f[1]);
	one( fptr);
	numgen++;
	if ( numgen == length) return numgen;
	
	f[2].degree = 1;
	if( !get_space( &f[2])) return numgen;
	fptr = Address( f[2]);
	one( &fptr[1]);
	null( fptr);
	numgen++;
	if ( numgen == length) return numgen;
	
	f[3].degree = 4;
	if( !get_space( &f[3])) return numgen;
	fptr = Address( f[3]);
	copy( &curv.a6, fptr);
	null( &(fptr[1]));
	null( &(fptr[2]));
	one( &(fptr[3]));
	one( &(fptr[4]));
	numgen++;
	if ( numgen == length) return numgen;
	
	f[4].degree = 6;
	if( !get_space( &f[4])) return numgen;
	fptr = Address( f[4]);
	null( fptr);
	null( &(fptr[1]));
	copy( &curv.a6, &(fptr[2]));
	null( &(fptr[3]));
	null( &(fptr[4]));
	null( &(fptr[5]));
	one( &(fptr[6]));
	numgen++;
	if ( numgen == length) return numgen;

/* do odd, then even terms using the formulas 
	f_(2n+1) = (f_n)^3 * f_(n+2) + f_(n-1) * (f_(n+1))^3
	x * f_2n = (f_(n-1))^2 * f_n * f_(n+2) + f_(n-2) * f_n * (f_(n+1))^2

  if we use more internal storage the following routines could be
  implemented much faster in time.  Save squared terms for use in later
  iterations.
*/

	while ( numgen < length)
	{
		if (numgen & 1)
		{
			n = (numgen - 1)/2;
			if( !mul_square_by( f[n], f[n], f[n+2], &term1)) break;
			if( !mul_square_by( f[n+1], f[n+1], f[n-1], &term2))
			{
				free_space( &term1);
				break;
			}
		}
		else
		{
			n = numgen/2;
			if( !mul_square_by( f[n-1], f[n], f[n+2], &term1)) break;
			if( !mul_square_by( f[n+1], f[n], f[n-2], &term2))
			{
				free_space( &term1);
				break;
			}
		}

/*  term1 = term1 + term2.  term1 is also the destination, so on success
	multi_add releases the old term1; on failure both are still ours.  */

		if( !multi_add( term1, term2, &term1))
		{
			free_space( &term1);
			free_space( &term2);
			break;
		}
		free_space( &term2);

		if (numgen & 1)
		{
			f[numgen].degree = term1.degree;
			f[numgen].memdex = term1.memdex;
		}
		else
		{
/*  divide out x (= f_2).  quot starts as a copy of term1, so the dividend
	is the only block multi_div's aliasing test can match: it is released
	exactly once, and f[2] is never compared against garbage.  */

			quot = term1;
			if( multi_div( term1, f[2], &quot, &term1) != 1)
			{
				free_space( &term1);
				break;
			}
			free_space( &term1);	/* remainder is discarded */
			i = quot.degree;
			fptr = Address( quot);
			while( !zero_check( &fptr[i]) && i>0) i--;
			f[numgen].degree = i;
			f[numgen].memdex = quot.memdex;
		}
		numgen++;
		printf("finished polynomial #%d\n", numgen);
	}
	return numgen;
}

/*  The folloing routine computes a table of MULTIPOLY entries
	consisting of the residues of x^2j modulo a given MULTIPOLY f.
	Input is MULTIPOLY f, output is a vector of MULTIPOLY
	of length f.degree and each entry corresponds to
	x^2j mod f.
	Returns 1 if space allocated successfully for all table
	entries and 0 on failure.  Space for output MULTIPOLY vector 
	should already be allocated (change this??)
	
	Write a "destructor" routine to make sure you don't have
	memory leak problems!!
*/

int gen_xmodf( MULTIPOLY f, MULTIPOLY *xmod)
{
	ELEMENT		i, j;
	MULTIPOLY	dummy;
	FIELD2N		*xptr;
	
/* create the "simple" elements first  */

	xmod[0].degree = 0;
	if( !get_space( &(xmod[0]))) return 0;
	xptr = AddressOf( xmod);
	one( xptr);
	if ( f.degree < 2) return 1;
	
	xmod[1].degree = 2;
	if( !get_space( &(xmod[1])) ) return 0;
	xptr = Address( xmod[1]);
	null( xptr);
	null( &(xptr[1]));
	one( &(xptr[2]));
	if( f.degree < 3) return 1;
	
/*  multiply x^2 * x^2j to get x^2(j+1) next entry.
	nothing modulo f yet since 2j < degree of f[k]  */

	i = 2;
	while ( xmod[i-1].degree < f.degree)
	{
		if( !multi_mul( xmod[1], xmod[i-1], &xmod[i])) return 0;;
		i++;
	}

/*  when 2j > degree of f[k], we need to take residue  */

	if( !multi_div( xmod[i-1], f, &dummy, &(xmod[i-1])) ) return 0;
	free_space( &dummy);
	while( i < f.degree)
	{
		if( !multi_mul( xmod[1], xmod[i-1], &(xmod[i]))) return 0;
		if( !multi_div( xmod[i], f, &dummy, &(xmod[i]))) return 0;
		free_space( &dummy);
		i++;
	}
	return 1;
}

/*  Destructor for an xmodf table.
	The table built for modulus f is taken to hold f.degree entries
	(indices 0 .. f.degree-1); each one is handed back to free_space.
	Nothing else in this file frees the table entries.
*/

void destroy_xmodf( MULTIPOLY f, MULTIPOLY *table)
{
	MULTIPOLY	*entry = table;
	MULTIPOLY	*stop = table + f.degree;

	while( entry != stop)
	{
		free_space( entry);
		entry++;
	}
}

/*  Basic modulo squaring routine.
	enter with MULTIPOLY to be squared (g) and pointer
	to vector MULTIPOLY xmodf table (xmodfTable) as well
	as modulus (f).  
	Returns h = g^2 mod f.  OK for h == g too.
	With g reduced below f, the result is the XOR-sum over j of
	g_j^2 * xmodfTable[j], which is g^2 mod f because the cross terms
	of the square cancel in characteristic 2.
	See Menezes "Elliptic Curve Public Key Cryptosystems" pg 111.
	xmodfTable must hold x^(2j) mod f for j < f.degree with every entry
	of degree < f.degree (as built by gen_xmodf); a larger entry would
	be written past the end of h.
	Returns 1 on success, 0 on failure.  On failure *h is unchanged and
	g is not released, even when h == g.
*/

int square_modf( MULTIPOLY g, MULTIPOLY f, MULTIPOLY *xmodfTable, 
					MULTIPOLY *h)
{
	INDEX		t;
	ELEMENT		i, j;
	int			h_is_g;
	FIELD2N		aj, xtemp;
	MULTIPOLY	useg, dummy, out;
	FIELD2N		*uptr, *gptr, *optr, *xtptr;

/*  note now whether the caller wants g replaced, before anything changes  */

	h_is_g = (h->memdex == g.memdex);
	
/*  work on a copy of g reduced to degree < f  */

	if (g.degree >= f.degree) 
	{
		if( multi_div( g, f, &dummy, &useg) != 1) return 0;
		free_space( &dummy);
	}
	else
	{
		useg.degree = g.degree;
		if( !get_space( &useg)) return 0;
		uptr = Address( useg);
		gptr = Address( g);
		multi_copy( useg.degree+1, gptr, uptr);
	}

/*  square each coefficient in g and multiply that with the corresponding
	vector in the xmodf table. Sum with result.
*/

	out.degree = f.degree - 1;
	if( !get_space( &out)) 
	{
		free_space( &useg);
		return 0;
	}
	optr = Address( out);
	for( i=0; i<f.degree; i++) null( &(optr[i]));
	for( j=0; j<=useg.degree; j++)
	{
		uptr = Address( useg) + j;
		ring_square( uptr,  &aj);
		xtptr = Address( xmodfTable[j]);
		for( i=0; i<= xmodfTable[j].degree; i++)
		{
			ring_mul( &aj, &xtptr[i], &xtemp);
			SUMLOOP(t) optr[i].e[t] ^= xtemp.e[t];
		}
	}
	free_space ( &useg);
	while ( !zero_check( &optr[out.degree]) && (out.degree > 0)) out.degree--;

/*  only once the result exists is the old g dropped (when h replaces it)  */

	if( h_is_g) free_space( &g);
	h->degree = out.degree;
	h->memdex = out.memdex;
	return 1;
}

/*  use xmodf table to compute x^(2^NUMBITS) mod f.
	Table entry xmodf[j] = x^(2j) mod f with j = 2^(k-1) gives
	x^(2^k) mod f directly, where k = log_2(f.degree) (j stays below
	f.degree assuming log_2 rounds down -- TODO confirm).  The value is
	then squared mod f via Menezes' algorithm (see page 111 in
	"Elliptic Curve Public Key Cryptosystems") once per step until the
	exponent reaches 2^NUMBITS.
	Enter with f, the xmodf table built for the same f, and a pointer
	to the result.
	Returns 0 if f.degree < 2 or no space is available for the
	calculations (nothing is then left allocated in *xq), 1 if all ok,
	in which case *xq holds a newly allocated MULTIPOLY.
	
	Don't forget to free xq before calling this routine too
	many times!
*/
int xqmodf( MULTIPOLY f, MULTIPOLY *xmodf, MULTIPOLY *xq)
{
	ELEMENT		k, j;
	MULTIPOLY	next;
	FIELD2N		*xptr, *xqptr;
	
/*  first figure out where to start  */

	if ( f.degree < 2) return 0;
	k =  log_2(f.degree) ;
	j =  (ELEMENT)1 << (k - 1);
	xq->degree = xmodf[j].degree;
	if( !get_space( xq)) return 0;
	xptr = Address( xmodf[j]);
	xqptr = AddressOf( xq);
	multi_copy( xmodf[j].degree+1, xptr, xqptr);

/*  square into a separate output so *xq stays valid if a step fails.
	next starts as a copy of table entry 0, a live block distinct from
	*xq, so square_modf's "is h the same as g" test is made on a defined
	value and releases nothing.  */

	while( k < NUMBITS)
	{
		next = xmodf[0];
		if( !square_modf( *xq, f, xmodf, &next))
		{
			free_space( xq);
			return 0;
		}
		free_space( xq);
		xq->degree = next.degree;
		xq->memdex = next.memdex;
		k++;
	}
	return 1;
}
/*
main()
{
	MULTIPOLY 	a, b, c, d, e, xtoq;
	MULTIPOLY 	f[10], TABLxmodf[40] ;
	CURVE		crv;
	int			x, y, z;
	FIELD2N		*aptr;
	
	init_ram_space();
	init_ring_math();
	
	a.degree = 7;
	if(!get_space( &a)) printf("you are royally fucked./n");
	for( x=1; x<7; x++) 
	{
		aptr = Address( a) + x;
		null( aptr);
	}
	aptr = Address(a);
	aptr->e[0] = 0x1;
	aptr[7].e[0] = 0x1;
/*	b.degree = 5;
	if(!get_space( &b)) printf("you are royally fucked./n");
	b.p[0].e[0] = 0x22;
	b.p[1].e[0] = 0x0;
	b.p[2].e[0] = 0x3FF;
	b.p[3].e[0] = 0x33;
	b.p[4].e[0] = 0x01;
	b.p[5].e[0] = 0x03;
	multi_mul( a, b, &c);
	multi_add( a, c, &c);
	multi_div( c, b, &d, &e);
	multi_div( c, a, &d, &e);
	multi_gcd( c, b, &d);
	multi_gcd( c, a, &e);*/
	
/*	crv.form = 0;
	null(&crv.a2);
	null(&crv.a6);
	crv.a6.e[0] = 0xc;
	x = 10;
	z = gen_division_polynomial( f, x, crv);
	printf("genrated %d polynomials out of %d\n", z, x);
	
	for( y=3; y<z; y++)
	{
		x = gen_xmodf( f[y], TABLxmodf);
		if (!x) 
		{
			printf("gen_xmodf failed y = %d\n", y);
			return;
		}
		x = square_modf( a, f[y], TABLxmodf, &xtoq);
		if (!x) 
		{
			printf("square _modf failed y = %d\n", y);
			return;
		}
		x = xqmodf( f[y], TABLxmodf, &xtoq);
		if (!x) 
		{
			printf("xmodf failed y = %d\n", y);
			return;
		}		
		free_space( &xtoq);
		destroy_xmodf( f[y], TABLxmodf);
	}
}
*/
	