Skip to content

ai module

This module contains functions for interacting with AI models. The Genie class source code is adapted from the ee_genie.ipynb at https://bit.ly/3YEm7B6. Credit to the original author Simon Ilyushchenko (https://github.com/simonff). The DataExplorer class source code is adapted from https://bit.ly/48cE24D. Credit to the original author Renee Johnston (https://github.com/raj02006)

BBox dataclass

Class representing a lat/lon bounding box.

Source code in geemap/ai.py
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
@dataclasses.dataclass
class BBox:
    """Class representing a lat/lon bounding box."""

    west: float
    south: float
    east: float
    north: float

    def is_global(self) -> bool:
        """Checks if the bounding box is global.

        Returns:
            bool: True if the bounding box is global, False otherwise.
        """
        return (
            self.west == -180
            and self.south == -90
            and self.east == 180
            and self.north == 90
        )

    @classmethod
    def from_list(cls, bbox_list: List[float]) -> "BBox":
        """Constructs a BBox from a list of four numbers [west, south, east, north].

        Args:
            bbox_list (List[float]): List of four numbers representing the bounding box.

        Returns:
            BBox: The constructed BBox object.

        Raises:
            ValueError: If the coordinates are not in the correct order.
        """
        if bbox_list[0] > bbox_list[2]:
            raise ValueError(
                "The smaller (west) coordinate must be listed first in a bounding box"
                f" corner list. Found {bbox_list}"
            )
        if bbox_list[1] > bbox_list[3]:
            raise ValueError(
                "The smaller (south) coordinate must be listed first in a bounding"
                f" box corner list. Found {bbox_list}"
            )
        return cls(bbox_list[0], bbox_list[1], bbox_list[2], bbox_list[3])

    def to_list(self) -> List[float]:
        """Converts the BBox to a list of four numbers [west, south, east, north].

        Returns:
            List[float]: List of four numbers representing the bounding box.
        """
        return [self.west, self.south, self.east, self.north]

    def intersects(self, query_bbox: "BBox") -> bool:
        """Checks if this bbox intersects with the query bbox.

        Doesn't handle bboxes extending past the antimeridian.

        Args:
            query_bbox (BBox): Bounding box from the query.

        Returns:
            bool: True if the two bounding boxes intersect, False otherwise.
        """
        return (
            query_bbox.west < self.east
            and query_bbox.east > self.west
            and query_bbox.south < self.north
            and query_bbox.north > self.south
        )

from_list(bbox_list) classmethod

Constructs a BBox from a list of four numbers [west, south, east, north].

Parameters:

Name Type Description Default
bbox_list List[float]

List of four numbers representing the bounding box.

required

Returns:

Name Type Description
BBox BBox

The constructed BBox object.

Raises:

Type Description
ValueError

If the coordinates are not in the correct order.

Source code in geemap/ai.py
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
@classmethod
def from_list(cls, bbox_list: List[float]) -> "BBox":
    """Constructs a BBox from a list of four numbers [west, south, east, north].

    Args:
        bbox_list (List[float]): List of four numbers representing the bounding box.

    Returns:
        BBox: The constructed BBox object.

    Raises:
        ValueError: If the coordinates are not in the correct order.
    """
    if bbox_list[0] > bbox_list[2]:
        raise ValueError(
            "The smaller (west) coordinate must be listed first in a bounding box"
            f" corner list. Found {bbox_list}"
        )
    if bbox_list[1] > bbox_list[3]:
        raise ValueError(
            "The smaller (south) coordinate must be listed first in a bounding"
            f" box corner list. Found {bbox_list}"
        )
    return cls(bbox_list[0], bbox_list[1], bbox_list[2], bbox_list[3])

intersects(query_bbox)

Checks if this bbox intersects with the query bbox.

Doesn't handle bboxes extending past the antimeridian.

Parameters:

Name Type Description Default
query_bbox BBox

Bounding box from the query.

required

Returns:

Name Type Description
bool bool

True if the two bounding boxes intersect, False otherwise.

Source code in geemap/ai.py
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
def intersects(self, query_bbox: "BBox") -> bool:
    """Checks if this bbox intersects with the query bbox.

    Doesn't handle bboxes extending past the antimeridian.

    Args:
        query_bbox (BBox): Bounding box from the query.

    Returns:
        bool: True if the two bounding boxes intersect, False otherwise.
    """
    return (
        query_bbox.west < self.east
        and query_bbox.east > self.west
        and query_bbox.south < self.north
        and query_bbox.north > self.south
    )

is_global()

Checks if the bounding box is global.

Returns:

Name Type Description
bool bool

True if the bounding box is global, False otherwise.

Source code in geemap/ai.py
857
858
859
860
861
862
863
864
865
866
867
868
def is_global(self) -> bool:
    """Checks if the bounding box is global.

    Returns:
        bool: True if the bounding box is global, False otherwise.
    """
    return (
        self.west == -180
        and self.south == -90
        and self.east == 180
        and self.north == 90
    )

to_list()

Converts the BBox to a list of four numbers [west, south, east, north].

Returns:

Type Description
List[float]

List[float]: List of four numbers representing the bounding box.

Source code in geemap/ai.py
895
896
897
898
899
900
901
def to_list(self) -> List[float]:
    """Converts the BBox to a list of four numbers [west, south, east, north].

    Returns:
        List[float]: List of four numbers representing the bounding box.
    """
    return [self.west, self.south, self.east, self.north]

Catalog

Class containing all collections in the EE STAC catalog.

Source code in geemap/ai.py
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
class Catalog:
    """Class containing all collections in the EE STAC catalog."""

    collections: CollectionList

    def __init__(self, storage_client: storage.Client) -> None:
        """Initializes the Catalog with collections loaded from Google Cloud Storage.

        Args:
            storage_client (storage.Client): The Google Cloud Storage client.
        """
        self.collections = CollectionList(self._load_collections(storage_client))

    def get_collection(self, id: str) -> Collection:
        """Returns the collection with the given id.

        Args:
            id (str): The ID of the collection.

        Returns:
            Collection: The collection with the given ID.

        Raises:
            ValueError: If no collection with the given ID is found.
        """
        col = self.collections.filter_by_ids([id])
        if len(col) == 0:
            raise ValueError(f"No collection with id {id}")
        return col[0]

    @tenacity.retry(
        stop=tenacity.stop_after_attempt(5),
        wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
        retry=tenacity.retry_if_exception_type(
            (
                google_exceptions.GoogleAPICallError,
                google_exceptions.RetryError,
                ConnectionError,
            )
        ),
        before_sleep=lambda retry_state: print(
            f"Error occurred: {str(retry_state.outcome.exception())}\n"
            f"Retrying in {retry_state.next_action.sleep} seconds... "
            f"(Attempt {retry_state.attempt_number}/3)"
        ),
    )
    def _read_file(self, file_blob: storage.Blob) -> Collection:
        """Reads the contents of a file from the specified bucket.

        Args:
            file_blob (storage.Blob): The blob representing the file.

        Returns:
            Collection: The collection created from the file contents.
        """
        file_contents = file_blob.download_as_string().decode()
        return Collection(json.loads(file_contents))

    def _read_files(self, file_blobs: List[storage.Blob]) -> List[Collection]:
        """Processes files in parallel.

        Args:
            file_blobs (List[storage.Blob]): The list of file blobs.

        Returns:
            List[Collection]: The list of collections created from the file contents.
        """
        collections = []
        with futures.ThreadPoolExecutor(max_workers=10) as executor:
            file_futures = [
                executor.submit(self._read_file, file_blob) for file_blob in file_blobs
            ]
            for future in file_futures:
                collections.append(future.result())
        return collections

    def _load_collections(self, storage_client: storage.Client) -> Sequence[Collection]:
        """Loads all EE STAC JSON files from GCS, with datetimes as objects.

        Args:
            storage_client (storage.Client): The Google Cloud Storage client.

        Returns:
            Sequence[Collection]: A tuple of collections loaded from the files.
        """
        bucket = storage_client.get_bucket("earthengine-stac")
        files = [
            x
            for x in bucket.list_blobs(prefix="catalog/")
            if x.name.endswith(".json")
            and not x.name.endswith("/catalog.json")
            and not x.name.endswith("/units.json")
        ]
        logging.warning("Found %d files, loading...", len(files))
        collections = self._read_files(files)

        code_samples_dict = self._load_all_code_samples(storage_client)

        res = []
        for c in collections:
            if c.is_deprecated():
                continue
            c.stac_json["code"] = code_samples_dict.get(c.hyphen_id())
            res.append(c)
        logging.warning("Loaded %d collections (skipping deprecated ones)", len(res))
        # Returning a tuple for immutability.
        return tuple(res)

    def _load_all_code_samples(
        self, storage_client: storage.Client
    ) -> Dict[str, Dict[str, str]]:
        """Loads js + py example scripts from GCS into dict keyed by dataset ID.

        Args:
            storage_client (storage.Client): The Google Cloud Storage client.

        Returns:
            Dict[str, Dict[str, str]]: A dictionary mapping dataset IDs to their code samples.
        """

        # Get json file from GCS bucket
        # 'gs://earthengine-catalog/catalog/example_scripts.json'
        bucket = storage_client.get_bucket("earthengine-catalog")
        blob = bucket.blob("catalog/example_scripts.json")
        file_contents = blob.download_as_string().decode()
        data = json.loads(file_contents)

        # Flatten json to get a map from ID (using '_' rather than '/') to code
        # sample.
        all_datasets_by_provider = data[0]["contents"]
        code_samples_dict = {}
        for provider in all_datasets_by_provider:
            for dataset in provider["contents"]:
                js_code = dataset["code"]
                py_code = self._make_python_code_sample(js_code)

                code_samples_dict[dataset["name"]] = {
                    "js_code": js_code,
                    "py_code": py_code,
                }

        return code_samples_dict

    def _make_python_code_sample(self, js_code: str) -> str:
        """Converts EE JS code into python.

        Args:
            js_code (str): The JavaScript code to convert.

        Returns:
            str: The converted Python code.
        """

        # geemap appears to have some stray print statements.
        _ = io.StringIO()
        with redirect_stdout(_):
            code_list = js_snippet_to_py(
                js_code,
                add_new_cell=False,
                import_ee=False,
                import_geemap=False,
                show_map=False,
            )
        return "".join(code_list)

__init__(storage_client)

Initializes the Catalog with collections loaded from Google Cloud Storage.

Parameters:

Name Type Description Default
storage_client Client

The Google Cloud Storage client.

required
Source code in geemap/ai.py
1322
1323
1324
1325
1326
1327
1328
def __init__(self, storage_client: storage.Client) -> None:
    """Initializes the Catalog with collections loaded from Google Cloud Storage.

    Args:
        storage_client (storage.Client): The Google Cloud Storage client.
    """
    self.collections = CollectionList(self._load_collections(storage_client))

get_collection(id)

Returns the collection with the given id.

Parameters:

Name Type Description Default
id str

The ID of the collection.

required

Returns:

Name Type Description
Collection Collection

The collection with the given ID.

Raises:

Type Description
ValueError

If no collection with the given ID is found.

Source code in geemap/ai.py
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
def get_collection(self, id: str) -> Collection:
    """Returns the collection with the given id.

    Args:
        id (str): The ID of the collection.

    Returns:
        Collection: The collection with the given ID.

    Raises:
        ValueError: If no collection with the given ID is found.
    """
    col = self.collections.filter_by_ids([id])
    if len(col) == 0:
        raise ValueError(f"No collection with id {id}")
    return col[0]

Collection

A simple wrapper for a STAC Collection..

Source code in geemap/ai.py
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
class Collection:
    """A simple wrapper for a STAC Collection.."""

    stac_json: dict[str, Any]

    def __init__(self, stac_json: Dict[str, Any]) -> None:
        """Initializes the Collection.

        Args:
            stac_json (Dict[str, Any]): The STAC JSON of the collection.
        """
        self.stac_json = stac_json
        if stac_json.get("gee:status") == "deprecated":
            # Set the STAC 'deprecated' field that we don't set in the jsonnet files
            stac_json["deprecated"] = True

    def __getitem__(self, item: str) -> Any:
        """Gets an item from the STAC JSON.

        Args:
            item (str): The key of the item to get.

        Returns:
            Any: The value of the item.
        """
        return self.stac_json[item]

    def get(self, item: str, default: Optional[Any] = None) -> Optional[Any]:
        """Matches dict's get by returning None if there is no item.

        Args:
            item (str): The key of the item to get.
            default (Optional[Any]): The default value to return if the item is not found. Defaults to None.

        Returns:
            Optional[Any]: The value of the item or the default value.
        """
        return self.stac_json.get(item, default)

    def public_id(self) -> str:
        """Gets the public ID of the collection.

        Returns:
            str: The public ID of the collection.
        """
        return self["id"]

    def hyphen_id(self) -> str:
        """Gets the hyphenated ID of the collection.

        Returns:
            str: The hyphenated ID of the collection.
        """
        return self["id"].replace("/", "_")

    def get_dataset_type(self) -> str:
        """Gets the dataset type of the collection.

        Returns:
            str: The dataset type of the collection.
        """
        return self["gee:type"]

    def is_deprecated(self) -> bool:
        """Checks if the collection is deprecated or has a successor.

        Returns:
            bool: True if the collection is deprecated or has a successor, False otherwise.
        """
        if self.get("deprecated", False):
            logging.info("Skipping deprecated collection: %s", self.public_id())
            return True

    def datetime_interval(
        self,
    ) -> Iterable[Tuple[datetime.datetime, Optional[datetime.datetime]]]:
        """Returns datetime objects representing temporal extents.

        Returns:
            Iterable[Tuple[datetime.datetime, Optional[datetime.datetime]]]:
                An iterable of tuples representing temporal extents.

        Raises:
            ValueError: If the temporal interval start is not found.
        """
        for stac_interval in self.stac_json["extent"]["temporal"]["interval"]:
            if not stac_interval[0]:
                raise ValueError(
                    "Expected a non-empty temporal interval start for "
                    + self.public_id()
                )
            start_date = iso8601.parse_date(stac_interval[0])
            if stac_interval[1] is not None:
                end_date = iso8601.parse_date(stac_interval[1])
            else:
                end_date = None
            yield (start_date, end_date)

    def start(self) -> datetime.datetime:
        """Gets the start datetime of the collection.

        Returns:
            datetime.datetime: The start datetime of the collection.
        """
        return list(self.datetime_interval())[0][0]

    def start_str(self) -> str:
        """Gets the start datetime of the collection as a string.

        Returns:
            str: The start datetime of the collection as a string.
        """
        if not self.start():
            return ""
        return self.start().strftime("%Y-%m-%d")

    def end(self) -> Optional[datetime.datetime]:
        """Gets the end datetime of the collection.

        Returns:
            Optional[datetime.datetime]: The end datetime of the collection.
        """
        return list(self.datetime_interval())[0][1]

    def end_str(self) -> str:
        """Gets the end datetime of the collection as a string.

        Returns:
            str: The end datetime of the collection as a string.
        """
        if not self.end():
            return ""
        return self.end().strftime("%Y-%m-%d")

    def bbox_list(self) -> Sequence[BBox]:
        """Gets the bounding boxes of the collection.

        Returns:
            Sequence[BBox]: A sequence of bounding boxes.
        """
        if "extent" not in self.stac_json:
            # Assume global if nothing listed.
            return (BBox(-180, -90, 180, 90),)
        return tuple(
            [BBox.from_list(x) for x in self.stac_json["extent"]["spatial"]["bbox"]]
        )

    def bands(self) -> List[Dict[str, Any]]:
        """Gets the bands of the collection.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries representing the bands.
        """
        summaries = self.stac_json.get("summaries")
        if not summaries:
            return []
        return summaries.get("eo:bands", [])

    def spatial_resolution_m(self) -> float:
        """Gets the spatial resolution of the collection in meters.

        Returns:
            float: The spatial resolution of the collection in meters.
        """
        summaries = self.stac_json.get("summaries")
        if not summaries:
            return -1
        if summaries.get("gsd"):
            return summaries.get("gsd")[0]

        # Fallback for cases where the stac does not follow convention.
        gsd_lst = re.findall(r'"gsd": (\d+)', json.dumps(self.stac_json))

        if len(gsd_lst) > 0:
            return float(gsd_lst[0])

        return -1

    def temporal_resolution_str(self) -> str:
        """Gets the temporal resolution of the collection as a string.

        Returns:
            str: The temporal resolution of the collection as a string.
        """
        interval_dict = self.stac_json.get("gee:interval")
        if not interval_dict:
            return ""
        return f"{interval_dict['interval']} {interval_dict['unit']}"

    def python_code(self) -> str:
        """Gets the Python code sample for the collection.

        Returns:
            str: The Python code sample for the collection.
        """
        code = self.stac_json.get("code")
        if not code:
            return ""

        return code.get("py_code")

    def set_python_code(self, code: str) -> None:
        """Sets the Python code sample for the collection.

        Args:
            code (str): The Python code sample to set.
        """
        if not code:
            self.stac_json["code"] = {"js_code": "", "py_code": code}

        self.stac_json["code"]["py_code"] = code

    def set_js_code(self, code: str) -> None:
        """Sets the JavaScript code sample for the collection.

        Args:
            code (str): The JavaScript code sample to set.
        """
        if not code:
            return ""
        js_code = self.stac_json.get("code").get("js_code")
        self.stac_json["code"] = {"js_code": "", "py_code": code}

    def image_preview_url(self) -> str:
        """Gets the URL of the preview image for the collection.

        Returns:
            str: The URL of the preview image for the collection.

        Raises:
            ValueError: If no preview image is found.
        """
        for link in self.stac_json["links"]:
            if (
                "rel" in link
                and link["rel"] == "preview"
                and link["type"] == "image/png"
            ):
                return link["href"]
        raise ValueError(f"No preview image found for {id}")

    def catalog_url(self) -> str:
        """Gets the URL of the catalog for the collection.

        Returns:
            str: The URL of the catalog for the collection.
        """
        links = self.stac_json["links"]
        for link in links:
            if "rel" in link and link["rel"] == "catalog":
                return link["href"]

            # Ideally there would be a 'catalog' link but sometimes there isn't.
            base_url = "https://developers.google.com/earth-engine/datasets/catalog/"
            if link["href"].startswith(base_url):
                return link["href"].split("#")[0]

        logging.warning(f"No catalog link found for {self.public_id()}")
        return ""

