# Unsupervised Machine Learning
In this example, we will demonstrate how to fit and score an unsupervised learning model with a [sample of Landsat 8 data](https://github.com/locationtech/rasterframes/tree/develop/pyrasterframes/src/test/resources).
## Imports and Data Preparation
```python, setup, echo=False
from IPython.display import display
import pyrasterframes.rf_ipython
from pyrasterframes.utils import create_rf_spark_session
import os
spark = create_rf_spark_session()
```
We import various Spark components needed to construct our `Pipeline`.
```python, imports, echo=True
import pandas as pd
from pyrasterframes import TileExploder
from pyrasterframes.rasterfunctions import *
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.clustering import KMeans
from pyspark.ml import Pipeline
```
The first step is to create a Spark DataFrame of our imagery data. To achieve that, we will create a catalog DataFrame using the pattern from [the I/O page](raster-io.html#Single-Scene--Multiple-Bands). In the catalog, each row represents a distinct area and time, and each column is the URI to a band's image product. The resulting Spark DataFrame may have many rows per URI, since each image is read as multiple tiles, with a column corresponding to each band.
```python, catalog
filenamePattern = "https://rasterframes.s3.amazonaws.com/samples/elkton/L8-B{}-Elkton-VA.tiff"
catalog_df = pd.DataFrame([
    {'b' + str(b): filenamePattern.format(b) for b in range(1, 8)}
])
tile_size = 256
df = spark.read.raster(catalog_df, catalog_col_names=catalog_df.columns, tile_size=tile_size)
df = df.withColumn('crs', rf_crs(df.b1)) \
       .withColumn('extent', rf_extent(df.b1))
df.printSchema()
```
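To make the catalog layout concrete, the following plain-Python sketch builds the single catalog row as a dictionary and prints each column name next to its URI. It uses the same URI template as above; the names `filename_pattern` and `catalog_row` are introduced here only for illustration.

```python
# Same URI template as in the catalog above.
filename_pattern = "https://rasterframes.s3.amazonaws.com/samples/elkton/L8-B{}-Elkton-VA.tiff"

# One catalog row: keys become column names (b1..b7), values are band URIs.
catalog_row = {'b' + str(b): filename_pattern.format(b) for b in range(1, 8)}

for column, uri in catalog_row.items():
    print(column, '->', uri)
```

Wrapping this one dictionary in a list is what yields a one-row pandas DataFrame with seven columns.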
In this small example, all the images in our `catalog_df` have the same @ref:[CRS](concepts.md#coordinate-reference-system-crs-), which we verify in the code snippet below. The `crs` object will be useful for visualization later.
```python, crses
crses = df.select('crs.crsProj4').distinct().collect()
print('Found ', len(crses), 'distinct CRS: ', crses)
assert len(crses) == 1
crs = crses[0]['crsProj4']
```
## Create ML Pipeline
SparkML requires that each observation be in its own row, and that the features for each observation be packed into a single `Vector`. For this unsupervised learning problem, we will treat each _pixel_ as an observation and each band as a feature. The first step is to "explode" the _tiles_ into a single row per pixel. RasterFrames generally refers to a pixel as a @ref:[`cell`](concepts.md#cell).
```python, exploder
exploder = TileExploder()
```
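As a rough illustration of what "exploding" means, the sketch below turns a small two-dimensional grid into one record per cell. It is a conceptual stand-in written in plain Python, not the `TileExploder` implementation; the helper `explode_tile` and the sample values are hypothetical, while the `column_index` and `row_index` names match the columns used later on this page.

```python
def explode_tile(tile):
    """Yield one (column_index, row_index, value) record per cell of a 2-D grid."""
    for row_index, row in enumerate(tile):
        for column_index, value in enumerate(row):
            yield (column_index, row_index, value)

# A made-up 2-row by 3-column tile.
tile = [[10, 11, 12],
        [20, 21, 22]]

rows = list(explode_tile(tile))
for record in rows:
    print(record)
```

Keeping the cell position alongside each value is what makes it possible to rebuild tiles after scoring.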
To "vectorize" the band columns, we use the SparkML `VectorAssembler`. Each of the seven bands is a different feature.
```python, assembler
assembler = VectorAssembler() \
    .setInputCols(list(catalog_df.columns)) \
    .setOutputCol("features")
```
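Conceptually, assembling amounts to gathering the per-band values of one pixel into a single ordered list. The plain-Python sketch below shows that idea under the assumption that a pixel is represented as a dictionary; `assemble_features` is a hypothetical helper and the band values are made up.

```python
def assemble_features(record, input_cols):
    """Collect the named columns of one record into a single feature list."""
    return [float(record[c]) for c in input_cols]

input_cols = ['b' + str(b) for b in range(1, 8)]

# One exploded pixel with made-up band values.
pixel = {'column_index': 0, 'row_index': 0,
         'b1': 9470, 'b2': 8500, 'b3': 7700, 'b4': 6900,
         'b5': 15000, 'b6': 10500, 'b7': 8000}

features = assemble_features(pixel, input_cols)
print(features)
```

The order of `input_cols` fixes the order of the features, so it must be the same at fit time and at scoring time.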
For this problem, we will use the K-means clustering algorithm and configure our model to have 5 clusters.
```python, kmeans
kmeans = KMeans().setK(5).setFeaturesCol('features')
```
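For readers unfamiliar with K-means, here is a minimal plain-Python sketch of the basic iteration (Lloyd's algorithm): assign each point to its nearest center, then move each center to the mean of its points. It is not how Spark implements `KMeans`; in particular, the evenly spaced starting centers are an assumption made for reproducibility, and all names and sample points are hypothetical.

```python
def squared_distance(p, q):
    return sum((a - b) ** 2 for a, b in zip(p, q))

def nearest(point, centers):
    """Index of the center closest to the point."""
    return min(range(len(centers)), key=lambda i: squared_distance(point, centers[i]))

def kmeans(points, k, iterations=20):
    # Assumption: start from k evenly spaced input points.
    centers = [list(points[i * len(points) // k]) for i in range(k)]
    for _ in range(iterations):
        clusters = [[] for _ in range(k)]
        for p in points:
            clusters[nearest(p, centers)].append(p)
        new_centers = []
        for i, members in enumerate(clusters):
            if members:
                dim = len(members[0])
                new_centers.append([sum(m[d] for m in members) / len(members)
                                    for d in range(dim)])
            else:
                new_centers.append(centers[i])  # keep an empty cluster's center
        if new_centers == centers:  # converged
            break
        centers = new_centers
    return centers

points = [[1.0, 1.0], [1.2, 0.8], [0.8, 1.1],
          [8.0, 8.0], [8.2, 7.9], [7.9, 8.3]]
centers = kmeans(points, k=2)
print(centers)
```

In this page's example the points are pixels with seven features each, and `k` is 5 rather than 2.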
We can combine the above stages into a single [`Pipeline`](https://spark.apache.org/docs/latest/ml-pipeline.html).
```python, pipeline
pipeline = Pipeline().setStages([exploder, assembler, kmeans])
```
## Fit the Model and Score
Fitting the _pipeline_ is the step that actually explodes the _tiles_, assembles the feature _vectors_, and fits the K-means clustering model.
```python, fit
model = pipeline.fit(df)
```
We can use the `transform` function of the fitted _pipeline_ model to score the training data. This will add a column called `prediction` containing the identifier of the closest cluster.
```python, transform
clustered = model.transform(df)
```
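Scoring with a fitted K-means model boils down to picking, for each feature vector, the index of the nearest cluster center. A minimal plain-Python sketch of that rule follows; the `predict` helper, the centers, and the sample points are all hypothetical.

```python
def predict(features, centers):
    """Return the index of the center with the smallest squared distance."""
    distances = [sum((f - c) ** 2 for f, c in zip(features, center))
                 for center in centers]
    return distances.index(min(distances))

# Made-up cluster centers and observations in two dimensions.
centers = [[0.0, 0.0], [10.0, 10.0], [0.0, 10.0]]
observations = [[1.0, 2.0], [9.0, 8.5], [2.0, 9.0]]

predictions = [predict(p, centers) for p in observations]
print(predictions)
```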
Now let's take a look at some sample output.
```python, view_predictions
clustered.select('prediction', 'extent', 'column_index', 'row_index', 'features')
```
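If we want to inspect the model statistics, the SparkML API requires us to go through an unfortunate contortion to access the clustering results: we pull the fitted K-means model out of the pipeline model's stages, where it is the third stage (index 2).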
```python, cluster_stats
cluster_stage = model.stages[2]
```
We can then compute the sum of squared distances from each point to its nearest center, a basic ingredient of most cluster quality metrics.
```python, distance
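# Note: `computeCost` was deprecated in Spark 2.4 and removed in Spark 3.0;
# on those versions, `cluster_stage.summary.trainingCost` reports this cost for the training data.
metric = cluster_stage.computeCost(clustered)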
print("Within set sum of squared errors: %s" % metric)
```
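The quantity being reported can be written out in a few lines. The plain-Python sketch below sums, over all points, the squared distance to the nearest center; it only illustrates the definition and is not Spark's implementation. The function name and the sample data are hypothetical.

```python
def within_set_sum_of_squared_errors(points, centers):
    """Sum over points of the squared distance to the nearest center."""
    total = 0.0
    for p in points:
        total += min(sum((a - b) ** 2 for a, b in zip(p, c)) for c in centers)
    return total

# Made-up data: every point lies exactly 1 unit from its nearest center.
points = [[0.0, 0.0], [0.0, 2.0], [10.0, 10.0], [10.0, 12.0]]
centers = [[0.0, 1.0], [10.0, 11.0]]

wssse = within_set_sum_of_squared_errors(points, centers)
print("Within set sum of squared errors: %s" % wssse)
```

Lower values mean tighter clusters, but the value always decreases as the number of clusters grows, so it is most useful for comparing models with the same `k`.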
## Visualize Prediction
We can recreate the tiled data structure using the metadata added by the `TileExploder` pipeline stage.
```python, assemble
from pyrasterframes.rf_types import CellType
retiled = clustered.groupBy('extent', 'crs') \
    .agg(
        rf_assemble_tile('column_index', 'row_index', 'prediction',
                         tile_size, tile_size, CellType.int8())
    )
```
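The reassembly is the inverse of exploding: each record's value is written back into a grid at its column and row position. The plain-Python sketch below illustrates the idea for a single tile; it is not the `rf_assemble_tile` implementation, and the `assemble_tile` helper, the `nodata` fill value, and the sample records are assumptions made for illustration.

```python
def assemble_tile(records, tile_cols, tile_rows, nodata=-1):
    """Place (column_index, row_index, value) records into a 2-D grid."""
    tile = [[nodata] * tile_cols for _ in range(tile_rows)]
    for column_index, row_index, value in records:
        tile[row_index][column_index] = value
    return tile

# Made-up predictions for a 2 x 2 tile, in arbitrary order.
records = [(1, 1, 4), (0, 0, 3), (1, 0, 3), (0, 1, 1)]

tile = assemble_tile(records, tile_cols=2, tile_rows=2)
for row in tile:
    print(row)
```

Because every record carries its own position, the order in which records arrive after the group-by does not matter.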
Next we will @ref:[write the output to a GeoTIFF file](raster-write.md#geotiffs). This is quick and reliable here for a few specific reasons that may not hold in general. We can write the data at full resolution, omitting the `raster_dimensions` argument, because we know the input raster dimensions are small. The data is also all in a single CRS, as we demonstrated above. Because `catalog_df` has only a single row, we know the output GeoTIFF value at a given location corresponds to a single input. Finally, the `retiled` `DataFrame` has only a single `Tile` column, so the band interpretation is trivial.
```python, viz
import rasterio
output_tif = 'unsupervised.tif'
retiled.write.geotiff(output_tif, crs=crs)
with rasterio.open(output_tif) as src:
    for b in range(1, src.count + 1):
        print("Tags on band", b, src.tags(b))
    display(src)
```