#include <iostream.h>
#include <stdlib.h>

#include "element.h"
#include "real.h"
#include "matrix.h"
#include "number.h"
#include "classes.h"

// Builds a size x size square matrix with every cell set to Number(0).
Matrix::Matrix(int size)
{
	const int side = size;

	createNumbers(side, side);	// allocate the side x side pointer grid
	fillNumbers(0);			// store a fresh zero in every slot
}

void Matrix::Matrix(int a_rows, int a_cols)
{
	createNumbers(a_rows, a_cols);
	fillNumbers(0);
}

// Fills the matrix according to a one-letter type code.
//
// Only "I" (identity) is recognised. Every cell is replaced by
// Number(i == j), i.e. 1 on the diagonal and 0 elsewhere. Any other code,
// an empty string, or a NULL pointer leaves the matrix untouched.
void Matrix::createType(char *type)
{
	// Reading type[0] through a NULL pointer would crash.
	if (type == NULL || type[0] != 'I')
		return;

	for (int i = 0; i < rows; i++) {
		for (int j = 0; j < cols; j++) {
			// Release any existing cell first; remove() takes 1-based indices.
			// NOTE(review): assumes remove() frees cell[i][j] -- confirm.
			if (cell[i][j] != NULL)
				remove(i + 1, j + 1);

			cell[i][j] = new Number(i == j);
		}
	}
}

// Records the dimensions and allocates the cell grid: a table of row
// pointers plus one row of Number* slots per matrix row.
//
// Every slot is initialised to NULL. This makes the NULL checks in the
// destructor and createType() meaningful even before a caller fills the
// grid.
//
// NOTE(review): malloc failures are not detected here.
void Matrix::createNumbers(int a_rows, int a_cols)
{
	rows = a_rows;
	cols = a_cols;

	// cell is Number***: the table holds Number** row pointers, and each
	// row holds Number* cell pointers. Size each allocation by its element
	// type.
	cell = (Number ***) malloc(rows * sizeof (Number **));

	for (int i = 0; i < rows; i++) {
		cell[i] = (Number **) malloc(cols * sizeof (Number *));

		for (int j = 0; j < cols; j++)
			cell[i][j] = NULL;
	}
}

void Matrix::fillNumbers(int value)
{
	for (register int i = 0; i < rows; i++)
		for (register int j = 0; j < cols; j++)
			cell[i][j] = new Number(value);
}


// Builds an a_rows x a_cols matrix in which every cell is a new Number
// constructed from `element`.
Matrix::Matrix(int a_rows, int a_cols, Element &element)
{
	createNumbers(a_rows, a_cols);

	for (int r = 0; r < rows; ++r) {
		Number **row = cell[r];

		for (int c = 0; c < cols; ++c)
			row[c] = new Number(element);
	}
}

// Copy constructor: allocates a grid with the same shape as `matrix`.
// Each cell is filled with a new Number copied from the source cell,
// read through the 1-based operator().
Matrix::Matrix(Matrix &matrix)
{
	const int srcRows = matrix.getRows();
	const int srcCols = matrix.getCols();

	createNumbers(srcRows, srcCols);

	for (int r = 1; r <= rows; ++r)
		for (int c = 1; c <= cols; ++c)
			cell[r - 1][c - 1] = new Number(matrix(r, c));
}

// Conversion constructor.
//
// When `element` reports class classMatrix, its cells are deep-copied.
// Otherwise a 1x1 zero matrix is built and the misuse is reported on
// cerr.
Matrix::Matrix(Element &element)
{
	if (element.getClass() != classMatrix) {
		createNumbers(1, 1);
		cell[0][0] = new Number(0);
		cerr << "invalid class used for matrix creation" << endl;
		return;
	}

	// The class tag is taken as proof that the dynamic type is Matrix.
	Matrix *source = (Matrix *) (&element);

	createNumbers(source->getRows(), source->getCols());

	for (int r = 0; r < rows; ++r)
		for (int c = 0; c < cols; ++c)
			cell[r][c] = new Number((*source)(r + 1, c + 1));
}

// Builds a size x size matrix of the kind named by `type`, populated by
// createType().
//
// Any cell that createType() leaves empty (e.g. for an unrecognised type
// code) is filled with Number(0). The result therefore never holds NULL
// cells, which operator(), toStream() and the arithmetic operators would
// dereference.
Matrix::Matrix(char *type, int size)
{
	createNumbers(size, size);

	// Mark every slot empty so createType() knows there is nothing to free.
	for (int i = 0; i < rows; i++)
		for (int j = 0; j < cols; j++)
			cell[i][j] = NULL;

	createType(type);

	// Fallback for unknown types: a zero matrix instead of NULL cells.
	for (int i = 0; i < rows; i++)
		for (int j = 0; j < cols; j++)
			if (cell[i][j] == NULL)
				cell[i][j] = new Number(0);
}