__getitem__(item)

Gets an item from the STAC JSON.

Parameters:

Name Type Description Default
item str

The key of the item to get.

required

Returns:

Name Type Description
Any Any

The value of the item.

Source code in geemap/ai.py
938
939
940
941
942
943
944
945
946
947
def __getitem__(self, item: str) -> Any:
    """Gets an item from the STAC JSON.

    Args:
        item (str): The key of the item to get.

    Returns:
        Any: The value of the item.
    """
    return self.stac_json[item]

__init__(stac_json)

Initializes the Collection.

Parameters:

Name Type Description Default
stac_json Dict[str, Any]

The STAC JSON of the collection.

required
Source code in geemap/ai.py
927
928
929
930
931
932
933
934
935
936
def __init__(self, stac_json: Dict[str, Any]) -> None:
    """Initializes the Collection.

    Args:
        stac_json (Dict[str, Any]): The STAC JSON of the collection.
    """
    self.stac_json = stac_json
    if stac_json.get("gee:status") == "deprecated":
        # Set the STAC 'deprecated' field that we don't set in the jsonnet files
        stac_json["deprecated"] = True

bands()

Gets the bands of the collection.

Returns:

Type Description
List[Dict[str, Any]]

List[Dict[str, Any]]: A list of dictionaries representing the bands.

Source code in geemap/ai.py
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
def bands(self) -> List[Dict[str, Any]]:
    """Gets the bands of the collection.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries representing the bands.
    """
    summaries = self.stac_json.get("summaries")
    if not summaries:
        return []
    return summaries.get("eo:bands", [])

bbox_list()

Gets the bounding boxes of the collection.

Returns:

Type Description
Sequence[BBox]

Sequence[BBox]: A sequence of bounding boxes.

Source code in geemap/ai.py
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
def bbox_list(self) -> Sequence[BBox]:
    """Gets the bounding boxes of the collection.

    Returns:
        Sequence[BBox]: A sequence of bounding boxes.
    """
    if "extent" not in self.stac_json:
        # Assume global if nothing listed.
        return (BBox(-180, -90, 180, 90),)
    return tuple(
        [BBox.from_list(x) for x in self.stac_json["extent"]["spatial"]["bbox"]]
    )

catalog_url()

Gets the URL of the catalog for the collection.

Returns:

Name Type Description
str str

The URL of the catalog for the collection.

Source code in geemap/ai.py
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
def catalog_url(self) -> str:
    """Gets the URL of the catalog for the collection.

    Returns:
        str: The URL of the catalog for the collection.
    """
    links = self.stac_json["links"]
    for link in links:
        if "rel" in link and link["rel"] == "catalog":
            return link["href"]

        # Ideally there would be a 'catalog' link but sometimes there isn't.
        base_url = "https://developers.google.com/earth-engine/datasets/catalog/"
        if link["href"].startswith(base_url):
            return link["href"].split("#")[0]

    logging.warning(f"No catalog link found for {self.public_id()}")
    return ""

datetime_interval()

Returns datetime objects representing temporal extents.

Returns:

Type Description
Iterable[Tuple[datetime, Optional[datetime]]]

Iterable[Tuple[datetime.datetime, Optional[datetime.datetime]]]: An iterable of tuples representing temporal extents.

Raises:

Type Description
ValueError

If the temporal interval start is not found.

Source code in geemap/ai.py
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
def datetime_interval(
    self,
) -> Iterable[Tuple[datetime.datetime, Optional[datetime.datetime]]]:
    """Returns datetime objects representing temporal extents.

    Returns:
        Iterable[Tuple[datetime.datetime, Optional[datetime.datetime]]]:
            An iterable of tuples representing temporal extents.

    Raises:
        ValueError: If the temporal interval start is not found.
    """
    for stac_interval in self.stac_json["extent"]["temporal"]["interval"]:
        if not stac_interval[0]:
            raise ValueError(
                "Expected a non-empty temporal interval start for "
                + self.public_id()
            )
        start_date = iso8601.parse_date(stac_interval[0])
        if stac_interval[1] is not None:
            end_date = iso8601.parse_date(stac_interval[1])
        else:
            end_date = None
        yield (start_date, end_date)

end()

Gets the end datetime of the collection.

Returns:

Type Description
Optional[datetime]

Optional[datetime.datetime]: The end datetime of the collection.

Source code in geemap/ai.py
1038
1039
1040
1041
1042
1043
1044
def end(self) -> Optional[datetime.datetime]:
    """Gets the end datetime of the collection.

    Returns:
        Optional[datetime.datetime]: The end datetime of the collection.
    """
    return list(self.datetime_interval())[0][1]

end_str()

Gets the end datetime of the collection as a string.

Returns:

Name Type Description
str str

The end datetime of the collection as a string.

Source code in geemap/ai.py
1046
1047
1048
1049
1050
1051
1052
1053
1054
def end_str(self) -> str:
    """Gets the end datetime of the collection as a string.

    Returns:
        str: The end datetime of the collection as a string.
    """
    if not self.end():
        return ""
    return self.end().strftime("%Y-%m-%d")

get(item, default=None)

Matches dict's get by returning None if there is no item.

Parameters:

Name Type Description Default
item str

The key of the item to get.

required
default Optional[Any]

The default value to return if the item is not found. Defaults to None.

None

Returns:

Type Description
Optional[Any]

Optional[Any]: The value of the item or the default value.

Source code in geemap/ai.py
949
950
951
952
953
954
955
956
957
958
959
def get(self, item: str, default: Optional[Any] = None) -> Optional[Any]:
    """Matches dict's get by returning None if there is no item.

    Args:
        item (str): The key of the item to get.
        default (Optional[Any]): The default value to return if the item is not found. Defaults to None.

    Returns:
        Optional[Any]: The value of the item or the default value.
    """
    return self.stac_json.get(item, default)

get_dataset_type()

Gets the dataset type of the collection.

Returns:

Name Type Description
str str

The dataset type of the collection.

Source code in geemap/ai.py
977
978
979
980
981
982
983
def get_dataset_type(self) -> str:
    """Gets the dataset type of the collection.

    Returns:
        str: The dataset type of the collection.
    """
    return self["gee:type"]

hyphen_id()

Gets the hyphenated ID of the collection.

Returns:

Name Type Description
str str

The hyphenated ID of the collection.

Source code in geemap/ai.py
969
970
971
972
973
974
975
def hyphen_id(self) -> str:
    """Gets the hyphenated ID of the collection.

    Returns:
        str: The hyphenated ID of the collection.
    """
    return self["id"].replace("/", "_")

image_preview_url()

Gets the URL of the preview image for the collection.

Returns:

Name Type Description
str str

The URL of the preview image for the collection.

Raises:

Type Description
ValueError

If no preview image is found.

Source code in geemap/ai.py
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
def image_preview_url(self) -> str:
    """Gets the URL of the preview image for the collection.

    Returns:
        str: The URL of the preview image for the collection.

    Raises:
        ValueError: If no preview image is found.
    """
    for link in self.stac_json["links"]:
        if (
            "rel" in link
            and link["rel"] == "preview"
            and link["type"] == "image/png"
        ):
            return link["href"]
    raise ValueError(f"No preview image found for {id}")

is_deprecated()

Checks if the collection is deprecated or has a successor.

Returns:

Name Type Description
bool bool

True if the collection is deprecated or has a successor, False otherwise.

Source code in geemap/ai.py
985
986
987
988
989
990
991
992
993
def is_deprecated(self) -> bool:
    """Checks if the collection is deprecated or has a successor.

    Returns:
        bool: True if the collection is deprecated or has a successor, False otherwise.
    """
    if self.get("deprecated", False):
        logging.info("Skipping deprecated collection: %s", self.public_id())
        return True

public_id()

Gets the public ID of the collection.

Returns:

Name Type Description
str str

The public ID of the collection.

Source code in geemap/ai.py
961
962
963
964
965
966
967
def public_id(self) -> str:
    """Gets the public ID of the collection.

    Returns:
        str: The public ID of the collection.
    """
    return self["id"]

python_code()

Gets the Python code sample for the collection.

Returns:

Name Type Description
str str

The Python code sample for the collection.

Source code in geemap/ai.py
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
def python_code(self) -> str:
    """Gets the Python code sample for the collection.

    Returns:
        str: The Python code sample for the collection.
    """
    code = self.stac_json.get("code")
    if not code:
        return ""

    return code.get("py_code")

set_js_code(code)

Sets the JavaScript code sample for the collection.

Parameters:

Name Type Description Default
code str

The JavaScript code sample to set.

required
Source code in geemap/ai.py
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
def set_js_code(self, code: str) -> None:
    """Sets the JavaScript code sample for the collection.

    Args:
        code (str): The JavaScript code sample to set.
    """
    if not code:
        return ""
    js_code = self.stac_json.get("code").get("js_code")
    self.stac_json["code"] = {"js_code": "", "py_code": code}

set_python_code(code)

Sets the Python code sample for the collection.

Parameters:

Name Type Description Default
code str

The Python code sample to set.

required
Source code in geemap/ai.py
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
def set_python_code(self, code: str) -> None:
    """Sets the Python code sample for the collection.

    Args:
        code (str): The Python code sample to set.
    """
    if not code:
        self.stac_json["code"] = {"js_code": "", "py_code": code}

    self.stac_json["code"]["py_code"] = code

spatial_resolution_m()

Gets the spatial resolution of the collection in meters.

Returns:

Name Type Description
float float

The spatial resolution of the collection in meters.

Source code in geemap/ai.py
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
def spatial_resolution_m(self) -> float:
    """Gets the spatial resolution of the collection in meters.

    Returns:
        float: The spatial resolution of the collection in meters.
    """
    summaries = self.stac_json.get("summaries")
    if not summaries:
        return -1
    if summaries.get("gsd"):
        return summaries.get("gsd")[0]

    # Fallback for cases where the stac does not follow convention.
    gsd_lst = re.findall(r'"gsd": (\d+)', json.dumps(self.stac_json))

    if len(gsd_lst) > 0:
        return float(gsd_lst[0])

    return -1

start()

Gets the start datetime of the collection.

Returns:

Type Description
datetime

datetime.datetime: The start datetime of the collection.

Source code in geemap/ai.py
1020
1021
1022
1023
1024
1025
1026
def start(self) -> datetime.datetime:
    """Gets the start datetime of the collection.

    Returns:
        datetime.datetime: The start datetime of the collection.
    """
    return list(self.datetime_interval())[0][0]

start_str()

Gets the start datetime of the collection as a string.

Returns:

Name Type Description
str str

The start datetime of the collection as a string.

Source code in geemap/ai.py
1028
1029
1030
1031
1032
1033
1034
1035
1036
def start_str(self) -> str:
    """Gets the start datetime of the collection as a string.

    Returns:
        str: The start datetime of the collection as a string.
    """
    if not self.start():
        return ""
    return self.start().strftime("%Y-%m-%d")

temporal_resolution_str()

Gets the temporal resolution of the collection as a string.

Returns:

Name Type Description
str str

The temporal resolution of the collection as a string.

Source code in geemap/ai.py
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
def temporal_resolution_str(self) -> str:
    """Gets the temporal resolution of the collection as a string.

    Returns:
        str: The temporal resolution of the collection as a string.
    """
    interval_dict = self.stac_json.get("gee:interval")
    if not interval_dict:
        return ""
    return f"{interval_dict['interval']} {interval_dict['unit']}"

CollectionList

Bases: Sequence[Collection]

List of stac.Collections; can be filtered to return a smaller sublist.

Source code in geemap/ai.py
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
class CollectionList(Sequence[Collection]):
    """List of stac.Collections; can be filtered to return a smaller sublist."""

    _collections = Sequence[Collection]

    def __init__(self, collections: Sequence[Collection]):
        self._collections = tuple(collections)

    def __iter__(self):
        return iter(self._collections)

    def __getitem__(self, index):
        return self._collections[index]

    def __len__(self):
        return len(self._collections)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, CollectionList):
            return self._collections == other._collections
        return False

    def __hash__(self) -> int:
        return hash(self._collections)

    def filter_by_ids(self, ids: Iterable[str]):
        """Returns a sublist with only the collections matching the given ids."""
        return self.__class__([c for c in self._collections if c.public_id() in ids])

    def filter_by_datetime(
        self,
        query_datetime: datetime.datetime,
    ):
        """Returns a sublist with the time interval matching the given time."""
        result = []
        for collection in self._collections:
            for datetime_interval in collection.datetime_interval():
                if matches_datetime(datetime_interval, query_datetime):
                    result.append(collection)
                    break
        return self.__class__(result)

    def filter_by_interval(
        self,
        query_interval: tuple[datetime.datetime, datetime.datetime],
    ):
        """Returns a sublist with the time interval matching the given interval."""
        result = []
        for collection in self._collections:
            for datetime_interval in collection.datetime_interval():
                if matches_interval(datetime_interval, query_interval):
                    result.append(collection)
                    break
        return self.__class__(result)

    def filter_by_bounding_box_list(self, query_bbox: BBox):
        """Returns a sublist with the bbox matching the given bbox."""
        result = []
        for collection in self._collections:
            for collection_bbox in collection.bbox_list():
                if collection_bbox.intersects(query_bbox):
                    result.append(collection)
                    break
        return self.__class__(result)

    def filter_by_bounding_box(self, query_bbox: BBox):
        """Returns a sublist with the bbox matching the given bbox."""
        result = []
        for collection in self._collections:
            for collection_bbox in collection.bbox_list():
                if collection_bbox.intersects(query_bbox):
                    result.append(collection)
                    break
        return self.__class__(result)

    def start_str(self) -> datetime.datetime:
        return self.start().strftime("%Y-%m-%d")

    def sort_by_spatial_resolution(self, reverse=False):
        """
        Sorts the collections based on their spatial resolution.
        Collections with spatial_resolution_m() == -1 are pushed to the end.

        Args:
            reverse (bool): If True, sort in descending order (highest resolution first).
                            If False (default), sort in ascending order (lowest resolution first).

        Returns:
            CollectionList: A new CollectionList instance with sorted collections.
        """

        def sort_key(collection):
            resolution = collection.spatial_resolution_m()
            if resolution == -1:
                return float("inf") if not reverse else float("-inf")
            return resolution

        sorted_collections = sorted(self._collections, key=sort_key, reverse=reverse)
        return self.__class__(sorted_collections)

    def limit(self, n: int):
        """
        Returns a new CollectionList containing the first n entries.

        Args:
            n (int): The number of entries to include in the new list.

        Returns:
            CollectionList: A new CollectionList instance with at most n collections.
        """
        return self.__class__(self._collections[:n])

    def to_df(self):
        """Converts a collection list to a dataframe with a select set of fields."""

        rows = []
        for col in self._collections:
            # Remove text in parens in dataset name.
            short_title = re.sub(r"\([^)]*\)", "", col.get("title")).strip()

            row = {
                "id": col.public_id(),
                "name": short_title,
                "temp_res": col.temporal_resolution_str(),
                "spatial_res_m": col.spatial_resolution_m(),
                "earliest": col.start_str(),
                "latest": col.end_str(),
                "url": col.catalog_url(),
            }
            rows.append(row)
        return pd.DataFrame(rows)

filter_by_bounding_box(query_bbox)

Returns a sublist with the bbox matching the given bbox.

Source code in geemap/ai.py
1249
1250
1251
1252
1253
1254
1255
1256
1257
def filter_by_bounding_box(self, query_bbox: BBox):
    """Returns a sublist with the bbox matching the given bbox."""
    result = []
    for collection in self._collections:
        for collection_bbox in collection.bbox_list():
            if collection_bbox.intersects(query_bbox):
                result.append(collection)
                break
    return self.__class__(result)

filter_by_bounding_box_list(query_bbox)

Returns a sublist with the bbox matching the given bbox.

