Skip to content
This repository was archived by the owner on Nov 16, 2019. It is now read-only.

Commit 054aa08

Browse files
authored
Merge pull request #176 from yahoo/lstm_inference
LSTM support in CaffeOnSpark
2 parents 081407c + b5f0a87 commit 054aa08

32 files changed

+2854
-19
lines changed

Makefile

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
HOME ?=/home/${USER}
22
ifeq ($(shell which spark-submit),)
3-
SPARK_HOME ?=/home/y/share/spark
3+
SPARK_HOME ?= /home/y/share/spark
44
else
55
SPARK_HOME ?=$(shell which spark-submit 2>&1 | sed 's/\/bin\/spark-submit//g')
66
endif
77
CAFFE_ON_SPARK ?=$(shell pwd)
8-
LD_LIBRARY_PATH ?=/home/y/lib64:/home/y/lib64/mkl/intel64
8+
LD_LIBRARY_PATH ?=/home/y/lib64:/home/y/lib64/mkl/intel64:/usr/local/cuda/
99
LD_LIBRARY_PATH2=${LD_LIBRARY_PATH}:${CAFFE_ON_SPARK}/caffe-public/distribute/lib:${CAFFE_ON_SPARK}/caffe-distri/distribute/lib:/usr/lib64:/lib64
10-
DYLD_LIBRARY_PATH ?=/home/y/lib64:/home/y/lib64/mkl/intel64
11-
DYLD_LIBRARY_PATH2=${DYLD_LIBRARY_PATH}:${CAFFE_ON_SPARK}/caffe-public/distribute/lib:${CAFFE_ON_SPARK}/caffe-distri/distribute/lib:/usr/lib64:/lib64
10+
DYLD_LIBRARY_PATH ?=/home/y/lib64:/home/y/lib64/mkl/intel64:/usr/local/cuda/lib
11+
DYLD_LIBRARY_PATH2=${DYLD_LIBRARY_PATH}:${CAFFE_ON_SPARK}/caffe-public/distribute/lib:${CAFFE_ON_SPARK}/caffe-distri/distribute/lib:/usr/lib64:/lib64
1212

1313
export SPARK_VERSION=$(shell ${SPARK_HOME}/bin/spark-submit --version 2>&1 | grep version | awk '{print $$5}' | cut -d'.' -f1)
1414
ifeq (${SPARK_VERSION}, 2)
@@ -17,24 +17,25 @@ endif
1717

