#
#	gridTableTools.py
#
#	(c) 2025 by Miguel Angel Reina Ortega & Andreas Kraft
#	License: BSD 3-Clause License. See the LICENSE file for further details.
#
""" Tools for working with grid tables in markdown files. """

from typing import Optional, Callable
from regexMatches import *


# HTML alignment attributes written into the generated <th>/<td> elements.
_alignLeft = 'align="left"'
_alignRight = 'align="right"'
_alignCenter = 'align="center"'
# Appended after every list item collected in a cell, so that the end of each
# item can be located again when the HTML list is built.
_nextListElementMark = '∆'	# !!! Must be a single character


# Module-level log callables. All of them default to print() and can be replaced via setLoggers().
printInfo = print
printDebug = print
printError = print

def setLoggers(info:Callable=print, debug:Callable=print, error:Callable=print) -> None:
	"""	Replace the module-level logging callables.

		Args:
			info: Callable stored as `printInfo`.
			debug: Callable stored as `printDebug`.
			error: Callable stored as `printError`.
	"""
	global printInfo, printDebug, printError
	printInfo, printDebug, printError = info, debug, error


class GridCell:
	"""	A single cell of a grid table together with its span, alignment and position metadata. """

	def __init__(self) -> None:
		"""	Create an empty cell: no content, zero spans, centered, positions not yet known.
		"""
		self.content:Optional[str] = None
		self.rowspan:int = 0
		self.colspan:int = 0
		self.colspanAdjusted:bool = False
		self.alignment:str = 'align="center"'
		self.positionStart:Optional[int] = None
		self.position:Optional[int] = None
		self.listFlag:bool = False
		self.auxiliarIndex:int = 0


	def calculateAndSetAlignment(self, 
								 headerDelimiterPositions:list[int], 
								 delimiterPositions:list[int], 
								 defaultAlignments:list[str], 
								 hasHeader:bool) -> None:
		"""	Pick the alignment of this cell from the header column it starts in.

			The cell receives the default alignment of the first header column whose
			delimiter position is not left of the cell's start position.

			Args:
				headerDelimiterPositions: The positions of the header delimiters.
				delimiterPositions: The positions of the delimiters (not used by this method).
				defaultAlignments: The default alignments, one per header column.
				hasHeader: True if the table has a header, False otherwise. Nothing is changed when False.

			Raises:
				ValueError: If the cell positions are not set, or if no header column matches.
		"""
		if self.positionStart is None or self.position is None:
			raise ValueError('Cell position must be set before calculating alignment.')

		if not hasHeader:
			return

		for columnNumber, columnAlignment in enumerate(defaultAlignments):
			if self.positionStart <= headerDelimiterPositions[columnNumber]:
				self.alignment = columnAlignment
				return

		# The cell starts right of every header delimiter
		raise ValueError('Invalid table formatting')


	def __str__(self):
		"""	Return a readable summary of the cell's fields. """
		summary = f'(Content: {self.content}, Rowspan: {self.rowspan}, Colspan: {self.colspan}, Alignment: {self.alignment}, Position: {self.position}, ListFlag: {self.listFlag}, AuxiliarIndex: {self.auxiliarIndex})'
		return summary


	def __repr__(self):
		"""	Same text as `__str__`. """
		return str(self)


class GridRow():
	"""	One row of a grid table, stored as a list of `GridCell` objects. """
	cells:list[GridCell] = []


	def __init__(self, length: int = 1) -> None:
		"""	Create a row with `length` new, independent cells. """
		freshCells = []
		for _ in range(length):
			freshCells.append(GridCell())
		self.cells = freshCells


	def __getitem__(self, item):
		"""	Return the cell(s) at index or slice `item`. """
		return self.cells[item]


	def __setitem__(self, key, value):
		"""	Replace the cell(s) at index or slice `key`. """
		self.cells[key] = value


	def __str__(self):
		"""	Return the string form of the cell list. """
		return f'{self.cells}'


	def __repr__(self):
		"""	Same text as `__str__`. """
		return str(self)
	

