Skip to content

Commit

Permalink
1.2.4 (#135)
Browse files Browse the repository at this point in the history
Added minor symengine features
  • Loading branch information
KristianJensen authored Oct 15, 2017
1 parent 95966da commit c34103b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 3 deletions.
18 changes: 15 additions & 3 deletions optlang/symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,14 @@
)


if USE_SYMENGINE: # pragma: no cover
if USE_SYMENGINE: # pragma: no cover # noqa: C901
import operator
from six.moves import reduce

Integer = symengine.Integer
Real = symengine.RealDouble
Basic = symengine.Basic
Number = symengine.Number
Zero = Integer(0)
One = Integer(1)
NegativeOne = Integer(-1)
Expand Down Expand Up @@ -84,12 +86,16 @@ def __getnewargs__(self):
def add(*args):
if len(args) == 1:
args = args[0]
return sum(args)
elif len(args) == 0:
return Zero
return Add(*args)

def mul(*args):
if len(args) == 1:
args = args[0]
return reduce(operator.mul, args, 1)
elif len(args) == 0:
return One # if you multiply nothing the result should be zero
return Mul(args)

else: # Use sympy
import sympy
Expand All @@ -99,6 +105,8 @@ def mul(*args):

Integer = sympy.Integer
Real = sympy.RealNumber
Basic = sympy.Basic
Number = sympy.Number
Zero = Integer(0)
One = Integer(1)
NegativeOne = Integer(-1)
Expand Down Expand Up @@ -128,9 +136,13 @@ def __init__(self, *args, **kwargs):
def add(*args):
if len(args) == 1:
args = args[0]
elif len(args) == 0:
return Zero
return sympy.Add._from_args(args)

def mul(*args):
if len(args) == 1:
args = args[0]
elif len(args) == 0:
return One
return sympy.Mul._from_args(args)
5 changes: 5 additions & 0 deletions optlang/tests/abstract_test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ def test_change_name(self):
self.model.remove(self.var)
self.model.update()

def test_non_string_name_raises(self):
for name in [2, None, True, ["name1", "name2"]]:
with self.assertRaises(TypeError):
self.interface.Variable(name)

@abc.abstractmethod
def test_get_primal(self):
pass
Expand Down
11 changes: 11 additions & 0 deletions optlang/tests/test_symbolics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import unittest
import optlang


class SymbolicsTestCase(unittest.TestCase):

def test_add_identity(self):
self.assertEqual(optlang.symbolics.add(), 0)

def test_mul_identity(self):
self.assertEqual(optlang.symbolics.mul(), 1)

0 comments on commit c34103b

Please sign in to comment.