#include "sound_analysis.h"
#include "adc/adc.h"
#include <avr/io.h>
#include <util/delay.h>
#include <util/delay_basic.h>
#include <stdint.h>

// Most recent audio capture, written by collect_audio_samples() with each
// reading shifted by -512 so it is centred on 0. detect_frequency() also
// subtracts the capture's mean from every element in place.
static int16_t samples[SAMPLE_COUNT];

// ------------------------------------------------
// Audio Capture
// ------------------------------------------------
/*
 * Fill `samples` with SAMPLE_COUNT ADC readings taken at approximately
 * DESIRED_FS. Each reading is shifted by -512 so it is centred on 0
 * (presumably a 10-bit, 0..1023 result).
 *
 * Each sample starts one conversion by setting ADSC and busy-waits for the
 * hardware to clear it. The rest of the sample period is then padded with a
 * software delay estimated from the conversion time, so the achieved rate is
 * approximate.
 *
 * NOTE(review): assumes the ADC is already enabled and clocked with a /128
 * prescaler (presumably set up via adc/adc.h) — TODO confirm; the timing
 * estimate below is wrong if the prescaler differs.
 */
void collect_audio_samples(void)
{
    const uint16_t prescaler = 128;
    const float adc_clk = (float)F_CPU / prescaler;
    // A normal (non-first) conversion takes 13 ADC clock cycles.
    const float conv_us = 13.0f * 1e6f / adc_clk;
    const float period_us = 1e6f / DESIRED_FS;

    // _delay_us() only works with a compile-time constant argument, so the
    // run-time padding uses _delay_loop_2(), which spins 4 CPU cycles per count.
    // The 2 us margin covers loop and register-access overhead.
    // Clamp in float *before* converting: casting a negative (or > 65535)
    // float to uint16_t is undefined behaviour.
    const float loops_per_us = (float)F_CPU / 4e6f;
    float pad_loops = (period_us - conv_us - 2.0f) * loops_per_us;
    if (pad_loops < 1.0f)
        pad_loops = 1.0f;        // never 0: _delay_loop_2(0) spins 65536 times
    else if (pad_loops > 65535.0f)
        pad_loops = 65535.0f;    // limit of one _delay_loop_2() call
    const uint16_t pad = (uint16_t)pad_loops;

    for (int i = 0; i < SAMPLE_COUNT; i++) {
        ADCSRA |= (1 << ADSC);           // start a single conversion
        while (ADCSRA & (1 << ADSC))     // ADSC stays set until it completes
            ;

        uint16_t raw = ADC;
        // Centre around 0 immediately
        samples[i] = (int16_t)raw - 512;
        _delay_loop_2(pad);
    }
}


// ------------------------------------------------
// Volume
// ------------------------------------------------
/*
 * Peak-to-peak amplitude of the current capture: the largest sample minus the
 * smallest. Despite the name, this is a range rather than the single maximum,
 * so it is unaffected by a constant offset applied to every sample.
 */
int16_t get_max_sample(void)
{
    int16_t lowest = INT16_MAX;
    int16_t highest = INT16_MIN;

    for (const int16_t *p = samples; p != samples + SAMPLE_COUNT; ++p) {
        const int16_t v = *p;
        if (v > highest)
            highest = v;
        if (v < lowest)
            lowest = v;
    }

    return (int16_t)(highest - lowest);
}

// ------------------------------------------------
// Pitch Detection (Autocorrelation)
// ------------------------------------------------
/*
 * Unnormalised autocorrelation of `samples` at `lag`: the sum of
 * samples[i] * samples[i + lag] over every pair that fits in the buffer.
 * Lags >= SAMPLE_COUNT have no overlapping pairs and yield 0.
 *
 * NOTE(review): int32 accumulation is safe while
 * SAMPLE_COUNT * max|sample|^2 < 2^31 (about 2000 samples of +/-1023).
 */
static int32_t autocorrelation_at(uint16_t lag)
{
    if (lag >= SAMPLE_COUNT)
        return 0;

    const uint16_t pairs = SAMPLE_COUNT - lag;
    int32_t acc = 0;
    for (uint16_t i = 0; i < pairs; i++)
        acc += (int32_t)samples[i] * samples[i + lag];
    return acc;
}

/*
 * Estimate the fundamental frequency (in Hz, at sample rate DESIRED_FS) of the
 * current capture using time-domain autocorrelation.
 *
 * Side effect: subtracts the capture's mean from `samples` in place.
 *
 * Returns 0.0f when no pitch can be found, i.e. when the lag search range is
 * empty or no lag in it correlates positively (e.g. silence).
 */
float detect_frequency(void)
{
    // -------- Remove DC --------
    int32_t mean = 0;
    for (int i = 0; i < SAMPLE_COUNT; i++)
        mean += samples[i];
    mean /= SAMPLE_COUNT;

    for (int i = 0; i < SAMPLE_COUNT; i++)
        samples[i] -= (int16_t)mean;

    // -------- Lag bounds (slightly shrunk to avoid extremes) --------
    // Computed in float so an integer DESIRED_FS / GTR_* does not truncate
    // before scaling. Clamped before casting, because an out-of-range
    // float -> uint16_t conversion is undefined behaviour.
    const float max_lag_limit = (float)(SAMPLE_COUNT - 1);
    float lag_lo = (float)DESIRED_FS / GTR_HIGH * 1.05f;
    float lag_hi = (float)DESIRED_FS / GTR_LOW * 0.95f;
    if (lag_lo < 1.0f)
        lag_lo = 1.0f;               // lag 0 is signal energy, not a period
    if (lag_hi > max_lag_limit)
        lag_hi = max_lag_limit;
    if (lag_lo > max_lag_limit)
        return 0.0f;                 // no candidate period fits in the buffer

    const uint16_t min_lag = (uint16_t)lag_lo;
    const uint16_t max_lag = (uint16_t)lag_hi;
    if (min_lag > max_lag)
        return 0.0f;

    // -------- Find lag with maximum autocorrelation --------
    int32_t best_corr = 0;
    uint16_t best_lag = 0;           // 0 = no positively correlated lag yet

    for (uint16_t lag = min_lag; lag <= max_lag; lag++)
    {
        const int32_t corr = autocorrelation_at(lag);
        if (corr > best_corr)
        {
            best_corr = corr;
            best_lag = lag;
        }
    }

    // Silence or an uncorrelated signal: no pitch. Without this check the
    // interpolation below divides 0 by 0 and returns NaN.
    if (best_lag == 0)
        return 0.0f;

    // -------- Parabolic interpolation --------
    // Fit a parabola through the true correlation values at best_lag - 1,
    // best_lag and best_lag + 1 and move to its vertex. This is only
    // meaningful at a concave local maximum; there the offset is within
    // +/-0.5. At a search-range edge with a larger neighbour, keep the
    // integer lag. Float arithmetic avoids int32 overflow in the differences.
    const float c_prev = (float)autocorrelation_at(best_lag - 1);
    const float c_best = (float)best_corr;
    const float c_next = (float)autocorrelation_at(best_lag + 1);
    const float curvature = c_prev - 2.0f * c_best + c_next;

    float refined_lag = (float)best_lag;
    if (c_best >= c_prev && c_best >= c_next && curvature < 0.0f)
        refined_lag += 0.5f * (c_prev - c_next) / curvature;

    // refined_lag >= best_lag - 0.5 >= 0.5, so this never divides by zero.
    return (float)DESIRED_FS / refined_lag;
}