@@ -781,6 +781,78 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
781781}
782782)" ;
783783
784+ static const char *kSum = R"(
785+ @group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
786+ @group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
787+ var<workgroup> buffer: array<{{precision}}, 1024>;
788+ @compute @workgroup_size({{workgroupSize}})
789+ fn main(
790+ @builtin(global_invocation_id) globalID : vec3<u32>,
791+ @builtin(local_invocation_id) localID : vec3<u32>,
792+ @builtin(workgroup_id) groupid : vec3<u32>,
793+ @builtin(num_workgroups) numGroups : vec3<u32>) {
794+ let blockSize3d: vec3<u32> = vec3({{workgroupSize}});
795+ let blockSize: u32 = blockSize3d.x;
796+ let threadId: u32 = localID.x;
797+ let blockId: u32 = groupid.x + groupid.y * numGroups.x;
798+ let blockStart = blockId * blockSize * 2 + threadId;
799+
800+ buffer[threadId] = inp[blockStart] + inp[blockStart + blockSize];
801+ workgroupBarrier();
802+ var stride: u32 = blockSize / 2;
803+
804+ if (blockSize >= 1024 && threadId < 512) {
805+ buffer[threadId] += buffer[threadId + 512];
806+ }
807+ workgroupBarrier();
808+
809+ if (blockSize >= 512 && threadId < 256) {
810+ buffer[threadId] += buffer[threadId + 256];
811+ }
812+ workgroupBarrier();
813+
814+ if (blockSize >= 256 && threadId < 128) {
815+ buffer[threadId] += buffer[threadId + 128];
816+ }
817+ workgroupBarrier();
818+
819+ if (threadId < 64) {
820+ buffer[threadId] += buffer[threadId + 64];
821+ }
822+ workgroupBarrier();
823+
824+ if (threadId < 32) {
825+ buffer[threadId] += buffer[threadId + 32];
826+ }
827+ workgroupBarrier();
828+
829+ if (threadId < 16) {
830+ buffer[threadId] += buffer[threadId + 16];
831+ }
832+ workgroupBarrier();
833+
834+ if (threadId < 8) {
835+ buffer[threadId] += buffer[threadId + 8];
836+ }
837+ workgroupBarrier();
838+
839+ if (threadId < 4) {
840+ buffer[threadId] += buffer[threadId + 4];
841+ }
842+ workgroupBarrier();
843+
844+ if (threadId < 2) {
845+ buffer[threadId] += buffer[threadId + 2];
846+ }
847+ workgroupBarrier();
848+
849+ if (threadId == 0) {
850+ buffer[0] += buffer[1];
851+ out[blockId] = buffer[0];
852+ }
853+ }
854+ )" ;
855+
784856} // namespace gpu
785857
786858#endif // KERNELS_H
0 commit comments