-
Notifications
You must be signed in to change notification settings - Fork 496
Expand file tree
/
Copy pathDecoder.h
More file actions
160 lines (132 loc) · 5.85 KB
/
Decoder.h
File metadata and controls
160 lines (132 loc) · 5.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
// Copyright CERN and copyright holders of ALICE O2. This software is
// distributed under the terms of the GNU General Public License v3 (GPL
// Version 3), copied verbatim in the file "COPYING".
//
// See http://alice-o2.web.cern.ch/license for full licensing information.
//
// In applying this license CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.
/// @file Decoder.h
/// @author Michael Lettrich
/// @since 2020-04-06
/// @brief Decoder - decode a rANS encoded state back into source symbols
#ifndef RANS_DECODER_H
#define RANS_DECODER_H
#include "internal/Decoder.h"
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <iomanip>
#include <iostream>
#include <memory>
#include <fairlogger/Logger.h>
#include "FrequencyTable.h"
#include "internal/DecoderSymbol.h"
#include "internal/ReverseSymbolLookupTable.h"
#include "internal/SymbolTable.h"
#include "internal/Decoder.h"
#include "internal/SymbolStatistics.h"
#include "internal/helper.h"
namespace o2
{
namespace rans
{
/// @brief Decodes a rANS encoded stream back into source symbols.
///
/// The decoder owns a decoder symbol table and a reverse symbol lookup table, both built from a
/// FrequencyTable. Copies are deep: each copy gets its own tables.
///
/// @tparam coder_T  forwarded to internal::Decoder (presumably the rANS state type — confirm)
/// @tparam stream_T element type of the encoded stream (forwarded to internal::Decoder)
/// @tparam source_T type of the decoded symbols written to the output range
template <typename coder_T, typename stream_T, typename source_T>
class Decoder
{
 public:
  using coder_t = coder_T;
  using stream_t = stream_T;
  using source_t = source_T;

 protected:
  using decoderSymbolTable_t = internal::SymbolTable<internal::DecoderSymbol>;
  using reverseSymbolLookupTable_t = internal::ReverseSymbolLookupTable;
  using ransDecoder = internal::Decoder<coder_T, stream_T>;

 public:
  /// Build the decoding tables from @p frequencies with the requested @p probabilityBits precision.
  Decoder(const FrequencyTable& frequencies, size_t probabilityBits);

  Decoder(const Decoder& d);
  Decoder(Decoder&& d) = default;
  ~Decoder() = default;
  Decoder& operator=(const Decoder& d);
  Decoder& operator=(Decoder&& d) = default;

  /// Decode @p messageLength symbols from the stream ending at @p inputEnd into @p outputBegin.
  template <typename stream_IT, typename source_IT, std::enable_if_t<internal::isCompatibleIter_v<stream_T, stream_IT> && internal::isCompatibleIter_v<source_T, source_IT>, bool> = true>
  void process(const source_IT outputBegin, const stream_IT inputEnd, size_t messageLength) const;

  // The accessors below forward to the symbol table; they must not be called on a moved-from
  // decoder, whose table pointer is null.
  size_t getAlphabetRangeBits() const { return mSymbolTable->getAlphabetRangeBits(); }
  int getMinSymbol() const { return mSymbolTable->getMinSymbol(); }
  int getMaxSymbol() const { return mSymbolTable->getMaxSymbol(); }

