Skip to content

Commit

Permalink
fix #476
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghuan committed Feb 2, 2024
1 parent 7191013 commit 9844824
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ local ArrayDictionary = (function ()
end
else
local dictionary, comparer = ...
if type(dictionary) ~= "number" then
if type(dictionary) ~= "number" then
buildFromDictionary(this, dictionary)
end
Comparer = comparer
Expand All @@ -556,7 +556,7 @@ local ArrayDictionary = (function ()
Add = function (this, ...)
local k, v
if select("#", ...) == 1 then
local pair = ...
local pair = ...
k, v = pair[1], pair[2]
else
k, v = ...
Expand Down Expand Up @@ -650,8 +650,8 @@ local ArrayDictionary = (function ()
this:set(key, value)
return true
end,
TryGetValue = function (this, key)
if key == nil then throw(ArgumentNullException("key")) end
TryGetValue = function (this, key, hasNil)
if key == nil and not hasNil then throw(ArgumentNullException("key")) end
local len = #this
if len > 0 then
local comparer = this.comparer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,21 @@ local define = System.define
local throw = System.throw
local equalsObj = System.equalsObj
local compareObj = System.compareObj
local hashObj = System.hashObj
local ArgumentException = System.ArgumentException
local ArgumentNullException = System.ArgumentNullException

local type = type

local EqualityComparer
EqualityComparer = define("System.EqualityComparer", function (T)
local equals
local Equals = T.Equals
if Equals then
if T.class == 'S' then
equals = Equals
else
equals = function (x, y)
return x:Equals(y)
end
end
local equals, getHashCode
if T.class == 'S' then
equals = T.Equals or equalsObj
getHashCode = T.GetHashCode
else
equals = equalsObj
end
local function getHashCode(x)
if type(x) == "table" then
return x:GetHashCode()
end
return x
equals = T.Equals and function (x, y) return x:Equals(y) end or equalsObj
getHashCode = hashObj
end
local defaultComparer
return {
Expand Down
49 changes: 22 additions & 27 deletions CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Linq.lua
Original file line number Diff line number Diff line change
Expand Up @@ -279,17 +279,23 @@ local Grouping = define("System.Linq.Grouping", function (TKey, TElement)
}
end, nil, 2)

local function getGrouping(this, key)
local hashCode = this.comparer:GetHashCodeOf(key)
local groupIndex = this.indexs[hashCode]
return this.groups[groupIndex]
local function getGrouping(this, key, create)
local t = this.groups
local found, group = t:TryGetValue(key, true)
if found then return group end
if create then
group = setmetatable({ key = key }, Grouping(this.__genericTKey__, this.__genericTElement__))
t[#t + 1] = setmetatable({ key, group }, t.__genericT__)
end
return group
end

local Lookup = {
__ctor__ = function (this, comparer)
this.comparer = comparer or EqualityComparer(this.__genericTKey__).getDefault()
this.groups = {}
this.indexs = {}
local TKey = this.__genericTKey__
comparer = comparer or EqualityComparer(TKey).getDefault()
local G = Grouping(TKey, this.__genericTElement__)
this.groups = System.Dictionary(TKey, G)(comparer)
end,
get = function (this, key)
local grouping = getGrouping(this, key)
Expand All @@ -303,7 +309,7 @@ local Lookup = {
return getGrouping(this, key) ~= nil
end,
GetEnumerator = function (this)
return arrayEnumerator(this.groups, IGrouping)
return this.groups:getValues():GetEnumerator()
end
}

Expand All @@ -316,18 +322,7 @@ local LookupFn = define("System.Linq.Lookup", function(TKey, TElement)
end, Lookup, 2)

local function addToLookup(this, key, value)
local hashCode = this.comparer:GetHashCodeOf(key)
local groupIndex = this.indexs[hashCode]
local group
if groupIndex == nil then
groupIndex = #this.groups + 1
this.indexs[hashCode] = groupIndex
group = setmetatable({ key = key }, Grouping(this.__genericTKey__, this.__genericTElement__))
this.groups[groupIndex] = group
else
group = this.groups[groupIndex]
assert(group)
end
local group = getGrouping(this, key, true)
group[#group + 1] = wrap(value)
end

Expand Down Expand Up @@ -400,15 +395,15 @@ local function ordered(source, compare)
local orderedEnumerable = createEnumerable(T, function()
local t = {}
local index = 0
return createEnumerator(T, source, function()
return createEnumerator(T, source, function()
index = index + 1
local v = t[index]
if v ~= nil then
return true, unWrap(v)
end
return false
end,
function()
end,
function()
local count = 1
if isDictLike(source) then
for k, v in pairs(source) do
Expand All @@ -423,7 +418,7 @@ local function ordered(source, compare)
end
if count > 1 then
tsort(t, function(x, y)
return compare(unWrap(x), unWrap(y)) < 0
return compare(unWrap(x), unWrap(y)) < 0
end)
end
end)
Expand All @@ -436,9 +431,9 @@ end
local function orderBy(source, keySelector, comparer, TKey, descending)
if source == nil then throw(ArgumentNullException("source")) end
if keySelector == nil then throw(ArgumentNullException("keySelector")) end
if comparer == nil then comparer = Comparer_1(TKey).getDefault() end
if comparer == nil then comparer = Comparer_1(TKey).getDefault() end
local keys = {}
local function getKey(t)
local function getKey(t)
local k = keys[t]
if k == nil then
k = keySelector(t)
Expand Down Expand Up @@ -487,7 +482,7 @@ local function thenBy(source, keySelector, comparer, TKey, descending)
if keySelector == nil then throw(ArgumentNullException("keySelector")) end
if comparer == nil then comparer = Comparer_1(TKey).getDefault() end
local keys = {}
local function getKey(t)
local function getKey(t)
local k = keys[t]
if k == nil then
k = keySelector(t)
Expand Down
11 changes: 11 additions & 0 deletions CSharp.lua/CoreSystem.Lua/CoreSystem/Core.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1203,13 +1203,23 @@ local function hash(v)
return addr(v)
end

local function hashObj(obj)
if obj == nil then return 0 end
local t = type(obj)
if t == "table" then
return obj:GetHashCode()
end
return hash(obj)
end

System.hasHash = function (t)
return t.GetHashCode ~= hash
end

System.equalsObj = equalsObj
System.compareObj = compareObj
System.hash = hash
System.hashObj = hashObj
System.toString = toString

Object = defCls("System.Object", {
Expand Down Expand Up @@ -1261,6 +1271,7 @@ ValueType = defCls("System.ValueType", {
end
end,
EqualsObj = function (this, obj)
if this == obj then return true end
if getmetatable(this) ~= getmetatable(obj) then return false end
for k, v in pairs(this) do
if not equalsObj(v, obj[k]) then
Expand Down
4 changes: 0 additions & 4 deletions CSharp.lua/CoreSystem.Lua/Sample/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,3 @@ test(testIO, "IO")
--test(testAsync, "Async")
--test(testAsyncForeach, "testAsyncForeach")


local dt = System.DateTime()
local o = System.Nullable.clone(dt)
print(System.Nullable.GetHashCode(dt), o:GetHashCode())

0 comments on commit 9844824

Please sign in to comment.