Source code in geemap/ai.py
1239
1240
1241
1242
1243
1244
1245
1246
1247
def filter_by_bounding_box_list(self, query_bbox: BBox):
    """Returns a sublist with the bbox matching the given bbox."""
    result = []
    for collection in self._collections:
        for collection_bbox in collection.bbox_list():
            if collection_bbox.intersects(query_bbox):
                result.append(collection)
                break
    return self.__class__(result)

filter_by_datetime(query_datetime)

Returns a sublist with the time interval matching the given time.

Source code in geemap/ai.py
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
def filter_by_datetime(
    self,
    query_datetime: datetime.datetime,
):
    """Returns a sublist with the time interval matching the given time."""
    result = []
    for collection in self._collections:
        for datetime_interval in collection.datetime_interval():
            if matches_datetime(datetime_interval, query_datetime):
                result.append(collection)
                break
    return self.__class__(result)

filter_by_ids(ids)

Returns a sublist with only the collections matching the given ids.

Source code in geemap/ai.py
1209
1210
1211
def filter_by_ids(self, ids: Iterable[str]):
    """Returns a sublist with only the collections matching the given ids."""
    return self.__class__([c for c in self._collections if c.public_id() in ids])

filter_by_interval(query_interval)

Returns a sublist with the time interval matching the given interval.

Source code in geemap/ai.py
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
def filter_by_interval(
    self,
    query_interval: tuple[datetime.datetime, datetime.datetime],
):
    """Returns a sublist with the time interval matching the given interval."""
    result = []
    for collection in self._collections:
        for datetime_interval in collection.datetime_interval():
            if matches_interval(datetime_interval, query_interval):
                result.append(collection)
                break
    return self.__class__(result)

limit(n)

Returns a new CollectionList containing the first n entries.

Parameters:

Name Type Description Default
n int

The number of entries to include in the new list.

required

Returns:

Name Type Description
CollectionList

A new CollectionList instance with at most n collections.

Source code in geemap/ai.py
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
def limit(self, n: int):
    """
    Returns a new CollectionList containing the first n entries.

    Args:
        n (int): The number of entries to include in the new list.

    Returns:
        CollectionList: A new CollectionList instance with at most n collections.
    """
    return self.__class__(self._collections[:n])

sort_by_spatial_resolution(reverse=False)

Sorts the collections based on their spatial resolution. Collections with spatial_resolution_m() == -1 are pushed to the end.

Parameters:

Name Type Description Default
reverse bool

If True, sort in descending order (highest resolution first). If False (default), sort in ascending order (lowest resolution first).

False

Returns:

Name Type Description
CollectionList

A new CollectionList instance with sorted collections.

Source code in geemap/ai.py
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
def sort_by_spatial_resolution(self, reverse=False):
    """
    Sorts the collections based on their spatial resolution.
    Collections with spatial_resolution_m() == -1 are pushed to the end.

    Args:
        reverse (bool): If True, sort in descending order (highest resolution first).
                        If False (default), sort in ascending order (lowest resolution first).

    Returns:
        CollectionList: A new CollectionList instance with sorted collections.
    """

    def sort_key(collection):
        resolution = collection.spatial_resolution_m()
        if resolution == -1:
            return float("inf") if not reverse else float("-inf")
        return resolution

    sorted_collections = sorted(self._collections, key=sort_key, reverse=reverse)
    return self.__class__(sorted_collections)

to_df()

Converts a collection list to a dataframe with a select set of fields.

Source code in geemap/ai.py
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
def to_df(self):
    """Converts a collection list to a dataframe with a select set of fields."""

    rows = []
    for col in self._collections:
        # Remove text in parens in dataset name.
        short_title = re.sub(r"\([^)]*\)", "", col.get("title")).strip()

        row = {
            "id": col.public_id(),
            "name": short_title,
            "temp_res": col.temporal_resolution_str(),
            "spatial_res_m": col.spatial_resolution_m(),
            "earliest": col.start_str(),
            "latest": col.end_str(),
            "url": col.catalog_url(),
        }
        rows.append(row)
    return pd.DataFrame(rows)

DatasetExplorer

A widget for exploring Earth Engine datasets. The DataExplorer class source code is adapted from https://bit.ly/48cE24D. Credit to the original author Renee Johnston (https://github.com/raj02006)

Source code in geemap/ai.py
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
class DatasetExplorer:
    """A widget for exploring Earth Engine datasets.
    The DataExplorer class source code is adapted from <https://bit.ly/48cE24D>.
    Credit to the original author Renee Johnston (<https://github.com/raj02006>)
    """

    def __init__(
        self,
        project_id: str = "GOOGLE_PROJECT_ID",
        google_api_key: str = "GOOGLE_API_KEY",
        vertex_ai_zone: str = "us-central1",
        model: str = "gemini-1.5-pro",
        embeddings_cloud_path: str = "gs://earthengine-catalog/embeddings/catalog_embeddings.jsonl",
    ) -> None:
        """Initializes the DatasetExplorer.

        Args:
            project_id (str): Google Cloud project ID. Defaults to "GOOGLE_PROJECT_ID".
            google_api_key (str): Google API key. Defaults to "GOOGLE_API_KEY".
            vertex_ai_zone (str): Vertex AI zone. Defaults to "us-central1".
            model (str): Model name for ChatGoogleGenerativeAI. Defaults to "gemini-1.5-pro".
            embeddings_cloud_path (str): Cloud path to the embeddings file.
                Defaults to "gs://earthengine-catalog/embeddings/catalog_embeddings.jsonl".
        """

        # @title Setup
        project_name = get_env_var(project_id)
        if project_name is None:
            project_name = get_env_var("EE_PROJECT_ID")
        if project_name is None:
            raise ValueError("Please provide a Google Cloud project ID.")

        genai.configure(api_key=get_env_var(google_api_key))

        ee.Authenticate()
        ee.Initialize(project=project_name)
        storage_client = storage.Client(project=project_name)
        vertexai.init(project=project_name, location=vertex_ai_zone)

        # Make sure geemap initialized correctly.
        ee_initialize(project=project_name)
        m = Map(draw_ctrl=False)
        m.add("layer_manager")

        catalog = Catalog(storage_client)

        # Pre-built embeddings.
        EMBEDDINGS_CLOUD_PATH = embeddings_cloud_path

        # Copy embeddings from GCS bucket to a local file
        EMBEDDINGS_LOCAL_PATH = "catalog_embeddings.jsonl"

        def load_embeddings(
            gcs_path=EMBEDDINGS_CLOUD_PATH, local_path=EMBEDDINGS_LOCAL_PATH
        ):
            parts = gcs_path.split("/")
            bucket_name = parts[2]
            blob_path = "/".join(parts[3:])
            bucket = storage_client.get_bucket(bucket_name)
            blob = bucket.blob(blob_path)
            blob.download_to_filename(local_path)
            return local_path

        # Load our embeddings data into a dataframe:
        local_path = load_embeddings(EMBEDDINGS_CLOUD_PATH, EMBEDDINGS_LOCAL_PATH)
        embeddings_df = pd.read_json(local_path, lines=True)

        llm = ChatGoogleGenerativeAI(
            model=model, google_api_key=get_env_var(google_api_key)
        )

        local_path = load_embeddings(EMBEDDINGS_CLOUD_PATH, EMBEDDINGS_LOCAL_PATH)
        embeddings_df = pd.read_json(local_path, lines=True)
        langchain_index = make_langchain_index(embeddings_df)

        self.ee_index = EarthEngineDatasetIndex(catalog, langchain_index, llm)

    def show(self, query: Optional[str] = None, **kwargs: Any) -> widgets.VBox:
        """Displays a query interface for searching datasets.

        Args:
            query (Optional[str]): The initial query string. Defaults to None.
            **kwargs (Any): Additional keyword arguments for widget styling.

        Returns:
            widgets.VBox: A VBox containing the query input and output display.
        """

        output.no_vertical_scroll()

        if "style" not in kwargs:
            kwargs["style"] = {"description_width": "initial"}
        if "layout" not in kwargs:
            kwargs["layout"] = widgets.Layout(width="98%")

        def Question(query: str) -> None:
            """Handles the query and displays the dataset search results.

            Args:
                query (str): The query string.
            """
            if not is_valid_question(query):
                print("Invalid question. Please try again.")
                return

            datasets = self.ee_index.find_top_matches(query)
            # datasets = datasets.sort_by_spatial_resolution().limit(5)
            datasets = datasets.limit(7)
            dataset_search = DatasetSearchInterface(query, datasets)
            dataset_search.display()

        if query is None:
            query = "How have global land surface temperatures changed over time?"

        query_widget = widgets.Text(
            value=query,
            placeholder="Type your query here and press Enter",
            description="Query:",
            tooltip="Type your query here and press Enter",
            **kwargs,
        )

        output_widget = widgets.Output()

        display_widget = widgets.VBox([query_widget, output_widget])

        def on_query_change(text: widgets.Text) -> None:
            """Handles the event when the query text is submitted.

            Args:
                text (widgets.Text): The text widget containing the query.
            """
            output_widget.clear_output()
            with output_widget:
                print("Loading ...")
                output_widget.clear_output(wait=True)
                Question(text.value)

        query_widget.on_submit(on_query_change)

        return display_widget

__init__(project_id='GOOGLE_PROJECT_ID', google_api_key='GOOGLE_API_KEY', vertex_ai_zone='us-central1', model='gemini-1.5-pro', embeddings_cloud_path='gs://earthengine-catalog/embeddings/catalog_embeddings.jsonl')

Initializes the DatasetExplorer.

Parameters:

Name Type Description Default
project_id str

Google Cloud project ID. Defaults to "GOOGLE_PROJECT_ID".

'GOOGLE_PROJECT_ID'
google_api_key str

Google API key. Defaults to "GOOGLE_API_KEY".

'GOOGLE_API_KEY'
vertex_ai_zone str

Vertex AI zone. Defaults to "us-central1".

'us-central1'
model str

Model name for ChatGoogleGenerativeAI. Defaults to "gemini-1.5-pro".

'gemini-1.5-pro'
embeddings_cloud_path str

Cloud path to the embeddings file. Defaults to "gs://earthengine-catalog/embeddings/catalog_embeddings.jsonl".

'gs://earthengine-catalog/embeddings/catalog_embeddings.jsonl'
Source code in geemap/ai.py
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
def __init__(
    self,
    project_id: str = "GOOGLE_PROJECT_ID",
    google_api_key: str = "GOOGLE_API_KEY",
    vertex_ai_zone: str = "us-central1",
    model: str = "gemini-1.5-pro",
    embeddings_cloud_path: str = "gs://earthengine-catalog/embeddings/catalog_embeddings.jsonl",
) -> None:
    """Initializes the DatasetExplorer.

    Args:
        project_id (str): Google Cloud project ID. Defaults to "GOOGLE_PROJECT_ID".
        google_api_key (str): Google API key. Defaults to "GOOGLE_API_KEY".
        vertex_ai_zone (str): Vertex AI zone. Defaults to "us-central1".
        model (str): Model name for ChatGoogleGenerativeAI. Defaults to "gemini-1.5-pro".
        embeddings_cloud_path (str): Cloud path to the embeddings file.
            Defaults to "gs://earthengine-catalog/embeddings/catalog_embeddings.jsonl".
    """

    # @title Setup
    project_name = get_env_var(project_id)
    if project_name is None:
        project_name = get_env_var("EE_PROJECT_ID")
    if project_name is None:
        raise ValueError("Please provide a Google Cloud project ID.")

    genai.configure(api_key=get_env_var(google_api_key))

    ee.Authenticate()
    ee.Initialize(project=project_name)
    storage_client = storage.Client(project=project_name)
    vertexai.init(project=project_name, location=vertex_ai_zone)

    # Make sure geemap initialized correctly.
    ee_initialize(project=project_name)
    m = Map(draw_ctrl=False)
    m.add("layer_manager")

    catalog = Catalog(storage_client)

    # Pre-built embeddings.
    EMBEDDINGS_CLOUD_PATH = embeddings_cloud_path

    # Copy embeddings from GCS bucket to a local file
    EMBEDDINGS_LOCAL_PATH = "catalog_embeddings.jsonl"

    def load_embeddings(
        gcs_path=EMBEDDINGS_CLOUD_PATH, local_path=EMBEDDINGS_LOCAL_PATH
    ):
        parts = gcs_path.split("/")
        bucket_name = parts[2]
        blob_path = "/".join(parts[3:])
        bucket = storage_client.get_bucket(bucket_name)
        blob = bucket.blob(blob_path)
        blob.download_to_filename(local_path)
        return local_path

    # Load our embeddings data into a dataframe:
    local_path = load_embeddings(EMBEDDINGS_CLOUD_PATH, EMBEDDINGS_LOCAL_PATH)
    embeddings_df = pd.read_json(local_path, lines=True)

    llm = ChatGoogleGenerativeAI(
        model=model, google_api_key=get_env_var(google_api_key)
    )

    local_path = load_embeddings(EMBEDDINGS_CLOUD_PATH, EMBEDDINGS_LOCAL_PATH)
    embeddings_df = pd.read_json(local_path, lines=True)
    langchain_index = make_langchain_index(embeddings_df)

    self.ee_index = EarthEngineDatasetIndex(catalog, langchain_index, llm)

show(query=None, **kwargs)

Displays a query interface for searching datasets.

Parameters:

Name Type Description Default
query Optional[str]

The initial query string. Defaults to None.

None
**kwargs Any

Additional keyword arguments for widget styling.

{}

Returns:

Type Description
VBox

widgets.VBox: A VBox containing the query input and output display.

Source code in geemap/ai.py
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
def show(self, query: Optional[str] = None, **kwargs: Any) -> widgets.VBox:
    """Displays a query interface for searching datasets.

    Args:
        query (Optional[str]): The initial query string. Defaults to None.
        **kwargs (Any): Additional keyword arguments for widget styling.

    Returns:
        widgets.VBox: A VBox containing the query input and output display.
    """

    output.no_vertical_scroll()

    if "style" not in kwargs:
        kwargs["style"] = {"description_width": "initial"}
    if "layout" not in kwargs:
        kwargs["layout"] = widgets.Layout(width="98%")

    def Question(query: str) -> None:
        """Handles the query and displays the dataset search results.

        Args:
            query (str): The query string.
        """
        if not is_valid_question(query):
            print("Invalid question. Please try again.")
            return

        datasets = self.ee_index.find_top_matches(query)
        # datasets = datasets.sort_by_spatial_resolution().limit(5)
        datasets = datasets.limit(7)
        dataset_search = DatasetSearchInterface(query, datasets)
        dataset_search.display()

    if query is None:
        query = "How have global land surface temperatures changed over time?"

    query_widget = widgets.Text(
        value=query,
        placeholder="Type your query here and press Enter",
        description="Query:",
        tooltip="Type your query here and press Enter",
        **kwargs,
    )

    output_widget = widgets.Output()

    display_widget = widgets.VBox([query_widget, output_widget])

    def on_query_change(text: widgets.Text) -> None:
        """Handles the event when the query text is submitted.

        Args:
            text (widgets.Text): The text widget containing the query.
        """
        output_widget.clear_output()
        with output_widget:
            print("Loading ...")
            output_widget.clear_output(wait=True)
            Question(text.value)

    query_widget.on_submit(on_query_change)

    return display_widget

DatasetSearchInterface

Interface for searching and displaying Earth Engine datasets.