 protected:
  std::unique_ptr<decoderSymbolTable_t> mSymbolTable;
  std::unique_ptr<reverseSymbolLookupTable_t> mReverseLUT;
  size_t mProbabilityBits;
};
/// @brief Deep-copies the decoding tables and precision of another decoder.
///
/// A moved-from source has null table pointers, because the defaulted move leaves the source
/// unique_ptrs empty. Copying such a decoder yields null tables here instead of dereferencing
/// null pointers.
template <typename coder_T, typename stream_T, typename source_T>
Decoder<coder_T, stream_T, source_T>::Decoder(const Decoder& d)
  : mSymbolTable(d.mSymbolTable ? std::make_unique<decoderSymbolTable_t>(*d.mSymbolTable) : nullptr),
    mReverseLUT(d.mReverseLUT ? std::make_unique<reverseSymbolLookupTable_t>(*d.mReverseLUT) : nullptr),
    mProbabilityBits(d.mProbabilityBits)
{
}
/// @brief Copy-assigns the decoding tables and precision of another decoder.
///
/// Both tables are copied into temporaries before any member of *this is modified. If either copy
/// throws, *this is left unchanged (strong exception guarantee). The commit uses unique_ptr::swap,
/// which cannot throw; the previous tables are released when the temporaries go out of scope.
/// Self-assignment is safe because the copies are taken before the old tables are released.
/// A moved-from source (null table pointers) yields null tables instead of a null dereference.
template <typename coder_T, typename stream_T, typename source_T>
Decoder<coder_T, stream_T, source_T>& Decoder<coder_T, stream_T, source_T>::operator=(const Decoder& d)
{
  std::unique_ptr<decoderSymbolTable_t> symbolTable = d.mSymbolTable ? std::make_unique<decoderSymbolTable_t>(*d.mSymbolTable) : nullptr;
  std::unique_ptr<reverseSymbolLookupTable_t> reverseLUT = d.mReverseLUT ? std::make_unique<reverseSymbolLookupTable_t>(*d.mReverseLUT) : nullptr;

  // Commit phase: nothing below can throw.
  mSymbolTable.swap(symbolTable);
  mReverseLUT.swap(reverseLUT);
  mProbabilityBits = d.mProbabilityBits;
  return *this;
}
/// @brief Builds the decoder symbol table and the reverse symbol lookup table from symbol frequencies.
///
/// @param frequencies     symbol frequencies used to build the SymbolStatistics
///                        (presumably the same table the encoder was built from — confirm at call sites)
/// @param probabilityBits requested precision. It is replaced by the precision reported by
///                        SymbolStatistics::getSymbolTablePrecision(), and that value is the one
///                        used to build the reverse lookup table and to decode.
template <typename coder_T, typename stream_T, typename source_T>
Decoder<coder_T, stream_T, source_T>::Decoder(const FrequencyTable& frequencies, size_t probabilityBits) : mSymbolTable(nullptr), mReverseLUT(nullptr), mProbabilityBits(probabilityBits)
{
  using namespace internal;
  SymbolStatistics stats(frequencies, mProbabilityBits);
  // Adopt the precision reported by the statistics; all tables below and process() use it.
  mProbabilityBits = stats.getSymbolTablePrecision();

  RANSTimer t;
  t.start();
  mSymbolTable = std::make_unique<decoderSymbolTable_t>(stats);
  t.stop();
  LOG(debug1) << "Decoder SymbolTable inclusive time (ms): " << t.getDurationMS();

  t.start();
  mReverseLUT = std::make_unique<reverseSymbolLookupTable_t>(mProbabilityBits, stats);
  t.stop();
  LOG(debug1) << "ReverseSymbolLookupTable inclusive time (ms): " << t.getDurationMS();
}
/// @brief Decodes @p messageLength symbols from a rANS encoded stream into an output range.
///
/// Two interleaved decoder states are initialised from the stream, starting at the last element
/// before @p inputEnd. Symbols are produced pairwise (state 0, then state 1). For an odd
/// @p messageLength, the final symbol is produced by state 0 alone.
///
/// @param outputBegin   start of the output range; presumably must have room for @p messageLength
///                      symbols (not checked here)
/// @param inputEnd      one-past-the-end iterator of the encoded stream
/// @param messageLength number of symbols to decode; 0 logs a warning and returns without decoding
template <typename coder_T, typename stream_T, typename source_T>
template <typename stream_IT, typename source_IT, std::enable_if_t<internal::isCompatibleIter_v<stream_T, stream_IT> && internal::isCompatibleIter_v<source_T, source_IT>, bool>>
void Decoder<coder_T, stream_T, source_T>::process(const source_IT outputBegin, const stream_IT inputEnd, size_t messageLength) const
{
  using namespace internal;
  LOG(trace) << "start decoding";
  RANSTimer timer;
  timer.start();

  if (messageLength == 0) {
    LOG(warning) << "Empty message passed to decoder, skipping decode process";
    return;
  }

  auto& reverseLUT = *mReverseLUT;
  auto& symbolTable = *mSymbolTable;

  // Step back from one-past-the-end so the iterator points at the last stream element;
  // state initialisation starts from there.
  stream_IT streamPos = inputEnd;
  --streamPos;
  ransDecoder state0;
  ransDecoder state1;
  streamPos = state0.init(streamPos);
  streamPos = state1.init(streamPos);

  source_IT outPos = outputBegin;
  const size_t pairCount = messageLength / 2;
  for (size_t pairIdx = 0; pairIdx != pairCount; ++pairIdx) {
    const int64_t sym0 = reverseLUT[state0.get(mProbabilityBits)];
    const int64_t sym1 = reverseLUT[state1.get(mProbabilityBits)];
    *outPos++ = sym0;
    *outPos++ = sym1;
    // advanceSymbol threads the shared stream iterator through both states,
    // so state 0 must advance before state 1.
    streamPos = state0.advanceSymbol(streamPos, symbolTable[sym0], mProbabilityBits);
    streamPos = state1.advanceSymbol(streamPos, symbolTable[sym1], mProbabilityBits);
  }

  const bool hasTrailingSymbol = (messageLength % 2) != 0;
  if (hasTrailingSymbol) {
    const int64_t sym0 = reverseLUT[state0.get(mProbabilityBits)];
    *outPos = sym0;
    streamPos = state0.advanceSymbol(streamPos, symbolTable[sym0], mProbabilityBits);
  }
  timer.stop();

  const size_t decodedBytes = messageLength * sizeof(source_T);
  LOG(debug1) << "Decoder::" << __func__ << " { DecodedSymbols: " << messageLength << ","
              << "processedBytes: " << decodedBytes << ","
              << " inclusiveTimeMS: " << timer.getDurationMS() << ","
              << " BandwidthMiBPS: " << std::fixed << std::setprecision(2) << (decodedBytes * 1.0) / (timer.getDurationS() * 1.0 * (1 << 20)) << "}";
  LOG(trace) << "done decoding";
}
} // namespace rans
} // namespace o2
#endif /* RANS_DECODER_H */