Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored and created an abstraction for control values #5362

Merged
merged 17 commits into from
Jul 8, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@
PhasedXZGate,
PhaseFlipChannel,
StatePreparationChannel,
ProductOfSums,
ProjectorString,
ProjectorSum,
RandomGateChannel,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def _symmetricalqidpair(qids):
'PhasedXPowGate': cirq.PhasedXPowGate,
'PhasedXZGate': cirq.PhasedXZGate,
'ProductState': cirq.ProductState,
'ProductOfSums': cirq.ProductOfSums,
'ProjectorString': cirq.ProjectorString,
'ProjectorSum': cirq.ProjectorSum,
'QasmUGate': cirq.circuits.qasm_output.QasmUGate,
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,5 @@
from cirq.ops.wait_gate import wait, WaitGate

from cirq.ops.state_preparation_channel import StatePreparationChannel

from cirq.ops.control_values import AbstractControlValues, ProductOfSums
18 changes: 13 additions & 5 deletions cirq-core/cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from cirq import protocols, value
from cirq._compat import proper_repr
from cirq._doc import document
from cirq.ops import controlled_gate, eigen_gate, gate_features, raw_types
from cirq.ops import controlled_gate, eigen_gate, gate_features, raw_types, control_values as cv

from cirq.type_workarounds import NotImplementedType

Expand Down Expand Up @@ -159,7 +159,9 @@ def _trace_distance_bound_(self) -> Optional[float]:
def controlled(
self,
num_controls: int = None,
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
control_values: Optional[
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
] = None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
) -> raw_types.Gate:
"""Returns a controlled `XPowGate`, using a `CXPowGate` where possible.
Expand Down Expand Up @@ -566,7 +568,9 @@ def with_canonical_global_phase(self) -> 'ZPowGate':
def controlled(
self,
num_controls: int = None,
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
control_values: Optional[
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
] = None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
) -> raw_types.Gate:
"""Returns a controlled `ZPowGate`, using a `CZPowGate` where possible.
Expand Down Expand Up @@ -998,7 +1002,9 @@ def _phase_by_(self, phase_turns, qubit_index):
def controlled(
self,
num_controls: int = None,
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
control_values: Optional[
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
] = None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
) -> raw_types.Gate:
"""Returns a controlled `CZPowGate`, using a `CCZPowGate` where possible.
Expand Down Expand Up @@ -1187,7 +1193,9 @@ def _pauli_expansion_(self) -> value.LinearDict[str]:
def controlled(
self,
num_controls: int = None,
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
control_values: Optional[
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
] = None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
) -> raw_types.Gate:
"""Returns a controlled `CXPowGate`, using a `CCXPowGate` where possible.
Expand Down
168 changes: 168 additions & 0 deletions cirq-core/cirq/ops/control_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright 2022 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Union, Tuple, List, TYPE_CHECKING, Any, Dict, Generator, Optional
from dataclasses import dataclass

import itertools

if TYPE_CHECKING:
import cirq


# ignore type to bypass github.com/python/mypy/issues/5374.
@dataclass(frozen=True, eq=False) # type: ignore
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved
class AbstractControlValues(abc.ABC):
"""AbstractControlValues is an abstract immutable data class.

AbstractControlValues defines an API for control values and implements
functions common to all implementations (e.g. comparison).
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring should be more detailed. For example:

Suggested change
"""AbstractControlValues is an abstract immutable data class.
AbstractControlValues defines an API for control values and implements
functions common to all implementations (e.g. comparison).
"""
"""Abstract immutable base class that defines the API for control values.
`cirq.ControlledGate` and `cirq.ControlledOperation` are useful to augment existing gates
and operations to have one or more control qubits. For every control qubit, the set of
integer values for which the control should be enabled is represented by
`cirq.AbstractControlValues`.
Implementations of `cirq.AbstractControlValues` can use different internal representations
to store control values, but they should all satisfy the public API defined here.
"""

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PTAL


_internal_representation: Any
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved

def __and__(self, other: 'AbstractControlValues') -> 'AbstractControlValues':
"""Sets self to be the cartesian product of all combinations in self x other.

Args:
other: An object that implements AbstractControlValues.

Returns:
An object that represents the cartesian product of the two inputs.
"""
return type(self)(self._internal_representation + other._internal_representation)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes that self._internal_representation and other._internal_representation can be simply added together. But their types are Any, so this need not be true.

Does the current implementation assume that ProductOfSums is the only derived type, and would be special cased later once we add more derived types?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this is just for refactoring, this function will be rewritten in a better way once the linked structure is introduced


@abc.abstractmethod
def _expand(self):
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
"""Returns the control values tracked by the object."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please explain more what the "expanded" representation means / contains. Does it depend upon what internal representation the subclasses use? Does every tuple correspond to a "one value per qubit" representation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PTAL


