from PIL import Image, ImageOps
import io
import os

# Base colors as [R, G, B] channel lists, one value per channel (0x00-0xFF).

# Neutrals
black_base = [0x00, 0x00, 0x00]
white_base = [0xFF, 0xFF, 0xFF]

# Primaries
red_base = [0xFF, 0x00, 0x00]
green_base = [0x00, 0xFF, 0x00]
blue_base = [0x00, 0x00, 0xFF]

# Mixes
yellow_base = [0xFF, 0xFF, 0x00]
orange_base = [0xFF, 0x80, 0x00]


# Each non-black base color is expanded below into 42 brightness steps (1 + 6 * 42 = 253 palette entries).
def int_mult(list, factor):
    """Return a new list with every element of ``list`` multiplied by ``factor``.

    Each product is passed through ``int()``, so fractional parts are
    truncated toward zero. The input sequence is not modified.
    """
    # NOTE: the parameter name shadows the builtin ``list``; it is kept so
    # that existing keyword callers keep working.
    scaled = []
    for channel in list:
        scaled.append(int(channel * factor))
    return scaled

# One brightness ramp per non-black base color: the base scaled by
# 1/42, 2/42, ..., 42/42, darkest shade first.
whites, greens, blues, reds, yellows, oranges = (
    [int_mult(base, step / 42) for step in range(1, 43)]
    for base in (white_base, green_base, blue_base, red_base, yellow_base, orange_base)
)

# Black first, then each ramp in turn: 1 + 6 * 42 = 253 colors.
palette_unflat = [black_base] + whites + greens + blues + reds + yellows + oranges
# Flatten the list of triples into a single channel sequence for putpalette().
palette = [channel for color in palette_unflat for channel in color]

# 1x1 "P"-mode image carrying the palette; in this file it is used as the
# reference palette passed to Image.quantize().
ref_image = Image.new("P", (1, 1))
ref_image.putpalette(palette)


class ImageShrink:
    """Fit an encoded image to a fixed resolution and reduce it to the module palette.

    ``convert`` is the entry point: it decodes the raw bytes, resizes the
    picture to ``resolution`` and quantizes it against ``ref_image``.
    """

    # Target size as (width, height) in pixels; read by shrink().
    resolution = (480, 800)

    def __init__(self) -> None:
        pass

    def convert(self, image: bytearray) -> Image.Image:
        """Decode ``image``, resize it and quantize it to the reduced palette.

        Args:
            image: Encoded image file contents (anything ``io.BytesIO`` accepts).

        Returns:
            The resized, palette-quantized image.

        Side effects:
            On hosts whose ``os.uname().machine`` is ``"x86_64"`` the result is
            also written to ``test.png`` in the current working directory.
        """
        # load image from the raw encoded bytes
        image = Image.open(io.BytesIO(image))
        image = self.shrink(image)
        image = self.convert_to_reduced_colors(image)
        # os.uname() only exists on Unix; guard the lookup so that platforms
        # without it skip the debug dump instead of raising AttributeError.
        if hasattr(os, "uname") and os.uname().machine == "x86_64":
            image.save("test.png")
        return image

    def shrink(self, image: Image.Image) -> Image.Image:
        """Return ``image`` scaled and center-cropped to ``self.resolution``.

        ``ImageOps.fit`` fills the whole target size and crops whatever
        overflows, measured from the center. (``ImageOps.contain`` would
        instead keep the entire picture and leave empty space.)
        """
        return ImageOps.fit(image, self.resolution, centering=(0.5, 0.5))

    def convert_to_reduced_colors(self, image: Image.Image) -> Image.Image:
        """Quantize ``image`` to the colors of the module-level ``ref_image``.

        Images that are not already in RGB mode are converted first (and a
        notice is printed). Dithering is enabled.
        """
        # convert image to RGB first if it is in any other mode
        if image.mode != "RGB":
            print("Converting image to RGB")
            image = image.convert("RGB")
        # ``palette`` is a flat channel list (three ints per color), so the
        # number of colors is a third of its length, not len(palette) itself.
        color_count = len(palette) // 3
        return image.quantize(colors=color_count, palette=ref_image, dither=True)