Skip to content

Commit

Permalink
fix some luacheck warning
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghuan committed Feb 23, 2024
1 parent 5f19a1f commit e6c615c
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 86 deletions.
7 changes: 3 additions & 4 deletions CSharp.lua/CoreSystem.Lua/CoreSystem/Array.lua
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ if coroutine ~= nil then
cyield = coroutine.yield
end

local null = {}
local null = { GetHashCode = System.zeroFn }
local arrayEnumerator
local arrayFromTable

Expand Down Expand Up @@ -527,9 +527,8 @@ end
local function checkOrderUniqueAndUnfoundElements(t, n, comparer, other, returnIfUnfound)
if n == 0 then
local numElementsInOther = 0
for _, v in each(other) do
numElementsInOther = numElementsInOther + 1
break
if other:GetEnumerator():MoveNext() then
numElementsInOther = 1
end
return 0, numElementsInOther
end
Expand Down
115 changes: 43 additions & 72 deletions CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Linq.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
--]]

local System = System
local System = _G.System
local define = System.define
local throw = System.throw
local each = System.each
Expand All @@ -38,18 +38,15 @@ local Comparer_1 = System.Comparer_1
local Empty = System.Array.Empty

local IEnumerable_1 = System.IEnumerable_1
local IEnumerable = System.IEnumerable
local IEnumerator_1 = System.IEnumerator_1
local IEnumerator = System.IEnumerator

local assert = assert
local getmetatable = getmetatable
local setmetatable = setmetatable
local select = select
local pairs = pairs
local tsort = table.sort