1818
build:
1919
cd caffe-public; make proto; make -j4 -e distribute; cd ..
20-
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH2}"; mvn ${MVN_SPARK_FLAG} -B package -DskipTests
20+
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH2}"; GLOG_minloglevel=1 mvn ${MVN_SPARK_FLAG} -B package -DskipTests
2121
jar -xvf caffe-grid/target/caffe-grid-0.1-SNAPSHOT-jar-with-dependencies.jar META-INF/native/linux64/liblmdbjni.so
2222
mv META-INF/native/linux64/liblmdbjni.so ${CAFFE_ON_SPARK}/caffe-distri/distribute/lib
2323
${CAFFE_ON_SPARK}/scripts/setup-mnist.sh
24-
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH2}"; mvn ${MVN_SPARK_FLAG} -B package
25-
cp -r ${CAFFE_ON_SPARK}/caffe-public/python/caffe ${CAFFE_ON_SPARK}/caffe-grid/src/main/python/
26-
cd ${CAFFE_ON_SPARK}/caffe-grid/src/main/python/; zip -r caffeonsparkpythonapi *; mv caffeonsparkpythonapi.zip ${CAFFE_ON_SPARK}/caffe-grid/target/;cd ${CAFFE_ON_SPARK}
27-
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH2}"; export SPARK_HOME="${SPARK_HOME}"; ${CAFFE_ON_SPARK}/caffe-grid/src/test/python/PythonTest.sh
24+
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH2}"; GLOG_minloglevel=1 mvn ${MVN_SPARK_FLAG} -B test
25+
cd ${CAFFE_ON_SPARK}/caffe-grid/src/main/python/; zip -r caffeonsparkpythonapi *; cd ${CAFFE_ON_SPARK}/caffe-public/python/; zip -ur ${CAFFE_ON_SPARK}/caffe-grid/src/main/python/caffeonsparkpythonapi.zip *; cd - ; mv caffeonsparkpythonapi.zip ${CAFFE_ON_SPARK}/caffe-grid/target/; cd ${CAFFE_ON_SPARK}
26+
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}; export SPARK_HOME=${SPARK_HOME};GLOG_minloglevel=1 ${CAFFE_ON_SPARK}/caffe-grid/src/test/python/PythonTest.sh
27+
2828
buildosx:
2929
cd caffe-public; make proto; make -j4 -e distribute; cd ..
30-
export DYLD_LIBRARY_PATH="${DYLD_LIBRARY_PATH2}"; mvn ${MVN_SPARK_FLAG} -B package -DskipTests
30+
export DYLD_LIBRARY_PATH="${DYLD_LIBRARY_PATH2}"; GLOG_minloglevel=1 mvn ${MVN_SPARK_FLAG} -B package -DskipTests
3131
jar -xvf caffe-grid/target/caffe-grid-0.1-SNAPSHOT-jar-with-dependencies.jar META-INF/native/osx64/liblmdbjni.jnilib
3232
mv META-INF/native/osx64/liblmdbjni.jnilib ${CAFFE_ON_SPARK}/caffe-distri/distribute/lib
3333
${CAFFE_ON_SPARK}/scripts/setup-mnist.sh
34-
export DYLD_LIBRARY_PATH="${DYLD_LIBRARY_PATH2}"; mvn ${MVN_SPARK_FLAG} -B package
35-
cp -r ${CAFFE_ON_SPARK}/caffe-public/python/caffe ${CAFFE_ON_SPARK}/caffe-grid/src/main/python/
36-
cd ${CAFFE_ON_SPARK}/caffe-grid/src/main/python/; zip -r caffeonsparkpythonapi *; mv caffeonsparkpythonapi.zip ${CAFFE_ON_SPARK}/caffe-grid/target/; cd ${CAFFE_ON_SPARK}
37-
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH2}"; export SPARK_HOME="${SPARK_HOME}"; ${CAFFE_ON_SPARK}/caffe-grid/src/test/python/PythonTest.sh
34+
export LD_LIBRARY_PATH="${DYLD_LIBRARY_PATH2}"; GLOG_minloglevel=1 mvn ${MVN_SPARK_FLAG} -B test
35+
cd ${CAFFE_ON_SPARK}/caffe-grid/src/main/python/; zip -r caffeonsparkpythonapi *; cd ${CAFFE_ON_SPARK}/caffe-public/python/; zip -ur ${CAFFE_ON_SPARK}/caffe-grid/src/main/python/caffeonsparkpythonapi.zip *; cd -; mv caffeonsparkpythonapi.zip ${CAFFE_ON_SPARK}/caffe-grid/target/; cd ${CAFFE_ON_SPARK}
36+
cd ${CAFFE_ON_SPARK}/caffe-grid/src/main/python/; zip -r caffeonsparkpythonapi *; mv caffeonsparkpythonapi.zip ${CAFFE_ON_SPARK}/caffe-grid/target/; cd ${CAFFE_ON_SPARK}
37+
export DYLD_LIBRARY_PATH=${DYLD_LIBRARY_PATH}; export SPARK_HOME=${SPARK_HOME};GLOG_minloglevel=1 ${CAFFE_ON_SPARK}/caffe-grid/src/test/python/PythonTest.sh
38+
3839
update:
3940
git submodule init
4041
git submodule update --force
@@ -48,7 +49,7 @@ gh-pages:
4849
rm -rf scala_doc
4950
git checkout gh-pages scala_doc
5051

51-
clean:
52+
clean:
5253
cd caffe-public; make clean; cd ..
5354
cd caffe-distri; make clean; cd ..
5455
mvn ${MVN_SPARK_FLAG} clean

