Source code for task_geo.testing
LOCATION_TERMS = ['city', 'sub_region', 'region', 'country']
[docs]def get_geographical_granularity(data):
    """Returns the index of the lowest granularity level available in data."""
    for index, term in enumerate(LOCATION_TERMS):
        if term in data.columns:
            return index
[docs]def check_column_and_get_index(name, data, reference):
    """Check that name is present in data.columns and returns it's index."""
    assert name in data.columns, f'If {reference} is provided, {name} should be present too'
    return data.columns.get_loc(name)
[docs]def check_dataset_format(data):
    message = 'All columns should be lowercase'
    assert all(data.columns == [column.lower() for column in data.columns]), message
    if 'latitude' in data.columns:
        message = 'Latitude should be provided with Longitude'
        assert 'longitude' in data.columns, message
    granularity = get_geographical_granularity(data)
    if granularity is not None:
        min_granularity_term = LOCATION_TERMS[granularity]
        locations = [data.columns.get_loc(LOCATION_TERMS[granularity])]
        for term in LOCATION_TERMS[granularity + 1:]:
            locations.append(check_column_and_get_index(term, data, min_granularity_term))
        message = 'The correct ordening of the columns is "country, region, sub_region, city"'
        assert (locations[::-1] == sorted(locations)), message
    message = 'date and timestamp columns should be of datetime dtype'
    time_index = None
    if 'date' in data.columns:
        assert data.date.dtype == 'datetime64[ns]', message
        time_index = data.columns.get_loc('date')
    if 'timestamp' in data.columns:
        assert data.timestamp.dtype == 'datetime64[ns]', message
        time_index = data.columns.get_loc('timestamp')
    if time_index is not None and granularity is not None:
        message = 'geographical columns should be before time columns.'
        assert all(time_index > location for location in locations), message
    object_columns = data.select_dtypes(include='object').columns
    for column in object_columns:
        try:
            data[column].astype(float)
            assert False, f'Column {column} is of dtype object, but casteable to float.'
        except ValueError:
            pass