local InternalEnumerable = define("System.Linq.InternalEnumerable", function(T)
local InternalEnumerable = define("System.Linq.InternalEnumerable", function(T)
return {
base = { IEnumerable_1(T) }
}
Expand All @@ -60,7 +57,7 @@ local function createEnumerable(T, GetEnumerator)
return setmetatable({ __genericT__ = T, GetEnumerator = GetEnumerator }, InternalEnumerable(T))
end

local InternalEnumerator = define("System.Linq.InternalEnumerator", function(T)
local InternalEnumerator = define("System.Linq.InternalEnumerator", function(T)
return {
base = { IEnumerator_1(T) }
}
Expand All @@ -76,10 +73,10 @@ local function createEnumerator(T, source, tryGetNext, init)
if state == 1 then
state = 2
if source then
en = source:GetEnumerator()
en = source:GetEnumerator()
end
if init then
init(en)
init(en)
end
end
if state == 2 then
Expand All @@ -91,7 +88,7 @@ local function createEnumerator(T, source, tryGetNext, init)
local dispose = en.Dispose
if dispose then
dispose(en)
end
end
end
end
return false
Expand All @@ -109,7 +106,7 @@ function Enumerable.Where(source, predicate)
if source == nil then throw(ArgumentNullException("source")) end
if predicate == nil then throw(ArgumentNullException("predicate")) end
local T = source.__genericT__
return createEnumerable(T, function()
return createEnumerable(T, function()
local index = -1
return createEnumerator(T, source, function(en)
while en:MoveNext() do
Expand All @@ -118,7 +115,7 @@ function Enumerable.Where(source, predicate)
if predicate(current, index) then
return true, current
end
end
end
return false
end)
end)
Expand All @@ -129,7 +126,7 @@ function Enumerable.Select(source, selector, T)
if selector == nil then throw(ArgumentNullException("selector")) end
return createEnumerable(T, function()
local index = -1
return createEnumerator(T, source, function(en)
return createEnumerator(T, source, function(en)
if en:MoveNext() then
index = index + 1
return true, selector(en:getCurrent(), index)
Expand Down Expand Up @@ -165,7 +162,7 @@ local function selectMany(source, collectionSelector, resultSelector, T)
end)
end

local function identityFnOfSelectMany(s, x)
local function identityFnOfSelectMany(_, x)
return x
end

Expand Down Expand Up @@ -295,9 +292,13 @@ local function getGrouping(this, key, create)
return nil
end

local function getComparer(source, comparer)
return comparer or EqualityComparer(source.__genericT__).getDefault()
end

local Lookup = {
__ctor__ = function (this, comparer)
this.comparer = comparer or EqualityComparer(this.__genericTKey__).getDefault()
this.comparer = getComparer(this, comparer)
end,
get = function (this, key)
local grouping = getGrouping(this, key)
Expand Down Expand Up @@ -604,7 +605,7 @@ function Enumerable.Concat(first, second)
end)
end

function Enumerable.Zip(first, second, resultSelector, TResult)
function Enumerable.Zip(first, second, resultSelector, TResult)
if first == nil then throw(ArgumentNullException("first")) end
if second == nil then throw(ArgumentNullException("second")) end
if resultSelector == nil then throw(ArgumentNullException("resultSelector")) end
Expand All @@ -621,39 +622,15 @@ function Enumerable.Zip(first, second, resultSelector, TResult)
end)
end

local function addToSet(set, v, getHashCode, comparer)
local hashCode = getHashCode(comparer, v)
if set[hashCode] == nil then
set[hashCode] = true
return true
end
return false
end

local function removeFromSet(set, v, getHashCode, comparer)
local hashCode = getHashCode(comparer, v)
if set[hashCode] ~= nil then
set[hashCode] = nil
return true
end
return false
end

local function getComparer(source, comparer)
return comparer or EqualityComparer(source.__genericT__).getDefault()
end

function Enumerable.Distinct(source, comparer)
if source == nil then throw(ArgumentNullException("source")) end
local T = source.__genericT__
return createEnumerable(T, function()
local set = {}
comparer = getComparer(source, comparer)
local getHashCode = comparer.GetHashCodeOf
local set = System.HashSet(T)(comparer)
return createEnumerator(T, source, function(en)
while en:MoveNext() do
local current = en:getCurrent()
if addToSet(set, current, getHashCode, comparer) then
if set:Add(current) then
return true, current
end
end
Expand All @@ -667,23 +644,21 @@ function Enumerable.Union(first, second, comparer)
if second == nil then throw(ArgumentNullException("second")) end
local T = first.__genericT__
return createEnumerable(T, function()
local set = {}
comparer = getComparer(first, comparer)
local getHashCode = comparer.GetHashCodeOf
local set = System.HashSet(T)(comparer)
local secondEn
return createEnumerator(T, first, function(en)
if secondEn == nil then
while en:MoveNext() do
local current = en:getCurrent()
if addToSet(set, current, getHashCode, comparer) then
if set:Add(current) then
return true, current
end
end
secondEn = second:GetEnumerator()
end
while secondEn:MoveNext() do
local current = secondEn:getCurrent()
if addToSet(set, current, getHashCode, comparer) then
if set:Add(current) then
return true, current
end
end
Expand All @@ -697,46 +672,42 @@ function Enumerable.Intersect(first, second, comparer)
if second == nil then throw(ArgumentNullException("second")) end
local T = first.__genericT__
return createEnumerable(T, function()
local set = {}
comparer = getComparer(first, comparer)
local getHashCode = comparer.GetHashCodeOf
local set = System.HashSet(T)(comparer)
return createEnumerator(T, first, function(en)
while en:MoveNext() do
local current = en:getCurrent()
if removeFromSet(set, current, getHashCode, comparer) then
if set:Remove(current) then
return true, current
end
end
return false
end,
function()
for _, v in each(second) do
addToSet(set, v, getHashCode, comparer)
set:Add(v)
end
end)
end)
end)
end

function Enumerable.Except(first, second, comparer)
if first == nil then throw(ArgumentNullException("first")) end
if second == nil then throw(ArgumentNullException("second")) end
local T = first.__genericT__
return createEnumerable(T, function()
local set = {}
comparer = getComparer(first, comparer)
local getHashCode = comparer.GetHashCodeOf
local set = System.HashSet(T)(comparer)
return createEnumerator(T, first, function(en)
while en:MoveNext() do
local current = en:getCurrent()
if addToSet(set, current, getHashCode, comparer) then
if set:Add(current) then
return true, current
end
end
return false
end,
function()
for _, v in each(second) do
addToSet(set, v, getHashCode, comparer)
set:Add(v)
end
end)
end)
Expand Down Expand Up @@ -843,7 +814,7 @@ end
function Enumerable.DefaultIfEmpty(source)
if source == nil then throw(ArgumentNullException("source")) end
local T = source.__genericT__
local state
local state
return createEnumerable(T, function()
return createEnumerator(T, source, function(en)
if not state then
Expand All @@ -866,7 +837,7 @@ end
function Enumerable.OfType(source, T)
if source == nil then throw(ArgumentNullException("source")) end
return createEnumerable(T, function()
return createEnumerator(T, source, function(en)
return createEnumerator(T, source, function(en)
while en:MoveNext() do
local current = en:getCurrent()
if is(current, T) then
Expand All @@ -882,7 +853,7 @@ function Enumerable.Cast(source, T)
if source == nil then throw(ArgumentNullException("source")) end
if is(source, IEnumerable_1(T)) then return source end
return createEnumerable(T, function()
return createEnumerator(T, source, function(en)
return createEnumerator(T, source, function(en)
if en:MoveNext() then
return true, cast(T, en:getCurrent())
end
Expand All @@ -907,7 +878,7 @@ local function first(source, ...)
end
else
local en = source:GetEnumerator()
if en:MoveNext() then
if en:MoveNext() then
return true, en:getCurrent()
end
end
Expand All @@ -916,7 +887,7 @@ local function first(source, ...)
local predicate = ...
if predicate == nil then throw(ArgumentNullException("predicate")) end
for _, v in each(source) do
if predicate(v) then
if predicate(v) then
return true, v
end
end
Expand Down Expand Up @@ -949,7 +920,7 @@ local function last(source, ...)
end
else
local en = source:GetEnumerator()
if en:MoveNext() then
if en:MoveNext() then
local result
repeat
result = en:getCurrent()
Expand All @@ -967,7 +938,7 @@ local function last(source, ...)
result = v
found = true
end
end
end
if found then return true, result end
return false, 1
end
Expand Down Expand Up @@ -1020,8 +991,8 @@ local function single(source, ...)
found = true
end
end
if found then return true, result end
return false, 0
if found then return true, result end
return false, 0
end
end

Expand Down Expand Up @@ -1077,7 +1048,7 @@ function Enumerable.Range(start, count)
return createEnumerator(Int32, nil, function()
index = index + 1
if index < count then
return true, start + index
return true, start + index
end
return false
end)
Expand All @@ -1091,7 +1062,7 @@ function Enumerable.Repeat(element, count, T)
return createEnumerator(T, nil, function()
index = index + 1
if index < count then
return true, element
return true, element
end
return false
end)
Expand Down Expand Up @@ -1136,8 +1107,8 @@ function Enumerable.Count(source, ...)
end
local count = 0
local en = source:GetEnumerator()
while en:MoveNext() do
count = count + 1
while en:MoveNext() do
count = count + 1
end
return count
else
Expand Down Expand Up @@ -1221,7 +1192,7 @@ end
local function minOrMax(compareFn, source, ...)
if source == nil then throw(ArgumentNullException("source")) end
local len = select("#", ...)
local selector, T
local selector, T
if len == 0 then
selector, T = identityFn, source.__genericT__
else
Expand All @@ -1236,7 +1207,7 @@ local function minOrMax(compareFn, source, ...)
x = selector(x)
if x ~= nil and (value == nil or compareFn(compare, comparer, x, value)) then
value = x
end
end
end
return value
else
Expand Down
4 changes: 2 additions & 2 deletions CSharp.lua/CoreSystem.Lua/CoreSystem/Core.lua
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ local function xpcallErr(e)
return e
end

local function try(try, catch, finally)
local ok, status, result = xpcall(try, xpcallErr)
local function try(tryFn, catch, finally)
local ok, status, result = xpcall(tryFn, xpcallErr)
if not ok then
if catch then
if finally then
Expand Down
Loading

0 comments on commit e6c615c

Please sign in to comment.