caffe-grid/src/main/python/com/yahoo/ml/caffe/Config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ class Config:
3737
:ivar int test_data_layer_id: Get layer ID of training data source
3838
:ivar int train_data_layer_id: Get layer ID of training data source
3939
:ivar int transform_thread_per_device: Get/Set # of transformer threads per device
40+
:ivar String imageCaptionDFDir: Path to generate the image caption dataframe
41+
:ivar String vocabDir: Path to generate the Vocab
42+
:ivar String embeddingDFDir: Path to generate the embedded dataframe
43+
:ivar String captionFile: Path to the caption file
44+
:ivar int captionLength: Embedding caption length
45+
:ivar int vocabSize: Vocab size to consider
4046
"""
4147
def __init__(self,sc,args=None):
4248
registerContext(sc)

caffe-grid/src/main/python/com/yahoo/ml/caffe/ConversionUtil.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def toJavaSC(pySc):
388388
Converts a Python SQLContext to a Scala SparkContext.
389389
'''
390390
def toScalaSQLC(pySQLc):
391-
return pySQLc._scala_SQLContext
391+
return jvm.org.apache.spark.sql.SQLContext(pySQLc._jsc.sc())
392392

393393
'''
394394
Converts a Python SQLContext to a Java SparkContext.

caffe-grid/src/main/python/com/yahoo/ml/caffe/DisplayUtils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def image_tag(np_array):
2020

2121
def show_df(df, nrows=10):
2222
"""Displays a table of labels with their images, inline in html
23-
2423
:param DataFrame df: A python dataframe
2524
:param int nrows: First n rows to display from the dataframe
2625
"""
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from tools import *
2+
3+
__all__=["tools"]
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
'''
2+
Copyright 2016 Yahoo Inc.
3+
Licensed under the terms of the Apache 2.0 license.
4+
Please see LICENSE file in the project root for terms.
5+
'''
6+
from PIL import Image
7+
from io import BytesIO
8+
from IPython.display import HTML
9+
import numpy as np
10+
from base64 import b64encode
11+
from google.protobuf import text_format
12+
import array
13+
from com.yahoo.ml.caffe.ConversionUtil import wrapClass, getScalaSingleton, toPython
14+
from com.yahoo.ml.caffe.RegisterContext import registerContext
15+
from pyspark.sql import DataFrame,SQLContext
16+
17+
class DFConversions:
18+
"""
19+
20+
:ivar SparkContext: The spark context of the current spark session
21+
"""
22+
23+
def __init__(self,sc):
24+
registerContext(sc)
25+
wrapClass("com.yahoo.ml.caffe.tools.Conversions$")
26+
self.__dict__['conversions']=toPython(getScalaSingleton("com.yahoo.ml.caffe.tools.Conversions"))
27+
self.__dict__['sqlContext']=SQLContext(sc)
28+
29+
def Coco2ImageCaptionFile(self,src,clusterSize):
30+
"""Convert Cocodataset to Image Caption Dataframe
31+
:param src: the source for coco dataset i.e the caption file
32+
:param clusterSize: No. of executors
33+
"""
34+
df = self.__dict__.get('conversions').Coco2ImageCaptionFile(self.__dict__.get('sqlContext'), src, clusterSize)
35+
pydf = DataFrame(df,self.__dict__.get('sqlContext'))
36+
return pydf
37+
38+
39+
def Image2Embedding(self, imageRootFolder, imageCaptionDF):
40+
"""Get the embedding for the image as a dataframe
41+
:param imageRootFolder: the src folder of the images
42+
:param imageCaptionDF: the dataframe with the image file and image attributes
43+
"""
44+
df = self.__dict__.get('conversions').Image2Embedding(imageRootFolder, imageCaptionDF._jdf)
45+
pydf = DataFrame(df,self.__dict__.get('sqlContext'))
46+
return pydf
47+
48+
def ImageCaption2Embedding(self, imageRootFolder, imageCaptionDF, vocab, captionLength):
49+
"""Get the embedding for the images as well as the caption as a dataframe
50+
:param imageRootFolder: the src folder of the images
51+
:param imageCaptionDF: the dataframe with the images as well as captions
52+
:param vocab: the vocab object
53+
:param captionLength: Length of the embedding to generate for the caption
54+
"""
55+
df = self.__dict__.get('conversions').ImageCaption2Embedding(imageRootFolder, imageCaptionDF._jdf, vocab.vocabObject, captionLength)
56+
pydf = DataFrame(df,self.__dict__.get('sqlContext'))
57+
return pydf
58+
59+
60+
def Embedding2Caption(self, embeddingDF, vocab, embeddingColumn, captionColumn):
61+
"""Get the captions from the embeddings
62+
:param embeddingDF: the dataframe which contains the embedding
63+
:param vocab: the vocab object
64+
:param embeddingColumn: the embedding column name in embeddingDF which contains the caption embedding
65+
"""
66+
df = self.__dict__.get('conversions').Embedding2Caption(embeddingDF._jdf, vocab.vocabObject, embeddingColumn, captionColumn)
67+
pydf = DataFrame(df,self.__dict__.get('sqlContext'))
68+
return pydf
69+
70+
71+
def get_image(image):
72+
bytes = array.array('b', image)
73+
return "<img src='data:image/png;base64," + b64encode(bytes) + "' />"
74+
75+
76+
def show_captions(df, nrows=10):
77+
"""Displays a table of captions(both original as well as predictions) with their images, inline in html
78+
79+
:param DataFrame df: A python dataframe
80+
:param int nrows: First n rows to display from the dataframe
81+
"""
82+
data = df.take(nrows)
83+
html = "<table><tr><th>Image Id</th><th>Image</th><th>Prediction</th>"
84+
for i in range(nrows):
85+
row = data[i]
86+
html += "<tr>"
87+
html += "<td>%s</td>" % row.id
88+
html += "<td>%s</td>" % get_image(row.data.image)
89+
html += "<td>%s</td>" % row.prediction
90+
html += "</tr>"
91+
html += "</table>"
92+
return HTML(html)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
'''
2+
Copyright 2016 Yahoo Inc.
3+
Licensed under the terms of the Apache 2.0 license.
4+
Please see LICENSE file in the project root for terms.
5+
'''
6+
7+
from com.yahoo.ml.caffe.ConversionUtil import wrapClass
8+
from com.yahoo.ml.caffe.RegisterContext import registerContext
9+
from pyspark.sql import DataFrame,SQLContext
10+
11+
class Vocab:
12+
"""
13+
14+
:ivar SparkContext: The spark context of the current spark session
15+
"""
16+
17+
def __init__(self,sc):
18+
registerContext(sc)
19+
self.vocab=wrapClass("com.yahoo.ml.caffe.tools.Vocab")
20+
self.sqlContext=SQLContext(sc)
21+
self.vocabObject=self.vocab(self.sqlContext)
22+
23+
def genFromData(self,dataset,columnName,vocabSize):
24+
"""Convert generate the vocabulary from dataset
25+
:param dataset: dataframe containing the captions
26+
:param columnName: column in the dataset which has the caption
27+
:param vocabSize: Size of the vocabulary to generate (with vocab in descending order)
28+
"""
29+
self.vocabObject.genFromData(dataset._jdf,columnName,vocabSize)
30+
31+
def save(self, vocabFilePath):
32+
"""Save the generated vocabulary
33+
:param vocabFilePath: the name of the file to save the vocabulary to
34+
"""
35+
self.vocabObject.save(vocabFilePath)
36+
37+
def load(self, vocabFilePath):
38+
"""Load the vocabulary from a file
39+
:param vocabFilePath: the name of the file to load the vocabulary from
40+
"""
41+
self.vocabObject.load(vocabFilePath)
42+
43+

caffe-grid/src/main/python/com/yahoo/ml/caffe/tools/__init__.py

Whitespace-only changes.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2016 Yahoo Inc.
2+
# Licensed under the terms of the Apache 2.0 license.
3+
# Please see LICENSE file in the project root for terms.
4+
import caffe
5+
from examples.coco.retrieval_experiment import *
6+
from pyspark.sql import SQLContext
7+
from pyspark import SparkConf,SparkContext
8+
from pyspark.sql.types import *
9+
from itertools import izip_longest
10+
import json
11+
import argparse
12+
13+
def predict_caption(list_of_images, model, imagenet, lstmnet, vocab):
14+
out_iterator = []
15+
ce = CaptionExperiment(str(model),str(imagenet),str(lstmnet),str(vocab))
16+
for image in list_of_images:
17+
out_iterator.append(ce.getCaption(image))
18+
return iter(out_iterator)
19+
20+
def get_predictions(sqlContext, images, model, imagenet, lstmnet, vocab):
21+
rdd = images.mapPartitions(lambda im: predict_caption(im, model, imagenet, lstmnet, vocab))
22+
INNERSCHEMA = StructType([StructField("id", StringType(), True),StructField("prediction", StringType(), True)])
23+
schema = StructType([StructField("result", INNERSCHEMA, True)])
24+
return sqlContext.createDataFrame(rdd, schema).select("result.id", "result.prediction")
25+
26+
def main():
27+
conf = SparkConf()
28+
sc = SparkContext(conf=conf)
29+
sqlContext = SQLContext(sc)
30+
cmdargs = conf.get('spark.pythonargs')
31+
parser = argparse.ArgumentParser(description="Image to Caption Util")
32+
parser.add_argument('-input', action="store", dest="input")
33+
parser.add_argument('-model', action="store", dest="model")
34+
parser.add_argument('-imagenet', action="store", dest="imagenet")
35+
parser.add_argument('-lstmnet', action="store", dest="lstmnet")
36+
parser.add_argument('-vocab', action="store", dest="vocab")
37+
parser.add_argument('-output', action="store", dest="output")
38+
39+
args=parser.parse_args(cmdargs.split(" "))
40+
41+
df_input = sqlContext.read.parquet(str(args.input))
42+
images = df_input.select("data.image","data.height", "data.width", "id")
43+
df=get_predictions(sqlContext, images, str(args.model), str(args.imagenet), str(args.lstmnet), str(args.vocab))
44+
df.write.json(str(args.output))
45+
46+
47+
if __name__ == "__main__":
48+
main()
49+
50+

0 commit comments

Comments
 (0)