forked from GoogleCloudPlatform/python-docs-samples
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
executable file
·354 lines (294 loc) · 13.5 KB
/
main.py
File metadata and controls
executable file
·354 lines (294 loc) · 13.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
#!/usr/bin/env python
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example uses the Google Cloud Vision API to detect text in images, then
analyzes that text using the Google Cloud Natural Language API to detect
entities in the text. It stores the detected entity information in an sqlite3
database, which may then be queried.
After this script has analyzed a directory of images, it outputs some
information on the images' entities to STDOUT. You can also further query
the generated sqlite3 database; see the README for more information.
Run the script on a directory of images to do the analysis, for example:
$ python main.py --input_directory=<path-to-image-directory>
You can try this on a sample directory of images:
$ curl -O http://storage.googleapis.com/python-docs-samples-tests/language/ocr_nl-images.zip
$ unzip ocr_nl-images.zip
$ python main.py --input_directory=images/
""" # noqa
import argparse
import base64
import contextlib
import logging
import os
import sqlite3
import sys
import time
import googleapiclient.discovery
import googleapiclient.errors
# Default number of items per tuple yielded by batch().
BATCH_SIZE = 10
class VisionApi(object):
    """Construct and use the Cloud Vision API service."""

    def __init__(self):
        """Build the Cloud Vision v1 API client."""
        self.service = googleapiclient.discovery.build('vision', 'v1')

    def detect_text(self, input_filenames, num_retries=3, max_results=6):
        """Uses the Vision API to detect text in a batch of image files.

        All files are sent in a single ``images().annotate`` request.

        Args:
            input_filenames: iterable of image paths. Paths starting with
                ``gs://`` are sent as Cloud Storage URIs; any other path is
                read from local disk and sent base64-encoded.
            num_retries: passed through to ``request.execute``.
            max_results: ``maxResults`` for the TEXT_DETECTION feature.

        Returns:
            A dict mapping each filename whose per-image response had no
            ``error`` to that response's ``textAnnotations`` list (``[]`` if
            the key is absent). Returns ``{}`` when there are no filenames or
            the reply has no ``responses`` key, and ``None`` when the request
            fails with an HTTP error or a ``KeyError``.
        """
        # Materialize once: the filenames are used both to build the request
        # and to pair results with inputs, so a one-shot iterator would be
        # exhausted before the second pass.
        input_filenames = list(input_filenames)
        if not input_filenames:
            # Nothing to annotate; don't send an empty batch request.
            return {}

        batch_request = []
        for filename in input_filenames:
            request = {
                'image': {},
                'features': [{
                    'type': 'TEXT_DETECTION',
                    'maxResults': max_results,
                }]
            }
            # Accept both files in cloud storage, as well as local files.
            if filename.startswith('gs://'):
                request['image']['source'] = {
                    'gcsImageUri': filename
                }
            else:
                with open(filename, 'rb') as image_file:
                    request['image']['content'] = base64.b64encode(
                        image_file.read()).decode('UTF-8')
            batch_request.append(request)

        request = self.service.images().annotate(
            body={'requests': batch_request})
        try:
            responses = request.execute(num_retries=num_retries)
        except googleapiclient.errors.HttpError as e:
            # The whole batch failed, so report every file that was in it
            # rather than whichever name the build loop happened to end on.
            logging.error('Http Error for {}: {}'.format(
                ', '.join(input_filenames), e))
            return None

        try:
            if 'responses' not in responses:
                return {}
            text_response = {}
            # Results are paired with inputs by position, i.e. this assumes
            # the API returns one response per request, in request order.
            for filename, response in zip(
                    input_filenames, responses['responses']):
                if 'error' in response:
                    logging.error('API Error for {}: {}'.format(
                        filename,
                        response['error'].get('message', '')))
                    continue
                text_response[filename] = response.get('textAnnotations', [])
            return text_response
        except KeyError as e2:
            logging.error('Key error: {}'.format(e2))
            return None
class TextAnalyzer(object):
    """Construct and use the Google Natural Language API service."""

    def __init__(self, db_filename=None):
        """Build the Natural Language v1 client and set up entity storage.

        Args:
            db_filename: sqlite3 database path. When not given, defaults to
                ``entities<current unix time>.db``.
        """
        self.service = googleapiclient.discovery.build('language', 'v1')
        # Entity information gleaned from the image files, one
        # (locale, type, name, salience, wiki_url, filename) tuple per entity.
        self.entity_info = []
        # This is the filename of the sqlite3 database to save to.
        self.db_filename = db_filename or 'entities{}.db'.format(
            int(time.time()))

    def _get_native_encoding_type(self):
        """Returns the encoding type that matches Python's native strings.

        'UTF16' when ``sys.maxunicode`` is 65535 (a narrow build), otherwise
        'UTF32'.
        """
        if sys.maxunicode == 65535:
            return 'UTF16'
        return 'UTF32'

    def nl_detect(self, text):
        """Use the Natural Language API to analyze the given text string.

        Returns:
            The ``entities`` list from the analyzeEntities response, or ``[]``
            if the request fails with an HTTP error or the response has no
            ``entities`` key (both cases are logged).
        """
        # We're only requesting 'entity' information from the Natural Language
        # API at this time.
        body = {
            'document': {
                'type': 'PLAIN_TEXT',
                'content': text,
            },
            'encodingType': self._get_native_encoding_type(),
        }
        entities = []
        try:
            request = self.service.documents().analyzeEntities(body=body)
            response = request.execute()
            entities = response['entities']
        except googleapiclient.errors.HttpError as e:
            logging.error('Http Error: %s', e)
        except KeyError as e2:
            logging.error('Key error: %s', e2)
        return entities

    def add_entities(self, filename, locale, document):
        """Apply the Natural Language API to the document, and collect the
        detected entities for ``filename`` under ``locale``."""
        entities = self.nl_detect(document)
        self.extract_and_save_entity_info(entities, locale, filename)

    def extract_entity_info(self, entity):
        """Extract information about an entity.

        Returns:
            A ``(type, lowercased name, salience, wikipedia_url)`` tuple.
            ``salience`` and ``wikipedia_url`` are ``None`` when the entity
            omits them; a missing ``metadata`` map is treated as empty rather
            than raising ``KeyError`` and aborting the whole run.
        """
        # Named entity_type to avoid shadowing the builtin type().
        entity_type = entity['type']
        name = entity['name'].lower()
        metadata = entity.get('metadata') or {}
        salience = entity.get('salience')
        wiki_url = metadata.get('wikipedia_url')
        return (entity_type, name, salience, wiki_url)

    def extract_and_save_entity_info(self, entities, locale, filename):
        """Append one row per entity to ``self.entity_info``."""
        for entity in entities:
            entity_type, name, salience, wiki_url = (
                self.extract_entity_info(entity))
            # Because this is a small example, we're using a list to hold
            # all the entity information, then we'll insert it into the
            # database all at once when we've processed all the files.
            # For a larger data set, you would want to write to the database
            # in batches.
            self.entity_info.append(
                (locale, entity_type, name, salience, wiki_url, filename))

    def write_entity_info_to_db(self):
        """Store the info gleaned about the entities in the text, via the
        Natural Language API, in an sqlite3 database table.

        The ``entities`` table is created if it does not exist yet, then all
        rows in ``self.entity_info`` are inserted.
        """
        logging.info('Saving entity info to the sqlite3 database.')
        with contextlib.closing(sqlite3.connect(self.db_filename)) as conn:
            # A connection used as a context manager commits on success and
            # rolls back on error; it does not close, hence closing() above.
            with conn:
                conn.execute(
                    'CREATE TABLE if not exists entities (locale text, '
                    'type text, name text, salience real, wiki_url text, '
                    'filename text)')
            with conn:
                conn.executemany(
                    'INSERT INTO entities VALUES (?,?,?,?,?,?)',
                    self.entity_info)

    def output_entity_data(self):
        """Output some info about the entities by querying the generated
        sqlite3 database.
        """
        with contextlib.closing(sqlite3.connect(self.db_filename)) as conn:
            # This query finds the number of times each entity name was
            # detected, in descending order by count, and returns information
            # about the first 15 names, including the files in which they were
            # found, their detected 'salience' and language (locale), and the
            # wikipedia urls (if any) associated with them.
            print('\n==============\nTop 15 most frequent entity names:')
            cursor = conn.cursor()
            results = cursor.execute(
                'select name, count(name) as wc from entities '
                'group by name order by wc desc limit 15;')
            for item in results:
                detail_cursor = conn.cursor()
                print(u'\n----Name: {} was found with count {}'.format(*item))
                details = detail_cursor.execute(
                    'SELECT name, type, filename, locale, wiki_url, salience '
                    'FROM entities WHERE name=?', (item[0],))
                urls = set()
                for elt in details:
                    print(('Found in file {}, detected as type {}, with\n'
                           ' locale {} and salience {}.').format(
                               elt[2], elt[1], elt[3], elt[5]))
                    if elt[4]:
                        urls.add(elt[4])
                if urls:
                    print('url(s): {}'.format(urls))

            # This query finds the number of times each wikipedia url was
            # detected, in descending order by count, and returns information
            # about the first 15 urls, including the files in which they were
            # found and the names and 'salience' with which they were
            # associated. Entities without a URL are excluded: they would
            # otherwise form a NULL group with count 0, whose follow-up
            # `wiki_url=?` lookup can never match.
            print('\n==============\nTop 15 most frequent Wikipedia URLs:')
            c = conn.cursor()
            results = c.execute(
                'select wiki_url, count(wiki_url) as wc from entities '
                'where wiki_url is not null '
                'group by wiki_url order by wc desc limit 15;')
            for item in results:
                detail_cursor = conn.cursor()
                print('\n----entity: {} was found with count {}'.format(*item))
                details = detail_cursor.execute(
                    'SELECT name, type, filename, locale, salience '
                    'FROM entities WHERE wiki_url=?', (item[0],))
                names = set()
                salience = set()
                for elt in details:
                    print(('Found in file {}, detected as type {}, with\n'
                           ' locale {}.').format(elt[2], elt[1], elt[3]))
                    names.add(elt[0])
                    salience.add(elt[4])
                print('names(s): {}'.format(names))
                print('salience measure(s): {}'.format(salience))
def extract_description(texts):
    """Returns the locale and detected text from Vision text annotations.

    Only the first annotation that has a ``description`` is used; it is the
    entry that contains all text detected. Annotations missing
    ``description`` are logged and skipped.

    Args:
        texts: iterable of annotation dicts (``textAnnotations`` entries).

    Returns:
        A ``(locale, text)`` tuple. ``locale`` is ``None`` when the chosen
        annotation has no ``locale``; ``(None, '')`` is returned when no
        annotation has a ``description`` at all.
    """
    for text in texts:
        try:
            description = text['description']
        except KeyError as e:
            logging.error('KeyError: %s\n%s', e, text)
            continue
        # Process only the first usable entry. Reading the locale with .get()
        # ensures a missing locale neither crashes the caller nor causes
        # later entries' text to be appended.
        return (text.get('locale'), description)
    return (None, '')
def extract_descriptions(input_filename, texts, text_analyzer):
    """Feed the text detected in one image to the text analyzer.

    Non-empty ``texts`` are condensed with ``extract_description`` and passed
    to ``text_analyzer.add_entities``, followed by a '.' progress mark on
    stdout. An empty list prints a "no discernible text" notice; any other
    falsy value (such as None) is ignored.
    """
    if not texts:
        if texts == []:
            print('%s had no discernible text.' % input_filename)
        return
    locale, document = extract_description(texts)
    text_analyzer.add_entities(input_filename, locale, document)
    # Progress indicator, flushed so it appears immediately.
    sys.stdout.write('.')
    sys.stdout.flush()
def get_text_from_files(vision, input_filenames, text_analyzer):
    """Run Vision API text detection on a batch of files and index the results.

    Each filename present in the detection result is passed, together with
    its annotations, to ``extract_descriptions``. Nothing happens when the
    detection call returns an empty or ``None`` result.
    """
    detected = vision.detect_text(input_filenames)
    if not detected:
        return
    for image_path in detected:
        extract_descriptions(image_path, detected[image_path], text_analyzer)
def batch(list_to_batch, batch_size=BATCH_SIZE):
    """Group a sequence into consecutive tuples of at most batch_size items.

    The final tuple holds the remainder and may be shorter. ``list_to_batch``
    must support ``len()`` and slicing.

    >>> tuple(batch([1, 2, 3, 4, 5], batch_size=2))
    ((1, 2), (3, 4), (5,))

    Raises:
        ValueError: if batch_size is less than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    if batch_size < 1:
        # range() rejects a step of 0, but with a negative step it would
        # silently yield nothing and drop every item; fail loudly instead.
        raise ValueError(
            'batch_size must be a positive integer, got {!r}'.format(
                batch_size))
    for start in range(0, len(list_to_batch), batch_size):
        yield tuple(list_to_batch[start:start + batch_size])
