CoalitionPredictionServiceImpl.java

/*
 * Copyright 2010-2025 James Pether Sörling
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 *	$Id$
 *  $HeadURL$
 */
package com.hack23.cia.service.impl;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.access.annotation.Secured;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;

import com.hack23.cia.model.internal.application.data.party.impl.ViewRiksdagenCoalitionAlignmentMatrix;
import com.hack23.cia.model.internal.application.data.party.impl.ViewRiksdagenCoalitionAlignmentMatrixEmbeddedId;
import com.hack23.cia.model.internal.application.data.party.impl.ViewRiksdagenParty;
import com.hack23.cia.service.data.api.ViewRiksdagenCoalitionAlignmentMatrixDAO;
import com.hack23.cia.service.data.api.ViewRiksdagenPartyDAO;

/**
 * Implementation of coalition prediction service.
 *
 * <p>Uses the existing {@code view_riksdagen_coalition_alignment_matrix} to generate
 * coalition formation scenarios based on historical voting patterns. Every combination
 * of two, three and four parties from {@code ALL_PARTIES} whose combined seats reach
 * {@code MAJORITY_SEATS} is scored by a weighted mix of average pairwise alignment and
 * seat surplus; the highest-scoring scenarios are returned.
 */
@Service
@Transactional(propagation = Propagation.REQUIRED, timeout = 1200)
@Secured({"ROLE_ANONYMOUS", "ROLE_USER", "ROLE_ADMIN"})
public class CoalitionPredictionServiceImpl implements CoalitionPredictionService {

	private static final Logger LOGGER = LoggerFactory.getLogger(CoalitionPredictionServiceImpl.class);

	/** Smallest strict majority of a 349-seat chamber (the fallback seat counts below sum to 349). */
	private static final int MAJORITY_SEATS = 175;

	/** Maximum number of scenarios returned by {@link #predictCoalitions(String)}. */
	private static final int MAX_SCENARIOS = 10;

	/** Multiplier turning an alignment rate (0-1) into an integer stability index (0-100). */
	private static final int PERCENT_SCALE = 100;

	/** Short code of the SD party, which is classified separately from the left and right blocs. */
	private static final String PARTY_SD = "SD";

	/** All parties considered when enumerating coalitions; list order defines enumeration order. */
	private static final List<String> ALL_PARTIES = Collections.unmodifiableList(
			Arrays.asList("S", "M", PARTY_SD, "C", "V", "KD", "MP", "L"));

	// Political bloc definitions
	private static final List<String> LEFT_BLOC_PARTIES = Collections.unmodifiableList(
			Arrays.asList("S", "V", "MP"));
	private static final List<String> RIGHT_BLOC_PARTIES = Collections.unmodifiableList(
			Arrays.asList("M", "KD", "L", "C"));

	// Probability calculation weights
	private static final double ALIGNMENT_WEIGHT = 0.6;
	private static final double SEATS_WEIGHT = 0.4;
	private static final double SEATS_SCALING_FACTOR = 100.0;

	// Coalition type constants
	private static final String COALITION_TYPE_TWO_PARTY = "TWO_PARTY";
	private static final String COALITION_TYPE_THREE_PARTY = "THREE_PARTY";
	private static final String COALITION_TYPE_FOUR_PARTY = "FOUR_PARTY";

	// Bloc relationship constants
	private static final String BLOC_CROSS_BLOC = "CROSS_BLOC";
	private static final String BLOC_LEFT_SD_COALITION = "LEFT_SD_COALITION";
	private static final String BLOC_RIGHT_SD_COALITION = "RIGHT_SD_COALITION";
	private static final String BLOC_LEFT_BLOC = "LEFT_BLOC";
	private static final String BLOC_RIGHT_BLOC = "RIGHT_BLOC";
	private static final String BLOC_OTHER = "OTHER";

