#include "PCH.h"
#include "ScreenCapture.h"
#include "logger.h"

namespace ScreenCapture {

    // -- your existing helpers: StringToImageFormat, StringToDDSCompression, etc. --

    // Forward declarations of your helper functions:
    HRESULT SaveToWICFile(const DirectX::ScratchImage& image, const CaptureParams& params, const std::wstring& filepath);
    CaptureResult GIF(const CaptureParams& params);

    // Main entrypoint
    CaptureResult CaptureScreen(const CaptureParams& params) {
        switch (params.format) {
            case ImageFormat::GIF:
                return GIF(params);
            default:
                // Other formats: call your generic WIC save
                {
                    CaptureResult result;
                    HRESULT hr = CoInitializeEx(nullptr, COINIT_APARTMENTTHREADED);
                    bool comInit = SUCCEEDED(hr);

                    DirectX::ScratchImage capture;
                    hr = CaptureSingleFrame(nullptr, nullptr, nullptr, capture);
                    if (FAILED(hr)) {
                        result.message = "Failed to capture frame";
                        if (comInit) CoUninitialize();
                        return result;
                    }

                    result.filepath = GenerateFilename(params.basePath, params.format);
                    hr = SaveToWICFile(capture, params, result.filepath);
                    result.success = SUCCEEDED(hr);
                    result.message = SUCCEEDED(hr) ? "Success" : "Failed to save WIC file";

                    if (comInit) CoUninitialize();
                    return result;
                }
        }
    }

    // --------------------------------------------------
    // Single-frame WIC save (PNG, JPG, BMP, TIF, GIF)
    // --------------------------------------------------
    HRESULT SaveToWICFile(
        const DirectX::ScratchImage& image,
        const CaptureParams& params,
        const std::wstring& filepath)
    {
        // Initialize WIC
        Microsoft::WRL::ComPtr<IWICImagingFactory> wicFactory;
        RETURN_IF_FAILED(CoCreateInstance(
            CLSID_WICImagingFactory, nullptr, CLSCTX_INPROC_SERVER,
            IID_PPV_ARGS(&wicFactory)
        ));

        // Create encoder for chosen format
        Microsoft::WRL::ComPtr<IWICBitmapEncoder> encoder;
        GUID container = GetWICCodec(params.format);
        RETURN_IF_FAILED(wicFactory->CreateEncoder(container, nullptr, &encoder));

        // Create output stream
        Microsoft::WRL::ComPtr<IWICStream> stream;
        RETURN_IF_FAILED(wicFactory->CreateStream(&stream));
        RETURN_IF_FAILED(stream->InitializeFromFilename(filepath.c_str(), GENERIC_WRITE));
        RETURN_IF_FAILED(encoder->Initialize(stream.Get(), WICBitmapEncoderNoCache));

        // For single-frame GIF, set logical screen background color index = 0
        if (params.format == ImageFormat::GIF) {
            Microsoft::WRL::ComPtr<IWICMetadataQueryWriter> meta;
            if (SUCCEEDED(encoder->GetMetadataQueryWriter(&meta))) {
                PROPVARIANT pv; PropVariantInit(&pv);
                pv.vt    = VT_UI1;
                pv.bVal  = 0;
                meta->SetMetadataByName(L"/logscrdesc/BackgroundColorIndex", &pv);
                PropVariantClear(&pv);
            }
        }

        // Create frame
        Microsoft::WRL::ComPtr<IWICBitmapFrameEncode> frame;
        Microsoft::WRL::ComPtr<IPropertyBag2> props;
        RETURN_IF_FAILED(encoder->CreateNewFrame(&frame, &props));

        // JPEG quality or TIFF compression
        if (params.format == ImageFormat::JPEG && props) {
            PROPBAG2 bag = {}; bag.pstrName = const_cast<LPOLESTR>(L"ImageQuality");
            VARIANT var; VariantInit(&var);
            var.vt    = VT_R4;
            var.fltVal = params.jpegQuality / 100.0f;
            props->Write(1, &bag, &var);
            VariantClear(&var);
        }
        else if (params.format == ImageFormat::TIF && props) {
            PROPBAG2 bag = {}; bag.pstrName = const_cast<LPOLESTR>(L"TiffCompressionMethod");
            VARIANT var; VariantInit(&var);
            var.vt   = VT_UI1;
            var.bVal = static_cast<BYTE>(params.tiffMode);
            props->Write(1, &bag, &var);
            VariantClear(&var);
        }

        RETURN_IF_FAILED(frame->Initialize(props.Get()));

        // Grab the image data
        const DirectX::Image* img = image.GetImage(0, 0, 0);
        if (!img) return E_FAIL;

        UINT width  = static_cast<UINT>(img->width);
        UINT height = static_cast<UINT>(img->height);
        RETURN_IF_FAILED(frame->SetSize(width, height));

        // Handle GIF differently: must be 8-bit indexed
        if (params.format == ImageFormat::GIF) {
            // 1) Convert RGBA → BGRA byte array
            size_t size = width * height * 4;
            std::vector<uint8_t> bgra(size);
            ConvertRGBAToBGRA(img, bgra.data());

            // 2) Wrap in WICBitmap
            Microsoft::WRL::ComPtr<IWICBitmap> bmp;
            RETURN_IF_FAILED(wicFactory->CreateBitmapFromMemory(
                width, height,
                GUID_WICPixelFormat32bppBGRA,
                width * 4, static_cast<UINT>(size),
                bgra.data(), &bmp
            ));

            // 3) Convert to 8bpp indexed
            Microsoft::WRL::ComPtr