def main(input_dir, db_filename=None):
    """Walk through all the image files in the given directory, extracting any
    text from them and feeding that text to the Natural Language API for
    analysis. Then save the results and print a summary.

    Args:
        input_dir: directory to scan recursively for images. If falsy, no
            images are analyzed and no new rows are added; the database is
            only queried.
        db_filename: sqlite3 database path, passed to ``TextAnalyzer``.
    """
    # Create an object to analyze our text using the Natural Language API
    text_analyzer = TextAnalyzer(db_filename)

    if input_dir:
        # The Vision API client is only needed when there are images to read,
        # so query-only runs don't construct it.
        vision_api_client = VisionApi()
        # Recursively list every file under input_dir, sorted so that runs
        # are reproducible regardless of os.walk's directory order.
        allfileslist = sorted(
            os.path.join(folder, filename)
            for folder, _, files in os.walk(input_dir)
            for filename in files)
        # Analyze the text in the files using the Vision and Natural Language
        # APIs.
        # NOTE(review): batch_size=1 overrides BATCH_SIZE, so each Vision
        # request carries a single image. This is presumably deliberate
        # (e.g. request-size limits for inline image content); confirm
        # before raising it.
        for filenames in batch(allfileslist, batch_size=1):
            get_text_from_files(vision_api_client, filenames, text_analyzer)

    # Save the result to a database, then run some queries on the database,
    # with output to STDOUT.
    text_analyzer.write_entity_info_to_db()
    # now, print some information about the entities detected.
    text_analyzer.output_entity_data()
if __name__ == '__main__':
    # Command-line entry point: parse the flags, validate them, then run.
    parser = argparse.ArgumentParser(
        description='Detects text in the images in the given directory.')
    parser.add_argument(
        '--input_directory',
        help=("The image directory you'd like to detect text in. If left "
              "unspecified, the --db specified will be queried without being "
              "updated."))
    parser.add_argument(
        '--db',
        help='The filename to use for the sqlite3 database.')
    args = parser.parse_args()
    input_directory, db_path = args.input_directory, args.db
    # At least one flag is required: a directory to analyze, or a database
    # to query.
    if not input_directory and not db_path:
        parser.error('Either --input_directory or --db must be specified.')
    main(input_directory, db_path)