Source code in geemap/ai.py
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
class DatasetSearchInterface:
    """Interface for searching and displaying Earth Engine datasets."""

    collections: CollectionList
    query: str
    dataset_table: widgets.Widget
    code_output: widgets.Widget
    details_output: widgets.Widget
    map_output: widgets.Widget
    geemap_instance: Map

    # Parent containers for controlling widget visibility.
    details_code_box: widgets.Widget
    map_widget: widgets.Widget

    def __init__(self, query: str, collections: CollectionList) -> None:
        """Initializes the DatasetSearchInterface.

        Args:
            query (str): The search query string.
            collections (CollectionList): The list of dataset collections.
        """

        self.query = query
        self.collections = collections

        # Create the output widgets
        self.code_output = widgets.Output(layout=widgets.Layout(width="50%"))
        self.details_output = widgets.Output(
            layout=widgets.Layout(height="300px", width="100%")
        )

        # Initialize dataset table
        table_html = self._build_table_html(collections)
        self.dataset_table = widgets.HTML(value=table_html)

        _callback_id = "dataset-select" + str(uuid.uuid4())
        output.register_callback(_callback_id, self.update_outputs)
        self._dataset_select_js_code = self._dataset_select_js_code(_callback_id)
        # self._dataset_select_js_code(_callback_id)

        # Initialize map
        self.map_output = widgets.Output(layout=widgets.Layout(width="100%"))
        self.geemap_instance = Map(height="600px", width="100%")

    def display(self):
        """Display the UI in the cell."""
        # Create title and description with Material Design styling
        # title = widgets.HTML(value='<h2 class="main-title">Earth Engine Dataset Explorer</h2>')

        # Wrap outputs in a widget box for border styling
        details_widget = widgets.Box(
            [self.details_output],
            layout=widgets.Layout(
                border="1px solid #E0E0E0", padding="10px", margin="5px", width="100%"
            ),
        )
        code_widget = widgets.Box(
            [self.code_output],
            layout=widgets.Layout(
                border="1px solid #E0E0E0", padding="10px", margin="5px", width="100%"
            ),
        )
        self.map_widget = widgets.Box(
            [self.map_output],
            layout=widgets.Layout(
                border="1px solid #E0E0E0",
                padding="10px",
                margin="5px",
                width="100%",
                height="600x",
            ),
        )

        # Create the vertical box for code and details
        self.details_code_box = widgets.VBox(
            [details_widget, code_widget],
            layout=widgets.Layout(width="50%", height="600px"),
        )

        # Create a horizontal box for map and details/code
        map_details_code_box = widgets.HBox(
            [self.map_widget, self.details_code_box],
            layout=widgets.Layout(
                border="1px solid #E0E0E0", padding="10px", margin="5px"
            ),
        )

        # Create the main layout with Material Design styling
        main_content = widgets.VBox(
            [self.dataset_table, map_details_code_box],
            layout=widgets.Layout(
                width="100%", border="1px solid #E0E0E0", padding="10px", margin="5px"
            ),
        )

        # Add debug panel to the main layout
        main_layout = widgets.VBox(
            [
                # title,
                main_content,
            ],
            layout=widgets.Layout(height="1500px", width="100%", padding="24px"),
        )

        # @title CSS
        # Custom CSS for Material Design styling with enhanced table styling, chat panel, and debug panel
        CSS = syntax.css(
            """
        @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500&display=swap');

        body {
            font-family: 'Roboto', sans-serif;
            margin: 0;
            padding: 0;
        }

        .main-title {
            font-size: 24px;
            font-weight: 500;
            color: #212121;
            margin-bottom: 16px;
        }

        .custom-title {
            font-size: 18px;
            font-weight: 500;
            color: #212121;
            margin-bottom: 12px;
        }

        .details-text {
            font-size: 14px;
            color: #616161;
            line-height: 1.5;
        }

        .custom-table {
            width: 100%;
            border-collapse: collapse;
            margin-bottom: 24px;
            font-family: 'Roboto', sans-serif;
        }
        .custom-table th, .custom-table td {
            text-align: left;
            padding: 12px;
            border-bottom: 1px solid #E0E0E0;
        }
        .custom-table th {
            background-color: #F5F5F5;
            font-weight: 500;
            color: #212121;
        }
        .custom-table tr:hover {
            background-color: #E3F2FD;
        }
        .custom-table tr.selected {
            background-color: #BBDEFB;
        }

        /* Ensure borders are visible */
        .jupyter-widgets.widget-box {
            border: 1px solid #E0E0E0 !important;
            overflow: auto;
        }

        /* Make the map span full width */
        .geemap-container {
            width: 100% !important;
            height: 600px !important;
        }

        # Disable table clicks while things are loading
        .disabled {
        pointer-events: none;
        /* You might also want to visually indicate the disabled state */
        opacity: 0.5;
        cursor: default;
        }
        """
        )

        # Display the widget
        display(HTML(f"<style>{CSS}</style>"))
        display(main_layout)
        display(Javascript(self._dataset_select_js_code))

    def _build_table_html(self, collections: CollectionList) -> str:
        """Builds the HTML for the dataset table.

        Args:
            collections (CollectionList): The list of dataset collections.

        Returns:
            str: The HTML string for the dataset table.
        """
        table_html = """
    <table class="custom-table">
        <tr>
            <th>Dataset ID</th>
            <th> Name </th>
            <th>Temporal Resolution</th>
            <th>Spatial Resolution (m)</th>
            <th>Earliest</th>
            <th>Latest</th>
        </tr>
    """
        for dataset in collections:
            table_html += f"""
        <tr data-dataset="{dataset.public_id()}">
            <td>{dataset.public_id()}</td>
            <td>{dataset.get('title')}</td>
            <td>{dataset.temporal_resolution_str()}</td>
            <td>{dataset.spatial_resolution_m()}</td>
            <td>{dataset.start_str()}</td>
            <td>{dataset.end_str()}</td>
        </tr>
        """

        table_html += "</table>"
        return table_html

    def update_outputs(self, selected_dataset: str) -> None:
        """Updates the output widgets based on the selected dataset.

        Args:
            selected_dataset (str): The ID of the selected dataset.
        """
        collection = self.collections.filter_by_ids([selected_dataset])

        if not collection:
            self.details_code_box.layout.visibility = "hidden"
            self.map_widget.layout.visibility = "hidden"
            return

        dataset = collection[0]

        # Clear everything when a new dataset is selected.
        self.map_output.clear_output()
        self.code_output.clear_output()
        self.details_output.clear_output()
        # Clear previous layers. Keep only the base layer
        self.geemap_instance.layers = self.geemap_instance.layers[:1]

        with self.map_output:
            display(HTML("<h3>Loading...</h3>"))
            code = fix_ee_python_code(dataset.python_code(), ee, self.geemap_instance)
            llm_thoughts = explain_relevance_from_stac_json(
                self.query, dataset.stac_json
            )

        # Code and LLM thought content is now fully loaded.
        # We ought to make this asynchronous in another version
        self.map_output.clear_output()

        with self.code_output:
            display(HTML('<div class="custom-title">Earth Engine Code</div>'))
            print(code)

        with self.details_output:
            # display(HTML('<h3>Thinking...</h3>'))
            # self.details_output.clear_output()
            display(HTML('<div class="custom-title">Thoughts with Gemini</div>'))
            print(llm_thoughts)

        with self.map_output:
            display(self.geemap_instance)

        self.details_code_box.layout.visibility = "visible"
        self.map_widget.layout.visibility = "visible"

    def _dataset_select_js_code(self, callback_id: str) -> str:
        """Generates JavaScript code for handling dataset selection.

        Args:
            callback_id (str): The callback ID for the dataset selection.

        Returns:
            str: The JavaScript code as a string.
        """
        return Template(
            syntax.javascript(
                """
    function initializeTableInteraction() {
        const table = document.querySelector('.custom-table');
        if (!table) {
            console.error('Table not found');
            return;
        }

        function selectRow(row) {
            // Remove selection from previously selected row
            const prevSelected = table.querySelector('tr.selected');
            if (prevSelected) prevSelected.classList.remove('selected');

            // Add selection to the new row
            row.classList.add('selected');
            const selectedDataset = row.dataset.dataset;
            console.log('Selected dataset:', selectedDataset);
            google.colab.kernel.invokeFunction('{{callback_id}}', [selectedDataset], {});

        }

        table.addEventListener('click', (event) => {
            const row = event.target.closest('tr');
            if (!row || !row.dataset.dataset) return;
            selectRow(row);
        });

        // Select the first row by default
        const firstRow = table.querySelector('tr[data-dataset]');
        if (firstRow) {
            selectRow(firstRow);
        }
    }

    // Run the initialization function after a short delay to ensure the DOM is ready
    setTimeout(initializeTableInteraction, 1000);
    """
            )
        ).render(callback_id=callback_id)

__init__(query, collections)

Initializes the DatasetSearchInterface.

Parameters:

Name Type Description Default
query str

The search query string.

required
collections CollectionList

The list of dataset collections.

required
Source code in geemap/ai.py
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
def __init__(self, query: str, collections: CollectionList) -> None:
    """Initializes the DatasetSearchInterface.

    Args:
        query (str): The search query string.
        collections (CollectionList): The list of dataset collections.
    """

    self.query = query
    self.collections = collections

    # Create the output widgets
    self.code_output = widgets.Output(layout=widgets.Layout(width="50%"))
    self.details_output = widgets.Output(
        layout=widgets.Layout(height="300px", width="100%")
    )

    # Initialize dataset table
    table_html = self._build_table_html(collections)
    self.dataset_table = widgets.HTML(value=table_html)

    _callback_id = "dataset-select" + str(uuid.uuid4())
    output.register_callback(_callback_id, self.update_outputs)
    self._dataset_select_js_code = self._dataset_select_js_code(_callback_id)
    # self._dataset_select_js_code(_callback_id)

    # Initialize map
    self.map_output = widgets.Output(layout=widgets.Layout(width="100%"))
    self.geemap_instance = Map(height="600px", width="100%")

display()

Display the UI in the cell.

Source code in geemap/ai.py
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
def display(self):
    """Display the UI in the cell."""
    # Create title and description with Material Design styling
    # title = widgets.HTML(value='<h2 class="main-title">Earth Engine Dataset Explorer</h2>')

    # Wrap outputs in a widget box for border styling
    details_widget = widgets.Box(
        [self.details_output],
        layout=widgets.Layout(
            border="1px solid #E0E0E0", padding="10px", margin="5px", width="100%"
        ),
    )
    code_widget = widgets.Box(
        [self.code_output],
        layout=widgets.Layout(
            border="1px solid #E0E0E0", padding="10px", margin="5px", width="100%"
        ),
    )
    self.map_widget = widgets.Box(
        [self.map_output],
        layout=widgets.Layout(
            border="1px solid #E0E0E0",
            padding="10px",
            margin="5px",
            width="100%",
            height="600x",
        ),
    )

    # Create the vertical box for code and details
    self.details_code_box = widgets.VBox(
        [details_widget, code_widget],
        layout=widgets.Layout(width="50%", height="600px"),
    )

    # Create a horizontal box for map and details/code
    map_details_code_box = widgets.HBox(
        [self.map_widget, self.details_code_box],
        layout=widgets.Layout(
            border="1px solid #E0E0E0", padding="10px", margin="5px"
        ),
    )

    # Create the main layout with Material Design styling
    main_content = widgets.VBox(
        [self.dataset_table, map_details_code_box],
        layout=widgets.Layout(
            width="100%", border="1px solid #E0E0E0", padding="10px", margin="5px"
        ),
    )

    # Add debug panel to the main layout
    main_layout = widgets.VBox(
        [
            # title,
            main_content,
        ],
        layout=widgets.Layout(height="1500px", width="100%", padding="24px"),
    )

    # @title CSS
    # Custom CSS for Material Design styling with enhanced table styling, chat panel, and debug panel
    CSS = syntax.css(
        """
    @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500&display=swap');

    body {
        font-family: 'Roboto', sans-serif;
        margin: 0;
        padding: 0;
    }

    .main-title {
        font-size: 24px;
        font-weight: 500;
        color: #212121;
        margin-bottom: 16px;
    }

    .custom-title {
        font-size: 18px;
        font-weight: 500;
        color: #212121;
        margin-bottom: 12px;
    }

    .details-text {
        font-size: 14px;
        color: #616161;
        line-height: 1.5;
    }

    .custom-table {
        width: 100%;
        border-collapse: collapse;
        margin-bottom: 24px;
        font-family: 'Roboto', sans-serif;
    }
    .custom-table th, .custom-table td {
        text-align: left;
        padding: 12px;
        border-bottom: 1px solid #E0E0E0;
    }
    .custom-table th {
        background-color: #F5F5F5;
        font-weight: 500;
        color: #212121;
    }
    .custom-table tr:hover {
        background-color: #E3F2FD;
    }
    .custom-table tr.selected {
        background-color: #BBDEFB;
    }

    /* Ensure borders are visible */
    .jupyter-widgets.widget-box {
        border: 1px solid #E0E0E0 !important;
        overflow: auto;
    }

    /* Make the map span full width */
    .geemap-container {
        width: 100% !important;
        height: 600px !important;
    }

    # Disable table clicks while things are loading
    .disabled {
    pointer-events: none;
    /* You might also want to visually indicate the disabled state */
    opacity: 0.5;
    cursor: default;
    }
    """
    )

    # Display the widget
    display(HTML(f"<style>{CSS}</style>"))
    display(main_layout)
    display(Javascript(self._dataset_select_js_code))

update_outputs(selected_dataset)

Updates the output widgets based on the selected dataset.

Parameters:

Name Type Description Default
selected_dataset str

The ID of the selected dataset.

required
Source code in geemap/ai.py
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
def update_outputs(self, selected_dataset: str) -> None:
    """Updates the output widgets based on the selected dataset.

    Args:
        selected_dataset (str): The ID of the selected dataset.
    """
    collection = self.collections.filter_by_ids([selected_dataset])

    if not collection:
        self.details_code_box.layout.visibility = "hidden"
        self.map_widget.layout.visibility = "hidden"
        return

    dataset = collection[0]

    # Clear everything when a new dataset is selected.
    self.map_output.clear_output()
    self.code_output.clear_output()
    self.details_output.clear_output()
    # Clear previous layers. Keep only the base layer
    self.geemap_instance.layers = self.geemap_instance.layers[:1]

    with self.map_output:
        display(HTML("<h3>Loading...</h3>"))
        code = fix_ee_python_code(dataset.python_code(), ee, self.geemap_instance)
        llm_thoughts = explain_relevance_from_stac_json(
            self.query, dataset.stac_json
        )

    # Code and LLM thought content is now fully loaded.
    # We ought to make this asynchronous in another version
    self.map_output.clear_output()

    with self.code_output:
        display(HTML('<div class="custom-title">Earth Engine Code</div>'))
        print(code)

    with self.details_output:
        # display(HTML('<h3>Thinking...</h3>'))
        # self.details_output.clear_output()
        display(HTML('<div class="custom-title">Thoughts with Gemini</div>'))
        print(llm_thoughts)

    with self.map_output:
        display(self.geemap_instance)

    self.details_code_box.layout.visibility = "visible"
    self.map_widget.layout.visibility = "visible"

EarthEngineDatasetIndex

Class for indexing and searching Earth Engine datasets.

Source code in geemap/ai.py
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
class EarthEngineDatasetIndex:
    """Class for indexing and searching Earth Engine datasets."""

    index: VectorStoreIndexWrapper
    vectorstore: VectorStore
    data_catalog: Catalog
    llm: BaseLanguageModel

    def __init__(self, data_catalog, index, llm):
        """Initializes the EarthEngineDatasetIndex.

        Args:
            data_catalog (Catalog): The data catalog containing the datasets.
            index (VectorStoreIndexWrapper): The vector store index wrapper.
            llm (BaseLanguageModel): The language model for query processing.
        """
        self.index = index
        self.data_catalog = data_catalog
        self.vectorstore = index.vectorstore
        self.llm = llm

    @tenacity.retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (requests.exceptions.RequestException, ConnectionError)
        ),
    )
    def find_top_matches(
        self,
        query: str,
        results: int = 10,
        threshold: float = 0.7,
        bounding_box: Optional[List[float]] = None,
        temporal_interval: Optional[Tuple[datetime.datetime, datetime.datetime]] = None,
    ) -> CollectionList:
        """Retrieve relevant datasets from the Earth Engine data catalog.

        Args:
            query (str): The kind of data being searched for, e.g., 'population'.
            results (int): The number of datasets to return. Defaults to 10.
            threshold (float): The maximum dot product between the query and catalog embeddings.
                Defaults to 0.7.
            bounding_box (Optional[List[float]]): The spatial bounding box for the query,
                in the format [lon1, lat1, lon2, lon2]. Defaults to None.
            temporal_interval (Optional[Tuple[datetime.datetime, datetime.datetime]]):
            Temporal constraints as a tuple of datetime objects. Defaults to None.

        Returns:
            CollectionList: A list of collections that match the query.
        """
        similar_docs = self.index.vectorstore.similarity_search_with_score(
            query, llm=self.llm, k=results
        )
        dataset_ids = [doc[0].page_content for doc in similar_docs]
        datasets = self.data_catalog.collections.filter_by_ids(dataset_ids)
        return datasets

    @tenacity.retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (requests.exceptions.RequestException, ConnectionError)
        ),
    )
    def find_top_matches_with_score_df(
        self,
        query: str,
        results: int = 20,
        threshold: float = 0.7,
        bounding_box: Optional[List[float]] = None,
        temporal_interval: Optional[Tuple[datetime.datetime, datetime.datetime]] = None,
    ) -> pd.DataFrame:
        """Retrieve relevant datasets and their match scores as a DataFrame.

        Args:
            query (str): The kind of data being searched for, e.g., 'population'.
            results (int): The number of datasets to return. Defaults to 20.
            threshold (float): The maximum dot product between the query and catalog embeddings.
                Defaults to 0.7.
            bounding_box (Optional[List[float]]): The spatial bounding box for the query,
                in the format [lon1, lat1, lon2, lon2]. Defaults to None.
            temporal_interval (Optional[Tuple[datetime.datetime, datetime.datetime]]):
                Temporal constraints as a tuple of datetime objects. Defaults to None.

        Returns:
            pd.DataFrame: A DataFrame containing the dataset IDs and their match scores.
        """
        scores_df = self.ids_to_match_scores_df(
            query, results, bounding_box, temporal_interval
        )
        dataset_ids = scores_df["id"].tolist()
        col_list = self.data_catalog.collections.filter_by_ids(dataset_ids)
        collection_df = col_list.to_df()
        df = pd.merge(collection_df, scores_df, on="id", how="inner")
        return df.sort_values(by="match_score", ascending=False)

    def ids_to_match_scores_df(
        self,
        query: str,
        results: int,
        bounding_box: Optional[List[float]] = None,
        temporal_interval: Optional[Tuple[datetime.datetime, datetime.datetime]] = None,
    ) -> pd.DataFrame:
        """Convert dataset IDs and match scores to a DataFrame.

        Args:
            query (str): The kind of data being searched for, e.g., 'population'.
            results (int): The number of datasets to return.
            bounding_box (Optional[List[float]]): The spatial bounding box for the query,
                in the format [lon1, lat1, lon2, lon2]. Defaults to None.
            temporal_interval (Optional[Tuple[datetime.datetime, datetime.datetime]]):
                Temporal constraints as a tuple of datetime objects. Defaults to None.

        Returns:
            pd.DataFrame: A DataFrame containing the dataset IDs and their match scores.
        """
        similar_docs = self.index.vectorstore.similarity_search_with_score(
            query, llm=self.llm, k=results
        )
        dataset_ids, scores = zip(
            *[(doc[0].page_content, doc[1]) for doc in similar_docs]
        )

        return pd.DataFrame({"id": dataset_ids, "match_score": scores})

