Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,37 @@ package org.locationtech.rasterframes.extensions
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.locationtech.rasterframes._
import org.locationtech.rasterframes.expressions.SpatialRelation
import org.locationtech.rasterframes.expressions.accessors.ExtractTile
import org.locationtech.rasterframes.functions.reproject_and_merge
import org.locationtech.rasterframes.util._

import scala.util.Random

object RasterJoin {

/**
 * Perform a raster join on two dataframes, each of which either has a proj_raster column
 * or explicitly includes columns named `crs` and `extent`.
 *
 * @param left  left-hand dataframe of the join
 * @param right right-hand dataframe of the join
 * @return the result of the six-argument `apply`, called with the CRS and extent columns
 *         resolved for each side
 */
def apply(left: DataFrame, right: DataFrame): DataFrame = {
// NOTE(review): the next two statements look like the removed side of the diff; their
// result is not the final expression of this method -- confirm whether they belong here.
val df = apply(left, right, left("extent"), left("crs"), right("extent"), right("crs"))
df.drop(right("extent")).drop(right("crs"))
// Resolve the dataframe, CRS column and extent column to use for one side of the join.
def usePRT(d: DataFrame) =
// Use the first proj_raster column when there is one...
d.projRasterColumns.headOption
.map(p => (rf_crs(p), rf_extent(p)))
// ...otherwise fall back to columns literally named "crs" and "extent".
.orElse(Some(col("crs"), col("extent")))
.map { case (crs, extent) =>
// Attach the chosen expressions as "crs"/"extent" columns and reference them through d2.
val d2 = d.withColumn("crs", crs).withColumn("extent", extent)
(d2, d2("crs"), d2("extent"))
}
// Always defined here: the orElse above supplies a Some.
.get

val (ldf, lcrs, lextent) = usePRT(left)
val (rdf, rcrs, rextent) = usePRT(right)

// The delegate takes (extent, crs) per side, hence the swapped order relative to usePRT's tuple.
apply(ldf, rdf, lextent, lcrs, rextent, rcrs)
}

def apply(left: DataFrame, right: DataFrame, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column): DataFrame = {
val leftGeom = st_geometry(leftExtent)
val rightGeomReproj = st_reproject(st_geometry(rightExtent), rightCRS, leftCRS)
val joinExpr = st_intersects(leftGeom, rightGeomReproj)
val joinExpr = new Column(SpatialRelation.Intersects(leftGeom.expr, rightGeomReproj.expr))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more of a nicety: it defers to the RasterFrames Expression form of the spatial relations, which are rewritten into the call graph anyway, to help with predicate push-down and, eventually, optimization.

apply(left, right, joinExpr, leftExtent, leftCRS, rightExtent, rightCRS)
}

Expand All @@ -65,7 +80,7 @@ object RasterJoin {
val leftAggCols = left.columns.map(s => first(left(s), true) as s)
// On the RHS we collect result as a list.
val rightAggCtx = Seq(collect_list(rightExtent) as rightExtent2, collect_list(rightCRS) as rightCRS2)
val rightAggTiles = right.tileColumns.map(c => collect_list(c) as c.columnName)
val rightAggTiles = right.tileColumns.map(c => collect_list(ExtractTile(c)) as c.columnName)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The primary fix for #419

val rightAggOther = right.notTileColumns
.filter(n => n.columnName != rightExtent.columnName && n.columnName != rightCRS.columnName)
.map(c => collect_list(c) as (c.columnName + "_agg"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,6 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers {

total18 should be > 0.0
total18 should be < total17


}

it("should pass through ancillary columns") {
Expand All @@ -164,5 +162,14 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers {
val joined = left.rasterJoin(right)
joined.columns should contain allElementsOf Seq("left_id", "right_id_agg")
}

it("should handle proj_raster types") {
val df1 = Seq(one).toDF("one")
val df2 = Seq(two).toDF("two")
noException shouldBe thrownBy {
val joined1 = df1.rasterJoin(df2)
val joined2 = df2.rasterJoin(df1)
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was at a loss as to how to really beef this test up without having to work out the expected results by hand.

}
}
}