Skip to content

Commit 84a2d5d

Browse files
ffjiangfacebook-github-bot
authored andcommitted
Add hashing to bucket-weighted pooling (#20673)
Summary: Pull Request resolved: #20673 Add option to bucket-weighted pooling to hash the bucket so that any cardinality score can be used. Reviewed By: huginhuangfb Differential Revision: D15003509 fbshipit-source-id: 575a149de395f18fd7759f3edb485619f8aa5363
1 parent 1aae4b0 commit 84a2d5d

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

caffe2/python/layers/bucket_weighted.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@
2121

2222
class BucketWeighted(ModelLayer):
2323
def __init__(self, model, input_record, max_score=0, bucket_boundaries=None,
24-
weight_optim=None, name="bucket_weighted"):
24+
hash_buckets=False, weight_optim=None, name="bucket_weighted"):
2525
super(BucketWeighted, self).__init__(model, name, input_record)
2626

2727
assert isinstance(input_record, schema.List), "Incorrect input type"
2828
self.bucket_boundaries = bucket_boundaries
29+
self.hash_buckets = hash_buckets
2930
if bucket_boundaries is not None:
3031
self.shape = len(bucket_boundaries) + 1
3132
elif max_score > 0:
@@ -63,6 +64,10 @@ def add_ops(self, net):
6364
"buckets_int",
6465
to=core.DataType.INT32
6566
)
67+
if self.hash_buckets:
68+
buckets_int = net.IndexHash(
69+
buckets_int, "hashed_buckets_int", seed=0, modulo=self.shape
70+
)
6671
net.Gather(
6772
[self.bucket_w, buckets_int],
6873
self.output_schema.bucket_weights.field_blobs())

0 commit comments

Comments
 (0)