__init__(data_catalog, index, llm)

Initializes the EarthEngineDatasetIndex.

Parameters:

Name Type Description Default
data_catalog Catalog

The data catalog containing the datasets.

required
index VectorStoreIndexWrapper

The vector store index wrapper.

required
llm BaseLanguageModel

The language model for query processing.

required
Source code in geemap/ai.py
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
def __init__(self, data_catalog, index, llm):
    """Initializes the EarthEngineDatasetIndex.

    Args:
        data_catalog (Catalog): The data catalog containing the datasets.
        index (VectorStoreIndexWrapper): The vector store index wrapper.
        llm (BaseLanguageModel): The language model for query processing.
    """
    self.index = index
    self.data_catalog = data_catalog
    self.vectorstore = index.vectorstore
    self.llm = llm

find_top_matches(query, results=10, threshold=0.7, bounding_box=None, temporal_interval=None)

Retrieve relevant datasets from the Earth Engine data catalog.

Parameters:

Name Type Description Default
query str

The kind of data being searched for, e.g., 'population'.

required
results int

The number of datasets to return. Defaults to 10.

10
threshold float

The maximum dot product between the query and catalog embeddings. Defaults to 0.7.

0.7
bounding_box Optional[List[float]]

The spatial bounding box for the query, in the format [lon1, lat1, lon2, lon2]. Defaults to None.

None
temporal_interval Optional[Tuple[datetime, datetime]]
None

Returns:

Name Type Description
CollectionList CollectionList

A list of collections that match the query.

Source code in geemap/ai.py
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
@tenacity.retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (requests.exceptions.RequestException, ConnectionError)
    ),
)
def find_top_matches(
    self,
    query: str,
    results: int = 10,
    threshold: float = 0.7,
    bounding_box: Optional[List[float]] = None,
    temporal_interval: Optional[Tuple[datetime.datetime, datetime.datetime]] = None,
) -> CollectionList:
    """Retrieve relevant datasets from the Earth Engine data catalog.

    Args:
        query (str): The kind of data being searched for, e.g., 'population'.
        results (int): The number of datasets to return. Defaults to 10.
        threshold (float): The maximum dot product between the query and catalog embeddings.
            Defaults to 0.7.
        bounding_box (Optional[List[float]]): The spatial bounding box for the query,
            in the format [lon1, lat1, lon2, lon2]. Defaults to None.
        temporal_interval (Optional[Tuple[datetime.datetime, datetime.datetime]]):
        Temporal constraints as a tuple of datetime objects. Defaults to None.

    Returns:
        CollectionList: A list of collections that match the query.
    """
    similar_docs = self.index.vectorstore.similarity_search_with_score(
        query, llm=self.llm, k=results
    )
    dataset_ids = [doc[0].page_content for doc in similar_docs]
    datasets = self.data_catalog.collections.filter_by_ids(dataset_ids)
    return datasets

find_top_matches_with_score_df(query, results=20, threshold=0.7, bounding_box=None, temporal_interval=None)

Retrieve relevant datasets and their match scores as a DataFrame.

Parameters:

Name Type Description Default
query str

The kind of data being searched for, e.g., 'population'.

required
results int

The number of datasets to return. Defaults to 20.

20
threshold float

The maximum dot product between the query and catalog embeddings. Defaults to 0.7.

0.7
bounding_box Optional[List[float]]

The spatial bounding box for the query, in the format [lon1, lat1, lon2, lon2]. Defaults to None.

None
temporal_interval Optional[Tuple[datetime, datetime]]

Temporal constraints as a tuple of datetime objects. Defaults to None.

None

Returns:

Type Description
DataFrame

pd.DataFrame: A DataFrame containing the dataset IDs and their match scores.

Source code in geemap/ai.py
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
@tenacity.retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (requests.exceptions.RequestException, ConnectionError)
    ),
)
def find_top_matches_with_score_df(
    self,
    query: str,
    results: int = 20,
    threshold: float = 0.7,
    bounding_box: Optional[List[float]] = None,
    temporal_interval: Optional[Tuple[datetime.datetime, datetime.datetime]] = None,
) -> pd.DataFrame:
    """Retrieve relevant datasets and their match scores as a DataFrame.

    Args:
        query (str): The kind of data being searched for, e.g., 'population'.
        results (int): The number of datasets to return. Defaults to 20.
        threshold (float): The maximum dot product between the query and catalog embeddings.
            Defaults to 0.7.
        bounding_box (Optional[List[float]]): The spatial bounding box for the query,
            in the format [lon1, lat1, lon2, lon2]. Defaults to None.
        temporal_interval (Optional[Tuple[datetime.datetime, datetime.datetime]]):
            Temporal constraints as a tuple of datetime objects. Defaults to None.

    Returns:
        pd.DataFrame: A DataFrame containing the dataset IDs and their match scores.
    """
    scores_df = self.ids_to_match_scores_df(
        query, results, bounding_box, temporal_interval
    )
    dataset_ids = scores_df["id"].tolist()
    col_list = self.data_catalog.collections.filter_by_ids(dataset_ids)
    collection_df = col_list.to_df()
    df = pd.merge(collection_df, scores_df, on="id", how="inner")
    return df.sort_values(by="match_score", ascending=False)

ids_to_match_scores_df(query, results, bounding_box=None, temporal_interval=None)

Convert dataset IDs and match scores to a DataFrame.

Parameters:

Name Type Description Default
query str

The kind of data being searched for, e.g., 'population'.

required
results int

The number of datasets to return.

required
bounding_box Optional[List[float]]

The spatial bounding box for the query, in the format [lon1, lat1, lon2, lon2]. Defaults to None.

None
temporal_interval Optional[Tuple[datetime, datetime]]

Temporal constraints as a tuple of datetime objects. Defaults to None.

None

Returns:

Type Description
DataFrame

pd.DataFrame: A DataFrame containing the dataset IDs and their match scores.

Source code in geemap/ai.py
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
def ids_to_match_scores_df(
    self,
    query: str,
    results: int,
    bounding_box: Optional[List[float]] = None,
    temporal_interval: Optional[Tuple[datetime.datetime, datetime.datetime]] = None,
) -> pd.DataFrame:
    """Convert dataset IDs and match scores to a DataFrame.

    Args:
        query (str): The kind of data being searched for, e.g., 'population'.
        results (int): The number of datasets to return.
        bounding_box (Optional[List[float]]): The spatial bounding box for the query,
            in the format [lon1, lat1, lon2, lon2]. Defaults to None.
        temporal_interval (Optional[Tuple[datetime.datetime, datetime.datetime]]):
            Temporal constraints as a tuple of datetime objects. Defaults to None.

    Returns:
        pd.DataFrame: A DataFrame containing the dataset IDs and their match scores.
    """
    similar_docs = self.index.vectorstore.similarity_search_with_score(
        query, llm=self.llm, k=results
    )
    dataset_ids, scores = zip(
        *[(doc[0].page_content, doc[1]) for doc in similar_docs]
    )

    return pd.DataFrame({"id": dataset_ids, "match_score": scores})

Genie

Bases: VBox

A widget for interacting with the Genie AI model.

The source code is adapted from the ee_genie.ipynb at https://bit.ly/3YEm7B6. Credit to the original author Simon Ilyushchenko (https://github.com/simonff).

Parameters:

Name Type Description Default
project Optional[str]

Google Cloud project ID. Defaults to None.

None
google_api_key Optional[str]

Google API key. Defaults to None.

None
gemini_model str

The Gemini model to use. Defaults to "gemini-1.5-flash". For a list of available models, see https://bit.ly/4fKfXW7.

'gemini-1.5-flash'
target_score float

The target score for the model. Defaults to 0.8.

0.8
widget_height str

The height of the widget. Defaults to "600px".

'600px'
initialize_ee bool

Whether to initialize Earth Engine. Defaults to True.

True

Raises:

Type Description
ValueError

If the project ID or Google API key is not provided.