class GridRowsTracker():
	"""	A list of integer counters that all start at zero.

		The grid table parser keeps one entry per column and uses it as the index
		of the row that is currently filled for that column.
	"""
	def __init__(self, size:int) -> None:
		"""	Create `size` counters, each initialised to 0. """
		self.gridRowTracker = [0] * size


	def __getitem__(self, item:int) -> int:
		"""	Return the counter at index `item`. """
		return self.gridRowTracker[item]


	def __setitem__(self, key:int, value:int) -> None:
		"""	Set the counter at index `key` to `value`. """
		self.gridRowTracker[key] = value


	def __str__(self):
		"""	Return the string form of the counter list. """
		return f'{self.gridRowTracker}'


	def __repr__(self):
		"""	Same text as `__str__`. """
		return str(self)

	def max(self) -> int:
		"""	Return the largest counter value. """
		highest = max(self.gridRowTracker)
		return highest
		


# Type aliases for the parser result:
# a table row is a list of cells, and a table section (header or body) is a list of such rows.
GridTableRow = list[GridCell]
GridTableRowList = list[GridTableRow]

def parseGridTableWithSpans(gridTable:str) -> tuple[GridTableRowList, GridTableRowList]:
	"""	Parse a Pandoc-style grid table into header and body rows that carry rowspan/colspan metadata.

		The table is split at its full separator lines. Every separator line opens a new row of
		`GridCell` objects, the following lines fill these cells, and partial separators add
		further rows. Finally the cell content is prepared for HTML: "<" is escaped, bold and
		italic markup is converted, and newlines become "<br />".

		Args:
			gridTable: The Pandoc-style grid table as a single string.

		Returns:
			A tuple `(headerRows, dataRows)`. Each element is a list of rows, and each row is a
			list of `GridCell` objects. Cells with a rowspan or colspan of 0 are placeholders
			covered by another cell.

		Raises:
			ValueError: If no separator line is found, if delimiters are misaligned, if a line has
				more cells than columns, if no rows are found, or if the spans of a row do not
				add up to the number of columns.
	"""
	hasHeader = False
	defaultAlignments:list[str] = []
	headerDelimiterPositions:list[int] = []
	delimiterPositions:list[int] = []

	# Split the input into lines
	lines:list[str] = gridTable.rstrip().split('\n')


	def isSeparator(line:str) -> bool:
		"""	Return True for a full separator line. Partial separators are not detected here. """
		return matchGridTableSeparator.match(line) is not None


	def alignmentOf(part:str) -> str:
		"""	Derive the alignment attribute from the colon markers of one separator segment. """
		startsWithColon = part.startswith(':')
		endsWithColon = part.endswith(':')
		if startsWithColon and not endsWithColon:
			return _alignLeft
		if endsWithColon and not startsWithColon:
			return _alignRight
		return _alignCenter


	def handleCellContent(cell:GridCell, content:str) -> None:
		"""	Store one line of text in a cell, concatenating multi-line content and flagging lists.

			A trailing backslash is turned into a newline. Lines starting with "- " are list items;
			they get the list element end mark appended so the end of each item can be found later.
			The first content of a cell also sets its rowspan and colspan to at least 1.
		"""
		_c = content.strip()

		if cell.content is None:	# Previous empty cell
			cell.rowspan += 1
			cell.colspan += 1
			if _c.startswith('- '):  # List in a cell
				cell.listFlag = True
				_c = re.sub(r'\\\s*$', '\n', _c)
				cell.content = _c + _nextListElementMark  # Add list element end mark to know when the list element ends
			elif cell.listFlag and len(_c) > 0:  # any other content when handling list is concatenated to the last list element
				_c = re.sub(r'\\\s*$', '\n', _c)
				cell.content = _c + _nextListElementMark  # add the list element end mark
			elif not _c:  # empty line. separation between list and other paragraph
				cell.content = '\n'  # cell content is always empty / None here.
			else:
				cell.content = re.sub(r'\\\s*$', '\n', _c)
		else:  # Cell has content
			if _c.startswith('- '):  # List
				if not cell.listFlag:
					cell.content += '\n'
				cell.listFlag = True
				_c = re.sub(r'\\\s*$', '\n', _c)
				cell.content += _c + _nextListElementMark  # Add list element end mark to know when the list element ends
			elif cell.listFlag and len(_c) > 0:  # any other content when handling list is concatenated to the last list element
				cell.content = cell.content.removesuffix(_nextListElementMark)  # remove list element end mark

				_c = re.sub(r'\\\s*$', '\n', _c)
				cell.content += ' ' + _c + _nextListElementMark  # add list element end mark
			elif len(_c) == 0:  # separation between list and other paragraph
				if cell.listFlag:
					cell.listFlag = False
					cell.content += '\n\n'  # end list by \n
				cell.content += '\n' if not cell.content.endswith('\n') else ''
			else:
				cell.content += ' ' + re.sub(r'\\\s*$', '\n', _c)


	def adjustColspan(row:GridRow, columnIndex:int, numberOfParts:int, line, numberOfColumns:int, delimiterPositions:list[int]) -> None:
		"""	Increase the colspan of `row[columnIndex]` while the next delimiter found in `line`
			lies right of the expected column delimiter, then mark the cell as adjusted.

			Raises:
				ValueError: If the delimiter lies left of the expected column delimiter.
		"""
		for j in range(columnIndex, numberOfParts):
			delimiterStart:Optional[int] = None
			colI = columnIndex
			# Walk left until a cell with a known end position (or the line start) is found
			while delimiterStart is None:
				delimiterStart = row[colI - 1].position if colI > 0 else 0
				colI -= 1
			positions = [line.find(delimiter, delimiterStart + 1) for delimiter in "|+" if delimiter in line[delimiterStart + 1:]]
			position = min(positions) if positions else -1
			if position > delimiterPositions[j]:  # Colspan to be increased
				row[columnIndex].colspan += 1
				if position == delimiterPositions[len(delimiterPositions) - 1]:  # last cell in row, adjust colspan to get max number columns
					colspan_allocated = row[columnIndex].colspan
					row[columnIndex].colspan += numberOfColumns - colspan_allocated - columnIndex
			elif position < delimiterPositions[j]:
				raise ValueError("Wrong cell formatting")
			else:
				break

		row[columnIndex].colspanAdjusted = True	# Mark cell as adjusted


	def checkDelimiterAlignment(line: str, delimiterPositions:list[int], delimiters: str = "|+") -> bool:
		"""
		Check if delimiters in a row align with expected positions.
		
		Args:
			line: The line of text to check
			delimiterPositions: List of expected positions (based on + characters)
			delimiters: String containing valid delimiter characters (default: "|+")
		
		Returns:
			bool: True if delimiters align correctly, False otherwise
		"""
		if not line or not delimiterPositions:
			return False
		
		printDebug(f'\nChecking line: "{line}"')
		printDebug(f'Expected delimiter positions: {delimiterPositions}')
		
		# For full separator lines (only +)
		if '+' in line and '|' not in line:
			currentPositions = [i for i, char in enumerate(line) if (char == '+' and i > 0)]
			printDebug(f'Full separator line - Found + at positions: {currentPositions}')
			return all(delimiterPositions[-1] in currentPositions and line.startswith('+') and pos in delimiterPositions 
					   for pos in currentPositions)
		
		# For data lines (only |)
		if '|' in line and '+' not in line:
			currentPositions = [i for i, char in enumerate(line) if (char == '|' and i > 0)]
			printDebug(f'Data line - Found | at positions: {currentPositions}')
			return all(delimiterPositions[-1] in currentPositions and line.startswith("|") and pos in delimiterPositions 
					   for pos in currentPositions)
		
		# For partial separators (mix of + and |)
		currentPositions = [i for i, char in enumerate(line) if (char in delimiters and i > 0)]
		printDebug(f'Partial separator - Found delimiters at positions: {currentPositions}')
		printDebug(f'Characters at those positions: {[line[pos] for pos in currentPositions]}')
		return all(delimiterPositions[-1] in currentPositions and line.startswith(('+', '|')) and pos in delimiterPositions 
				   for pos in currentPositions)


	def checkSpans(gridRows:GridTableRowList) -> None:
		"""	Verify that the colspans of every row, plus the columns still covered by rowspans
			of earlier rows, add up to the number of columns.

			Raises:
				ValueError: If a row does not add up.
		"""
		forwardRowspan:list[int] = []
		for idx, gridRow in enumerate(gridRows):
			if not forwardRowspan:
				forwardRowspan = [0] * len(gridRow)
			columnSum = 0

			for cellIndex, cell in enumerate(gridRow):
				columnSum += cell.colspan
				if idx > 0 and cell.colspan == 0:
					# Placeholder cell: count it if a rowspan from above still covers it
					if forwardRowspan[cellIndex] > 0:
						columnSum += 1
					forwardRowspan[cellIndex] -= 1
				if forwardRowspan[cellIndex] == 0 and cell.rowspan > 1:
					forwardRowspan[cellIndex] = cell.rowspan - 1

			if columnSum != numberOfColumns:
				raise ValueError('Grid table not converted properly')


	separatorIndices = [i for i, line in enumerate(lines) if isSeparator(line)]

	if not separatorIndices:
		raise ValueError('No valid separators found in the provided grid table.')

	# Calculate max number of columns. The separator with the most '+' defines the column delimiters.
	delimiterPositions = []
	numberOfColumns:int = 0

	for separatorIndex in separatorIndices:
		if (_cnt := lines[separatorIndex].count('+') - 1) > numberOfColumns:
			numberOfColumns = _cnt
			delimiterPositions = []
			for rowIndex in range(numberOfColumns):
				delimiterPositionsStart = delimiterPositions[rowIndex - 1] if rowIndex != 0 else 0
				# str.find() returns -1 when there is no further '+'
				delimiterPositions.append(lines[separatorIndex].find('+', delimiterPositionsStart + 1))
	
	# Determine delimiter positions and alignments
	headerRows:GridTableRowList = []
	dataRows:GridTableRowList = []

	for index in separatorIndices:
		if matchGridTableHeaderSeparator.match(lines[index]):
			hasHeader = True
			headerSeparatorIndex = index
			parts = re.split(r'\+', lines[index].strip('+'))
			# Calculate default alignments and positions of delimiters
			for partIndex, part in enumerate(parts):
				defaultAlignments.append(alignmentOf(part))

				# Delimiter position
				delimiterPositionsStart = delimiterPositions[partIndex - 1] if partIndex != 0 else 0
				headerDelimiterPositions.append(lines[index].find('+', delimiterPositionsStart + 1))

	if not hasHeader:
		# Set default alignments from the first line, which takes the role of the header separator
		hasHeader = True
		headerSeparatorIndex = 0
		firstLine = lines[0]
		parts = re.split(r'\+', firstLine.strip('+'))

		# Calculate default alignments and positions of delimiters
		for partIndex, part in enumerate(parts):
			defaultAlignments.append(alignmentOf(part))

			# Delimiter position. Read it from the same line the alignments are taken from,
			# not from whatever separator a previous loop happened to visit last.
			delimiterPositionsStart = delimiterPositions[partIndex - 1] if partIndex != 0 else 0
			headerDelimiterPositions.append(firstLine.find('+', delimiterPositionsStart + 1))

	# Check end table delimiter alignment (not checked during the lines processing)
	if not checkDelimiterAlignment(lines[-1], delimiterPositions):
		raise ValueError(f'Misaligned delimiters in end table separator: {lines[-1]}')

	# Process the lines between each pair of consecutive separators
	for start, end in zip(separatorIndices, separatorIndices[1:]):
		rows:list[GridRow] = []
		rowsTracker:GridRowsTracker
		inDataRow = False
		rowLines = lines[start:end]  # Lines between separators including separator line start as it gives information about the number of columns of the row
		if rowLines:
			# Combine multiline content into single strings for each cell
			for line in rowLines:
				line = line.rstrip()
				if isSeparator(line) and not inDataRow:
					inDataRow = True
					# Add delimiter alignment check for separator lines
					if not checkDelimiterAlignment(line, delimiterPositions):
						raise ValueError(f'Misaligned delimiters in separator row: {line}')
					
					parts = re.split(r'\s*\+\s*', line.strip('+'))
					delimiterIndex = 0

					rows.append(GridRow(numberOfColumns))
					rowsTracker = GridRowsTracker(numberOfColumns)
					columnIndex = 0

					for part in parts:
						if columnIndex in range(numberOfColumns):
							delimiterIndex += len(part) + 1
							cell = rows[-1][columnIndex]
							
							# Set position
							cell.positionStart = delimiterIndex - len(part)
							cell.position = delimiterIndex  # Position of cell delimiter +
							
							# Set alignment as defined by header separator line
							cell.calculateAndSetAlignment(headerDelimiterPositions, delimiterPositions, defaultAlignments, hasHeader)

							# Skip the columns that this segment spans
							while delimiterIndex > delimiterPositions[columnIndex]:
								columnIndex += 1
							columnIndex += 1

				elif inDataRow:
					# Regular data row or partial separator
					if matchGridTableBodySeparator.match(line):  # Partial separator
						# Add delimiter alignment check for partial separators
						if not checkDelimiterAlignment(line, delimiterPositions):
							raise ValueError(f'Misaligned delimiters in partial separator: {line}')

						cellsContent = re.split(r'[\|\+]', line.strip('|').strip('+'))
						# Add another row, set delimiters for each cell
						rows.append(GridRow(numberOfColumns))
						auxDelimiterIndex = 0
						auxiliarCellIndex = 0

						for content in cellsContent:
							if auxiliarCellIndex < numberOfColumns:
								auxDelimiterIndex += len(content) + 1
								cell = rows[-1][auxiliarCellIndex]
								cell.positionStart = auxDelimiterIndex - len(content)  # Position after the previous delimiter
								cell.position = auxDelimiterIndex  # Position of cell delimiter +
								cell.calculateAndSetAlignment(headerDelimiterPositions, delimiterPositions, defaultAlignments, hasHeader)
								while auxDelimiterIndex > delimiterPositions[auxiliarCellIndex]:
									auxiliarCellIndex += 1
								auxiliarCellIndex += 1

						if len(cellsContent) <= numberOfColumns:  # Colspan: Positions of | with respect to + need to be determined
							columnCellIndex = 0

							maxRowsTracker = rowsTracker.max()
							# Go through all cells in a column
							for content in cellsContent:
								rowIndex = rowsTracker[columnCellIndex]
								cell = rows[rowIndex][columnCellIndex]

								# Check whether a cell contains a header separator
								if matchGridTableBodySeparatorLine.match(content):  # A new row is to be added
									rowsTracker[columnCellIndex] = maxRowsTracker + 1	# That actual row will have more than one row
									rowIndex = rowsTracker[columnCellIndex]
									cell = rows[rowIndex][columnCellIndex]

									cell.listFlag = False
									columnForward = 0
								
									for delIndex in range(columnCellIndex, len(delimiterPositions)):
										rowIndex = rowsTracker[columnCellIndex]	# Correcting the rowIndex. Might have been changed by a previous iteration
										if rows[rowIndex][columnCellIndex].position >= delimiterPositions[delIndex]:
											columnForward += 1
											rowsTracker[columnCellIndex + columnForward - 1] = maxRowsTracker + 1 if columnForward > 1 else 0
									columnCellIndex += columnForward

								else:
									# Handle content of the cell
									handleCellContent(cell, content)
									cell.rowspan += 1
									if not cell.colspanAdjusted:
										# TO BE CHECKED Most probably the code below is never executed, colspan should be already adjusted when dealing with a partial separator
										adjustColspan(rows[rowIndex], columnCellIndex, numberOfColumns, line, numberOfColumns, delimiterPositions)

									if cell.position >= delimiterPositions[columnCellIndex]:
										columnCellIndex += cell.colspan if cell.colspan != 0 else 1

						else:
							raise ValueError(f'More cells than columns found ({len(cellsContent)} {numberOfColumns})')
						
					else:  # Data row
						cellsContent = re.split(r'\|', line.strip('|'))
						
						# Add delimiter alignment check
						if not checkDelimiterAlignment(line, delimiterPositions):
							raise ValueError(f'Misaligned delimiters in row: {line}')
							
						columnCellIndex = 0
						if len(cellsContent) < numberOfColumns:  # Colspan: Positions of | with respect to + need to be determined
							for content in cellsContent:
								row = rows[rowsTracker[columnCellIndex]]
								cell = row[columnCellIndex]
								# Handle content of the cell
								handleCellContent(cell, content)
								if not cell.colspanAdjusted:
									# TO BE CHECKED Most probably the code below is never executed, colspan should be already adjusted when dealing with a partial separator
									adjustColspan(row, columnCellIndex, numberOfColumns, line, numberOfColumns, delimiterPositions)
								if cell.position >= delimiterPositions[columnCellIndex]:
									columnCellIndex += cell.colspan  # Move forward to the next cell that is not covered

						elif len(cellsContent) == numberOfColumns:  # Simple row
							for columnIndex, content in enumerate(cellsContent):
								rowIndex = rowsTracker[columnIndex]
								handleCellContent(rows[rowIndex][columnIndex], content)
						else:
							raise ValueError(f'More cells than columns found ({len(cellsContent)} {numberOfColumns})')
				else:
					raise ValueError('No separator line found for row starting')

			# hasHeader is always True here (a table without header separator uses its first line).
			# Rows that start before the header separator are header rows, all others are data rows.
			targetRows = headerRows if start < headerSeparatorIndex else dataRows
			for row in rows:
				targetRows.append(row.cells)

	# Check if there are any data rows
	if not dataRows and not headerRows:
		raise ValueError('No valid rows found in the provided grid table.')

	# Format text. Every cell object belongs to exactly one row, so each cell is converted once.
	for gridRows in (headerRows, dataRows):
		for gridRow in gridRows:
			for cell in gridRow:
				if cell.content is None:
					continue
				# Replacing "<" by &lt; (must happen before any HTML tags are inserted)
				formatted = cell.content.replace('<', '&lt;')
				# Bold replacements
				formatted = matchBold.sub(r'\1<strong>\g<text></strong>', formatted)
				# Italic replacements
				formatted = matchItalic.sub(r'\1<i>\g<text></i>', formatted)
				# Correct newlines characters
				cell.content = formatted.replace('\n', '<br />')

	#
	# Checking that the grid is correct. Not too much tested - need to take into account rowspan of previous rows
	#
	checkSpans(headerRows)
	checkSpans(dataRows)

	return headerRows, dataRows