	/**
	 * Fallback seat counts (2022 election results), used when no seat data can be loaded.
	 * Should be externalized to configuration. Unmodifiable so no caller can corrupt it.
	 */
	private static final Map<String, Integer> FALLBACK_SEAT_COUNTS;
	static {
		final Map<String, Integer> seats = new HashMap<>();
		seats.put("S", 107);
		seats.put("M", 68);
		seats.put(PARTY_SD, 73);
		seats.put("C", 24);
		seats.put("V", 24);
		seats.put("KD", 19);
		seats.put("MP", 18);
		seats.put("L", 16);
		FALLBACK_SEAT_COUNTS = Collections.unmodifiableMap(seats);
	}

	@Autowired
	private ViewRiksdagenCoalitionAlignmentMatrixDAO coalitionAlignmentMatrixDAO;

	@Autowired
	private ViewRiksdagenPartyDAO partyDAO;

	/**
	 * Generates majority coalition scenarios of two, three and four parties.
	 *
	 * @param year logged for auditing only; not used to filter the alignment data
	 * @return at most {@code MAX_SCENARIOS} scenarios, highest probability first
	 */
	@Override
	public List<CoalitionScenario> predictCoalitions(final String year) {
		// NOTE: Year parameter is logged for auditing but not currently used to filter alignment data
		// because ViewRiksdagenCoalitionAlignmentMatrix represents aggregate historical patterns
		// rather than year-specific snapshots. Future enhancement could filter by year if needed.
		LOGGER.info("Generating coalition scenarios for year: {}", year);

		final Map<String, Map<String, Double>> alignmentMatrix =
				buildAlignmentMatrix(coalitionAlignmentMatrixDAO.getAll());
		final Map<String, Integer> seatCounts = loadSeatCounts();

		final List<CoalitionScenario> scenarios = new ArrayList<>();
		scenarios.addAll(generateCoalitions(2, COALITION_TYPE_TWO_PARTY, alignmentMatrix, seatCounts));
		scenarios.addAll(generateCoalitions(3, COALITION_TYPE_THREE_PARTY, alignmentMatrix, seatCounts));
		scenarios.addAll(generateCoalitions(4, COALITION_TYPE_FOUR_PARTY, alignmentMatrix, seatCounts));

		// Highest probability first. List.sort is stable, so ties keep generation order.
		scenarios.sort(Comparator.comparingDouble(CoalitionScenario::getProbability).reversed());

		return scenarios.stream().limit(MAX_SCENARIOS).collect(Collectors.toList());
	}

	/**
	 * Builds the symmetric party-to-party alignment matrix.
	 *
	 * @param year accepted for API consistency; not used for filtering
	 * @return mutable nested map keyed by party short code; self-alignment is 1.0
	 */
	@Override
	public Map<String, Map<String, Double>> getAlignmentMatrix(final String year) {
		// NOTE: Year parameter is accepted for API consistency but not currently used for filtering
		// because ViewRiksdagenCoalitionAlignmentMatrix represents aggregate historical patterns.
		final List<ViewRiksdagenCoalitionAlignmentMatrix> alignmentData = coalitionAlignmentMatrixDAO.getAll();
		return buildAlignmentMatrix(alignmentData);
	}

	/**
	 * Computes the average pairwise alignment of the given parties as an index from 0 to 100
	 * (truncated toward zero). Pairs missing from the matrix count as alignment 0.
	 *
	 * @param parties party short codes; {@code null} or fewer than two parties yields 0
	 * @param year forwarded to {@link #getAlignmentMatrix(String)}
	 * @return the stability index
	 */
	@Override
	public int calculateStabilityIndex(final List<String> parties, final String year) {
		if (parties == null || parties.size() < 2) {
			return 0;
		}
		return toStabilityIndex(calculateAverageAlignment(parties, getAlignmentMatrix(year)));
	}

