#include "zsm_backend.hpp"
#include <algorithm>
extern "C" {
#include "x16emu/glue.h"
#include "x16emu/vera_pcm.h"
#include "x16emu/vera_psg.h"
#include "x16emu/ymglue.h"
}
#include <ipc/common.pb.h>
#include <exception>
#include <filesystem>
#include "file_backend.hpp"
#include <stddef.h>
#include <string.h>
#include <file_backend.hpp>
#include <util.hpp>
#include <license.hpp>
#include <assets/assets.h>
// Output sample rate used by tick() to convert one ZSM tick into a number of
// rendered samples (HZ / tick_rate samples per tick).
#define HZ (AUDIO_SAMPLERATE)
// NOTE(review): BUFFERS is not referenced in the visible part of this file;
// confirm it is still used elsewhere before removing it.
#define BUFFERS 32
// Generates one property getter per entry in properties.inc.
// For an entry _PROPERTY(name, type, default_value) this expands to
// `type ZsmBackend::name()`, which looks up the property called "name" via
// get(); when a value is stored it is converted with resolve_value<type>(),
// otherwise default_value is returned.
// load() re-includes the same list with a different _PROPERTY definition to
// register the default values.
// NOTE(review): identifiers starting with an underscore and an uppercase letter
// are reserved in C++; renaming would require updating properties.inc too.
#define _PROPERTY(name, type, default_value) \
type ZsmBackend::name() { \
	std::optional<google::protobuf::Any> value_maybe = get(#name); \
	if (value_maybe.has_value()) { \
		return resolve_value<type>(value_maybe.value()); \
	} \
	return default_value; \
}

#include "properties.inc"
#undef _PROPERTY
std::vector<Property> ZsmBackend::get_property_list() {
	std::vector<Property> properties;
	properties.push_back(make_property(PropertyType::Boolean, "Enable PCM channel", "pcm_enable"));
	properties.push_back(make_property(PropertyType::Boolean, "Enable PSG channels", "psg_enable"));
	properties.push_back(make_property(PropertyType::Boolean, "Enable FM channels", "fm_enable"));
	properties.push_back(make_property(PropertyType::Double, "Volume of PCM channel", "pcm_volume", make_hint(0.0, 1.0)));
	properties.push_back(make_property(PropertyType::Double, "Volume of PSG channels", "psg_volume", make_hint(0.0, 1.0)));
	properties.push_back(make_property(PropertyType::Double, "Volume of FM channels", "fm_volume", make_hint(0.0, 1.0)));
	return properties;
}
void ZsmBackend::load(const char *filename) {
    // Output format: signed 16-bit stereo at the PSG rate, 100-frame blocks.
    memset(&spec, 0, sizeof(spec));
    current_file = filename;
    spec.format = AUDIO_S16SYS;
    spec.samples = 100;
    spec.channels = 2;
    spec.freq = PSG_FREQ;
    spec.size = 100 * 2 * sizeof(int16_t);
    file = open_file(filename);

    // Little-endian helpers for the multi-byte header/instrument fields.
    auto read_le24 = [](const uint8_t *b) -> uint32_t {
        return (uint32_t)b[0] | ((uint32_t)b[1] << 8) | ((uint32_t)b[2] << 16);
    };
    auto read_le16 = [](const uint8_t *b) -> uint16_t {
        return (uint16_t)(b[0] | ((uint16_t)b[1] << 8));
    };

    // ---- File header --------------------------------------------------------
    char magic[2] = {0, 0};
    file->read(magic, 2, 1);
    if (magic[0] != 0x7a || magic[1] != 0x6d) { // "zm"
        throw std::exception();
    }
    uint8_t version; // Read only to advance past the field; not used here.
    file->read(&version, 1, 1);
    uint8_t field[3];
    file->read(field, 3, 1);
    this->loop_point = read_le24(field);
    file->read(field, 3, 1);
    this->pcm_offset = read_le24(field);
    // Per the ZSM format, a PCM offset of 0 means the file has no PCM block;
    // previously such files were parsed as if a PCM table started at byte 3.
    const bool has_pcm = this->pcm_offset != 0;
    pcm_offset += 3; // Skip 3 leading bytes of the PCM block (presumably its magic).
    file->read(&fm_mask, 1, 1);
    file->read(field, 2, 1);
    this->psg_channel_mask = read_le16(field);
    file->read(field, 2, 1);
    this->tick_rate = read_le16(field);
    file->read(field, 2, 1); // Reserved.
    music_data_start = file->get_pos();
    this->loop_point += music_data_start;

    // ---- PCM instrument table -----------------------------------------------
    if (has_pcm) {
        file->seek(pcm_offset, SeekType::SET);
        uint8_t last_instrument = 0;
        file->read(&last_instrument, 1, 1);
        pcm_offset++;
        // Use a wide counter: a uint8_t loop with `<= 255` never terminates.
        const size_t instrument_count = (size_t)last_instrument + 1;
        // Sample data follows the table of 16-byte instrument records.
        pcm_data_offs = instrument_count * 16 + pcm_offset;
        for (size_t i = 0; i < instrument_count; i++) {
            pcm_instrument *inst = new pcm_instrument();
            // Skip byte 0 of each 16-byte record and start at the geometry byte.
            file->seek(pcm_offset + i * 16 + 1, SeekType::SET);
            file->read(&inst->geom, 1, 1);
            uint8_t bytes[10];
            file->read(bytes, 10, 1);
            // bytes[0..2]: data offset, [3..5]: length, [6]: flags, [7..9]: loop point.
            uint32_t data_offset = read_le24(&bytes[0]);
            data_offset += pcm_data_offs;
            inst->remain = read_le24(&bytes[3]);
            inst->islooped = bytes[6] & 0x80;
            uint32_t loop_start = read_le24(&bytes[7]);
            if (loop_start >= inst->remain) {
                // A loop point at or past the end leaves nothing to loop and
                // would underflow loop_rem below.
                inst->islooped = false;
                loop_start = inst->remain;
            }
            // The loop restart offset is the instrument's own loop point (this
            // previously read the unrelated member `this->loop_rem`).
            inst->loop = loop_start;
            // Bytes played per loop iteration: from the loop point to the end.
            inst->loop_rem = inst->remain - loop_start;
            inst->data = (uint8_t*)malloc(inst->remain);
            file->seek(data_offset, SeekType::SET);
            file->read(inst->data, 1, inst->remain);
            instruments.push_back(inst);
        }
    }

    // ---- Pre-scan: total length and loop position in seconds ----------------
    file->seek(music_data_start, SeekType::SET);
    this->loop_point = std::max(this->loop_point, (uint32_t)music_data_start);
    double time = 0.0;
    double tmpDelayTicks = 0.0;
    loop_pos = -1.0;
    while (true) {
        tmpDelayTicks -= get_delay_per_frame();
        if (tmpDelayTicks < 0.0) {
            ZsmCommand cmd = get_command();
            size_t cur_pos = file->get_pos();
            // The first command boundary at/after the loop point becomes the
            // exact loop target; remember the song time at that point.
            if (cur_pos >= this->loop_point && this->loop_pos < 0) {
                loop_pos = time;
                this->loop_point = cur_pos;
            }
            if (cmd.id == ZsmEOF) {
                break;
            } else if (cmd.id == Delay) {
                time += ((double)cmd.delay) / ((double)(tick_rate));
                tmpDelayTicks += cmd.delay;
            }
        }
    }
    if (this->loop_pos < 0.0) {
        // The loop point was never reached: loop back to the start of the music.
        this->loop_pos = 0.0;
        this->loop_point = music_data_start;
    }
    length = time;
    music_data_len = file->get_pos();
    switch_stream(0);
    loop_end = length;
    loop_start = this->loop_pos;
    // Resamples YM output (YM_FREQ) to the output rate (PSG_FREQ).
    fm_stream = SDL_NewAudioStream(AUDIO_S16SYS, 2, YM_FREQ, AUDIO_S16SYS, 2, PSG_FREQ);
    DEBUG.writefln("fm_stream: %ld -> %ld (Is null: %s)", YM_FREQ, PSG_FREQ, fm_stream == NULL ? "true" : "false");
    // Register the default value of every property listed in properties.inc.
#define _PROPERTY(name, type, default_value) \
	{ \
		std::string type_str = #type; \
		google::protobuf::Any value; \
		if (type_str == "bool") { \
			BooleanProperty value_b; \
			value_b.set_value(default_value); \
			value.PackFrom(value_b); \
		} else if (type_str == "double") { \
			DoubleProperty value_d; \
			value_d.set_value(default_value); \
			value.PackFrom(value_d); \
		} \
		property_defaults[#name] = value; \
	}
#include "properties.inc"
#undef _PROPERTY
}
extern SDL_AudioSpec obtained;
void ZsmBackend::switch_stream(int idx) {
    YM_Create(YM_FREQ);
    YM_init(YM_FREQ/64, 60);
    YM_reset();
    psg_reset();
    pcm_reset();
    for (uint8_t i = 0; i < 16; i++) {
        psg_writereg(i * 4 + 2, 0);
    }
    file->seek(music_data_start, SeekType::SET);
    this->cpuClocks = 0.0;
    this->delayTicks = 0.0;
    this->ticks = 0.0;
}
/*
 * Releases what load() acquired: the input file, the FM resampling stream and
 * the PCM instrument table. Also drops any queued output audio.
 */
void ZsmBackend::cleanup() {
    delete file;
    file = nullptr;
    audio_buf.clear();
    SDL_FreeAudioStream(fm_stream);
    fm_stream = nullptr;
    // audio_sample points into an instrument's data buffer (set on PCM trigger).
    audio_sample = nullptr;
    // NOTE(review): load() allocates each instrument's `data` with malloc();
    // confirm pcm_instrument's destructor frees it, otherwise it leaks here.
    for (size_t n = 0; n < instruments.size(); ++n) {
        delete instruments[n];
    }
    instruments.clear();
}
void ZsmBackend::tick(bool step) {
    delayTicks -= 1;
    const double ClocksPerTick = ((double)HZ) / ((double)tick_rate);
    double prevCpuClocks = cpuClocks;
    double nextCpuClocks = cpuClocks + ClocksPerTick;
    double ticks_remaining = ClocksPerTick;
    while (delayTicks <= 0) {
        ZsmCommand cmd = get_command();
        switch (cmd.id) {
            case ZsmEOF: {
                if (step) {
                    file->seek(this->loop_point, SeekType::SET);
                    this->position = loop_pos;
                } else {
                    throw std::exception();
                }
            } break;
            case PsgWrite: {
                psg_writereg(cmd.psg_write.reg, cmd.psg_write.val);
            } break;
            case FmWrite: {
                for (uint8_t i = 0; i < cmd.fm_write.len; i++) {
                    auto &pair = cmd.fm_write.regs[i];
                    YM_write_reg(pair.reg, pair.val);
                    while (YM_read_status()) {
                        size_t clocksToAddForYm = (size_t)std::ceil(((double)YM_FREQ)/((double)PSG_FREQ));
                        ticks_remaining -= clocksToAddForYm;
                        if (ticks_remaining < 0) {
                            delayTicks -= 1;
                            nextCpuClocks += ClocksPerTick;
                            ticks_remaining += ClocksPerTick;
                        }
                        audio_step(clocksToAddForYm);
                    }
                }
            } break;
            case Delay: {
                delayTicks += cmd.delay;
                position += ((double)cmd.delay) / ((double)(tick_rate));
            } break;
            case ExtCmd: {
                //cmd.extcmd.channel
                switch (cmd.extcmd.channel) {
                    case 0: {
                        for (size_t i = 0; i < cmd.extcmd.bytes; i += 2) {
                            switch (cmd.extcmd.pcm[i]) {
                                case 0: { // ctrl
                                	uint8_t ctrl = cmd.extcmd.pcm[i + 1];
                                 	if (ctrl & 0x80) {
                                  		remain = 0;
                                  	}
                                    pcm_write_ctrl(ctrl);
                                } break;
                                case 1: { // rate
                                    pcm_write_rate(cmd.extcmd.pcm[i + 1]);
                                } break;
                                default: { // trigger
                                    size_t file_pos = file->get_pos();
                                    uint8_t ctrl = pcm_read_ctrl();
                                    pcm_write_ctrl(ctrl | 0x80);
                                    uint16_t pcm_idx = cmd.extcmd.pcm[i + 1];
                                    pcm_instrument *inst = instruments[pcm_idx];
                                    ctrl = pcm_read_ctrl() & 0x0F;
                                    ctrl |= inst->geom & 0x30;
                                    pcm_write_ctrl(ctrl);
                                    audio_sample = inst->data;
                                    loop = inst->loop;
                                    loop_rem = inst->loop_rem;
                                    remain = inst->remain;
                                    islooped = inst->islooped;
                                    cur = 0;
                                } break;
                            }
                        } break;
                    }
                // Nothing handled yet.
                }
            } break;
        }
    }
    size_t nextCpuClocksInt = std::floor(nextCpuClocks);
    size_t prevCpuClocksInt = std::floor(prevCpuClocks);
    size_t cpuClocksIntDelta = nextCpuClocksInt - prevCpuClocksInt;
    audio_step(ticks_remaining);
    cpuClocks = std::fmod(nextCpuClocks, ClocksPerTick);
}
/*
 * Fills `buf` with up to `maxlen` bytes of interleaved signed 16-bit audio.
 * Ticks the player (end of data wraps to the loop point) until strictly more
 * 16-bit values are queued than requested, then copies them out.
 * Returns copy_out()'s count scaled to bytes.
 */
size_t ZsmBackend::render(void *buf, size_t maxlen) {
    constexpr size_t bytes_per_value = sizeof(int16_t);
    const size_t values_wanted = maxlen / bytes_per_value;
    for (;;) {
        if (audio_buf.size() > values_wanted) {
            break;
        }
        tick(true);
    }
    return copy_out(buf, values_wanted) * bytes_per_value;
}
uint64_t ZsmBackend::get_min_samples() {
    return spec.size;
}
std::optional<uint64_t> ZsmBackend::get_max_samples() {
    return get_min_samples();
}
/*
 * Reads and decodes the next ZSM command at the current file position.
 *
 * Command byte ranges as decoded here:
 *   0x00-0x3F  PSG write: register = low 6 bits, one value byte follows
 *   0x40       extension command: channel/length byte and payload follow
 *   0x41-0x7F  FM write: low 6 bits = number of (register, value) pairs
 *   0x80       end of data
 *   0x81-0xFF  delay: low 7 bits = ticks
 *
 * Heap buffers allocated here (FM pairs, extension payloads) are released by
 * ~ZsmCommand().
 */
ZsmCommand ZsmBackend::get_command() {
    // Pre-load the end-of-data marker. A read that fails at end of file
    // (presumably leaving the byte untouched) then decodes as ZsmEOF instead of
    // an uninitialised command, so a truncated file cannot make load()'s scan
    // loop forever.
    uint8_t cmd_byte = 0x80;
    file->read(&cmd_byte, 1, 1);
    ZsmCommandId cmdid;
    if (cmd_byte == 0x80) {
        cmdid = ZsmEOF;
    } else if ((cmd_byte >> 6) == 0) {
        cmdid = PsgWrite;
    } else if ((cmd_byte >> 6) == 0b01) {
        cmdid = (cmd_byte == 0b01000000) ? ExtCmd : FmWrite;
    } else {
        cmdid = Delay;
    }
    ZsmCommand output;
    output.id = cmdid;
    if (cmdid == ZsmEOF) {
        return output;
    } else if (cmdid == PsgWrite) {
        uint8_t value = 0;
        file->read(&value, 1, 1);
        output.psg_write.reg = cmd_byte & 0x3F;
        output.psg_write.val = value;
    } else if (cmdid == FmWrite) {
        uint8_t pairs = cmd_byte & 0b111111;
        output.fm_write.len = pairs;
        output.fm_write.regs = (reg_pair*)malloc((sizeof(reg_pair))*pairs);
        for (uint8_t i = 0; i < pairs; i++) {
            uint8_t pair_bytes[2] = {0, 0};
            file->read(pair_bytes, 2, 1);
            output.fm_write.regs[i].reg = pair_bytes[0];
            output.fm_write.regs[i].val = pair_bytes[1];
        }
    } else if (cmdid == ExtCmd) {
        uint8_t ext_cmd_byte = 0;
        file->read(&ext_cmd_byte, 1, 1);
        // Top 2 bits: channel; low 6 bits: payload length in bytes.
        uint8_t bytes = ext_cmd_byte & 0x3F;
        uint8_t ch = ext_cmd_byte >> 6;
        output.extcmd.channel = ch;
        output.extcmd.bytes = bytes;
        if (ch == 1) {
            output.extcmd.expansion.chip_id = 0;
            output.extcmd.expansion.writes = 0;
            output.extcmd.expansion.write_bytes = NULL;
        } else {
            output.extcmd.pcm = (uint8_t*)malloc(bytes); // Handles all other cases due to them being in a union, and each having the same type.
        }
        for (size_t i = 0; i < bytes; i++) {
            uint8_t byte = 0;
            file->read(&byte, 1, 1);
            switch (ch) {
                case 0: {
                    output.extcmd.pcm[i] = byte;
                } break;
                case 1: {
                    // Expansion payload: chip id, declared write count, data.
                    if (i == 0) {
                        output.extcmd.expansion.chip_id = byte;
                    } else if (i == 1) {
                        output.extcmd.expansion.writes = byte;
                        output.extcmd.expansion.write_bytes = (uint8_t*)malloc(byte);
                    } else if (output.extcmd.expansion.write_bytes != NULL
                               && i - 2 < output.extcmd.expansion.writes) {
                        // The buffer holds `writes` bytes. Excess payload in a
                        // malformed file is consumed but not stored, instead of
                        // overflowing the allocation.
                        output.extcmd.expansion.write_bytes[i - 2] = byte;
                    }
                } break;
                case 2: {
                    output.extcmd.sync[i] = byte;
                } break;
                case 3: {
                    output.extcmd.custom[i] = byte;
                } break;
            }
        }
    } else if (cmdid == Delay) {
        output.delay = cmd_byte & 0x7F;
    }
    return output;
}
/*
 * Frees the heap buffer owned by the decoded command, if any: the register
 * pair array of an FM write, or the payload of an extension command
 * (expansion write data on channel 1, the shared payload pointer otherwise).
 */
ZsmCommand::~ZsmCommand() {
    if (id == FmWrite) {
        free(fm_write.regs);
        return;
    }
    if (id != ExtCmd) {
        return;
    }
    if (extcmd.channel != 1) {
        free(extcmd.pcm);
    } else if (extcmd.expansion.write_bytes != NULL) {
        free(extcmd.expansion.write_bytes);
    }
}
void ZsmBackend::seek_internal(double position, bool loop) {
    this->position = std::floor(this->position * PSG_FREQ) / PSG_FREQ;
    position = std::floor(position * PSG_FREQ) / PSG_FREQ;
    if (this->position > position) {
    	switch_stream(0);
        file->seek(music_data_start, SeekType::SET);
        this->cpuClocks = 0.0;
        this->delayTicks = 0;
        this->ticks = 0.0;
        this->position = 0.0;
    } else if (this->position == position) {
        audio_buf.clear();
        return;
    }
    while (this->position < position) {
        audio_buf.clear();
        try {
            tick(false);
        } catch (std::exception) {
            switch_stream(0);
            file->seek(music_data_start, SeekType::SET);
            this->cpuClocks = 0.0;
            this->delayTicks = 0;
            this->ticks = 0.0;
            this->position = 0.0;
            audio_buf.clear();
            return;
        }
    }
    size_t samples = std::min((size_t)((this->position - position) * PSG_FREQ), audio_buf.size());
    while (samples--) {
        audio_buf.pop();
    }
    this->position = position;
}
void ZsmBackend::seek(double position) {
    seek_internal(position, false);
}
double ZsmBackend::get_position() {
    return position;
}
/* This backend exposes a single stream, so its index is always 0. */
int ZsmBackend::get_stream_idx() {
    constexpr int only_stream = 0;
    return only_stream;
}

/*
 * Renders `samples` stereo frames of mixed audio and appends them to audio_buf
 * as interleaved int16 values.
 *
 * 1. Tops up the PCM FIFO from the active instrument.
 * 2. Renders PSG and PCM at the output rate.
 * 3. Renders the YM FM chip at YM_FREQ and resamples it through fm_stream.
 * 4. Mixes the three sources using the enable/volume properties.
 */
void ZsmBackend::audio_step(size_t samples) {
	if (samples == 0) return;
	// Queue the next byte, then count it. Previously `remain` was decremented
	// before writing, so the final byte of every sample was never queued. For
	// looped samples that dropped a byte before each loop restart, which
	// misaligns 16-bit data.
	while (remain != 0 && pcm_fifo_avail() < samples) {
		if (pcm_read_rate() == 0) break;
		uint8_t sample = audio_sample[cur++];
		pcm_write_fifo(sample);
		if ((--remain) == 0) {
			// Loop only if there is something to loop. An empty loop region
			// used to read one byte past the end of the sample.
			if (islooped && loop_rem != 0) {
				cur = loop;
				remain = loop_rem;
			} else {
				break;
			}
		}
	}
	// From here on `samples` counts interleaved int16 values (2 per frame).
	samples *= 2;
	int16_t *psg_ptr = psg_buf.get_item_sized<int16_t>(samples);
	int16_t *pcm_ptr = pcm_buf.get_item_sized<int16_t>(samples);
	psg_render(psg_ptr, samples / 2);
	pcm_render(pcm_ptr, samples / 2);
	int16_t *out_ptr = out_buf.get_item_sized<int16_t>(samples);
	// YM frames needed to cover the same time span at YM_FREQ.
	double ratio = ((double)YM_FREQ) / ((double)PSG_FREQ);
	size_t needed_samples = ((size_t)std::floor(samples * ratio)) / 2;
	// Frames rendered per top-up when the resampler is still short of data.
	constexpr size_t ym_topup_frames = 8;
	// Size the scratch buffer for whichever write is larger. Top-ups used to
	// write 8 frames into a buffer sized for as few as 2.
	int16_t *ym_ptr = ym_buf.get_item_sized<int16_t>(std::max(needed_samples, ym_topup_frames) * 2);
	YM_stream_update(ym_ptr, needed_samples);
	SDL_AudioStreamPut(fm_stream, ym_ptr, needed_samples * 2 * sizeof(int16_t));
	while (SDL_AudioStreamAvailable(fm_stream) < (samples * sizeof(int16_t))) {
		YM_stream_update(ym_ptr, ym_topup_frames);
		SDL_AudioStreamPut(fm_stream, ym_ptr, ym_topup_frames * 2 * sizeof(int16_t));
	}
	// The Get below may return up to samples + 2 values, so the buffer must
	// hold that many (it was previously sized for only `samples`).
	int16_t *ym_resample_ptr = ym_resample_buf.get_item_sized<int16_t>(samples + 2);
	ssize_t ym_resample_len = SDL_AudioStreamGet(fm_stream, ym_resample_ptr, (samples + 2) * sizeof(int16_t));
	assert(ym_resample_len >= 0);
	ym_resample_len /= sizeof(int16_t);
	// Property lookups are loop-invariant. Read them once instead of six
	// lookups per frame.
	const auto pcm_on = pcm_enable();
	const auto psg_on = psg_enable();
	const auto fm_on = fm_enable();
	const auto pcm_gain = pcm_volume();
	const auto psg_gain = psg_volume();
	const auto fm_gain = fm_volume();
	for (size_t i = 0; i < samples / 2; i++) {
		size_t j = i * 2;
		int16_t psg[2] = {(int16_t)(psg_ptr[j] >> 1), (int16_t)(psg_ptr[j + 1] >> 1)};
		int16_t pcm[2] = {(int16_t)(pcm_ptr[j] >> 2), (int16_t)(pcm_ptr[j + 1] >> 2)};
		if (!pcm_on) memset(pcm, 0, sizeof(pcm));
		if (!psg_on) memset(psg, 0, sizeof(psg));
		pcm[0] *= pcm_gain;
		pcm[1] *= pcm_gain;
		psg[0] *= psg_gain;
		psg[1] *= psg_gain;
		int16_t vera[2] = {(int16_t)(psg[0] + pcm[0]), (int16_t)(psg[1] + pcm[1])};
		int16_t fm[2] = {ym_resample_ptr[j], ym_resample_ptr[j + 1]};
		if (!fm_on) memset(fm, 0, sizeof(fm));
		fm[0] *= fm_gain;
		fm[1] *= fm_gain;
		int16_t mix[2] = {(int16_t)(vera[0] + (fm[0] >> 1)), (int16_t)(vera[1] + (fm[1] >> 1))};
		out_ptr[j++] = mix[0];
		out_ptr[j++] = mix[1];
	}
	audio_buf.push(out_ptr, samples);
}
void ZsmBackend::add_licenses() {
	auto &license_data = get_license_data();
	auto x16emu = LicenseData("x16emu", "bsd-2-clause");
	LOAD_LICENSE(x16emu, x16emu);
	license_data.insert(x16emu);
}