/* Floating point matrix manipulation routines for S-Lang */
/* 
 * Copyright (c) 1992, 1994 John E. Davis 
 * All rights reserved.
 *
 * Permission is hereby granted, without written agreement and without
 * license or royalty fees, to use, copy, and distribute this
 * software and its documentation for any purpose, provided that the
 * above copyright notice and the following two paragraphs appear in
 * all copies of this software.
 *
 * IN NO EVENT SHALL JOHN E. DAVIS BE LIABLE TO ANY PARTY FOR DIRECT,
 * INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
 * THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF JOHN E. DAVIS
 * HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * JOHN E. DAVIS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
 * PARTICULAR PURPOSE.  THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS"
 * BASIS, AND JOHN E. DAVIS HAS NO OBLIGATION TO PROVIDE MAINTENANCE,
 * SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
 */
#include <stdio.h>
#ifndef FLOAT_TYPE
#define FLOAT_TYPE
#endif
#include "slang.h"
#include "_slang.h"
#include "slarray.h"

/* Pop an array off the S-Lang stack and require it to hold FLOAT_TYPE data.
 *
 * Returns the popped array on success.  Returns NULL if SLang_pop_array
 * fails, or if the array's element type is not FLOAT_TYPE; in the latter
 * case SLang_Error is set to TYPE_MISMATCH.
 */
static SLArray_Type *SLpop_float_matrix (void)
{
   int flag = 0;
   SLArray_Type *arr = SLang_pop_array (&flag);

   if (arr == NULL)
     return NULL;

   if (arr->type != FLOAT_TYPE)
     {
	SLang_Error = TYPE_MISMATCH;
	return NULL;
     }

   return arr;
}

   
/* multiply 2 matrices (assumed float) producing a third */
static void SLmatrix_multiply (void)
{
   SLArray_Type *a, *b, *c;
   FLOAT *aa, *bb, *cc, sum;
   int c_handle, dim;
   unsigned int imax, jmax, kmax;
   unsigned int i, j, k, ofs_a, ofs_b;
   
   if ((NULL == (b = SLpop_float_matrix ()))
       || (NULL == (a = SLpop_float_matrix ())))
     return;
     
   /* Now is a is n*m, then b must be m*x.  Result it n*x */
   
   imax = a->x;
   jmax = a->y;
   
   if ((b->x != jmax) || (a->dim > 2) || (b->dim > 2))
     {
	SLang_Error = TYPE_MISMATCH;
	return;
     }
   
   kmax = b->y;
   
   /* Now result will be  imax by kmax 2d array */
   if (kmax == 1) dim = 1; else dim = 2;
   
   if (-1 == (c_handle = SLcreate_array(NULL, dim, imax, kmax, 1, 
					   'f', 0)))
     {
	SLang_doerror("Unable to create array.");
	return;
     }
   c = SLarray_from_handle (c_handle);
   /* Finally!! */
   cc = (FLOAT *) c->ptr;
   bb = (FLOAT *) b->ptr;
   aa = (FLOAT *) a->ptr;
   
   for (i = 0; i < imax; i++)
     {
	for (k = 0; k < kmax; k++)
	  {
	     sum = 0.0;
	     ofs_a = i * jmax;
	     ofs_b = k;
	     for (j = 0; j < jmax; j++)
	       {
		  sum += *(aa + ofs_a) * *(bb + ofs_b);
		  ofs_a++;
		  ofs_b += kmax;
	       }
	     
	     /* cc[i][k] */
	     *(cc + (int) (i * kmax + k)) = sum;
	  }
     }
   
   SLpush_array (c_handle);
}

/* Element-wise sum c = a + b of two float arrays of identical shape.
 *
 * Pops b (top of stack) and then a.  Both must have the same dim and the
 * same x, y and z extents, otherwise SLang_Error is set to TYPE_MISMATCH.
 * If the result cannot be allocated, SLang_doerror is called.  Nothing is
 * pushed on failure; on success the new array c is pushed.
 */
static void SLmatrix_addition (void)
{
   SLArray_Type *a, *b, *c;
   FLOAT *aa, *bb, *cc;
   int c_handle;
   unsigned int imax, jmax, kmax, n, ofs;

   if ((NULL == (b = SLpop_float_matrix ()))
       || (NULL == (a = SLpop_float_matrix ())))
     return;

   /* For the addition to make sense, both operands must have the same shape. */
   imax = a->x; jmax = a->y; kmax = a->z;

   if ((b->dim != a->dim) || (b->x != imax)
       || (b->y != jmax) || (b->z != kmax))
     {
	SLang_Error = TYPE_MISMATCH;
	return;
     }

   if (-1 == (c_handle = SLcreate_array(NULL, a->dim, imax, jmax,
					   kmax, 'f', 0)))
     {
	SLang_doerror("Unable to create array.");
	return;
     }

   c = SLarray_from_handle (c_handle);
   cc = (FLOAT *) c->ptr;
   bb = (FLOAT *) b->ptr;
   aa = (FLOAT *) a->ptr;

   /* Element (i,j,k) lives at offset (i*jmax + j)*kmax + k in all three
    * arrays, so offsets 0 .. imax*jmax*kmax-1 each name exactly one
    * element.  A single linear pass is therefore equivalent to a nested
    * loop, and it walks memory in order instead of striding by jmax*kmax. */
   n = imax * jmax * kmax;
   for (ofs = 0; ofs < n; ofs++)
     cc[ofs] = aa[ofs] + bb[ofs];

   SLpush_array (c_handle);
}

/* Intrinsic table registered by init_SLmatrix().  Each entry binds a
 * S-Lang-visible name to one of the C routines in this file.  Both are
 * declared VOID_TYPE: they take their operands from, and push their
 * result onto, the S-Lang stack rather than returning a C value.
 * NOTE(review): the trailing 0 is presumably the declared argument count,
 * and the leading '.' in the names appears to be a naming convention of
 * this S-Lang version -- confirm against MAKE_INTRINSIC in slang.h.
 */
static SLang_Name_Type slmatrix_table[] =
{
   MAKE_INTRINSIC(".matrix_multiply", SLmatrix_multiply, VOID_TYPE, 0),
   MAKE_INTRINSIC(".matrix_add", SLmatrix_addition, VOID_TYPE, 0),
   SLANG_END_TABLE
};

/* Register the matrix intrinsics (slmatrix_table) with the interpreter
 * under the table name "_Matrix".
 *
 * Returns whatever SLang_add_table returns.  Declared with an explicit
 * (void) parameter list: the old-style empty "()" is a non-prototype
 * definition, which is obsolescent and means something different in C23.
 */
int init_SLmatrix (void)
{
   return SLang_add_table (slmatrix_table, "_Matrix");
}


