Skip to content

Commit 464fd02

Browse files
committed
fixed usage of IntQubit(), fiexed a bug in IntQubit constructor
1 parent 0e2574d commit 464fd02

File tree

5 files changed

+43
-30
lines changed

5 files changed

+43
-30
lines changed

examples/advanced/grover_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
def demo_vgate_app(v):
1313
for i in range(2**v.nqubits):
1414
print('qapply(v*IntQubit(%i, %r))' % (i, v.nqubits))
15-
pprint(qapply(v*IntQubit(i, v.nqubits)))
16-
qapply(v*IntQubit(i, v.nqubits))
15+
pprint(qapply(v*IntQubit(i, nqubits=v.nqubits)))
16+
qapply(v*IntQubit(i, nqubits=v.nqubits))
1717

1818

1919
def black_box(qubits):
20-
return True if qubits == IntQubit(1, qubits.nqubits) else False
20+
return True if qubits == IntQubit(1, nqubits=qubits.nqubits) else False
2121

2222

2323
def main():

sympy/physics/quantum/grover.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def superposition_basis(nqubits):
5555
"""
5656

5757
amp = 1/sqrt(2**nqubits)
58-
return sum([amp*IntQubit(n, nqubits) for n in range(2**nqubits)])
58+
return sum([amp*IntQubit(n, nqubits=nqubits) for n in range(2**nqubits)])
5959

6060

6161
class OracleGate(Gate):
@@ -176,7 +176,7 @@ def _represent_ZGate(self, basis, **options):
176176
matrixOracle = eye(nbasis)
177177
# Flip the sign given the output of the oracle function
178178
for i in range(nbasis):
179-
if self.search_function(IntQubit(i, self.nqubits)):
179+
if self.search_function(IntQubit(i, nqubits=self.nqubits)):
180180
matrixOracle[i, i] = NegativeOne()
181181
return matrixOracle
182182

sympy/physics/quantum/qubit.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,12 @@ class IntQubitState(QubitState):
284284
@classmethod
285285
def _eval_args(cls, args, **extra_args):
286286
nqubits = extra_args.get('nqubits')
287+
# use nqubits if specified
288+
if nqubits is not None:
289+
if len(args) != 1:
290+
raise ValueError(
291+
'too many positional arguments (%s). should be (number, nqubits=n)' % (args,))
292+
return cls._eval_args_with_nqubits(args[0], nqubits)
287293
# The case of a QubitState instance
288294
if len(args) == 1 and isinstance(args[0], QubitState):
289295
return QubitState._eval_args(args)
@@ -296,18 +302,20 @@ def _eval_args(cls, args, **extra_args):
296302
return QubitState._eval_args(qubit_values)
297303
# For two numbers, the second number is the number of bits
298304
# on which it is expressed, so IntQubit(0,5) == |00000>.
299-
elif nqubits is not None or (len(args) == 2 and args[1] > 1):
300-
if nqubits is None:
301-
nqubits = args[1]
302-
need = bitcount(abs(args[0]))
303-
if nqubits < need:
304-
raise ValueError(
305-
'cannot represent %s with %s bits' % (args[0], nqubits))
306-
qubit_values = [(args[0] >> i) & 1 for i in reversed(range(nqubits))]
307-
return QubitState._eval_args(qubit_values)
305+
elif len(args) == 2 and args[1] > 1:
306+
return cls._eval_args_with_nqubits(args[0], args[1])
308307
else:
309308
return QubitState._eval_args(args)
310309

310+
@classmethod
311+
def _eval_args_with_nqubits(cls, number, nqubits):
312+
need = bitcount(abs(number))
313+
if nqubits < need:
314+
raise ValueError(
315+
'cannot represent %s with %s bits' % (number, nqubits))
316+
qubit_values = [(number >> i) & 1 for i in reversed(range(nqubits))]
317+
return QubitState._eval_args(qubit_values)
318+
311319
def as_int(self):
312320
"""Return the numerical value of the qubit."""
313321
number = 0
@@ -544,7 +552,7 @@ def measure_all(qubit, format='sympy', normalize=True):
544552
for i in range(size):
545553
if m[i] != 0.0:
546554
results.append(
547-
(Qubit(IntQubit(i, nqubits)), m[i]*conjugate(m[i]))
555+
(Qubit(IntQubit(i, nqubits=nqubits)), m[i]*conjugate(m[i]))
548556
)
549557
return results
550558
else:

sympy/physics/quantum/tests/test_grover.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@ def return_one_on_two(qubits):
1111

1212

1313
def return_one_on_one(qubits):
14-
return qubits == IntQubit(1, qubits.nqubits)
14+
return qubits == IntQubit(1, nqubits=qubits.nqubits)
1515

1616

1717
def test_superposition_basis():
1818
nbits = 2
19-
first_half_state = IntQubit(0, nbits)/2 + IntQubit(1, nbits)/2
19+
first_half_state = IntQubit(0, nqubits=nbits)/2 + IntQubit(1, nqubits=nbits)/2
2020
second_half_state = IntQubit(2, nbits)/2 + IntQubit(3, nbits)/2
2121
assert first_half_state + second_half_state == superposition_basis(nbits)
2222

2323
nbits = 3
24-
firstq = (1/sqrt(8))*IntQubit(0, nbits) + (1/sqrt(8))*IntQubit(1, nbits)
24+
firstq = (1/sqrt(8))*IntQubit(0, nqubits=nbits) + (1/sqrt(8))*IntQubit(1, nqubits=nbits)
2525
secondq = (1/sqrt(8))*IntQubit(2, nbits) + (1/sqrt(8))*IntQubit(3, nbits)
2626
thirdq = (1/sqrt(8))*IntQubit(4, nbits) + (1/sqrt(8))*IntQubit(5, nbits)
2727
fourthq = (1/sqrt(8))*IntQubit(6, nbits) + (1/sqrt(8))*IntQubit(7, nbits)
@@ -35,30 +35,30 @@ def test_OracleGate():
3535

3636
nbits = 2
3737
v = OracleGate(2, return_one_on_two)
38-
assert qapply(v*IntQubit(0, nbits)) == IntQubit(0, nbits)
39-
assert qapply(v*IntQubit(1, nbits)) == IntQubit(1, nbits)
38+
assert qapply(v*IntQubit(0, nbits)) == IntQubit(0, nqubits=nbits)
39+
assert qapply(v*IntQubit(1, nbits)) == IntQubit(1, nqubits=nbits)
4040
assert qapply(v*IntQubit(2, nbits)) == -IntQubit(2, nbits)
4141
assert qapply(v*IntQubit(3, nbits)) == IntQubit(3, nbits)
4242

43-
# Due to a bug of IntQubit, this first assertion is buggy
44-
# assert represent(OracleGate(1, lambda qubits: qubits == IntQubit(0)), nqubits=1) == \
45-
# Matrix([[-1/sqrt(2), 0], [0, 1/sqrt(2)]])
43+
assert represent(OracleGate(1, lambda qubits: qubits == IntQubit(0)), nqubits=1) == \
44+
Matrix([[-1, 0], [0, 1]])
4645
assert represent(v, nqubits=2) == Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
4746

47+
4848
def test_WGate():
4949
nqubits = 2
5050
basis_states = superposition_basis(nqubits)
5151
assert qapply(WGate(nqubits)*basis_states) == basis_states
5252

53-
expected = ((2/sqrt(pow(2, nqubits)))*basis_states) - IntQubit(1, nqubits)
54-
assert qapply(WGate(nqubits)*IntQubit(1, nqubits)) == expected
53+
expected = ((2/sqrt(pow(2, nqubits)))*basis_states) - IntQubit(1, nqubits=nqubits)
54+
assert qapply(WGate(nqubits)*IntQubit(1, nqubits=nqubits)) == expected
5555

5656

5757
def test_grover_iteration_1():
5858
numqubits = 2
5959
basis_states = superposition_basis(numqubits)
6060
v = OracleGate(numqubits, return_one_on_one)
61-
expected = IntQubit(1, numqubits)
61+
expected = IntQubit(1, nqubits=numqubits)
6262
assert qapply(grover_iteration(basis_states, v)) == expected
6363

6464

@@ -83,7 +83,7 @@ def test_grover_iteration_2():
8383

8484
def test_grover():
8585
nqubits = 2
86-
assert apply_grover(return_one_on_one, nqubits) == IntQubit(1, nqubits)
86+
assert apply_grover(return_one_on_one, nqubits) == IntQubit(1, nqubits=nqubits)
8787

8888
nqubits = 4
8989
basis_states = superposition_basis(nqubits)

sympy/physics/quantum/tests/test_qubit.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,13 @@ def test_IntQubit():
5757
iqb = IntQubit(0, nqubits=1)
5858
assert qubit_to_matrix(Qubit('0')) == qubit_to_matrix(iqb)
5959

60-
iqb = IntQubit(1, nbits=1)
60+
iqb = IntQubit(1, nqubits=1)
6161
assert qubit_to_matrix(Qubit('1')) == qubit_to_matrix(iqb)
62+
assert qubit_to_matrix(IntQubit(1)) == qubit_to_matrix(iqb)
6263

63-
iqb = IntQubit(7, nbits=4)
64-
assert qubit_to_matrix(Qubit('111')) == qubit_to_matrix(iqb)
64+
iqb = IntQubit(7, nqubits=4)
65+
assert qubit_to_matrix(Qubit('0111')) == qubit_to_matrix(iqb)
66+
assert qubit_to_matrix(IntQubit(7, 4)) == qubit_to_matrix(iqb)
6567

6668
iqb = IntQubit(8)
6769
assert iqb.as_int() == 8
@@ -177,6 +179,9 @@ def test_measure_all():
177179
assert measure_all(state2) == \
178180
[(Qubit('00'), Rational(4, 5)), (Qubit('11'), Rational(1, 5))]
179181

182+
# from issue #12585
183+
assert measure_all(qapply(Qubit('0'))) == [(Qubit('0'), 1)]
184+
180185

181186
def test_eval_trace():
182187
q1 = Qubit('10110')

0 commit comments

Comments
 (0)