ai module

Genie (VBox)

A widget for interacting with the Genie AI model.

The source code is adapted from the ee_genie.ipynb at https://bit.ly/3YEm7B6. Credit to the original author Simon Ilyushchenko (https://github.com/simonff).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| project | Optional[str] | Google Cloud project ID. Defaults to None. | None |
| google_api_key | Optional[str] | Google API key. Defaults to None. | None |
| gemini_model | str | The Gemini model to use. Defaults to "gemini-1.5-flash". For a list of available models, see https://bit.ly/4fKfXW7. | 'gemini-1.5-flash' |
| target_score | float | The target score for the model. Defaults to 0.8. | 0.8 |
| widget_height | str | The height of the widget. Defaults to "600px". | '600px' |
| initialize_ee | bool | Whether to initialize Earth Engine. Defaults to True. | True |

Exceptions:

| Type | Description |
|------|-------------|
| ValueError | If the project ID or Google API key is not provided. |
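
Example usage (a minimal sketch: it assumes geemap and its google-generativeai dependencies are installed, and that the project ID and API key are either passed explicitly as shown or available through the EE_PROJECT_ID/GOOGLE_PROJECT_ID and GOOGLE_API_KEY environment variables; the literal values below are placeholders):

```python
from geemap.ai import Genie

# Placeholder project ID and API key; omit them to fall back to the
# EE_PROJECT_ID / GOOGLE_PROJECT_ID and GOOGLE_API_KEY environment variables.
genie = Genie(
    project="my-cloud-project",
    google_api_key="YOUR_GOOGLE_API_KEY",
    gemini_model="gemini-1.5-flash",
    target_score=0.8,
    widget_height="600px",
    initialize_ee=True,
)
genie  # display the widget in a notebook cell
```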

Source code in geemap/ai.py
class Genie(widgets.VBox):
    """A widget for interacting with the Genie AI model.

    The source code is adapted from the ee_genie.ipynb at <https://bit.ly/3YEm7B6>.
    Credit to the original author Simon Ilyushchenko (<https://github.com/simonff>).

    Args:
        project (Optional[str], optional): Google Cloud project ID. Defaults to None.
        google_api_key (Optional[str], optional): Google API key. Defaults to None.
        gemini_model (str, optional): The Gemini model to use. Defaults to "gemini-1.5-flash".
            For a list of available models, see https://bit.ly/4fKfXW7.
        target_score (float, optional): The target score for the model. Defaults to 0.8.
        widget_height (str, optional): The height of the widget. Defaults to "600px".
        initialize_ee (bool, optional): Whether to initialize Earth Engine. Defaults to True.

    Raises:
        ValueError: If the project ID or Google API key is not provided.
    """

    def __init__(
        self,
        project: Optional[str] = None,
        google_api_key: Optional[str] = None,
        gemini_model: str = "gemini-1.5-flash",
        target_score: float = 0.8,
        widget_height: str = "600px",
        initialize_ee: bool = True,
    ) -> None:
        # Initialization

        if project is None:
            project = get_api_key("EE_PROJECT_ID") or get_api_key("GOOGLE_PROJECT_ID")
        if project is None:
            raise ValueError(
                "Please provide a valid project ID via the 'project' parameter."
            )

        if google_api_key is None:
            google_api_key = get_api_key("GOOGLE_API_KEY")
        if google_api_key is None:
            raise ValueError(
                "Please provide a valid Google API key via the 'google_api_key' parameter."
            )

        if initialize_ee:
            ee_initialize(project=project)

        genai.configure(api_key=google_api_key)
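        # The public "earthengine-stac" GCS bucket holds the Earth Engine Data Catalog
        # STAC JSON files; get_dataset_description() below reads dataset metadata from it.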
        storage_client = storage.Client(project=project)
        bucket = storage_client.get_bucket("earthengine-stac")

        # Score to aim for (on the 0-1 scale). The exact meaning of "score"
        # is left to the LLM.

        # Count of analysis rounds

        self.iteration = 1
        self.map_dirty = False

        m = Map()
        m.add("layer_manager")
        self.map = m

        analysis_model = None
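        # Reassigned further down to a Gemini wrapper instance; used by score_response().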

        image_model = genai.GenerativeModel(gemini_model)

        # UI widget definitions

        # We define the widgets early because some functions will write to the debug
        # and/or chat panels.

        command_input = widgets.Text(
            value="show a whole continent Australia DEM visualization using a palette that captures the elevation range",
            # value='show NYC',
            # value='show an area with many center pivot irrigation circles',
            # value='show a fire scar',
            # value='show an open pit mine',
            # value='a sea port',
            # value='flood consequences',
            # value='show an interesting modis composite with the relevant visualization and analyze it over Costa Rica',
            description="❓",
            layout=widgets.Layout(width="100%", height="50px"),
        )

        command_output = widgets.Label(
            value="Last command will be here",
        )

        status_label = widgets.Textarea(
            value="LLM response will be here",
            layout=widgets.Layout(width="50%", height="100px"),
        )

        # widget_height = "600px"
        debug_output = widgets.Output(
            layout={
                "border": "1px solid black",
                "height": widget_height,
                "overflow": "scroll",
                "width": "500px",
                "padding": "5px",
            }
        )
        with debug_output:
            print("DEBUG COLUMN\n")

        logo = requests.get(
            "https://drive.usercontent.google.com/download?id=1zE6G_nxXa3G5N0G_32jEhzdum2kMDfNY&export=view&authuser=0"
        ).content

        image_widget = widgets.Image(value=logo, format="png", width=400, height=600)

        chat_output = widgets.Output(
            layout={
                "border": "1px solid black",
                "height": "600px",
                "overflow": "scroll",
                "width": "300px",
            }
        )

        with chat_output:
            print("CHAT COLUMN\n")

        # Simple functions that LLM will call

        def set_center(x: float, y: float, zoom: int) -> str:
            """Sets the map center to the given coordinates and zoom level and
            returns instructions on what to do next."""
            with debug_output:
                print(f"SET_CENTER({x}, {y}, {zoom})\n")
            m.set_center(x, y)
            m.zoom = zoom
            # global map_dirty
            self.map_dirty = True
            return (
                "Do not call any more functions in this request to let geemap bounds "
                "update. Wait for user input."
            )

        def add_image_layer(image_id: str) -> str:
            """Adds an ee.Image with the given id to the map
            and returns a status message (success or failure)."""
            m.clear()
            command_output.value = f"add_image_layer('{image_id}')"
            m.addLayer(ee.Image(image_id))
            return "success"

        def get_dataset_description(dataset_id: str) -> str:
            """Fetches JSON STAC description for the given Earth Engine dataset id."""
            with debug_output:
                print(f"LOOKING UP {dataset_id}\n")
            parent = dataset_id.split("/")[0]

            # Get the blob (file)
            path = (
                os.path.join("catalog", parent, dataset_id.replace("/", "_")) + ".json"
            )
            blob = bucket.blob(path)

            if not blob.exists():
                return "dataset file not found: " + path

            file_contents = blob.download_as_string().decode()
            return file_contents

        def get_image(image_url: str) -> bytes:
            """Fetches from Earth Engine the content of the given URL as bytes."""
            response = requests.get(image_url)

            if response.status_code == 200:
                image_widget.value = response.content
                return response.content
            else:
                error_message = f"Error downloading image: {response}"
                try:
                    error_details = (
                        json.loads(response.content.decode())
                        .get("error", {})
                        .get("message")
                    )
                    if error_details:
                        error_message += f" - {error_details}"
                except json.JSONDecodeError:
                    pass
                with debug_output:
                    print(error_message)
                raise ValueError("URL %s causes %s" % (image_url, error_message))

        def show_layer(python_code: str) -> str:
            """Execute the given Earth Engine Python client code and add the result to
            the map. Returns the status message (success or error message)."""
            m.layers = m.layers[:2]
            while '\\"' in python_code:
                python_code = python_code.replace('\\"', '"')
            command_output.value = f"show_layer('{python_code}')"
            with debug_output:
                print(f"IMAGE:\n {python_code}\n")
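            # The LLM-supplied snippet is expected to be a single Python expression
            # that evaluates to an Earth Engine object (e.g., ee.Image(...)).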
            try:
                locals = {}
                exec(f"import ee; im = {python_code}", {}, locals)
                m.addLayer(locals["im"])
            except Exception as e:
                with debug_output:
                    print(f"ERROR: {e}")
                return str(e)
            return "success"

        def inner_monologue(thoughts: str) -> str:
            """Sends the current thinking of the LLM model to the user so that they are
            aware of what the model is thinking between function calls."""
            with debug_output:
                print(f"THOUGHTS:\n {thoughts}\n")
            return "success"

        # Functions for textual analysis of images

        def _lat_lon_to_tile(lon, lat, zoom_level):
            # Convert latitude and longitude to Mercator coordinates
            x_merc = (lon + 180) / 360
            y_merc = (
                1
                - math.log(
                    math.tan(math.radians(lat)) + 1 / math.cos(math.radians(lat))
                )
                / math.pi
            ) / 2

            # Calculate number of tiles
            n = 2**zoom_level

            # Convert to tile coordinates
            X = int(x_merc * n)
            Y = int(y_merc * n)

            return X, Y

        def analyze_image(additional_instructions: str = "") -> str:
            """Returns GenAI image analysis describing the current map image.
            Optional additional instructions might be passed to target the analysis
            more precisely.
            """
            # global map_dirty
            if self.map_dirty:
                print("MAP DIRTY")
                return (
                    "Map is not ready. Stop further processing and ask for user input"
                )

            try:
                return _analyze_image(additional_instructions)
            except ValueError as e:
                return str(e)

        def _analyze_image(additional_instructions: str = "") -> str:
            # bounds = m.bounds
            # s, w = bounds[0]
            # n, e = bounds[1]
            # zoom = int(m.zoom)

            # min_tile_x, max_tile_y = _lat_lon_to_tile(w, s, zoom)
            # max_tile_x, min_tile_y = _lat_lon_to_tile(e, n, zoom)
            # min_tile_x = max(0, min_tile_x)
            # max_tile_x = min(2**zoom - 1, max_tile_x)
            # min_tile_y = max(0, min_tile_y)
            # max_tile_y = min(2**zoom - 1, max_tile_y)

            # with debug_output:
            #     if additional_instructions:
            #         print(f"RUNNING IMAGE ANALYSIS: {additional_instructions}...\n")
            #     else:
            #         print("RUNNING IMAGE ANALYSIS...\n")

            # layers = list(m.ee_layer_dict.values())
            # if not layers:
            #     return "No data layers loaded"
            # url_template = layers[-1]["ee_layer"].url
            # tile_width = 256
            # tile_height = 256
            # image_width = (max_tile_x - min_tile_x + 1) * tile_width
            # image_height = (max_tile_y - min_tile_y + 1) * tile_height

            # # Create a new blank image
            # image = PIL.Image.new("RGB", (image_width, image_height))

            # for y in range(min_tile_y, max_tile_y + 1):
            #     for x in range(min_tile_x, max_tile_x + 1):
            #         tile_url = str.format(url_template, x=x, y=y, z=zoom)
            #         # print(tile_url)
            #         tile_img = PIL.Image.open(io.BytesIO(get_image(tile_url)))

            #         offset_x = (x - min_tile_x) * tile_width
            #         offset_y = (y - min_tile_y) * tile_height
            #         image.paste(tile_img, (offset_x, offset_y))

            # width, height = image.size
            # num_bands = len(image.getbands())

            with debug_output:
                if additional_instructions:
                    print(f"RUNNING IMAGE ANALYSIS: {additional_instructions}...\n")
                else:
                    print("RUNNING IMAGE ANALYSIS...\n")

            layers = list(m.ee_layer_dict.values())
            if not layers:
                return "No data layers loaded"
            image_temp_file = temp_file_path(extension="jpg")
            layer_name = layers[-1]["ee_layer"].name
            m.layer_to_image(layer_name, output=image_temp_file, scale=m.get_scale())
            image = PIL.Image.open(image_temp_file)

            image_array = np.array(image)
            image_min = np.min(image_array)
            image_max = np.max(image_array)

            with open(image_temp_file, "rb") as file:
                image_widget.value = file.read()

            # Skip an LLM call when we can simply tell that something is wrong.
            # (Also, LLMs might hallucinate on uniform images.)
            if image_min == image_max:
                return (
                    f"The image tile has a single uniform color with value "
                    f"{image_min}."
                )

            query = """You are an objective, precise overhead imagery analyst.
        Describe what the provided map tile depicts in terms of:

        1. The colors, textures, and patterns visible in the image.
        2. The spatial distribution, shape, and extent of distinct features or regions.
        3. Any notable contrasts, boundaries, or gradients between different areas.

        Avoid making assumptions about the specific geographic location, time period,
        or cause of the observed features. Focus solely on the literal contents of the
        image itself. Clearly indicate which features look natural, which look human-made,
        and which look like image artifacts. (Eg, a completely straight blue line
        is unlikely to be a river.)

        If the image is ambiguous or unclear, state so directly. Do not speculate or
        hypothesize beyond what is directly visible.

        If the image is mostly the same color (white, gray, or black) with little
        contrast, just report that and do not describe the features.

        Use clear, concise language. Avoid subjective interpretations or analogies.
        Organize your response into structured paragraphs.
        """
            if additional_instructions:
                query += additional_instructions
            req = {
                "parts": [
                    {"text": query},
                    {"inline_data": image},
                ]
            }
            image_response = image_model.generate_content(req)
            try:
                with debug_output:
                    print(f"ANALYSIS RESULT: {image_response.text}\n")
                return image_response.text
            except ValueError as e:
                with debug_output:
                    print(f"UNEXPECTED IMAGE RESPONSE: {e}")
                    print(image_response)
                breakpoint()

        # Function for scoring how well image analysis corresponds to the user query.

        # Note that we ask for the score outside of the main agent chat to keep
        # the scoring more objective.

        scoring_system_prompt = """
        After looking at the user query and the map tile analysis, start
        your answer with a number between 0 and 1 indicating how relevant
        the image is as an answer to the query. (0=irrelevant, 1=perfect answer)

        Make sure you have enough justification to definitively declare the analysis
        relevant - it's better to give a false negative than a false positive. However,
        if the image analysis identifies specific matching landmarks (eg, the
        outlines of Manhattan island for a request to show NYC), believe it.

        Do not assume too much (eg, the presence of green doesn't by itself mean the
        image shows forest); attempt to find multiple (at least three) independent
        lines of evidence before declaring victory and cite all these lines of evidence
        in your response.

        Be very, very skeptical - look for specific features that match only the query
        and nothing else (eg, if the query looks for a river, a completely straight blue
        line is unlikely to be a river). Think about what size the features are based on
        the zoom level and whether this size matches the feature size expected from
        first principles.

        If there is ambiguity or uncertainty, express it in your analysis and
        lower the score accordingly. If the image analysis is inconclusive, try zooming
        out to make sure you are looking at the right spot. Do not reduce the score if
        the analysis does not mention visualization parameters - they are just given for
        your reference. The image might show an area slightly larger than requested -
        this is okay, do not reduce the score on this account.
        """

        def score_response(
            query: str, visualization_parameters: str, analysis: str
        ) -> str:
            """Returns how well the given analysis describes a map tile returned for
            the given query. The analysis starts with a number between 0 and 1.

            Arguments:
            query: user-specified query
            visualization_parameters: description of the bands used and visualization
                parameters applied to the map tile
            analysis: the textual description of the map tile
            """
            with debug_output:
                print(f"VIZ PARAMS: {visualization_parameters}\n")
            question = f"""For user query {query} please score the following analysis:
            {analysis}. The answer must start with a number between 0 and 1."""
            if visualization_parameters:
                question += f""" Do not assume that common bands or visualization
                parameters should have been used, as the visualization used the
                following parameters: {visualization_parameters}"""

            result = analysis_model.ask(question)
            # global iteration
            with debug_output:
                print(f"SCORE #{self.iteration}:\n {result}\n")
            try:
                self.iteration += 1
            except Exception as e:
                with debug_output:
                    print(f"UNEXPECTED SCORE RESPONSE: {e}")
            return result

        # Main prompt for the agent

        system_prompt = f"""
        The client is running in a Python notebook with a geemap Map displayed.
        When composing Python code, do not use getMapId - just return the single-line
        layer definition like 'ee.Image("USGS/SRTMGL1_003")' that we will pass to
        Map.addLayer(). Do not escape quotation marks in Python code.

        Be sure to use Python, not Javascript, syntax for keyword parameters in
        Python code (that is, "function(arg=value)"). Using the provided functions,
        respond to the user command following below (or respond why it's not possible).
        If you get an Earth Engine error, attempt to fix it and then try again.

        Before you choose a dataset, think about what kind of dataset would be most
        suitable for the query. Also think about what zoom level would be suitable for
        the query, keeping in mind that for high-resolution image collections higher
        zoom levels are better to speed up tile loading.

        Once you have chosen a dataset, read its description using the provided function
        to see what spatial and temporal range it covers, what bands it has, as well as
        to find the recommended visualization parameters. Explain using the inner
        monologue function why you chose a specific dataset, zoom level and map location.

        Prefer mosaicking image collections using the mosaic() function; don't get
        individual images from collections via 'first()'. Choose a tile size and
        zoom level that will ensure the
        tile has enough pixels in it to avoid graininess, but not so many
        that processing becomes very expensive. Do not use wide date ranges
        with collections that have many images, but remember that Landsat and
        Sentinel-2 have a revisit period of several days. Do not use sample
        locations - try to come up with actual locations that are relevant to
        the request.

        Use Landsat Collection 2, not Landsat Collection 1 ids. If you are getting
        repeated errors when filtering by a time range, read the dataset description
        to confirm that the dataset has data for the selected range.

        Important: after using the set_center() function, just say that you have called
        this function and wait for the user to hit enter, after which you should
        continue answering the original request. This will make sure the map is updated
        on the client side.

        Once the map is updated and the user has told you to proceed, call the
        analyze_image() function to describe the image for the same location that will
        be shown in geemap. If you pass additional instructions to analyze_image(), do
        not disclose what the image is supposed to show, to discourage hallucinations -
        you can only tell
        the analysis function to pay attention to specific areas (eg, center or top left)
        or shapes (eg, a line at the bottom) in the image. You can also tell the analysis
        function about the chosen bands, color palette and min/max visualization
        parameters, if any, to help it interpret the colors correctly. If the image
        turns out to be uniform in color with no features,
        use min/max visualization parameters to enhance contrast.

        Frequently call the inner_monologue() function to tell the user about your
        current thought process. This is a good time to reflect if you have been running
        into repeated errors of the same kind, and if so, to try a different approach.

        When you are done, call the score_response() function to evaluate the analysis.
        You can also tell the scoring function about the chosen bands, color palette
        and min/max visualization parameters, if any. If the analysis score is below
        {target_score}, keep trying to find and show a better image. You might have to
        change the dataset,
        map location, zoom level, date range, bands, or other parameters - think about
        what went wrong in the previous attempt and make the change that's most likely
        to improve the score.
        """

        # Tools exposed to Gemini, plus a wrapper class for LLM chat with function calling

        gemini_tools = [
            set_center,
            show_layer,
            analyze_image,
            inner_monologue,
            get_dataset_description,
            score_response,
        ]

        class Gemini:
            """Gemini LLM."""

            def __init__(
                self, system_prompt, tools=None, model_name="gemini-1.5-pro-latest"
            ):
                if not tools:
                    tools = []
                self._text_model = genai.GenerativeModel(
                    model_name=model_name, tools=tools
                )

                initial_messages = glm.Content(
                    role="model", parts=[glm.Part(text=system_prompt)]
                )
                self._chat_proxy = self._text_model.start_chat(
                    history=initial_messages, enable_automatic_function_calling=True
                )

            def ask(self, question, temperature=0):
                condition = ""
                sleep_duration = 10
                while True:
                    try:
                        response = self._text_model.generate_content(
                            question + condition
                        )
                        return response.text
                    except genai.types.generation_types.StopCandidateException as e:
                        if (
                            glm.Candidate.FinishReason.RECITATION
                            == e.args[0].finish_reason
                        ):
                            condition = (
                                "Previous attempt returned a RECITATION error. "
                                "Rephrase the answer to avoid it."
                            )
                        with chat_output:
                            command_input.description = "🆁"
                        time.sleep(1)
                        with chat_output:
                            command_input.description = "🤔"
                        continue
                    except (
                        google.api_core.exceptions.TooManyRequests,
                        google.api_core.exceptions.DeadlineExceeded,
                    ):
                        with debug_output:
                            command_input.description = "💤"
                        time.sleep(sleep_duration)
                        continue
                    except ValueError as e:
                        with debug_output:
                            print(f"Response {response} led to error: {e}")
                        breakpoint()
                        i = 1

            def chat(self, question: str, temperature=0) -> str:
                """Sends a question to the ongoing chat session and returns the response text."""
                # Always delay a bit to reduce the chance for rate-limiting errors.
                time.sleep(1)
                condition = ""
                sleep_duration = 10
                while True:
                    response = ""
                    try:
                        response = self._chat_proxy.send_message(
                            question + condition,
                            generation_config={
                                "temperature": temperature,
                                # Use a generous but limited output size to encourage in-depth
                                # replies.
                                "max_output_tokens": 5000,
                            },
                        )
                        if not response.parts:
                            raise ValueError(
                                "Cannot get analysis with reason"
                                f" {response.candidates[0].finish_reason.name}, terminating"
                            )
                    except genai.types.generation_types.StopCandidateException as e:
                        if (
                            glm.Candidate.FinishReason.RECITATION
                            == e.args[0].finish_reason
                        ):
                            condition = (
                                "Previous attempt returned a RECITATION error. "
                                "Rephrase the answer to avoid it."
                            )
                        with chat_output:
                            command_input.description = "🆁"
                        time.sleep(1)
                        with chat_output:
                            command_input.description = "🤔"
                        continue
                    except (
                        google.api_core.exceptions.TooManyRequests,
                        google.api_core.exceptions.DeadlineExceeded,
                    ):
                        with debug_output:
                            command_input.description = "💤"
                        time.sleep(10)
                        continue
                    try:
                        return response.text
                    except ValueError as e:
                        with debug_output:
                            print(f"Response {response} led to the error {e}")

        model = Gemini(system_prompt, gemini_tools, model_name=gemini_model)
        analysis_model = Gemini(scoring_system_prompt, model_name=gemini_model)

        # UI functions

        def set_cursor_waiting():
            js_code = """
            document.querySelector('body').style.cursor = 'wait';
            """
            display(HTML(f"<script>{js_code}</script>"))

        def set_cursor_default():
            js_code = """
            document.querySelector('body').style.cursor = 'default';
            """
            display(HTML(f"<script>{js_code}</script>"))

        def on_submit(widget):
            # global map_dirty
            self.map_dirty = False
            command_input.description = "❓"
            command = widget.value
            if not command:
                command = "go on"
            with chat_output:
                print("> " + command + "\n")
            if command != "go on":
                with debug_output:
                    print("> " + command + "\n")
            widget.value = ""
            set_cursor_waiting()
            command_input.description = "🤔"
            response = model.chat(command, temperature=0)
            if self.map_dirty:
                command_input.description = "🙏"
            else:
                command_input.description = "❓"
            set_cursor_default()
            response = response.strip()
            if not response:
                response = "<EMPTY RESPONSE, HIT ENTER>"
            with chat_output:
                print(response + "\n")
            command_input.value = ""

        command_input.on_submit(on_submit)

        # UI layout

        # Arrange the chat history and input in a vertical box
        chat_ui = widgets.VBox(
            [image_widget, chat_output],
            layout=widgets.Layout(width="420px", height=widget_height),
        )

        chat_output.layout = widgets.Layout(
            width="400px"
        )  # Fixed width for the left control
        m.layout = widgets.Layout(flex="1 1 auto", height=widget_height)

        table = widgets.HBox(
            [chat_ui, debug_output, m], layout=widgets.Layout(align_items="flex-start")
        )

        message_widget = widgets.Output()
        with message_widget:
            print("❓ = waiting for user input")
            print("🙏 = waiting for user to hit enter after calling set_center()")
            print("🤔 = thinking")
            print("💤 = sleeping due to retries")
            print("🆁 = Gemini recitation error")

        super().__init__(
            [table, command_input, message_widget], layout={"overflow": "hidden"}
        )