// Destructor. It runs in two passes:
//   1. delete every non-NULL cell;
//   2. free each malloc'ed pointer row, then the row table itself.
Matrix::~Matrix()
{
	for (int r = 0; r < rows; ++r) {
		Number **row = cell[r];

		for (int c = 0; c < cols; ++c)
			if (row[c] != NULL)
				delete row[c];
	}

	for (int r = 0; r < rows; ++r)
		free(cell[r]);

	free(cell);
}

// 1-based element access: (1, 1) is the top-left cell.
//
// Indices outside [1, rows] x [1, cols] fall back to the top-left cell
// instead of reading outside the grid. Index 0 is rejected as well,
// because it would map to cell[-1].
Number& Matrix::operator () (int row, int col)
{
	const bool inRange = row >= 1 && row <= rows
		&& col >= 1 && col <= cols;

	return inRange ? *cell[row - 1][col - 1] : *cell[0][0];
}

// Returns a heap-allocated deep copy of this matrix. Nothing here keeps a
// pointer to it, so releasing it is up to the caller.
//
// Delegating to the copy constructor avoids first allocating zero cells
// whose pointers would immediately be overwritten and leaked.
Element& Matrix::copy()
{
	Matrix *matrix = new Matrix(*this);

	return *matrix;
}

// Element-wise sum with another Matrix of identical shape, returned as a
// new heap-allocated Matrix.
//
// If `element` is not a Matrix, or the shapes differ, both operands are
// collapsed to their getNum() values. A new Real holding their sum is
// returned instead.
Element& Matrix::operator+(Element &element)
{
	if (element.getClass() == classMatrix) {
		Matrix *other = (Matrix *) (&element);

		if (other->getRows() == rows && other->getCols() == cols) {
			Real zero(0);
			Matrix *sum = new Matrix(rows, cols, zero);

			for (int r = 0; r < rows; ++r)
				for (int c = 0; c < cols; ++c)
					*(sum->cell[r][c]) = *cell[r][c] + *(other->cell[r][c]);

			return *sum;
		}
	}

	// Scalar fallback: non-Matrix operand or mismatched shapes.
	return *(new Real(getNum() + element.getNum()));
}

// Element-wise difference (this - element) with another Matrix of
// identical shape, returned as a new heap-allocated Matrix.
//
// If `element` is not a Matrix, or the shapes differ, both operands are
// collapsed to their getNum() values. A new Real holding their difference
// is returned instead.
Element& Matrix::operator-(Element &element)
{
	if (element.getClass() == classMatrix) {
		Matrix *other = (Matrix *) (&element);

		if (other->getRows() == rows && other->getCols() == cols) {
			Real zero(0);
			Matrix *diff = new Matrix(rows, cols, zero);

			for (int r = 0; r < rows; ++r)
				for (int c = 0; c < cols; ++c)
					*(diff->cell[r][c]) = *cell[r][c] - *(other->cell[r][c]);

			return *diff;
		}
	}

	// Scalar fallback: non-Matrix operand or mismatched shapes.
	return *(new Real(getNum() - element.getNum()));
}

// Unary minus: returns a new heap-allocated Matrix whose cells are the
// negations of this matrix's cells.
Element& Matrix::operator-()
{
	Matrix *negated = new Matrix(*this);

	for (int r = 0; r < rows; ++r)
		for (int c = 0; c < cols; ++c)
			// Plain assignment of the negated cell (not a compound operator).
			*(negated->cell[r][c]) = -(*cell[r][c]);

	return *negated;
}

// Product with another Element.
//
// Matrix operand: standard matrix product. This matrix is rows x cols,
// the operand must be cols x p, and the result is a new heap-allocated
// rows x p Matrix. A non-conforming shape is reported on cerr and, like
// the non-Matrix case, falls back to a scalar result.
//
// Non-Matrix operand: returns a new Real holding
// getNum() * element.getNum(). This mirrors how operator+ and operator-
// combine getNum() values.
Element& Matrix::operator*(Element &element)
{
	if (element.getClass() == classMatrix) {
		Matrix *matrix = (Matrix *) (&element);
		const int inner = matrix->getRows();
		const int outCols = matrix->getCols();

		// Inner dimensions must match exactly. A mere "cols < inner" test
		// would accept cols > inner and silently ignore this matrix's extra
		// columns. inner == 0 leaves no term to seed each sum with.
		if (cols != inner || inner == 0) {
			cerr << "multplying error: uncompact matrixes" << endl;
			return *(new Real(getNum() * element.getNum()));
		}

		Real placeholder(-1);	// every cell is overwritten below
		Matrix *ret = new Matrix(rows, outCols, placeholder);

		for (int i = 0; i < rows; i++) {
			for (int j = 0; j < outCols; j++) {
				Number &out = *(ret->cell[i][j]);

				// Seed with the k = 0 term, then accumulate the remaining terms
				// directly in the result cell. No pointer to an operator result
				// is held.
				out = (*cell[i][0]) * *(matrix->cell[0][j]);

				for (int k = 1; k < inner; k++)
					out = out + *cell[i][k] * *(matrix->cell[k][j]);
			}
		}

		return *ret;
	}

	return *(new Real(getNum() * element.getNum()));
}