	/**
	 * Converts raw alignment rows into a symmetric matrix pre-populated for every party in
	 * {@code ALL_PARTIES}. When the same pair occurs more than once, the last row wins.
	 */
	private Map<String, Map<String, Double>> buildAlignmentMatrix(
			final List<ViewRiksdagenCoalitionAlignmentMatrix> alignmentData) {
		final Map<String, Map<String, Double>> matrix = new HashMap<>();

		for (final String party : ALL_PARTIES) {
			matrix.put(party, new HashMap<>());
		}

		if (alignmentData != null) {
			for (final ViewRiksdagenCoalitionAlignmentMatrix data : alignmentData) {
				final ViewRiksdagenCoalitionAlignmentMatrixEmbeddedId embeddedId =
						data == null ? null : data.getEmbeddedId();
				if (embeddedId == null) {
					continue;
				}

				final String party1 = embeddedId.getParty1();
				final String party2 = embeddedId.getParty2();
				// Use alignmentRate from the main entity, not the embedded ID
				final Double alignmentRate = data.getAlignmentRate();

				if (party1 != null && party2 != null && alignmentRate != null) {
					// alignment_rate is already a rate (0-1), no conversion needed
					matrix.computeIfAbsent(party1, k -> new HashMap<>()).put(party2, alignmentRate);
					matrix.computeIfAbsent(party2, k -> new HashMap<>()).put(party1, alignmentRate);
				}
			}
		}

		// Set self-alignment to 1.0
		for (final String party : ALL_PARTIES) {
			matrix.get(party).put(party, 1.0);
		}

		return matrix;
	}

	/**
	 * Loads seat counts keyed by party short code from the party view.
	 *
	 * <p>If loading fails, any partially loaded data is discarded. If nothing usable was
	 * loaded, {@code FALLBACK_SEAT_COUNTS} is used instead.
	 *
	 * @return a fresh mutable map of party short code to seat count
	 */
	private Map<String, Integer> loadSeatCounts() {
		final Map<String, Integer> seatCounts = new HashMap<>();

		try {
			final List<ViewRiksdagenParty> parties = partyDAO.getAll();
			for (final ViewRiksdagenParty party : parties) {
				// ViewRiksdagenParty uses partyNumber as the short code (e.g., "S", "M", "SD")
				// and partyId is the primary key
				final String partyShortCode = party.getPartyNumber();
				final long headCount = party.getHeadCount();
				if (headCount > 0 && partyShortCode != null) {
					seatCounts.put(partyShortCode, (int) headCount);
				}
			}
		} catch (final Exception e) {
			// Deliberate best-effort: fall back to static data rather than failing the prediction.
			LOGGER.warn("Could not load seat counts from database, using defaults", e);
			// A failure midway would otherwise leave a partial, non-empty map that skips the fallback.
			seatCounts.clear();
		}

		if (seatCounts.isEmpty()) {
			seatCounts.putAll(FALLBACK_SEAT_COUNTS);
		}

		return seatCounts;
	}

	/**
	 * Scores every coalition of exactly {@code size} parties from {@code ALL_PARTIES} that
	 * reaches {@code MAJORITY_SEATS}. Coalitions are visited in lexicographic index order.
	 *
	 * @param size number of parties per coalition
	 * @param coalitionType type label stored on each scenario
	 * @param alignmentMatrix pairwise alignment matrix
	 * @param seatCounts seats per party short code; missing parties count as 0
	 * @return the majority scenarios, in generation order
	 */
	private List<CoalitionScenario> generateCoalitions(final int size, final String coalitionType,
			final Map<String, Map<String, Double>> alignmentMatrix,
			final Map<String, Integer> seatCounts) {
		final List<CoalitionScenario> scenarios = new ArrayList<>();

		for (final List<String> coalition : combinations(ALL_PARTIES, size)) {
			final int totalSeats = getTotalSeats(coalition, seatCounts);
			if (totalSeats < MAJORITY_SEATS) {
				continue;
			}

			final double avgAlignment = calculateAverageAlignment(coalition, alignmentMatrix);
			final double probability = calculateProbability(avgAlignment, totalSeats);
			final int stability = toStabilityIndex(avgAlignment);
			final String blocRelation = determineBlocRelationship(coalition);

			scenarios.add(new CoalitionScenario(coalition, totalSeats, probability,
					stability, coalitionType, blocRelation));
		}

		return scenarios;
	}