Source code in geemap/ai.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
class Genie(widgets.VBox):
    """A widget for interacting with the Genie AI model.

    The source code is adapted from the ee_genie.ipynb at <https://bit.ly/3YEm7B6>.
    Credit to the original author Simon Ilyushchenko (<https://github.com/simonff>).

    Args:
        project (Optional[str], optional): Google Cloud project ID. Defaults to None.
        google_api_key (Optional[str], optional): Google API key. Defaults to None.
        gemini_model (str, optional): The Gemini model to use. Defaults to "gemini-1.5-flash".
            For a list of available models, see https://bit.ly/4fKfXW7.
        target_score (float, optional): The target score for the model. Defaults to 0.8.
        widget_height (str, optional): The height of the widget. Defaults to "600px".
        initialize_ee (bool, optional): Whether to initialize Earth Engine. Defaults to True.

    Raises:
        ValueError: If the project ID or Google API key is not provided.
    """

    def __init__(
        self,
        project: Optional[str] = None,
        google_api_key: Optional[str] = None,
        gemini_model: str = "gemini-1.5-flash",
        target_score: float = 0.8,
        widget_height: str = "600px",
        initialize_ee: bool = True,
    ) -> None:
        # Initialization

        if project is None:
            project = get_env_var("EE_PROJECT_ID") or get_env_var("GOOGLE_PROJECT_ID")
        if project is None:
            raise ValueError(
                "Please provide a valid project ID via the 'project' parameter."
            )

        if google_api_key is None:
            google_api_key = get_env_var("GOOGLE_API_KEY")
        if google_api_key is None:
            raise ValueError(
                "Please provide a valid Google API key via the 'google_api_key' parameter."
            )

        if initialize_ee:
            ee_initialize(project=project)

        genai.configure(api_key=google_api_key)
        storage_client = storage.Client(project=project)
        bucket = storage_client.get_bucket("earthengine-stac")

        # Score to aim for (on the 0-1 scale). The exact meaning of what "score" means
        # is left to the LLM.

        # Count of analysis rounds

        self.iteration = 1
        self.map_dirty = False

        m = Map()
        m.add("layer_manager")
        self.map = m

        analysis_model = None

        image_model = genai.GenerativeModel(gemini_model)

        # UI widget definitions

        # We define the widgets early because some functions will write to the debug
        # and/or chat panels.

        command_input = widgets.Text(
            value="show a whole continent Australia DEM visualization using a palette that captures the elevation range",
            # value='show NYC',
            # value='show an area with many center pivot irrigation circles',
            # value='show a fire scar',
            # value='show an open pit mine',
            # value='a sea port',
            # value='flood consequences',
            # value='show an interesting modis composite with the relevant visualization and analyze it over Costa Rica',
            description="❓",
            layout=widgets.Layout(width="100%", height="50px"),
        )

        command_output = widgets.Label(
            value="Last command will be here",
        )

        status_label = widgets.Textarea(
            value="LLM response will be here",
            layout=widgets.Layout(width="50%", height="100px"),
        )

        # widget_height = "600px"
        debug_output = widgets.Output(
            layout={
                "border": "1px solid black",
                "height": widget_height,
                "overflow": "scroll",
                "width": "500px",
                "padding": "5px",
            }
        )
        with debug_output:
            print("DEBUG COLUMN\n")

        logo = requests.get(
            "https://drive.usercontent.google.com/download?id=1zE6G_nxXa3G5N0G_32jEhzdum2kMDfNY&export=view&authuser=0"
        ).content

        image_widget = widgets.Image(value=logo, format="png", width=400, height=600)

        chat_output = widgets.Output(
            layout={
                "border": "1px solid black",
                "height": "600px",
                "overflow": "scroll",
                "width": "300px",
            }
        )

        with chat_output:
            print("CHAT COLUMN\n")

        # Simple functions that LLM will call

        def set_center(x: float, y: float, zoom: int) -> str:
            """Sets the map center to the given coordinates and zoom level and
            returns instructions on what to do next."""
            with debug_output:
                print(f"SET_CENTER({x}, {y}, {zoom})\n")
            m.set_center(x, y)
            m.zoom = zoom
            # global map_dirty
            self.map_dirty = True
            return (
                "Do not call any more functions in this request to let geemap bounds "
                "update. Wait for user input."
            )

        def add_image_layer(image_id: str) -> str:
            """Adds to the map center an ee.Image with the given id
            and returns status message (success or failure)."""
            m.clear()
            command_output.value = f"add_image_layer('{image_id}')"
            m.addLayer(ee.Image(image_id))
            return "success"

        def get_dataset_description(dataset_id: str) -> str:
            """Fetches JSON STAC description for the given Earth Engine dataset id."""
            with debug_output:
                print(f"LOOKING UP {dataset_id}\n")
            parent = dataset_id.split("/")[0]

            # Get the blob (file)
            path = (
                os.path.join("catalog", parent, dataset_id.replace("/", "_")) + ".json"
            )
            blob = bucket.blob(path)

            if not blob.exists():
                return "dataset file not found: " + path

            file_contents = blob.download_as_string().decode()
            return file_contents

        def get_image(image_url: str) -> bytes:
            """Fetches from Earth Engine the content of the given URL as bytes."""
            response = requests.get(image_url)

            if response.status_code == 200:
                image_widget.value = response.content
                return response.content
            else:
                error_message = f"Error downloading image: {response}"
                try:
                    error_details = (
                        json.loads(response.content.decode())
                        .get("error", {})
                        .get("message")
                    )
                    if error_details:
                        error_message += f" - {error_details}"
                except json.JSONDecodeError:
                    pass
                with debug_output:
                    print(error_message)
                raise ValueError("URL %s causes %s" % (image_url, error_message))

        def show_layer(python_code: str) -> str:
            """Execute the given Earth Engine Python client code and add the result to
            the map. Returns the status message (success or error message)."""
            m.layers = m.layers[:2]
            while '\\"' in python_code:
                python_code = python_code.replace('\\"', '"')
            command_output.value = f"show_layer('{python_code}')"
            with debug_output:
                print(f"IMAGE:\n {python_code}\n")
            try:
                locals = {}
                exec(f"import ee; im = {python_code}", {}, locals)
                m.addLayer(locals["im"])
            except Exception as e:
                with debug_output:
                    print(f"ERROR: {e}")
                return str(e)
            return "success"

        def inner_monologue(thoughts: str) -> str:
            """Sends the current thinking of the LLM model to the user so that they are
            aware of what the model is thinking between function calls."""
            with debug_output:
                print(f"THOUGHTS:\n {thoughts}\n")
            return "success"

        # Functions for textual analysis of images

        def _lat_lon_to_tile(lon, lat, zoom_level):
            # Convert latitude and longitude to Mercator coordinates
            x_merc = (lon + 180) / 360
            y_merc = (
                1
                - math.log(
                    math.tan(math.radians(lat)) + 1 / math.cos(math.radians(lat))
                )
                / math.pi
            ) / 2

            # Calculate number of tiles
            n = 2**zoom_level

            # Convert to tile coordinates
            X = int(x_merc * n)
            Y = int(y_merc * n)

            return X, Y

        def analyze_image(additional_instructions: str = "") -> str:
            """Returns GenAI image analysis describing the current map image.
            Optional additional instructions might be passed to target the analysis
            more precisely.
            """
            # global map_dirty
            if self.map_dirty:
                print("MAP DIRTY")
                return (
                    "Map is not ready. Stop further processing and ask for user input"
                )

            try:
                return _analyze_image(additional_instructions)
            except ValueError as e:
                return str(e)

        def _analyze_image(additional_instructions: str = "") -> str:
            # bounds = m.bounds
            # s, w = bounds[0]
            # n, e = bounds[1]
            # zoom = int(m.zoom)

            # min_tile_x, max_tile_y = _lat_lon_to_tile(w, s, zoom)
            # max_tile_x, min_tile_y = _lat_lon_to_tile(e, n, zoom)
            # min_tile_x = max(0, min_tile_x)
            # max_tile_x = min(2**zoom - 1, max_tile_x)
            # min_tile_y = max(0, min_tile_y)
            # max_tile_y = min(2**zoom - 1, max_tile_y)

            # with debug_output:
            #     if additional_instructions:
            #         print(f"RUNNING IMAGE ANALYSIS: {additional_instructions}...\n")
            #     else:
            #         print("RUNNING IMAGE ANALYSIS...\n")

            # layers = list(m.ee_layer_dict.values())
            # if not layers:
            #     return "No data layers loaded"
            # url_template = layers[-1]["ee_layer"].url
            # tile_width = 256
            # tile_height = 256
            # image_width = (max_tile_x - min_tile_x + 1) * tile_width
            # image_height = (max_tile_y - min_tile_y + 1) * tile_height

            # # Create a new blank image
            # image = PIL.Image.new("RGB", (image_width, image_height))

            # for y in range(min_tile_y, max_tile_y + 1):
            #     for x in range(min_tile_x, max_tile_x + 1):
            #         tile_url = str.format(url_template, x=x, y=y, z=zoom)
            #         # print(tile_url)
            #         tile_img = PIL.Image.open(io.BytesIO(get_image(tile_url)))

            #         offset_x = (x - min_tile_x) * tile_width
            #         offset_y = (y - min_tile_y) * tile_height
            #         image.paste(tile_img, (offset_x, offset_y))

            # width, height = image.size
            # num_bands = len(image.getbands())

            with debug_output:
                if additional_instructions:
                    print(f"RUNNING IMAGE ANALYSIS: {additional_instructions}...\n")
                else:
                    print("RUNNING IMAGE ANALYSIS...\n")

            layers = list(m.ee_layer_dict.values())
            if not layers:
                return "No data layers loaded"
            image_temp_file = temp_file_path(extension="jpg")
            layer_name = layers[-1]["ee_layer"].name
            m.layer_to_image(layer_name, output=image_temp_file, scale=m.get_scale())
            image = PIL.Image.open(image_temp_file)

            image_array = np.array(image)
            image_min = np.min(image_array)
            image_max = np.max(image_array)

            file = open(image_temp_file, "rb")
            image_widget.value = file.read()
            file.close()

            # Skip an LLM call when we can simply tell that something is wrong.
            # (Also, LLMs might hallucinate on uniform images.)
            if image_min == image_max:
                return (
                    f"The image tile has a single uniform color with value "
                    f"{image_min}."
                )

            query = """You are an objective, precise overhead imagery analyst.
        Describe what the provided map tile depicts in terms of:

        1. The colors, textures, and patterns visible in the image.
        2. The spatial distribution, shape, and extent of distinct features or regions.
        3. Any notable contrasts, boundaries, or gradients between different areas.

        Avoid making assumptions about the specific geographic location, time period,
        or cause of the observed features. Focus solely on the literal contents of the
        image itself. Clearly indicate which features look natural, which look human-made,
        and which look like image artifacts. (Eg, a completely straight blue line
        is unlikely to be a river.)

        If the image is ambiguous or unclear, state so directly. Do not speculate or
        hypothesize beyond what is directly visible.

        If the image is of mostly the same color (white, gray, or black) with little
        contrast, just report that and do not describe the features.

        Use clear, concise language. Avoid subjective interpretations or analogies.
        Organize your response into structured paragraphs.
        """
            if additional_instructions:
                query += additional_instructions
            req = {
                "parts": [
                    {"text": query},
                    {"inline_data": image},
                ]
            }
            image_response = image_model.generate_content(req)
            try:
                with debug_output:
                    print(f"ANALYSIS RESULT: {image_response.text}\n")
                return image_response.text
            except ValueError as e:
                with debug_output:
                    print(f"UNEXPECTED IMAGE RESPONSE: {e}")
                    print(image_response)
                breakpoint()

        # Function for scoring how well image analysis corresponds to the user query.

        # Note that we ask for the score outside of the main agent chat to keep
        # the scoring more objective.

        scoring_system_prompt = """
        After looking at the user query and the map tile analysis, start
        your answer with a number between 0 and 1 indicating how relevant
        the image is as an answer to the query. (0=irrelevant, 1=perfect answer)

        Make sure you have enough justification to definitively declare the analysis
        relevant - it's better to give a false negative than a false positive. However,
        the image analysis identifies specific matching landmarks (eg, the
        the outlines of Manhattan island for a request to show NYC), believe it.

        Do not assume  too much (eg, that the presence of green doesn't by itself mean the
        image shows forest); attempt to find multiple (at least three) independent
        lines of evidence before declaring victory and cite all these lines of evidence
        in your response.

        Be very, very skeptical - look for specific features that match only the query
        and nothing else (eg, if the query looks for a river, a completely straight blue
        line is unlikely to be a river). Think about what size the features are based on
        the zoom level and whether this size matches the feature size expected from
        first principles.

        If there is ambiguity or uncertainty, express it in your analysis and
        lower the score accordingly. If the image analysis is inconclusive, try zooming
        out to make sure you are looking at the right spot. Do not reduce the score if
        the analysis does not mention visualization parameters - they are just given for
        your reference. The image might show an area slightly larger than requested -
        this is okay, do not reduce the score on this account.
        """

        def score_response(
            query: str, visualization_parameters: str, analysis: str
        ) -> str:
            """Returns how well the given analysis describes a map tile returned for
            the given query. The analysis starts with a number between 0 and 1.

            Arguments:
            query: user-specified query
            visualization_parameters: description of the bands used and visualization
                parameters applied to the map tile
            analysis: the textual description of the map tile
            """
            with debug_output:
                print(f"VIZ PARAMS: {visualization_parameters}\n")
            question = f"""For user query {query} please score the following analysis:
            {analysis}. The answer must start with a number between 0 and 1."""
            if visualization_parameters:
                question += f"""Do not assume that common bands or visualization
                parameters should have been used, as the visualization used the
                following parameters: {visualization_parameters}"""

            result = analysis_model.ask(question)
            # global iteration
            with debug_output:
                print(f"SCORE #{self.iteration}:\n {result}\n")
            try:
                self.iteration += 1
            except Exception as e:
                with debug_output:
                    print(f"UNEXPECTED SCORE RESPONSE: {e}")
            return result

        # Main prompt for the agent

        system_prompt = f"""
        The client is running in a Python notebook with a geemap Map displayed.
        When composing Python code, do not use getMapId - just return the single-line
        layer definition like 'ee.Image("USGS/SRTMGL1_003")' that we will pass to
        Map.addLayer(). Do not escape quotation marks in Python code.

        Be sure to use Python, not Javascript, syntax for keyword parameters in
        Python code (that is, "function(arg=value)") Using the provided functions,
        respond to the user command following below (or respond why it's not possible).
        If you get an Earth Engine error, attempt to fix it and then try again.

        Before you choose a dataset, think about what kind of dataset would be most
        suitable for the query. Also think about what zoom level would be suitable for
        the query, keeping in mind that for high-resolution image collections higher
        zoom levels are better to speed up tile loading.

        Once you have chosen a dataset, read its description using the provided function
        to see what spatial and temporal range it covers, what bands it has, as well as
        to find the recommended visualization parameters. Explain using the inner
        monlogue function why you chose a specific dataset, zoom level and map location.

        Prefer mosaicing image collections using the mosaic() function, don't get
        individual images from collections via
        'first()'. Choose a tile size and zoom level that will ensure the
        tile has enough pixels in it to avoid graininess, but not so many
        that processing becomes very expensive. Do not use wide date ranges
        with collections that have many images, but remember that Landsat and
        Sentinel-2 have revisit period of several days. Do not use sample
        locations - try to come up with actual locations that are relevant to
        the request.

        Use Landsat Collection 2, not Landsat Collection 1 ids. If you are getting
        repeated errors when filtering by a time range, read the dataset description
        to confirm that the dataset has data for the selected range.

        Important: after using the set_center() function, just say that you have called
        this function and wait for the user to hit enter, after which you should
        continue answering the original request. This will make sure the map is updated
        on the client side.

        Once the map is updated and the user told you to proceed, call the analyze_image
        function() to describe the image for the same location that will be shown in
        geemap. If you pass additional instructions to analyze_image(), do not disclose
        what the image is supposed to be to discourage hallucinations - you can only tell
        the analysis function to pay attention to specific areas (eg, center or top left)
        or shapes (eg, a line at the bottom) in the image. You can also tell the analysis
        function about the chosen bands, color palette and min/max visualization
        parameters, if any, to help it interpret the colors correctly. If the image
        turns out to be uniform in color with no features,
        use min/max visualization parameters to enhance contrast.

        Frequently call the inner_monologue() functions to tell the user about your
        current thought process. This is a good time to reflect if you have been running
        into repeated errors of the same kind, and if so, to try a different approach.

        When you are done, call the score_response() function to evaluate the analysis.
        You can also tell the scoring function about the chosen bands, color palette
        and min/max visualization parameters, if any. If the analysis score is below
        {target_score},
        keep trying to find and show a better image. You might have to change the dataset,
        map location, zoom level, date range, bands, or other parameters - think about
        what went wrong in the previous attempt and make the change that's most likely
        to improve the score.
        """

        # Class for LLM chat with function calling

        gemini_tools = [
            set_center,
            show_layer,
            analyze_image,
            inner_monologue,
            get_dataset_description,
            score_response,
        ]

        class Gemini:
            """Gemini LLM."""

            def __init__(
                self, system_prompt, tools=None, model_name="gemini-1.5-pro-latest"
            ):
                if not tools:
                    tools = []
                self._text_model = genai.GenerativeModel(
                    model_name=model_name, tools=tools
                )

                initial_messages = glm.Content(
                    role="model", parts=[glm.Part(text=system_prompt)]
                )
                self._chat_proxy = self._text_model.start_chat(
                    history=initial_messages, enable_automatic_function_calling=True
                )

            def ask(self, question, temperature=0):
                while True:
                    condition = ""
                    try:
                        sleep_duration = 10
                        response = self._text_model.generate_content(
                            question + condition
                        )
                        return response.text
                    except genai.types.generation_types.StopCandidateException as e:
                        if (
                            glm.Candidate.FinishReason.RECITATION
                            == e.args[0].finish_reason
                        ):
                            condition = (
                                "Previous attempt returned a RECITATION error. "
                                "Rephrase the answer to avoid it."
                            )
                        with chat_output:
                            command_input.description = "🆁"
                        time.sleep(1)
                        with chat_output:
                            command_input.description = "🤔"
                        continue
                    except (
                        google.api_core.exceptions.TooManyRequests,
                        google.api_core.exceptions.DeadlineExceeded,
                    ):
                        with debug_output:
                            command_input.description = "💤"
                        time.sleep(sleep_duration)
                        continue
                    except ValueError as e:
                        with debug_output:
                            print(f"Response {response} led to error: {e}")
                        breakpoint()
                        i = 1

            def chat(self, question: str, temperature=0) -> str:
                """Adds a question to the ongoing chat session."""
                # Always delay a bit to reduce the chance for rate-limiting errors.
                time.sleep(1)
                condition = ""
                sleep_duration = 10
                while True:
                    response = ""
                    try:
                        response = self._chat_proxy.send_message(
                            question + condition,
                            generation_config={
                                "temperature": temperature,
                                # Use a generous but limited output size to encourage in-depth
                                # replies.
                                "max_output_tokens": 5000,
                            },
                        )
                        if not response.parts:
                            raise ValueError(
                                "Cannot get analysis with reason"
                                f" {response.candidates[0].finish_reason.name}, terminating"
                            )
                    except genai.types.generation_types.StopCandidateException as e:
                        if (
                            glm.Candidate.FinishReason.RECITATION
                            == e.args[0].finish_reason
                        ):
                            condition = (
                                "Previous attempt returned a RECITATION error. "
                                "Rephrase the answer to avoid it."
                            )
                        with chat_output:
                            command_input.description = "🆁"
                        time.sleep(1)
                        with chat_output:
                            command_input.description = "🤔"
                        continue
                    except (
                        google.api_core.exceptions.TooManyRequests,
                        google.api_core.exceptions.DeadlineExceeded,
                    ):
                        with debug_output:
                            command_input.description = "💤"
                        time.sleep(10)
                        continue
                    try:
                        return response.text
                    except ValueError as e:
                        with debug_output:
                            print(f"Response {response} led to the error {e}")

        model = Gemini(system_prompt, gemini_tools, model_name=gemini_model)
        analysis_model = Gemini(scoring_system_prompt, model_name=gemini_model)

        # UI functions

        def set_cursor_waiting():
            js_code = """
            document.querySelector('body').style.cursor = 'wait';
            """
            display(HTML(f"<script>{js_code}</script>"))

        def set_cursor_default():
            js_code = """
            document.querySelector('body').style.cursor = 'default';
            """
            display(HTML(f"<script>{js_code}</script>"))

        def on_submit(widget):
            # global map_dirty
            self.map_dirty = False
            command_input.description = "❓"
            command = widget.value
            if not command:
                command = "go on"
            with chat_output:
                print("> " + command + "\n")
            if command != "go on":
                with debug_output:
                    print("> " + command + "\n")
            widget.value = ""
            set_cursor_waiting()
            command_input.description = "🤔"
            response = model.chat(command, temperature=0)
            if self.map_dirty:
                command_input.description = "🙏"
            else:
                command_input.description = "❓"
            set_cursor_default()
            response = response.strip()
            if not response:
                response = "<EMPTY RESPONSE, HIT ENTER>"
            with chat_output:
                print(response + "\n")
            command_input.value = ""

        command_input.on_submit(on_submit)

        # UI layout

        # Arrange the chat history and input in a vertical box
        chat_ui = widgets.VBox(
            [image_widget, chat_output],
            layout=widgets.Layout(width="420px", height=widget_height),
        )

        chat_output.layout = widgets.Layout(
            width="400px"
        )  # Fixed width for the left control
        m.layout = widgets.Layout(flex="1 1 auto", height=widget_height)

        table = widgets.HBox(
            [chat_ui, debug_output, m], layout=widgets.Layout(align_items="flex-start")
        )

        message_widget = widgets.Output()
        with message_widget:
            print("❓ = waiting for user input")
            print("🙏 = waiting for user to hit enter after calling set_center()")
            print("🤔 = thinking")
            print("💤 = sleeping due to retries")
            print("🆁 = Gemini recitation error")

        super().__init__(
            [table, command_input, message_widget], layout={"overflow": "hidden"}
        )

__init__(project=None, google_api_key=None, gemini_model='gemini-1.5-flash', target_score=0.8, widget_height='600px', initialize_ee=True)