// Writes the matrix to `stream` in three parts:
//   - a header line;
//   - one parenthesised, comma-separated line per row, with ',' after
//     every row except the last;
//   - a closing marker.
// Returns `stream` so calls can be chained.
ostream & Matrix::toStream(ostream &stream)
{
	stream << "[ matrix " << rows << "x" << cols << " ]" << endl;

	for (int i = 0; i < rows; i++) {
		stream << "(";

		// Emit the separator before every cell except the first. This avoids
		// reading the loop counter after its for-scope has ended (ill-formed
		// in standard C++). It also handles cols == 0 without touching
		// cell[i][-1].
		for (int j = 0; j < cols; j++) {
			if (j > 0)
				stream << ", ";
			stream << *cell[i][j];
		}

		stream << ")";
		stream << (i == rows - 1 ? "" : ",");
		stream << endl;
	}

	stream << "[ matrix end ]";

	return stream;
}

// Returns a new heap-allocated Matrix of the same shape. Each cell is
// this matrix's corresponding cell multiplied by `element`.
Element& Matrix::scalarMultiply(Element &element)
{
	Real zero(0);
	Matrix *scaled = new Matrix(rows, cols, zero);

	for (int r = 0; r < rows; ++r) {
		Number **dst = scaled->cell[r];
		Number **src = cell[r];

		for (int c = 0; c < cols; ++c)
			*dst[c] = *src[c] * element;
	}

	return *scaled;
}


// Returns the minor obtained by deleting row `row` and column `col` (both
// 0-based), as a new heap-allocated (rows-1) x (cols-1) Matrix. Nothing
// here keeps a pointer to the result, so the caller must delete it.
//
// If the matrix is smaller than 2x2, or row/col lies outside the grid, an
// error is printed and a 1x1 matrix built from the top-left element is
// returned.
Matrix & Matrix::subMatrix(int row,int col)
{
	const bool tooSmall = (rows < 2) || (cols < 2);
	// With an out-of-range index no row/column would be skipped, and the
	// copy below would write past the end of the smaller result.
	const bool badIndex = row < 0 || row >= rows || col < 0 || col >= cols;

	if (tooSmall || badIndex) {
		cerr << "unable to create submatrix" << endl;
		return *(new Matrix(1, 1, cell[0][0]->getElement()));
	}

	Real zero(0);
	Matrix *ret = new Matrix(rows - 1, cols - 1, zero);

	int i2 = 0;
	for (int i = 0; i < rows; i++) {
		if (i == row)
			continue;

		int j2 = 0;
		for (int j = 0; j < cols; j++) {
			if (j == col)
				continue;

			*(ret->cell[i2][j2]) = (*cell[i][j]);
			j2++;
		}

		i2++;
	}

	return *ret;
}

float Matrix::determinant()
{
	if (rows != cols) {
		cout << "unable to count determinant for non-NxN matrix" << endl;
		return 0;
	}

	if (rows == 2)
		return cell[0][0]->getNum() * cell[1][1]->getNum()
			- cell[0][1]->getNum() * cell[1][0]->getNum();

	if (rows > 2) {
		register int c, k;
		float det;

		for (det = 0, k = 1, c = 0; c < cols; c++, k = -k)
			det += (float) k
				* cell[0][c]->getNum()
				* subMatrix(0,c).determinant();

		return det;
	}

	return cell[0][0]->getNum();
}

// The scalar value a Matrix reports through getNum() is its determinant.
float Matrix::getNum()
{
	const float value = determinant();

	return value;
}

// Returns a new heap-allocated inverse computed as adj(A) / det(A).
//
// The adjugate adj(A) is the TRANSPOSE of the cofactor matrix, so
//   inverse(i, j) = (-1)^(i+j) * det(minor(j, i)) / det(A).
//
// When det(A) == 0 the matrix cannot be inverted: an error is printed and
// a zero matrix of the same shape is returned instead of dividing by
// zero. This covers singular matrices and non-square ones, for which
// determinant() reports an error and returns 0.
Matrix& Matrix::inverse()
{
	Real zero(0);
	Matrix *ret = new Matrix(rows, cols, zero);
	const float det = determinant();

	// Exact comparison is intended: it only guards the division below.
	if (det == 0) {
		cerr << "unable to invert singular or non-square matrix" << endl;
		return *ret;
	}

	const float invDet = 1 / det;

	// A 1x1 matrix has no minors; its inverse is simply 1 / a.
	if (rows == 1) {
		*ret->cell[0][0] = Real(invDet);
		return *ret;
	}

	for (int i = 0; i < rows; i++) {
		for (int j = 0; j < cols; j++) {
			const float sign = ((i + j) % 2 == 0) ? 1.0f : -1.0f;

			// Transposed cofactor: element (i, j) uses the minor that
			// removes row j and column i. subMatrix() heap-allocates it.
			Matrix &sub = subMatrix(j, i);

			*ret->cell[i][j] = Real(sign * sub.determinant() * invDet);
			delete &sub;
		}
	}

	return *ret;
}