	/** Returns all {@code size}-element combinations of {@code pool}, preserving pool order. */
	private static List<List<String>> combinations(final List<String> pool, final int size) {
		final List<List<String>> result = new ArrayList<>();
		collectCombinations(pool, size, 0, new ArrayList<>(size), result);
		return result;
	}

	/** Recursive backtracking step for {@link #combinations(List, int)}. */
	private static void collectCombinations(final List<String> pool, final int size, final int start,
			final List<String> current, final List<List<String>> result) {
		if (current.size() == size) {
			// Each coalition gets its own copy so later backtracking cannot alter it.
			result.add(new ArrayList<>(current));
			return;
		}
		final int remaining = size - current.size();
		for (int index = start; index <= pool.size() - remaining; index++) {
			current.add(pool.get(index));
			collectCombinations(pool, size, index + 1, current, result);
			current.remove(current.size() - 1);
		}
	}

	/** Sums the seats of the given parties; parties without an entry contribute 0. */
	private int getTotalSeats(final List<String> parties, final Map<String, Integer> seatCounts) {
		return parties.stream()
				.mapToInt(party -> seatCounts.getOrDefault(party, 0))
				.sum();
	}

	/** Looks up the alignment between two parties, returning 0.0 when the pair is unknown. */
	private static double getAlignment(final Map<String, Map<String, Double>> alignmentMatrix,
			final String party1, final String party2) {
		final Map<String, Double> row = alignmentMatrix.get(party1);
		if (row == null) {
			return 0.0;
		}
		final Double alignment = row.get(party2);
		return alignment != null ? alignment : 0.0;
	}

	/** Averages the alignment over every unordered pair of the given parties (0.0 if no pairs). */
	private static double calculateAverageAlignment(final List<String> parties,
			final Map<String, Map<String, Double>> alignmentMatrix) {
		final List<Double> alignments = new ArrayList<>();

		for (int i = 0; i < parties.size(); i++) {
			for (int j = i + 1; j < parties.size(); j++) {
				alignments.add(getAlignment(alignmentMatrix, parties.get(i), parties.get(j)));
			}
		}

		return alignments.stream().mapToDouble(Double::doubleValue).average().orElse(0.0);
	}

	/**
	 * Combines weighted alignment with a capped seat-surplus bonus, clamped to at most 1.0.
	 * Callers only pass majority coalitions, so the surplus is non-negative.
	 */
	private static double calculateProbability(final double alignment, final int totalSeats) {
		final double alignmentFactor = alignment * ALIGNMENT_WEIGHT;
		final double seatsFactor = Math.min((totalSeats - MAJORITY_SEATS) / SEATS_SCALING_FACTOR, SEATS_WEIGHT);
		return Math.min(alignmentFactor + seatsFactor, 1.0);
	}

	/** Scales an alignment rate to an integer index, truncating toward zero. */
	private static int toStabilityIndex(final double alignment) {
		return (int) (alignment * PERCENT_SCALE);
	}

	/**
	 * Classifies a coalition by the blocs it spans. A mix of left and right bloc parties is
	 * cross-bloc regardless of SD participation.
	 */
	private static String determineBlocRelationship(final List<String> parties) {
		final boolean hasLeft = parties.stream().anyMatch(LEFT_BLOC_PARTIES::contains);
		final boolean hasRight = parties.stream().anyMatch(RIGHT_BLOC_PARTIES::contains);
		final boolean hasSD = parties.contains(PARTY_SD);

		if (hasLeft && hasRight) {
			return BLOC_CROSS_BLOC;
		} else if (hasLeft && hasSD) {
			return BLOC_LEFT_SD_COALITION;
		} else if (hasRight && hasSD) {
			return BLOC_RIGHT_SD_COALITION;
		} else if (hasLeft) {
			return BLOC_LEFT_BLOC;
		} else if (hasRight) {
			return BLOC_RIGHT_BLOC;
		} else {
			return BLOC_OTHER;
		}
	}
}