Source code in geemap/ai.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
def __init__(
    self,
    project: Optional[str] = None,
    google_api_key: Optional[str] = None,
    gemini_model: str = "gemini-1.5-flash",
    target_score: float = 0.8,
    widget_height: str = "600px",
    initialize_ee: bool = True,
) -> None:
    # Initialization

    if project is None:
        project = get_env_var("EE_PROJECT_ID") or get_env_var("GOOGLE_PROJECT_ID")
    if project is None:
        raise ValueError(
            "Please provide a valid project ID via the 'project' parameter."
        )

    if google_api_key is None:
        google_api_key = get_env_var("GOOGLE_API_KEY")
    if google_api_key is None:
        raise ValueError(
            "Please provide a valid Google API key via the 'google_api_key' parameter."
        )

    if initialize_ee:
        ee_initialize(project=project)

    genai.configure(api_key=google_api_key)
    storage_client = storage.Client(project=project)
    bucket = storage_client.get_bucket("earthengine-stac")

    # Score to aim for (on the 0-1 scale). The exact meaning of what "score" means
    # is left to the LLM.

    # Count of analysis rounds

    self.iteration = 1
    self.map_dirty = False

    m = Map()
    m.add("layer_manager")
    self.map = m

    analysis_model = None

    image_model = genai.GenerativeModel(gemini_model)

    # UI widget definitions

    # We define the widgets early because some functions will write to the debug
    # and/or chat panels.

    command_input = widgets.Text(
        value="show a whole continent Australia DEM visualization using a palette that captures the elevation range",
        # value='show NYC',
        # value='show an area with many center pivot irrigation circles',
        # value='show a fire scar',
        # value='show an open pit mine',
        # value='a sea port',
        # value='flood consequences',
        # value='show an interesting modis composite with the relevant visualization and analyze it over Costa Rica',
        description="❓",
        layout=widgets.Layout(width="100%", height="50px"),
    )

    command_output = widgets.Label(
        value="Last command will be here",
    )

    status_label = widgets.Textarea(
        value="LLM response will be here",
        layout=widgets.Layout(width="50%", height="100px"),
    )

    # widget_height = "600px"
    debug_output = widgets.Output(
        layout={
            "border": "1px solid black",
            "height": widget_height,
            "overflow": "scroll",
            "width": "500px",
            "padding": "5px",
        }
    )
    with debug_output:
        print("DEBUG COLUMN\n")

    logo = requests.get(
        "https://drive.usercontent.google.com/download?id=1zE6G_nxXa3G5N0G_32jEhzdum2kMDfNY&export=view&authuser=0"
    ).content

    image_widget = widgets.Image(value=logo, format="png", width=400, height=600)

    chat_output = widgets.Output(
        layout={
            "border": "1px solid black",
            "height": "600px",
            "overflow": "scroll",
            "width": "300px",
        }
    )

    with chat_output:
        print("CHAT COLUMN\n")

    # Simple functions that LLM will call

    def set_center(x: float, y: float, zoom: int) -> str:
        """Sets the map center to the given coordinates and zoom level and
        returns instructions on what to do next."""
        with debug_output:
            print(f"SET_CENTER({x}, {y}, {zoom})\n")
        m.set_center(x, y)
        m.zoom = zoom
        # global map_dirty
        self.map_dirty = True
        return (
            "Do not call any more functions in this request to let geemap bounds "
            "update. Wait for user input."
        )

    def add_image_layer(image_id: str) -> str:
        """Adds to the map center an ee.Image with the given id
        and returns status message (success or failure)."""
        m.clear()
        command_output.value = f"add_image_layer('{image_id}')"
        m.addLayer(ee.Image(image_id))
        return "success"

    def get_dataset_description(dataset_id: str) -> str:
        """Fetches JSON STAC description for the given Earth Engine dataset id."""
        with debug_output:
            print(f"LOOKING UP {dataset_id}\n")
        parent = dataset_id.split("/")[0]

        # Get the blob (file)
        path = (
            os.path.join("catalog", parent, dataset_id.replace("/", "_")) + ".json"
        )
        blob = bucket.blob(path)

        if not blob.exists():
            return "dataset file not found: " + path

        file_contents = blob.download_as_string().decode()
        return file_contents

    def get_image(image_url: str) -> bytes:
        """Fetches from Earth Engine the content of the given URL as bytes."""
        response = requests.get(image_url)

        if response.status_code == 200:
            image_widget.value = response.content
            return response.content
        else:
            error_message = f"Error downloading image: {response}"
            try:
                error_details = (
                    json.loads(response.content.decode())
                    .get("error", {})
                    .get("message")
                )
                if error_details:
                    error_message += f" - {error_details}"
            except json.JSONDecodeError:
                pass
            with debug_output:
                print(error_message)
            raise ValueError("URL %s causes %s" % (image_url, error_message))

    def show_layer(python_code: str) -> str:
        """Execute the given Earth Engine Python client code and add the result to
        the map. Returns the status message (success or error message)."""
        m.layers = m.layers[:2]
        while '\\"' in python_code:
            python_code = python_code.replace('\\"', '"')
        command_output.value = f"show_layer('{python_code}')"
        with debug_output:
            print(f"IMAGE:\n {python_code}\n")
        try:
            locals = {}
            exec(f"import ee; im = {python_code}", {}, locals)
            m.addLayer(locals["im"])
        except Exception as e:
            with debug_output:
                print(f"ERROR: {e}")
            return str(e)
        return "success"

    def inner_monologue(thoughts: str) -> str:
        """Sends the current thinking of the LLM model to the user so that they are
        aware of what the model is thinking between function calls."""
        with debug_output:
            print(f"THOUGHTS:\n {thoughts}\n")
        return "success"

    # Functions for textual analysis of images

    def _lat_lon_to_tile(lon, lat, zoom_level):
        # Convert latitude and longitude to Mercator coordinates
        x_merc = (lon + 180) / 360
        y_merc = (
            1
            - math.log(
                math.tan(math.radians(lat)) + 1 / math.cos(math.radians(lat))
            )
            / math.pi
        ) / 2

        # Calculate number of tiles
        n = 2**zoom_level

        # Convert to tile coordinates
        X = int(x_merc * n)
        Y = int(y_merc * n)

        return X, Y

    def analyze_image(additional_instructions: str = "") -> str:
        """Returns GenAI image analysis describing the current map image.
        Optional additional instructions might be passed to target the analysis
        more precisely.
        """
        # global map_dirty
        if self.map_dirty:
            print("MAP DIRTY")
            return (
                "Map is not ready. Stop further processing and ask for user input"
            )

        try:
            return _analyze_image(additional_instructions)
        except ValueError as e:
            return str(e)

    def _analyze_image(additional_instructions: str = "") -> str:
        # bounds = m.bounds
        # s, w = bounds[0]
        # n, e = bounds[1]
        # zoom = int(m.zoom)

        # min_tile_x, max_tile_y = _lat_lon_to_tile(w, s, zoom)
        # max_tile_x, min_tile_y = _lat_lon_to_tile(e, n, zoom)
        # min_tile_x = max(0, min_tile_x)
        # max_tile_x = min(2**zoom - 1, max_tile_x)
        # min_tile_y = max(0, min_tile_y)
        # max_tile_y = min(2**zoom - 1, max_tile_y)

        # with debug_output:
        #     if additional_instructions:
        #         print(f"RUNNING IMAGE ANALYSIS: {additional_instructions}...\n")
        #     else:
        #         print("RUNNING IMAGE ANALYSIS...\n")

        # layers = list(m.ee_layer_dict.values())
        # if not layers:
        #     return "No data layers loaded"
        # url_template = layers[-1]["ee_layer"].url
        # tile_width = 256
        # tile_height = 256
        # image_width = (max_tile_x - min_tile_x + 1) * tile_width
        # image_height = (max_tile_y - min_tile_y + 1) * tile_height

        # # Create a new blank image
        # image = PIL.Image.new("RGB", (image_width, image_height))

        # for y in range(min_tile_y, max_tile_y + 1):
        #     for x in range(min_tile_x, max_tile_x + 1):
        #         tile_url = str.format(url_template, x=x, y=y, z=zoom)
        #         # print(tile_url)
        #         tile_img = PIL.Image.open(io.BytesIO(get_image(tile_url)))

        #         offset_x = (x - min_tile_x) * tile_width
        #         offset_y = (y - min_tile_y) * tile_height
        #         image.paste(tile_img, (offset_x, offset_y))

        # width, height = image.size
        # num_bands = len(image.getbands())

        with debug_output:
            if additional_instructions:
                print(f"RUNNING IMAGE ANALYSIS: {additional_instructions}...\n")
            else:
                print("RUNNING IMAGE ANALYSIS...\n")

        layers = list(m.ee_layer_dict.values())
        if not layers:
            return "No data layers loaded"
        image_temp_file = temp_file_path(extension="jpg")
        layer_name = layers[-1]["ee_layer"].name
        m.layer_to_image(layer_name, output=image_temp_file, scale=m.get_scale())
        image = PIL.Image.open(image_temp_file)

        image_array = np.array(image)
        image_min = np.min(image_array)
        image_max = np.max(image_array)

        file = open(image_temp_file, "rb")
        image_widget.value = file.read()
        file.close()

        # Skip an LLM call when we can simply tell that something is wrong.
        # (Also, LLMs might hallucinate on uniform images.)
        if image_min == image_max:
            return (
                f"The image tile has a single uniform color with value "
                f"{image_min}."
            )

        query = """You are an objective, precise overhead imagery analyst.
    Describe what the provided map tile depicts in terms of:

    1. The colors, textures, and patterns visible in the image.
    2. The spatial distribution, shape, and extent of distinct features or regions.
    3. Any notable contrasts, boundaries, or gradients between different areas.

    Avoid making assumptions about the specific geographic location, time period,
    or cause of the observed features. Focus solely on the literal contents of the
    image itself. Clearly indicate which features look natural, which look human-made,
    and which look like image artifacts. (Eg, a completely straight blue line
    is unlikely to be a river.)

    If the image is ambiguous or unclear, state so directly. Do not speculate or
    hypothesize beyond what is directly visible.

    If the image is of mostly the same color (white, gray, or black) with little
    contrast, just report that and do not describe the features.

    Use clear, concise language. Avoid subjective interpretations or analogies.
    Organize your response into structured paragraphs.
    """
        if additional_instructions:
            query += additional_instructions
        req = {
            "parts": [
                {"text": query},
                {"inline_data": image},
            ]
        }
        image_response = image_model.generate_content(req)
        try:
            with debug_output:
                print(f"ANALYSIS RESULT: {image_response.text}\n")
            return image_response.text
        except ValueError as e:
            with debug_output:
                print(f"UNEXPECTED IMAGE RESPONSE: {e}")
                print(image_response)
            breakpoint()

    # Function for scoring how well image analysis corresponds to the user query.

    # Note that we ask for the score outside of the main agent chat to keep
    # the scoring more objective.

    scoring_system_prompt = """
    After looking at the user query and the map tile analysis, start
    your answer with a number between 0 and 1 indicating how relevant
    the image is as an answer to the query. (0=irrelevant, 1=perfect answer)

    Make sure you have enough justification to definitively declare the analysis
    relevant - it's better to give a false negative than a false positive. However,
    the image analysis identifies specific matching landmarks (eg, the
    the outlines of Manhattan island for a request to show NYC), believe it.

    Do not assume  too much (eg, that the presence of green doesn't by itself mean the
    image shows forest); attempt to find multiple (at least three) independent
    lines of evidence before declaring victory and cite all these lines of evidence
    in your response.

    Be very, very skeptical - look for specific features that match only the query
    and nothing else (eg, if the query looks for a river, a completely straight blue
    line is unlikely to be a river). Think about what size the features are based on
    the zoom level and whether this size matches the feature size expected from
    first principles.

    If there is ambiguity or uncertainty, express it in your analysis and
    lower the score accordingly. If the image analysis is inconclusive, try zooming
    out to make sure you are looking at the right spot. Do not reduce the score if
    the analysis does not mention visualization parameters - they are just given for
    your reference. The image might show an area slightly larger than requested -
    this is okay, do not reduce the score on this account.
    """

    def score_response(
        query: str, visualization_parameters: str, analysis: str
    ) -> str:
        """Returns how well the given analysis describes a map tile returned for
        the given query. The analysis starts with a number between 0 and 1.

        Arguments:
        query: user-specified query
        visualization_parameters: description of the bands used and visualization
            parameters applied to the map tile
        analysis: the textual description of the map tile
        """
        with debug_output:
            print(f"VIZ PARAMS: {visualization_parameters}\n")
        question = f"""For user query {query} please score the following analysis:
        {analysis}. The answer must start with a number between 0 and 1."""
        if visualization_parameters:
            question += f"""Do not assume that common bands or visualization
            parameters should have been used, as the visualization used the
            following parameters: {visualization_parameters}"""

        result = analysis_model.ask(question)
        # global iteration
        with debug_output:
            print(f"SCORE #{self.iteration}:\n {result}\n")
        try:
            self.iteration += 1
        except Exception as e:
            with debug_output:
                print(f"UNEXPECTED SCORE RESPONSE: {e}")
        return result

    # Main prompt for the agent

    system_prompt = f"""
    The client is running in a Python notebook with a geemap Map displayed.
    When composing Python code, do not use getMapId - just return the single-line
    layer definition like 'ee.Image("USGS/SRTMGL1_003")' that we will pass to
    Map.addLayer(). Do not escape quotation marks in Python code.

    Be sure to use Python, not Javascript, syntax for keyword parameters in
    Python code (that is, "function(arg=value)") Using the provided functions,
    respond to the user command following below (or respond why it's not possible).
    If you get an Earth Engine error, attempt to fix it and then try again.

    Before you choose a dataset, think about what kind of dataset would be most
    suitable for the query. Also think about what zoom level would be suitable for
    the query, keeping in mind that for high-resolution image collections higher
    zoom levels are better to speed up tile loading.

    Once you have chosen a dataset, read its description using the provided function
    to see what spatial and temporal range it covers, what bands it has, as well as
    to find the recommended visualization parameters. Explain using the inner
    monlogue function why you chose a specific dataset, zoom level and map location.

    Prefer mosaicing image collections using the mosaic() function, don't get
    individual images from collections via
    'first()'. Choose a tile size and zoom level that will ensure the
    tile has enough pixels in it to avoid graininess, but not so many
    that processing becomes very expensive. Do not use wide date ranges
    with collections that have many images, but remember that Landsat and
    Sentinel-2 have revisit period of several days. Do not use sample
    locations - try to come up with actual locations that are relevant to
    the request.

    Use Landsat Collection 2, not Landsat Collection 1 ids. If you are getting
    repeated errors when filtering by a time range, read the dataset description
    to confirm that the dataset has data for the selected range.

    Important: after using the set_center() function, just say that you have called
    this function and wait for the user to hit enter, after which you should
    continue answering the original request. This will make sure the map is updated
    on the client side.

    Once the map is updated and the user told you to proceed, call the analyze_image
    function() to describe the image for the same location that will be shown in
    geemap. If you pass additional instructions to analyze_image(), do not disclose
    what the image is supposed to be to discourage hallucinations - you can only tell
    the analysis function to pay attention to specific areas (eg, center or top left)
    or shapes (eg, a line at the bottom) in the image. You can also tell the analysis
    function about the chosen bands, color palette and min/max visualization
    parameters, if any, to help it interpret the colors correctly. If the image
    turns out to be uniform in color with no features,
    use min/max visualization parameters to enhance contrast.

    Frequently call the inner_monologue() functions to tell the user about your
    current thought process. This is a good time to reflect if you have been running
    into repeated errors of the same kind, and if so, to try a different approach.

    When you are done, call the score_response() function to evaluate the analysis.
    You can also tell the scoring function about the chosen bands, color palette
    and min/max visualization parameters, if any. If the analysis score is below
    {target_score},
    keep trying to find and show a better image. You might have to change the dataset,
    map location, zoom level, date range, bands, or other parameters - think about
    what went wrong in the previous attempt and make the change that's most likely
    to improve the score.
    """

    # Class for LLM chat with function calling

    gemini_tools = [
        set_center,
        show_layer,
        analyze_image,
        inner_monologue,
        get_dataset_description,
        score_response,
    ]

    class Gemini:
        """Gemini LLM."""

        def __init__(
            self, system_prompt, tools=None, model_name="gemini-1.5-pro-latest"
        ):
            if not tools:
                tools = []
            self._text_model = genai.GenerativeModel(
                model_name=model_name, tools=tools
            )

            initial_messages = glm.Content(
                role="model", parts=[glm.Part(text=system_prompt)]
            )
            self._chat_proxy = self._text_model.start_chat(
                history=initial_messages, enable_automatic_function_calling=True
            )

        def ask(self, question, temperature=0):
            while True:
                condition = ""
                try:
                    sleep_duration = 10
                    response = self._text_model.generate_content(
                        question + condition
                    )
                    return response.text
                except genai.types.generation_types.StopCandidateException as e:
                    if (
                        glm.Candidate.FinishReason.RECITATION
                        == e.args[0].finish_reason
                    ):
                        condition = (
                            "Previous attempt returned a RECITATION error. "
                            "Rephrase the answer to avoid it."
                        )
                    with chat_output:
                        command_input.description = "🆁"
                    time.sleep(1)
                    with chat_output:
                        command_input.description = "🤔"
                    continue
                except (
                    google.api_core.exceptions.TooManyRequests,
                    google.api_core.exceptions.DeadlineExceeded,
                ):
                    with debug_output:
                        command_input.description = "💤"
                    time.sleep(sleep_duration)
                    continue
                except ValueError as e:
                    with debug_output:
                        print(f"Response {response} led to error: {e}")
                    breakpoint()
                    i = 1

        def chat(self, question: str, temperature=0) -> str:
            """Adds a question to the ongoing chat session."""
            # Always delay a bit to reduce the chance for rate-limiting errors.
            time.sleep(1)
            condition = ""
            sleep_duration = 10
            while True:
                response = ""
                try:
                    response = self._chat_proxy.send_message(
                        question + condition,
                        generation_config={
                            "temperature": temperature,
                            # Use a generous but limited output size to encourage in-depth
                            # replies.
                            "max_output_tokens": 5000,
                        },
                    )
                    if not response.parts:
                        raise ValueError(
                            "Cannot get analysis with reason"
                            f" {response.candidates[0].finish_reason.name}, terminating"
                        )
                except genai.types.generation_types.StopCandidateException as e:
                    if (
                        glm.Candidate.FinishReason.RECITATION
                        == e.args[0].finish_reason
                    ):
                        condition = (
                            "Previous attempt returned a RECITATION error. "
                            "Rephrase the answer to avoid it."
                        )
                    with chat_output:
                        command_input.description = "🆁"
                    time.sleep(1)
                    with chat_output:
                        command_input.description = "🤔"
                    continue
                except (
                    google.api_core.exceptions.TooManyRequests,
                    google.api_core.exceptions.DeadlineExceeded,
                ):
                    with debug_output:
                        command_input.description = "💤"
                    time.sleep(10)
                    continue
                try:
                    return response.text
                except ValueError as e:
                    with debug_output:
                        print(f"Response {response} led to the error {e}")

    model = Gemini(system_prompt, gemini_tools, model_name=gemini_model)
    analysis_model = Gemini(scoring_system_prompt, model_name=gemini_model)

    # UI functions

    def set_cursor_waiting():
        js_code = """
        document.querySelector('body').style.cursor = 'wait';
        """
        display(HTML(f"<script>{js_code}</script>"))

    def set_cursor_default():
        js_code = """
        document.querySelector('body').style.cursor = 'default';
        """
        display(HTML(f"<script>{js_code}</script>"))

    def on_submit(widget):
        # global map_dirty
        self.map_dirty = False
        command_input.description = "❓"
        command = widget.value
        if not command:
            command = "go on"
        with chat_output:
            print("> " + command + "\n")
        if command != "go on":
            with debug_output:
                print("> " + command + "\n")
        widget.value = ""
        set_cursor_waiting()
        command_input.description = "🤔"
        response = model.chat(command, temperature=0)
        if self.map_dirty:
            command_input.description = "🙏"
        else:
            command_input.description = "❓"
        set_cursor_default()
        response = response.strip()
        if not response:
            response = "<EMPTY RESPONSE, HIT ENTER>"
        with chat_output:
            print(response + "\n")
        command_input.value = ""

    command_input.on_submit(on_submit)

    # UI layout

    # Arrange the chat history and input in a vertical box
    chat_ui = widgets.VBox(
        [image_widget, chat_output],
        layout=widgets.Layout(width="420px", height=widget_height),
    )

    chat_output.layout = widgets.Layout(
        width="400px"
    )  # Fixed width for the left control
    m.layout = widgets.Layout(flex="1 1 auto", height=widget_height)

    table = widgets.HBox(
        [chat_ui, debug_output, m], layout=widgets.Layout(align_items="flex-start")
    )

    message_widget = widgets.Output()
    with message_widget:
        print("❓ = waiting for user input")
        print("🙏 = waiting for user to hit enter after calling set_center()")
        print("🤔 = thinking")
        print("💤 = sleeping due to retries")
        print("🆁 = Gemini recitation error")

    super().__init__(
        [table, command_input, message_widget], layout={"overflow": "hidden"}
    )