@abc.abstractmethod
def diagram_repr(self) -> str:
"""Returns a string representation to be used in circuit diagrams."""

@abc.abstractmethod
def number_variables(self) -> int:
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
"""Returns the control values tracked by the object."""
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved

@abc.abstractmethod
def __len__(self) -> int:
pass

@abc.abstractmethod
def identifier(self) -> Tuple[Any]:
"""Returns an identifier from which the object can be rebuilt."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improve docstrings, it's not clear to me what it means to "have an identifier from which object can be rebuilt". Does this mean an identifier we can use for serialization / deserialization ?

Also, when would a user use this method? Do we need this to be part of the public API ? Can we just enforce that the class should have a __repr__ defined that satisfies eval(repr(cv)) == cv?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the docstring, this function essentially returns the internal representation, I changed it to be private and I will remove it in the next PR, it's now used in functions in controlled_gate and controlled_operation that require access to the internal representation


@abc.abstractmethod
def __hash__(self):
pass

@abc.abstractmethod
def __repr__(self) -> str:
pass

@abc.abstractmethod
def validate(self, qid_shapes: Union[Tuple[int, ...], List[int]]) -> Optional[ValueError]:
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved
"""Validates control values

Validate that control values are in the half closed interval
[0, qid_shapes) for each qubit.
"""

@abc.abstractmethod
def _are_ones(self) -> bool:
"""Checks whether all control values are equal to 1."""

@abc.abstractmethod
def _json_dict_(self) -> Dict[str, Any]:
pass

@abc.abstractmethod
def __getitem__(self, key):
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
pass

def __iter__(self) -> Generator[Tuple[int], None, None]:
for assignment in self._expand():
yield assignment

def __eq__(self, other) -> bool:
"""Returns True iff self and other represent the same configurations.

Args:
other: A AbstractControlValues object.

Returns:
boolean whether the two objects are equivalent or not.
"""
if not isinstance(other, AbstractControlValues):
other = ProductOfSums(other)
return sorted(v for v in self) == sorted(v for v in other)


class ProductOfSums(AbstractControlValues):
"""ProductOfSums represents control values in a form of a cartesian product of tuples."""

_internal_representation: Tuple[Tuple[int, ...]]

def identifier(self):
return self._internal_representation

def _expand(self):
"""Returns the combinations tracked by the object."""
return itertools.product(*self._internal_representation)

def __repr__(self) -> str:
return f'cirq.ProductOfSums({str(self.identifier())})'

def number_variables(self) -> int:
return len(self._internal_representation)

def __len__(self) -> int:
return self.number_variables()

def __hash__(self):
return hash(self._internal_representation)

def validate(self, qid_shapes: Union[Tuple[int, ...], List[int]]) -> Optional[ValueError]:
for i, (vals, shape) in enumerate(zip(self._internal_representation, qid_shapes)):
if not all(0 <= v < shape for v in vals):
message = (
f'Control values <{vals!r}> outside of range for control qubit '
f'number <{i}>.'
)
return ValueError(message)
return None

def _are_ones(self) -> bool:
return frozenset(self._internal_representation) == {(1,)}

def diagram_repr(self) -> str:
if self._are_ones():
return 'C' * self.number_variables()

def get_prefix(control_vals):
control_vals_str = ''.join(map(str, sorted(control_vals)))
return f'C{control_vals_str}'

return ''.join(map(get_prefix, self._internal_representation))

def __getitem__(self, key):
if isinstance(key, slice):
return ProductOfSums(self._internal_representation[key])
return self._internal_representation[key]

def _json_dict_(self) -> Dict[str, Any]:
return {
'_internal_representation': self._internal_representation,
'cirq_type': 'ProductOfSums',
}
41 changes: 41 additions & 0 deletions cirq-core/cirq/ops/control_values_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2022 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cirq
from cirq.ops import control_values as cv


def test_init_productOfSum():
eq = cirq.testing.EqualsTester()
tests = [
(((1,),), {(1,)}),
(((0, 1), (1,)), {(0, 1), (1, 1)}),
((((0, 1), (1, 0))), {(0, 0), (0, 1), (1, 0), (1, 1)}),
]
for control_values, want in tests:
print(control_values)
got = {c for c in cv.ProductOfSums(control_values)}
eq.add_equality_group(got, want)


def test_and_operation():
eq = cirq.testing.EqualsTester()
originals = [((1,),), ((0, 1), (1,)), (((0, 1), (1, 0)))]
for control_values1 in originals:
for control_values2 in originals:
control_vals1 = cv.ProductOfSums(control_values1)
control_vals2 = cv.ProductOfSums(control_values2)
want = [v1 + v2 for v1 in control_vals1 for v2 in control_vals2]
got = [c for c in control_vals1 & control_vals2]
eq.add_equality_group(got, want)
Loading