Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghuan committed Jan 10, 2024
2 parents f7c21b0 + 7020564 commit e83c02d
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 35 deletions.
2 changes: 1 addition & 1 deletion CSharp.lua/CSharp.lua.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.6.0" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.8.0" />
</ItemGroup>

<ItemGroup>
Expand Down
2 changes: 1 addition & 1 deletion CSharp.lua/CoreSystem.Lua/CoreSystem/Array.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1661,7 +1661,7 @@ local MultiArray = {
Clone = function (this)
local array = { __rank__ = this.__rank__ }
tmove(this, 1, #this, 1, array)
return arrayFromTable(array, this.__genericT__)
return setmetatable(array, Array(this.__genericT__, #this.__rank__))
end
}

Expand Down
5 changes: 5 additions & 0 deletions CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Linq.lua
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,11 @@ function Enumerable.Cast(source, T)
end)
end

function Enumerable.AsEnumerable(source)
if source == nil then throw(ArgumentNullException("source")) end
return source
end

local function first(source, ...)
if source == nil then throw(ArgumentNullException("source")) end
local len = select("#", ...)
Expand Down
25 changes: 11 additions & 14 deletions CSharp.lua/LuaAst/LuaIfStatementSyntax.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,29 +63,26 @@ internal override void Render(LuaRenderer renderer) {

public sealed class LuaSwitchAdapterStatementSyntax : LuaStatementSyntax {
public readonly LuaRepeatStatementSyntax RepeatStatement = new(LuaIdentifierNameSyntax.One);
public LuaIdentifierNameSyntax Temp { get; }
public LuaBlockSyntax Body => RepeatStatement.Body;

public LuaIdentifierNameSyntax Temp { get; set; }
private LuaBlockSyntax defaultBlock_;
private readonly LuaLocalVariablesSyntax caseLabelVariables_ = new();
public LuaIdentifierNameSyntax DefaultLabel { get; set; }
public readonly Dictionary<int, LuaIdentifierNameSyntax> CaseLabels = new();
private LuaIfStatementSyntax headIfStatement_;

public LuaSwitchAdapterStatementSyntax(LuaIdentifierNameSyntax temp) {
Temp = temp;
public LuaSwitchAdapterStatementSyntax() {
}

public void Fill(LuaExpressionSyntax expression, IEnumerable<LuaStatementSyntax> sections) {
if (expression == null) {
throw new ArgumentNullException(nameof(expression));
}
public void Fill(IEnumerable<LuaStatementSyntax> sections) {
if (sections == null) {
throw new ArgumentNullException(nameof(sections));
}

var body = RepeatStatement.Body;
var body = Body;
body.Statements.Add(caseLabelVariables_);
body.Statements.Add(new LuaLocalVariableDeclaratorSyntax(Temp, expression));


LuaIfStatementSyntax ifStatement = null;
foreach (var section in sections) {
if (section is LuaIfStatementSyntax statement) {
Expand Down Expand Up @@ -122,10 +119,10 @@ private void CheckHasDefaultLabel() {
Contract.Assert(defaultBlock_ != null);
caseLabelVariables_.Variables.Add(DefaultLabel);
LuaLabeledStatement labeledStatement = new LuaLabeledStatement(DefaultLabel);
RepeatStatement.Body.Statements.Add(labeledStatement);
Body.Statements.Add(labeledStatement);
LuaIfStatementSyntax ifStatement = new LuaIfStatementSyntax(DefaultLabel);
ifStatement.Body.Statements.AddRange(defaultBlock_.Statements);
RepeatStatement.Body.Statements.Add(ifStatement);
Body.Statements.Add(ifStatement);
}
}

Expand All @@ -143,10 +140,10 @@ private void CheckHasCaseLabel() {
caseLabelVariables_.Variables.AddRange(CaseLabels.Values);
foreach (var (index, labelIdentifier) in CaseLabels) {
var caseLabelStatement = FindMatchIfStatement(index);
RepeatStatement.Body.Statements.Add(new LuaLabeledStatement(labelIdentifier));
Body.Statements.Add(new LuaLabeledStatement(labelIdentifier));
LuaIfStatementSyntax ifStatement = new LuaIfStatementSyntax(labelIdentifier);
ifStatement.Body.Statements.AddRange(caseLabelStatement.Statements);
RepeatStatement.Body.Statements.Add(ifStatement);
Body.Statements.Add(ifStatement);
}
}
}
Expand Down
14 changes: 10 additions & 4 deletions CSharp.lua/LuaSyntaxNodeTransform.Object.cs
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,10 @@ public override LuaSyntaxNode VisitTryStatement(TryStatementSyntax node) {
}

private LuaStatementSyntax BuildUsingStatement(List<LuaIdentifierNameSyntax> variableIdentifiers, List<LuaExpressionSyntax> variableExpressions, Action<LuaBlockSyntax> writeStatements) {
var usingAdapterExpress = new LuaUsingAdapterExpressionSyntax();
return BuildUsingStatement(new LuaUsingAdapterExpressionSyntax(), variableIdentifiers, variableExpressions, writeStatements);
}

private LuaStatementSyntax BuildUsingStatement(LuaUsingAdapterExpressionSyntax usingAdapterExpress, List<LuaIdentifierNameSyntax> variableIdentifiers, List<LuaExpressionSyntax> variableExpressions, Action<LuaBlockSyntax> writeStatements) {
usingAdapterExpress.ParameterList.Parameters.AddRange(variableIdentifiers);
PushFunction(usingAdapterExpress);
var block = new LuaBlockSyntax();
Expand Down Expand Up @@ -854,7 +857,10 @@ private void ApplyUsingDeclarations(LuaBlockSyntax block, List<int> indexes, Blo

int lastIndex = indexes.Last();
var statements = block.Statements.Skip(lastIndex + 1);
var usingStatement = BuildUsingStatement(variableIdentifiers, variableExpressions, body => body.Statements.AddRange(statements));
var usingAdapterExpress = new LuaUsingAdapterExpressionSyntax() {
HasReturn = node.Statements.Any(i => i.IsKind((SyntaxKind.ReturnStatement))),
};
var usingStatement = BuildUsingStatement(usingAdapterExpress, variableIdentifiers, variableExpressions, body => body.Statements.AddRange(statements));
block.Statements.RemoveRange(indexes[position]);
block.AddStatement(usingStatement);
indexes.RemoveRange(position);
Expand Down Expand Up @@ -1434,7 +1440,7 @@ private LuaExpressionSyntax BuildPatternExpression(LuaExpressionSyntax targetExp
case SyntaxKind.RecursivePattern: {
var recursivePattern = (RecursivePatternSyntax)notPattern.Pattern;
var governingIdentifier = GetIdentifierNameFromExpression(targetExpression);
var expression = BuildRecursivePatternExpression(recursivePattern, governingIdentifier, null, targetNode);
var expression = BuildRecursivePatternExpression(recursivePattern, governingIdentifier, targetNode);
return expression.Parenthesized().Not();
}
case SyntaxKind.ConstantPattern: {
Expand Down Expand Up @@ -1480,7 +1486,7 @@ private LuaExpressionSyntax BuildPatternExpression(LuaExpressionSyntax targetExp
} else {
name = GetIdentifierNameFromExpression(targetExpression);
}
return BuildRecursivePatternExpression(recursivePattern, name, null, targetNode);
return BuildRecursivePatternExpression(recursivePattern, name, targetNode);
}
}
}
Expand Down
67 changes: 52 additions & 15 deletions CSharp.lua/LuaSyntaxNodeTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1239,20 +1239,24 @@ public override LuaSyntaxNode VisitIndexerDeclaration(IndexerDeclarationSyntax n
var indexName = GetMemberName(symbol);
var parameterList = node.ParameterList.Accept<LuaParameterListSyntax>(this);

void Fill(Action<LuaFunctionExpressionSyntax, LuaPropertyOrEventIdentifierNameSyntax> action) {
void Fill(IMethodSymbol accessorSymbol, Action<LuaFunctionExpressionSyntax, LuaPropertyOrEventIdentifierNameSyntax> action) {
var methodInfo = new MethodInfo(accessorSymbol);
methodInfos_.Push(methodInfo);
var function = new LuaFunctionExpressionSyntax();
function.AddParameter(LuaIdentifierNameSyntax.This);
function.ParameterList.Parameters.AddRange(parameterList.Parameters);
var name = new LuaPropertyOrEventIdentifierNameSyntax(true, indexName);
PushFunction(function);
action(function, name);
PopFunction();
methodInfos_.Pop();
CurType.AddMethod(name, function, isPrivate);
}

if (node.AccessorList != null) {
foreach (var accessor in node.AccessorList.Accessors) {
Fill((function, name) => {
var accessorSymbol = semanticModel_.GetDeclaredSymbol(accessor);
Fill(accessorSymbol, (function, name) => {
bool isGet = accessor.IsKind(SyntaxKind.GetAccessorDeclaration);
if (accessor.Body != null) {
var block = accessor.Body.Accept<LuaBlockSyntax>(this);
Expand All @@ -1272,7 +1276,7 @@ void Fill(Action<LuaFunctionExpressionSyntax, LuaPropertyOrEventIdentifierNameSy
});
}
} else {
Fill((function, name) => {
Fill(symbol.GetMethod, (function, name) => {
var bodyExpression = node.ExpressionBody.AcceptExpression(this);
function.AddStatement(new LuaReturnStatementSyntax(bodyExpression));
});
Expand Down Expand Up @@ -3675,11 +3679,19 @@ public override LuaSyntaxNode VisitElseClause(ElseClauseSyntax node) {
}

public override LuaSyntaxNode VisitSwitchStatement(SwitchStatementSyntax node) {
var temp = GetTempIdentifier();
var switchStatement = new LuaSwitchAdapterStatementSyntax(temp);
var switchStatement = new LuaSwitchAdapterStatementSyntax();
switches_.Push(switchStatement);
PushBlock(switchStatement.Body);
var expression = node.Expression.AcceptExpression(this);
switchStatement.Fill(expression, node.Sections.Select(i => i.Accept<LuaStatementSyntax>(this)));
if (expression is LuaIdentifierNameSyntax name) {
switchStatement.Temp = name;
} else {
var temp = GetTempIdentifier();
switchStatement.Temp = temp;
switchStatement.Body.Statements.Add(new LuaLocalVariableDeclaratorSyntax(temp, expression));
}
switchStatement.Fill(node.Sections.Select(i => i.Accept<LuaStatementSyntax>(this)));
PopBlock();
switches_.Pop();
return switchStatement;
}
Expand Down Expand Up @@ -3713,8 +3725,14 @@ public override LuaSyntaxNode VisitSwitchSection(SwitchSectionSyntax node) {

public override LuaSyntaxNode VisitCaseSwitchLabel(CaseSwitchLabelSyntax node) {
var left = switches_.Peek().Temp;
var right = node.Value.AcceptExpression(this);
return left.EqualsEquals(right);
var symbol = semanticModel_.GetSymbolInfo(node.Value).Symbol;
if (symbol?.Kind == SymbolKind.NamedType) {
var switchStatement = (SwitchStatementSyntax)FindParent(node, SyntaxKind.SwitchStatement);
return BuildTypePattern(node.Value, left, switchStatement.Expression, null);
} else {
var right = node.Value.AcceptExpression(this);
return left.EqualsEquals(right);
}
}

private LuaExpressionSyntax BuildSwitchLabelWhenClause(LuaExpressionSyntax expression, WhenClauseSyntax whenClause) {
Expand All @@ -3735,14 +3753,19 @@ private LuaExpressionSyntax BuildDeclarationPattern(DeclarationPatternSyntax dec
if (!declarationPattern.Designation.IsKind(SyntaxKind.DiscardDesignation)) {
AddLocalVariableMapping(left, declarationPattern.Designation);
}
var isExpression = BuildIsPatternExpression(expressionType, declarationPattern.Type, left);

return BuildTypePattern(declarationPattern.Type, left, expressionType, whenClause);
}

private LuaExpressionSyntax BuildTypePattern(ExpressionSyntax typePattern, LuaIdentifierNameSyntax left, ExpressionSyntax expressionType, WhenClauseSyntax whenClause) {
var isExpression = BuildIsPatternExpression(expressionType, typePattern, left);
if (isExpression == LuaIdentifierLiteralExpressionSyntax.True) {
return whenClause != null ? whenClause.AcceptExpression(this) : LuaIdentifierLiteralExpressionSyntax.True;
}

return BuildSwitchLabelWhenClause(isExpression, whenClause);
}

}
public override LuaSyntaxNode VisitCasePatternSwitchLabel(CasePatternSwitchLabelSyntax node) {
var left = switches_.Peek().Temp;
switch (node.Pattern.Kind()) {
Expand All @@ -3759,7 +3782,11 @@ public override LuaSyntaxNode VisitCasePatternSwitchLabel(CasePatternSwitchLabel
case SyntaxKind.RecursivePattern: {
var recursivePattern = (RecursivePatternSyntax)node.Pattern;
var switchStatement = (SwitchStatementSyntax)FindParent(node, SyntaxKind.SwitchStatement);
var expression = BuildRecursivePatternExpression(recursivePattern, left, null, switchStatement.Expression);
LuaLocalVariablesSyntax deconstruct = null;
var expression = BuildRecursivePatternExpression(recursivePattern, left, ref deconstruct, switchStatement.Expression);
if (deconstruct != null) {
CurBlock.AddStatement(deconstruct);
}
return BuildSwitchLabelWhenClause(expression, node.WhenClause);
}
case SyntaxKind.AndPattern:
Expand All @@ -3768,6 +3795,11 @@ public override LuaSyntaxNode VisitCasePatternSwitchLabel(CasePatternSwitchLabel
var expression = BuildPatternExpression(left, node.Pattern, switchStatement.Expression);
return BuildSwitchLabelWhenClause(expression, node.WhenClause);
}
case SyntaxKind.TypePattern: {
var switchStatement = (SwitchStatementSyntax)FindParent(node, SyntaxKind.SwitchStatement);
var typePattern = (TypePatternSyntax)node.Pattern;
return BuildTypePattern(typePattern.Type, left, switchStatement.Expression, node.WhenClause);
}
default: {
var patternExpression = node.Pattern.AcceptExpression(this);
var expression = left.EqualsEquals(patternExpression);
Expand Down Expand Up @@ -3840,7 +3872,12 @@ public override LuaSyntaxNode VisitRelationalPattern(RelationalPatternSyntax nod
return node.Expression.AcceptExpression(this);
}

private LuaExpressionSyntax BuildRecursivePatternExpression(RecursivePatternSyntax recursivePattern, LuaIdentifierNameSyntax governingIdentifier, LuaLocalVariablesSyntax deconstruct, ExpressionSyntax governingExpression) {
private LuaExpressionSyntax BuildRecursivePatternExpression(RecursivePatternSyntax recursivePattern, LuaIdentifierNameSyntax governingIdentifier, ExpressionSyntax governingExpression) {
LuaLocalVariablesSyntax deconstruct = null;
return BuildRecursivePatternExpression(recursivePattern, governingIdentifier, ref deconstruct, governingExpression);
}

private LuaExpressionSyntax BuildRecursivePatternExpression(RecursivePatternSyntax recursivePattern, LuaIdentifierNameSyntax governingIdentifier, ref LuaLocalVariablesSyntax deconstruct, ExpressionSyntax governingExpression) {
var subpatterns = recursivePattern.PropertyPatternClause?.Subpatterns ?? recursivePattern.PositionalPatternClause.Subpatterns;
var subpatternExpressions = new List<LuaExpressionSyntax>();
int subpatternIndex = 0;
Expand Down Expand Up @@ -3868,8 +3905,8 @@ private LuaExpressionSyntax BuildRecursivePatternExpression(RecursivePatternSynt
var variable = deconstruct.Variables[subpatternIndex];
subpatternExpressions.Add(variable.EqualsEquals(expression));
}
++subpatternIndex;
}
++subpatternIndex;
var condition = subpatternExpressions.Count > 0
? subpatternExpressions.Aggregate((x, y) => x.And(y))
: governingIdentifier.NotEquals(LuaIdentifierNameSyntax.Nil);
Expand Down Expand Up @@ -3919,7 +3956,7 @@ public override LuaSyntaxNode VisitSwitchExpression(SwitchExpressionSyntax node)
if (recursivePattern.Designation != null) {
AddLocalVariableMapping(governingIdentifier, recursivePattern.Designation);
}
var condition = BuildRecursivePatternExpression(recursivePattern, governingIdentifier, deconstruct, node.GoverningExpression);
var condition = BuildRecursivePatternExpression(recursivePattern, governingIdentifier, ref deconstruct, node.GoverningExpression);
FillSwitchPatternSyntax(ref ifStatement, condition, arm.WhenClause, result, arm.Expression);
break;
}
Expand Down
1 change: 1 addition & 0 deletions CSharp.lua/System.xml
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,7 @@ limitations under the License.
<method name="ThenByDescending" ArgCount="3" Template="Linq.ThenByDescending({0}, {1}, {2}, {`1})" />
<method name="Average" Template="Linq.Average({0}, {1})" />
<method name="DefaultIfEmpty" Template="Linq.DefaultIfEmpty({0})" />
<method name="AsEnumerable" Template="Linq.AsEnumerable({0})" />
</class>
</namespace>
<namespace name="System.Diagnostics" Name="System">
Expand Down

0 comments on commit e83c02d

Please sign in to comment.