def generateHtmlTableWithSpans(gridTable:str) -> str:
	"""	Generate an HTML table from a Pandoc-style grid table with row and column spans.

		Cells whose rowspan or colspan is 0 are skipped. List items found in a cell are
		converted to `<ul>` lists, and such a cell is forced to left alignment.

		Args:
			gridTable: The Pandoc-style grid table.

		Returns:
			The HTML table in string format.

		Raises:
			RuntimeError: If the grid table cannot be parsed. The original exception is chained as the cause.
	"""
	mark = re.escape(_nextListElementMark)
	# One list item: a marker ("-", "*", "+" or a number followed by "."), whitespace,
	# then the item text (group 2) up to the list element end mark.
	listItemRegex = r'\s*([-*+]|\s*\d+\.)\s+((?:(?!' + mark + r').)+)' + mark
	# A run of one or more directly consecutive list items.
	listRunRegex = r'(\s*([-*+]|\s*\d+\.)\s+(?:(?!' + mark + r').)+' + mark + r')+'


	def listRunToHtml(run) -> str:
		"""	Build the <ul> element for the list items of one matched run. """
		items = re.findall(listItemRegex, run.group(0))
		return '<ul>' + ''.join(f'<li>{item[1]}</li>' for item in items) + '</ul>'


	def renderCell(cell, tag:str) -> str:
		"""	Return the HTML line for one visible cell, using `tag` ("th" or "td"). """
		# Prepare content, in case there's a list
		if cell.content is not None and re.search(listItemRegex, cell.content):
			# A function is used as replacement so that backslashes in the cell text are not
			# interpreted as regex template escapes, and each run only gets its own items.
			cell.content = re.sub(listRunRegex, listRunToHtml, cell.content)
			# Enforce left alignment if cell contains a list
			cell.alignment = _alignLeft

		rowspan = f' rowspan="{cell.rowspan}"' if cell.rowspan > 1 else ''
		colspan = f' colspan="{cell.colspan}"' if cell.colspan > 1 else ''
		return f'            <{tag}{rowspan}{colspan} {cell.alignment}>{cell.content}</{tag}>\n'


	def renderRows(gridRows, tag:str) -> str:
		"""	Return the <tr> elements for all rows. Cells with a zero span are covered by another cell and skipped. """
		fragments:list[str] = []
		for row in gridRows:
			fragments.append('        <tr>\n')
			for cell in row:
				if cell.rowspan != 0 and cell.colspan != 0:
					fragments.append(renderCell(cell, tag))
			fragments.append('        </tr>\n')
		return ''.join(fragments)


	try:
		gridHeader, gridBody = parseGridTableWithSpans(gridTable)
	except Exception as e:
		printDebug('Grid table could not be generated')
		raise RuntimeError(f'HTML TABLE COULD NOT BE GENERATED FROM MARKDOWN GRID TABLE:\n{str(e)}') from e

	# A <thead> is only written if the header has at least one visible cell
	hasHeader = any(cell.rowspan != 0 and cell.colspan != 0
					for row in gridHeader
					for cell in row)

	# Generate table HTML...
	html = ['<table>\n']
	if hasHeader:
		html.append('    <thead>\n')
		html.append(renderRows(gridHeader, 'th'))
		html.append('    </thead>\n')

	html.append('    <tbody>\n')
	html.append(renderRows(gridBody, 'td'))
	html.append('    </tbody>\n')
	html.append('</table>')
	return ''.join(html)