PrecomputedEmbeddings

Bases: Embeddings

Class for handling precomputed embeddings.

Source code in geemap/ai.py
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
class PrecomputedEmbeddings(Embeddings):
    """Class for handling precomputed embeddings."""

    def __init__(self, embeddings_dict: Dict[str, List[float]]) -> None:
        """Initializes the PrecomputedEmbeddings.

        Args:
            embeddings_dict (Dict[str, List[float]]): A dictionary mapping texts to their embeddings.
        """
        self.embeddings_dict = embeddings_dict
        self.model = TextEmbeddingModel.from_pretrained("google/text-embedding-004")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embeds a list of documents.

        Args:
            texts (List[str]): The list of texts to embed.

        Returns:
            List[List[float]]: The list of embeddings.
        """
        return [self.embeddings_dict[text] for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embeds a query text.

        Args:
            text (str): The query text to embed.

        Returns:
            List[float]: The embedding of the query text.
        """
        embeddings = self.model.get_embeddings([text])
        return embeddings[0].values

__init__(embeddings_dict)

Initializes the PrecomputedEmbeddings.

Parameters:

Name Type Description Default
embeddings_dict Dict[str, List[float]]

A dictionary mapping texts to their embeddings.

required
Source code in geemap/ai.py
1489
1490
1491
1492
1493
1494
1495
1496
def __init__(self, embeddings_dict: Dict[str, List[float]]) -> None:
    """Initializes the PrecomputedEmbeddings.

    Args:
        embeddings_dict (Dict[str, List[float]]): A dictionary mapping texts to their embeddings.
    """
    self.embeddings_dict = embeddings_dict
    self.model = TextEmbeddingModel.from_pretrained("google/text-embedding-004")

embed_documents(texts)

Embeds a list of documents.

Parameters:

Name Type Description Default
texts List[str]

The list of texts to embed.

required

Returns:

Type Description
List[List[float]]

List[List[float]]: The list of embeddings.

Source code in geemap/ai.py
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
def embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Embeds a list of documents.

    Args:
        texts (List[str]): The list of texts to embed.

    Returns:
        List[List[float]]: The list of embeddings.
    """
    return [self.embeddings_dict[text] for text in texts]

embed_query(text)

Embeds a query text.

Parameters:

Name Type Description Default
text str

The query text to embed.

required

Returns:

Type Description
List[float]

List[float]: The embedding of the query text.

Source code in geemap/ai.py
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
def embed_query(self, text: str) -> List[float]:
    """Embeds a query text.

    Args:
        text (str): The query text to embed.

    Returns:
        List[float]: The embedding of the query text.
    """
    embeddings = self.model.get_embeddings([text])
    return embeddings[0].values

explain_relevance(query, dataset_id, catalog, model_name='gemini-1.5-pro-latest', stream=False)

Prompts LLM to explain the relevance of a dataset to a query.

Parameters:

Name Type Description Default
query str

The user's query.

required
dataset_id str

The ID of the dataset.

required
catalog Catalog

The catalog containing the dataset.

required
model_name str

The name of the model to use. Defaults to "gemini-1.5-pro-latest".

'gemini-1.5-pro-latest'
stream bool

Whether to stream the response. Defaults to False.

False

Returns:

Name Type Description
str str

The explanation of the dataset's relevance to the query.

Source code in geemap/ai.py
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
def explain_relevance(
    query: str,
    dataset_id: str,
    catalog: Catalog,
    model_name: str = "gemini-1.5-pro-latest",
    stream: bool = False,
) -> str:
    """Prompts LLM to explain the relevance of a dataset to a query.

    Args:
        query (str): The user's query.
        dataset_id (str): The ID of the dataset.
        catalog (Catalog): The catalog containing the dataset.
        model_name (str): The name of the model to use. Defaults to "gemini-1.5-pro-latest".
        stream (bool): Whether to stream the response. Defaults to False.

    Returns:
        str: The explanation of the dataset's relevance to the query.
    """

    stac_json = catalog.get_collection(dataset_id).stac_json
    return explain_relevance_from_stac_json(query, stac_json, model_name, stream)

explain_relevance_from_stac_json(query, stac_json, model_name='gemini-1.5-pro-latest', stream=False)

Prompts LLM to explain the relevance of a dataset to a query using its STAC JSON.

Parameters:

Name Type Description Default
query str

The user's query.

required
stac_json Dict[str, Any]

The STAC JSON of the dataset.

required
model_name str

The name of the model to use. Defaults to "gemini-1.5-pro-latest".

'gemini-1.5-pro-latest'
stream bool

Whether to stream the response. Defaults to False.

False

Returns:

Name Type Description
str str

The explanation of the dataset's relevance to the query.

Source code in geemap/ai.py
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (requests.exceptions.RequestException, ConnectionError)
    ),
)
def explain_relevance_from_stac_json(
    query: str,
    stac_json: Dict[str, Any],
    model_name: str = "gemini-1.5-pro-latest",
    stream: bool = False,
) -> str:
    """Prompts LLM to explain the relevance of a dataset to a query using its STAC JSON.

    Args:
        query (str): The user's query.
        stac_json (Dict[str, Any]): The STAC JSON of the dataset.
        model_name (str): The name of the model to use. Defaults to "gemini-1.5-pro-latest".
        stream (bool): Whether to stream the response. Defaults to False.

    Returns:
        str: The explanation of the dataset's relevance to the query.
    """
    stac_json_str = json.dumps(stac_json)

    prompt = f"""
  I am an Earth Engine user contemplating using a dataset to support
  my investigation of the following query. Provide a concise, paragraph-long
  summary explaining why this dataset may be a good fit for my use case.
  If it does not seem like an appropriate dataset, say so.
  If relevant, call attention to a max of 3 bands that may be of particular interest.
  Weigh the tradeoffs between temporal and spatial resolution, particularly
  if the original query specifies regions of interest, time periods, or
  frequency of data collection. If I have not specified any
  spatial constraints, do your best based on the nature of their query. For example,
  if I'm wanting to study something small, like buildings, I will likely need good spatial resolution.

  Here is the original query:
  {query}

  Here is the stac json metadata for the dataset:
  {stac_json_str}
  """
    model = genai.GenerativeModel(model_name)
    response = model.generate_content(prompt, stream=stream)
    if stream:
        return response
    return response.text

fix_ee_python_code(code, ee, geemap_instance, model_name='gemini-1.5-pro-latest')

Asks a model to do ee python code correction in the event of error.

Parameters:

Name Type Description Default
code str

The Earth Engine Python code to fix.

required
ee Any

The Earth Engine module.

required
geemap_instance Map

The geemap instance.

required
model_name str

The name of the model to use. Defaults to "gemini-1.5-pro-latest".

'gemini-1.5-pro-latest'

Returns:

Name Type Description
str str

The corrected Earth Engine Python code.

Source code in geemap/ai.py
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (requests.exceptions.RequestException, ConnectionError)
    ),
)
def fix_ee_python_code(
    code: str, ee: Any, geemap_instance: Map, model_name: str = "gemini-1.5-pro-latest"
) -> str:
    """Asks a model to do ee python code correction in the event of error.

    Args:
        code (str): The Earth Engine Python code to fix.
        ee (Any): The Earth Engine module.
        geemap_instance (Map): The geemap instance.
        model_name (str): The name of the model to use. Defaults to "gemini-1.5-pro-latest".

    Returns:
        str: The corrected Earth Engine Python code.
    """

    def create_error_prompt(code: str, error_msg: str) -> str:
        return f"""
      You are an extremely laconic code correction robot.
      I am attempting to execute the following snippet of python Earth Engine code,
      using a geemap instance:

      ```
        {code}
      ```

      I have encountered the following error, please fix it. In 1-2 sentences,
      explain your debugging thought process in the 'thoughts' field. Note that
      the setOptions method exists only in the ee javascript library. Code
      referencing that method can be removed.

      Include the complete revised code snippet in the code field.
      Do not provide any other comentary in the code field. Do not add any new
      imports to the code snippet.

      {error_msg}
    """

    generation_config = {
        "response_mime_type": "application/json",
        "response_schema": CodeThoughts,
    }

    model = genai.GenerativeModel(model_name, generation_config=generation_config)

    max_attempts = 5
    total_attempts = 0
    broken = True
    while total_attempts < max_attempts and broken:
        try:
            run_ee_code(code, ee, geemap_instance)
            # logging.warning(f'Code success! after {total_attempts} try.')
            return code
        except Exception as e:
            logging.warning("Code execution error, asking Gemini for help.")

            gemini_json_fail = True
            while gemini_json_fail:
                response = model.generate_content(create_error_prompt(code, str(e)))
                try:
                    code_thoughts = json.loads(response.text)
                    gemini_json_fail = False
                except json.JSONDecodeError:
                    pass

            total_attempts += 1

            code = code_thoughts["code"]
            thoughts = code_thoughts["thoughts"]
            logging.warning(f"Gemini thoughts: \n{thoughts}")
            # logging.warning(f'Gemini new code: {code}')
            if total_attempts == max_attempts:
                raise Exception(e)
    logging.warning(f"Failed to fix code after {max_attempts} attempts.")

is_valid_question(question, model_name='gemini-1.5-pro-latest')

Filters out questions that cannot be answered by a dataset search tool.

Parameters:

Name Type Description Default
question str

The user's question.

required
model_name str

The name of the model to use. Defaults to "gemini-1.5-pro-latest".

'gemini-1.5-pro-latest'

Returns:

Name Type Description
bool bool

True if the question is valid, False otherwise.

Source code in geemap/ai.py
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (requests.exceptions.RequestException, ConnectionError)
    ),
)
def is_valid_question(question: str, model_name: str = "gemini-1.5-pro-latest") -> bool:
    """Filters out questions that cannot be answered by a dataset search tool.

    Args:
        question (str): The user's question.
        model_name (str): The name of the model to use. Defaults to "gemini-1.5-pro-latest".

    Returns:
        bool: True if the question is valid, False otherwise.
    """

    prompt = f"""
  You are a tool whose job is to determine whether or not the following question
  relates even in a small way to geospatial datasets.  Please provide a single
  word answer either True or False.

  For example, if the original query is "hello" - you should answer False. If
  the original query is "cheese futures" you should still answer True because
  the user could be interested in cheese production, and therefore agricultural
  land where cattle might be raised.

  Here is the original query:
  {question}
  """
    model = genai.GenerativeModel(model_name)
    response = model.generate_content(prompt)
    # Err on the side of returning True
    return response.text.lower().strip() != "false"

make_langchain_index(embeddings_df)

Creates an index from a dataframe of precomputed embeddings.

Parameters:

Name Type Description Default
embeddings_df DataFrame

The dataframe containing precomputed embeddings.

required

Returns:

Name Type Description
VectorStoreIndexWrapper VectorStoreIndexWrapper

The vector store index wrapper.

Source code in geemap/ai.py
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
def make_langchain_index(embeddings_df: pd.DataFrame) -> VectorStoreIndexWrapper:
    """Creates an index from a dataframe of precomputed embeddings.

    Args:
        embeddings_df (pd.DataFrame): The dataframe containing precomputed embeddings.

    Returns:
        VectorStoreIndexWrapper: The vector store index wrapper.
    """
    # Create a dictionary mapping texts to their embeddings
    embeddings_dict = dict(zip(embeddings_df["id"], embeddings_df["embedding"]))

    # Create our custom embeddings class
    precomputed_embeddings = PrecomputedEmbeddings(embeddings_dict)

    # Create Langchain Document objects
    documents = []
    for index, row in embeddings_df.iterrows():
        page_content = row["id"]
        metadata = {"summary": row["summary"], "name": row["name"]}
        documents.append(Document(page_content=page_content, metadata=metadata))

    # Create the VectorstoreIndexCreator
    index_creator = VectorstoreIndexCreator(embedding=precomputed_embeddings)

    # Create the index
    return index_creator.from_documents(documents)

matches_datetime(collection_interval, query_datetime)

Checks if the collection's datetime interval matches the query datetime.

Parameters:

Name Type Description Default
collection_interval Tuple[datetime, Optional[datetime]]

Temporal interval of the collection.

required
query_datetime datetime

A datetime coming from a query.

required

Returns:

Name Type Description
bool bool

True if the datetime interval matches, False otherwise.

Source code in geemap/ai.py
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
def matches_datetime(
    collection_interval: Tuple[datetime.datetime, Optional[datetime.datetime]],
    query_datetime: datetime.datetime,
) -> bool:
    """Checks if the collection's datetime interval matches the query datetime.

    Args:
        collection_interval (Tuple[datetime.datetime, Optional[datetime.datetime]]):
            Temporal interval of the collection.
        query_datetime (datetime.datetime): A datetime coming from a query.

    Returns:
        bool: True if the datetime interval matches, False otherwise.
    """
    if collection_interval[1] is None:
        # End date should always be set in STAC JSON files, but just in case...
        end_date = datetime.datetime.now(tz=datetime.UTC)
    else:
        end_date = collection_interval[1]
    return collection_interval[0] <= query_datetime <= end_date

matches_interval(collection_interval, query_interval)

Checks if the collection's datetime interval matches the query datetime interval.

Parameters:

Name Type Description Default
collection_interval Tuple[datetime, datetime]

Temporal interval of the collection.

required
query_interval Tuple[datetime, datetime]

A tuple with the query interval start and end.

required

Returns:

Name Type Description
bool bool

True if the datetime interval matches, False otherwise.

Source code in geemap/ai.py
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
def matches_interval(
    collection_interval: Tuple[datetime.datetime, datetime.datetime],
    query_interval: Tuple[datetime.datetime, datetime.datetime],
) -> bool:
    """Checks if the collection's datetime interval matches the query datetime interval.

    Args:
        collection_interval (Tuple[datetime.datetime, datetime.datetime]):
            Temporal interval of the collection.
        query_interval (Tuple[datetime.datetime, datetime.datetime]): A tuple
            with the query interval start and end.

    Returns:
        bool: True if the datetime interval matches, False otherwise.
    """
    start_query, end_query = query_interval
    start_collection, end_collection = collection_interval
    if end_collection is None:
        # End date should always be set in STAC JSON files, but just in case...
        end_collection = datetime.datetime.now(tz=datetime.UTC)
    return end_query > start_collection and start_query <= end_collection

run_ee_code(code, ee, geemap_instance)

Executes Earth Engine Python code within the context of a geemap instance.

Parameters:

Name Type Description Default
code str

The Earth Engine Python code to execute.

required
ee Any

The Earth Engine module.

required
geemap_instance Map

The geemap instance.

required

Raises:

Type Description
Exception

Re-raises any exception encountered during code execution.

Source code in geemap/ai.py
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
@tenacity.retry(
    stop=tenacity.stop_after_attempt(3),
    wait=tenacity.wait_fixed(1),
    retry=tenacity.retry_if_exception_type(LayerException),
    # before_sleep=lambda retry_state: print(f"LayerException occurred. Retrying in 1 seconds... (Attempt {retry_state.attempt_number}/3)")
)
def run_ee_code(code: str, ee: Any, geemap_instance: Map) -> None:
    """Executes Earth Engine Python code within the context of a geemap instance.

    Args:
        code (str): The Earth Engine Python code to execute.
        ee (Any): The Earth Engine module.
        geemap_instance (Map): The geemap instance.

    Raises:
        Exception: Re-raises any exception encountered during code execution.
    """
    try:
        # geemap appears to have some stray print statements.
        _ = io.StringIO()
        with redirect_stdout(_):
            # Note that sometimes the geemap code uses both 'Map' and 'm' to refer to a map instance.
            exec(code, {"ee": ee, "Map": geemap_instance, "m": geemap_instance})
    except Exception:
        # Re-raise the exception with the original traceback
        exc_type, exc_value, exc_traceback = sys.exc_info()
        raise exc_value.with_traceback(exc_traceback)