diff --git a/Extension/src/LanguageServer/copilotProviders.ts b/Extension/src/LanguageServer/copilotProviders.ts new file mode 100644 index 0000000000..e3205070cb --- /dev/null +++ b/Extension/src/LanguageServer/copilotProviders.ts @@ -0,0 +1,136 @@ +/* -------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All Rights Reserved. + * See 'LICENSE' in the project root for license information. + * ------------------------------------------------------------------------------------------ */ +'use strict'; + +import * as vscode from 'vscode'; +import * as util from '../common'; +import * as telemetry from '../telemetry'; +import { ChatContextResult, GetIncludesResult } from './client'; +import { getActiveClient } from './extension'; + +let isRelatedFilesApiEnabled: boolean | undefined; + +export interface CopilotTrait { + name: string; + value: string; + includeInPrompt?: boolean; + promptTextOverride?: string; +} + +export interface CopilotApi { + registerRelatedFilesProvider( + providerId: { extensionId: string; languageId: string }, + callback: ( + uri: vscode.Uri, + context: { flags: Record }, + cancellationToken: vscode.CancellationToken + ) => Promise<{ entries: vscode.Uri[]; traits?: CopilotTrait[] }> + ): Disposable; +} + +export async function registerRelatedFilesProvider(): Promise { + if (!await getIsRelatedFilesApiEnabled()) { + return; + } + + const api = await getCopilotApi(); + if (util.extensionContext && api) { + try { + for (const languageId of ['c', 'cpp', 'cuda-cpp']) { + api.registerRelatedFilesProvider( + { extensionId: util.extensionContext.extension.id, languageId }, + async (_uri: vscode.Uri, context: { flags: Record }, token: vscode.CancellationToken) => { + + const getIncludesHandler = async () => { + return (await getIncludesWithCancellation(1, token))?.includedFiles.map(file => vscode.Uri.file(file)) ?? []; + } + const getTraitsHandler = async () => { + const chatContext: ChatContextResult | undefined = await (getActiveClient().getChatContext(token) ?? undefined); + + if (!chatContext) { + return undefined; + } + + let traits: CopilotTrait[] = [ + { name: "language", value: chatContext.language, includeInPrompt: true, promptTextOverride: `The language is ${chatContext.language}.` }, + { name: "compiler", value: chatContext.compiler, includeInPrompt: true, promptTextOverride: `This project compiles using ${chatContext.compiler}.` }, + { name: "standardVersion", value: chatContext.standardVersion, includeInPrompt: true, promptTextOverride: `This project uses the ${chatContext.standardVersion} language standard.` }, + { name: "targetPlatform", value: chatContext.targetPlatform, includeInPrompt: true, promptTextOverride: `This build targets ${chatContext.targetPlatform}.` }, + { name: "targetArchitecture", value: chatContext.targetArchitecture, includeInPrompt: true, promptTextOverride: `This build targets ${chatContext.targetArchitecture}.` }, + ]; + + const excludeTraits = (context.flags.copilotcppExcludeTraits as string[] ?? []); + traits = traits.filter(trait => !excludeTraits.includes(trait.name)); + + return traits.length > 0 ? traits : undefined; + } + + // Call both handlers in parallel + let traitsPromise = ((context.flags.copilotcppTraits as boolean) ?? false) ? getTraitsHandler() : Promise.resolve(undefined); + let includesPromise = getIncludesHandler(); + + return ({ entries: await includesPromise, traits: await traitsPromise }); + } + ); + } + } catch { + console.log("Failed to register Copilot related files provider."); + } + } +} + +export async function registerRelatedFilesCommands(commandDisposables: vscode.Disposable[], enabled: boolean): Promise { + if (await getIsRelatedFilesApiEnabled()) { + commandDisposables.push(vscode.commands.registerCommand('C_Cpp.getIncludes', enabled ? (maxDepth: number) => getIncludes(maxDepth) : () => Promise.resolve())); + } +} + +async function getIncludesWithCancellation(maxDepth: number, token: vscode.CancellationToken): Promise { + let activeClient = getActiveClient(); + const includes = await activeClient.getIncludes(maxDepth, token); + const wksFolder = activeClient.RootUri?.toString(); + + if (!wksFolder) { + return includes; + } + + includes.includedFiles = includes.includedFiles.filter(header => vscode.Uri.file(header).toString().startsWith(wksFolder)); + return includes; +} + +async function getIncludes(maxDepth: number): Promise { + const tokenSource = new vscode.CancellationTokenSource(); + try { + const includes = await getIncludesWithCancellation(maxDepth, tokenSource.token); + return includes; + } finally { + tokenSource.dispose(); + } +} + +async function getIsRelatedFilesApiEnabled(): Promise { + if (isRelatedFilesApiEnabled === undefined) { + isRelatedFilesApiEnabled = await telemetry.isExperimentEnabled("CppToolsRelatedFilesApi"); + } + + return isRelatedFilesApiEnabled; +} + +export async function getCopilotApi(): Promise { + const copilotExtension = vscode.extensions.getExtension('github.copilot'); + if (!copilotExtension) { + return undefined; + } + + if (!copilotExtension.isActive) { + try { + return await copilotExtension.activate(); + } catch { + return undefined; + } + } else { + return copilotExtension.exports; + } +} diff --git a/Extension/src/LanguageServer/extension.ts b/Extension/src/LanguageServer/extension.ts index 47a7cd0342..bdc72886f8 100644 --- a/Extension/src/LanguageServer/extension.ts +++ b/Extension/src/LanguageServer/extension.ts @@ -20,9 +20,10 @@ import * as util from '../common'; import { getCrashCallStacksChannel } from '../logger'; import { PlatformInformation } from '../platform'; import * as telemetry from '../telemetry'; -import { Client, DefaultClient, DoxygenCodeActionCommandArguments, GetIncludesResult, openFileVersions } from './client'; +import { Client, DefaultClient, DoxygenCodeActionCommandArguments, openFileVersions } from './client'; import { ClientCollection } from './clientCollection'; import { CodeActionDiagnosticInfo, CodeAnalysisDiagnosticIdentifiersAndUri, codeAnalysisAllFixes, codeAnalysisCodeToFixes, codeAnalysisFileToCodeActions } from './codeAnalysis'; +import { registerRelatedFilesCommands, registerRelatedFilesProvider } from './copilotProviders'; import { CppBuildTaskProvider } from './cppBuildTaskProvider'; import { getCustomConfigProviders } from './customProviders'; import { getLanguageConfig } from './languageConfig'; @@ -33,24 +34,6 @@ import { CppSettings } from './settings'; import { LanguageStatusUI, getUI } from './ui'; import { makeLspRange, rangeEquals, showInstallCompilerWalkthrough } from './utils'; -interface CopilotTrait { - name: string; - value: string; - includeInPrompt?: boolean; - promptTextOverride?: string; -} - -interface CopilotApi { - registerRelatedFilesProvider( - providerId: { extensionId: string; languageId: string }, - callback: ( - uri: vscode.Uri, - context: { flags: Record }, - cancellationToken: vscode.CancellationToken - ) => Promise<{ entries: vscode.Uri[]; traits?: CopilotTrait[] }> - ): Disposable; -} - nls.config({ messageFormat: nls.MessageFormat.bundle, bundleFormat: nls.BundleFormat.standalone })(); const localize: nls.LocalizeFunc = nls.loadMessageBundle(); export const CppSourceStr: string = "C/C++"; @@ -201,8 +184,7 @@ export async function activate(): Promise { void clients.ActiveClient.ready.then(() => intervalTimer = global.setInterval(onInterval, 2500)); - const isRelatedFilesApiEnabled = await telemetry.isExperimentEnabled("CppToolsRelatedFilesApi"); - registerCommands(true, isRelatedFilesApiEnabled); + registerCommands(true); vscode.tasks.onDidStartTask(() => getActiveClient().PauseCodeAnalysis()); @@ -274,22 +256,7 @@ export async function activate(): Promise { disposables.push(tool); } - if (isRelatedFilesApiEnabled) { - const api = await getCopilotApi(); - if (util.extensionContext && api) { - try { - for (const languageId of ['c', 'cpp', 'cuda-cpp']) { - api.registerRelatedFilesProvider( - { extensionId: util.extensionContext.extension.id, languageId }, - async (_uri: vscode.Uri, _context: { flags: Record }, token: vscode.CancellationToken) => - ({ entries: (await getIncludesWithCancellation(1, token))?.includedFiles.map(file => vscode.Uri.file(file)) ?? [] }) - ); - } - } catch { - console.log("Failed to register Copilot related files provider."); - } - } - } + registerRelatedFilesProvider(); } export function updateLanguageConfigurations(): void { @@ -386,7 +353,7 @@ function onInterval(): void { /** * registered commands */ -export function registerCommands(enabled: boolean, isRelatedFilesApiEnabled: boolean): void { +export function registerCommands(enabled: boolean): void { commandDisposables.forEach(d => d.dispose()); commandDisposables.length = 0; commandDisposables.push(vscode.commands.registerCommand('C_Cpp.SwitchHeaderSource', enabled ? onSwitchHeaderSource : onDisabledCommand)); @@ -445,9 +412,7 @@ export function registerCommands(enabled: boolean, isRelatedFilesApiEnabled: boo commandDisposables.push(vscode.commands.registerCommand('C_Cpp.ExtractToMemberFunction', enabled ? () => onExtractToFunction(false, true) : onDisabledCommand)); commandDisposables.push(vscode.commands.registerCommand('C_Cpp.ExpandSelection', enabled ? (r: Range) => onExpandSelection(r) : onDisabledCommand)); - if (!isRelatedFilesApiEnabled) { - commandDisposables.push(vscode.commands.registerCommand('C_Cpp.getIncludes', enabled ? (maxDepth: number) => getIncludes(maxDepth) : () => Promise.resolve())); - } + registerRelatedFilesCommands(commandDisposables, enabled) } function onDisabledCommand() { @@ -1412,42 +1377,3 @@ export async function preReleaseCheck(): Promise { } } } - -export async function getIncludesWithCancellation(maxDepth: number, token: vscode.CancellationToken): Promise { - const includes = await clients.ActiveClient.getIncludes(maxDepth, token); - const wksFolder = clients.ActiveClient.RootUri?.toString(); - - if (!wksFolder) { - return includes; - } - - includes.includedFiles = includes.includedFiles.filter(header => vscode.Uri.file(header).toString().startsWith(wksFolder)); - return includes; -} - -async function getIncludes(maxDepth: number): Promise { - const tokenSource = new vscode.CancellationTokenSource(); - try { - const includes = await getIncludesWithCancellation(maxDepth, tokenSource.token); - return includes; - } finally { - tokenSource.dispose(); - } -} - -async function getCopilotApi(): Promise { - const copilotExtension = vscode.extensions.getExtension('github.copilot'); - if (!copilotExtension) { - return undefined; - } - - if (!copilotExtension.isActive) { - try { - return await copilotExtension.activate(); - } catch { - return undefined; - } - } else { - return copilotExtension.exports; - } -} diff --git a/Extension/src/main.ts b/Extension/src/main.ts index fd6bacb89e..7379235277 100644 --- a/Extension/src/main.ts +++ b/Extension/src/main.ts @@ -146,8 +146,7 @@ export async function activate(context: vscode.ExtensionContext): Promise { + let moduleUnderTest: any; + let getIsRelatedFilesApiEnabledStub: sinon.SinonStub; + let mockCopilotApi: sinon.SinonStubbedInstance; + let getActiveClientStub: sinon.SinonStub; + let activeClientStub: sinon.SinonStubbedInstance; + let vscodeGetExtensionsStub: sinon.SinonStub; + let callbackPromise: Promise<{ entries: vscode.Uri[]; traits?: CopilotTrait[] }> | undefined; + let vscodeExtension: vscode.Extension;; + + beforeEach(() => { + proxyquire.noPreserveCache(); // Tells proxyquire to not fetch the module from cache + // Ensures that each test has a freshly loaded instance of moduleUnderTest + moduleUnderTest = proxyquire( + '../../../../src/LanguageServer/copilotProviders', + {} // Stub if you need to, or keep the object empty + ); + + getIsRelatedFilesApiEnabledStub = sinon.stub(telemetry, 'isExperimentEnabled'); + sinon.stub(util, 'extensionContext').value({ extension: { id: 'test-extension-id' } }); + + class MockCopilotApi implements CopilotApi { + public registerRelatedFilesProvider( + providerId: { extensionId: string; languageId: string }, + callback: ( + uri: vscode.Uri, + context: { flags: Record }, + cancellationToken: vscode.CancellationToken + ) => Promise<{ entries: vscode.Uri[]; traits?: CopilotTrait[] }> + ): vscode.Disposable & { [Symbol.dispose](): void } { + return { + dispose: () => { }, + [Symbol.dispose]: () => { } + } + } + }; + mockCopilotApi = sinon.createStubInstance(MockCopilotApi); + vscodeExtension = { + id: 'test-extension-id', + extensionUri: vscode.Uri.parse('file:///test-extension-path'), + extensionPath: 'test-extension-path', + isActive: true, + packageJSON: { name: 'test-extension-name' }, + activate: async () => { }, + exports: mockCopilotApi, + extensionKind: vscode.ExtensionKind.UI + }; + + activeClientStub = sinon.createStubInstance(DefaultClient); + getActiveClientStub = sinon.stub(extension, 'getActiveClient').returns(activeClientStub); + activeClientStub.getIncludes.resolves({ includedFiles: [] }); + }); + + afterEach(() => { + sinon.restore() + }); + + const arrange = ({ cppToolsRelatedFilesApi, vscodeExtension, getIncludeFiles, chatContext, rootUri, flags }: + { cppToolsRelatedFilesApi: boolean, vscodeExtension?: vscode.Extension, getIncludeFiles?: GetIncludesResult, chatContext?: ChatContextResult, rootUri?: vscode.Uri, flags?: Record } = + { cppToolsRelatedFilesApi: false, vscodeExtension: undefined, getIncludeFiles: undefined, chatContext: undefined, rootUri: undefined, flags: {} } + ) => { + getIsRelatedFilesApiEnabledStub.withArgs('CppToolsRelatedFilesApi').resolves(cppToolsRelatedFilesApi); + activeClientStub.getIncludes.resolves(getIncludeFiles); + activeClientStub.getChatContext.resolves(chatContext); + sinon.stub(activeClientStub, 'RootUri').get(() => rootUri); + mockCopilotApi.registerRelatedFilesProvider.callsFake((providerId: { extensionId: string; languageId: string }, callback: (uri: vscode.Uri, context: { flags: Record }, cancellationToken: vscode.CancellationToken) => Promise<{ entries: vscode.Uri[]; traits?: CopilotTrait[] }>) => { + const tokenSource = new vscode.CancellationTokenSource(); + try { + callbackPromise = callback(vscode.Uri.parse('file:///test-extension-path'), { flags: flags ?? {} }, tokenSource.token) + } finally { + tokenSource.dispose(); + } + + return { + dispose: () => { }, + [Symbol.dispose]: () => { } + } + }); + vscodeGetExtensionsStub = sinon.stub(vscode.extensions, 'getExtension').returns(vscodeExtension); + } + + it('should not register provider if CppToolsRelatedFilesApi is not enabled', async () => { + arrange( + { cppToolsRelatedFilesApi: false } + ); + + await moduleUnderTest.registerRelatedFilesProvider(); + + ok(getIsRelatedFilesApiEnabledStub.calledOnce, 'getIsRelatedFilesApiEnabled should be called once'); + }); + + it('should register provider if CppToolsRelatedFilesApi is enabled', async () => { + arrange( + { cppToolsRelatedFilesApi: true, vscodeExtension: vscodeExtension } + ); + + await moduleUnderTest.registerRelatedFilesProvider(); + + ok(getIsRelatedFilesApiEnabledStub.calledOnce, 'getIsRelatedFilesApiEnabled should be called once'); + ok(vscodeGetExtensionsStub.calledOnce, 'vscode.extensions.getExtension should be called once'); + ok(mockCopilotApi.registerRelatedFilesProvider.calledWithMatch(sinon.match({ extensionId: 'test-extension-id', languageId: sinon.match.in(['c', 'cpp', 'cuda-cpp']) })), 'registerRelatedFilesProvider should be called with the correct providerId and languageId'); + }); + + it('should not add #cpp traits when ChatContext isn\'t available.', async () => { + arrange({ + cppToolsRelatedFilesApi: true, + vscodeExtension: vscodeExtension, + getIncludeFiles: { includedFiles: ['c:\\system\\include\\vector', 'c:\\system\\include\\string', 'C:\\src\\my_project\\foo.h'] }, + chatContext: undefined, + rootUri: vscode.Uri.file('C:\\src\\my_project'), + flags: { copilotcppTraits: true } + }); + await moduleUnderTest.registerRelatedFilesProvider(); + + const result = await callbackPromise; + + ok(getIsRelatedFilesApiEnabledStub.calledOnce, 'getIsRelatedFilesApiEnabled should be called once'); + ok(vscodeGetExtensionsStub.calledOnce, 'vscode.extensions.getExtension should be called once'); + ok(mockCopilotApi.registerRelatedFilesProvider.calledWithMatch(sinon.match({ extensionId: 'test-extension-id', languageId: sinon.match.in(['c', 'cpp', 'cuda-cpp']) })), 'registerRelatedFilesProvider should be called with the correct providerId and languageId'); + ok(getActiveClientStub.callCount !== 0, 'getActiveClient should be called'); + ok(callbackPromise, 'callbackPromise should be defined'); + ok(result, 'result should be defined'); + ok(result.entries.length === 1, 'result.entries should have 1 included file'); + ok(result.entries[0].toString() === 'file:///c%3A/src/my_project/foo.h', 'result.entries should have "file:///c%3A/src/my_project/foo.h"'); + ok(result.traits === undefined, 'result.traits should be undefined'); + }); + + it('should not add #cpp traits when copilotcppTraits flag is false.', async () => { + arrange({ + cppToolsRelatedFilesApi: true, + vscodeExtension: vscodeExtension, + getIncludeFiles: { includedFiles: ['c:\\system\\include\\vector', 'c:\\system\\include\\string', 'C:\\src\\my_project\\foo.h'] }, + chatContext: { + language: 'c++', + standardVersion: 'c++20', + compiler: 'msvc', + targetPlatform: 'windows', + targetArchitecture: 'x64' + }, + rootUri: vscode.Uri.file('C:\\src\\my_project'), + flags: { copilotcppTraits: false } + }); + await moduleUnderTest.registerRelatedFilesProvider(); + + const result = await callbackPromise; + + ok(getIsRelatedFilesApiEnabledStub.calledOnce, 'getIsRelatedFilesApiEnabled should be called once'); + ok(vscodeGetExtensionsStub.calledOnce, 'vscode.extensions.getExtension should be called once'); + ok(mockCopilotApi.registerRelatedFilesProvider.calledWithMatch(sinon.match({ extensionId: 'test-extension-id', languageId: sinon.match.in(['c', 'cpp', 'cuda-cpp']) })), 'registerRelatedFilesProvider should be called with the correct providerId and languageId'); + ok(getActiveClientStub.callCount !== 0, 'getActiveClient should be called'); + ok(callbackPromise, 'callbackPromise should be defined'); + ok(result, 'result should be defined'); + ok(result.entries.length === 1, 'result.entries should have 1 included file'); + ok(result.entries[0].toString() === 'file:///c%3A/src/my_project/foo.h', 'result.entries should have "file:///c%3A/src/my_project/foo.h"'); + ok(result.traits === undefined, 'result.traits should be undefined'); + }); + + it('should add #cpp traits when copilotcppTraits flag is true.', async () => { + arrange({ + cppToolsRelatedFilesApi: true, + vscodeExtension: vscodeExtension, + getIncludeFiles: { includedFiles: ['c:\\system\\include\\vector', 'c:\\system\\include\\string', 'C:\\src\\my_project\\foo.h'] }, + chatContext: { + language: 'c++', + standardVersion: 'c++20', + compiler: 'msvc', + targetPlatform: 'windows', + targetArchitecture: 'x64' + }, + rootUri: vscode.Uri.file('C:\\src\\my_project'), + flags: { copilotcppTraits: true } + }); + await moduleUnderTest.registerRelatedFilesProvider(); + + const result = await callbackPromise; + + ok(getIsRelatedFilesApiEnabledStub.calledOnce, 'getIsRelatedFilesApiEnabled should be called once'); + ok(vscodeGetExtensionsStub.calledOnce, 'vscode.extensions.getExtension should be called once'); + ok(mockCopilotApi.registerRelatedFilesProvider.calledThrice, 'registerRelatedFilesProvider should be called three times'); + ok(mockCopilotApi.registerRelatedFilesProvider.calledWithMatch(sinon.match({ extensionId: 'test-extension-id', languageId: sinon.match.in(['c', 'cpp', 'cuda-cpp']) })), 'registerRelatedFilesProvider should be called with the correct providerId and languageId'); + ok(getActiveClientStub.callCount !== 0, 'getActiveClient should be called'); + ok(callbackPromise, 'callbackPromise should be defined'); + ok(result, 'result should be defined'); + ok(result.entries.length === 1, 'result.entries should have 1 included file'); + ok(result.entries[0].toString() === 'file:///c%3A/src/my_project/foo.h', 'result.entries should have "file:///c%3A/src/my_project/foo.h"'); + ok(result.traits, 'result.traits should be defined'); + ok(result.traits.length === 5, 'result.traits should have 5 traits'); + ok(result.traits[0].name === 'language', 'result.traits[0].name should be "language"'); + ok(result.traits[0].value === 'c++', 'result.traits[0].value should be "c++"'); + ok(result.traits[0].includeInPrompt, 'result.traits[0].includeInPrompt should be true'); + ok(result.traits[0].promptTextOverride === 'The language is c++.', 'result.traits[0].promptTextOverride should be "The language is c++."'); + ok(result.traits[1].name === 'compiler', 'result.traits[1].name should be "compiler"'); + ok(result.traits[1].value === 'msvc', 'result.traits[1].value should be "msvc"'); + ok(result.traits[1].includeInPrompt, 'result.traits[1].includeInPrompt should be true'); + ok(result.traits[1].promptTextOverride === 'This project compiles using msvc.', 'result.traits[1].promptTextOverride should be "This project compiles using msvc."'); + ok(result.traits[2].name === 'standardVersion', 'result.traits[2].name should be "standardVersion"'); + ok(result.traits[2].value === 'c++20', 'result.traits[2].value should be "c++20"'); + ok(result.traits[2].includeInPrompt, 'result.traits[2].includeInPrompt should be true'); + ok(result.traits[2].promptTextOverride === 'This project uses the c++20 language standard.', 'result.traits[2].promptTextOverride should be "This project uses the c++20 language standard."'); + ok(result.traits[3].name === 'targetPlatform', 'result.traits[3].name should be "targetPlatform"'); + ok(result.traits[3].value === 'windows', 'result.traits[3].value should be "windows"'); + ok(result.traits[3].includeInPrompt, 'result.traits[3].includeInPrompt should be true'); + ok(result.traits[3].promptTextOverride === 'This build targets windows.', 'result.traits[3].promptTextOverride should be "This build targets windows."'); + ok(result.traits[4].name === 'targetArchitecture', 'result.traits[4].name should be "targetArchitecture"'); + ok(result.traits[4].value === 'x64', 'result.traits[4].value should be "x64"'); + ok(result.traits[4].includeInPrompt, 'result.traits[4].includeInPrompt should be true'); + ok(result.traits[4].promptTextOverride === 'This build targets x64.', 'result.traits[4].promptTextOverride should be "This build targets x64."'); + }); + + it('should exclude #cpp traits per copilotcppExcludeTraits.', async () => { + const excludeTraits = ['compiler', 'targetPlatform']; + arrange({ + cppToolsRelatedFilesApi: true, + vscodeExtension: vscodeExtension, + getIncludeFiles: { includedFiles: ['c:\\system\\include\\vector', 'c:\\system\\include\\string', 'C:\\src\\my_project\\foo.h'] }, + chatContext: { + language: 'c++', + standardVersion: 'c++20', + compiler: 'msvc', + targetPlatform: 'windows', + targetArchitecture: 'x64' + }, + rootUri: vscode.Uri.file('C:\\src\\my_project'), + flags: { copilotcppTraits: true, copilotcppExcludeTraits: excludeTraits } + }); + await moduleUnderTest.registerRelatedFilesProvider(); + + const result = await callbackPromise; + + ok(getIsRelatedFilesApiEnabledStub.calledOnce, 'getIsRelatedFilesApiEnabled should be called once'); + ok(vscodeGetExtensionsStub.calledOnce, 'vscode.extensions.getExtension should be called once'); + ok(mockCopilotApi.registerRelatedFilesProvider.calledThrice, 'registerRelatedFilesProvider should be called three times'); + ok(mockCopilotApi.registerRelatedFilesProvider.calledWithMatch(sinon.match({ extensionId: 'test-extension-id', languageId: sinon.match.in(['c', 'cpp', 'cuda-cpp']) })), 'registerRelatedFilesProvider should be called with the correct providerId and languageId'); + ok(getActiveClientStub.callCount !== 0, 'getActiveClient should be called'); + ok(callbackPromise, 'callbackPromise should be defined'); + ok(result, 'result should be defined'); + ok(result.entries.length === 1, 'result.entries should have 1 included file'); + ok(result.entries[0].toString() === 'file:///c%3A/src/my_project/foo.h', 'result.entries should have "file:///c%3A/src/my_project/foo.h"'); + ok(result.traits, 'result.traits should be defined'); + ok(result.traits.length === 3, 'result.traits should have 3 traits'); + ok(result.traits.filter(trait => excludeTraits.includes(trait.name)).length === 0, 'result.traits should not include excluded traits'); + }); + + it('should handle errors during provider registration', async () => { + arrange( + { cppToolsRelatedFilesApi: true } + ); + + await moduleUnderTest.registerRelatedFilesProvider(); + + ok(getIsRelatedFilesApiEnabledStub.calledOnce, 'getIsRelatedFilesApiEnabled should be called once'); + ok(vscodeGetExtensionsStub.calledOnce, 'vscode.extensions.getExtension should be called once'); + ok(mockCopilotApi.registerRelatedFilesProvider.notCalled, 'registerRelatedFilesProvider should not be called'); + }); +}) + +describe('registerRelatedFilesCommands', () => { + let moduleUnderTest: any; + let getIsRelatedFilesApiEnabledStub: sinon.SinonStub; + let registerCommandStub: sinon.SinonStub; + let commandDisposables: vscode.Disposable[]; + let getActiveClientStub: sinon.SinonStub; + let activeClientStub: sinon.SinonStubbedInstance; + let getIncludesResult: Promise = Promise.resolve(undefined); + + beforeEach(() => { + proxyquire.noPreserveCache(); // Tells proxyquire to not fetch the module from cache + // Ensures that each test has a freshly loaded instance of moduleUnderTest + moduleUnderTest = proxyquire( + '../../../../src/LanguageServer/copilotProviders', + {} // Stub if you need to, or keep the object empty + ); + + activeClientStub = sinon.createStubInstance(DefaultClient); + getActiveClientStub = sinon.stub(extension, 'getActiveClient').returns(activeClientStub); + activeClientStub.getIncludes.resolves({ includedFiles: [] }); + getIsRelatedFilesApiEnabledStub = sinon.stub(telemetry, 'isExperimentEnabled'); + registerCommandStub = sinon.stub(vscode.commands, 'registerCommand'); + commandDisposables = []; + }); + + afterEach(() => { + sinon.restore(); + }); + + const arrange = ({ cppToolsRelatedFilesApi, getIncludeFiles, chatContext, rootUri }: + { cppToolsRelatedFilesApi: boolean, getIncludeFiles?: GetIncludesResult, chatContext?: ChatContextResult, rootUri?: vscode.Uri } = + { cppToolsRelatedFilesApi: false, getIncludeFiles: undefined, chatContext: undefined, rootUri: undefined } + ) => { + getIsRelatedFilesApiEnabledStub.withArgs('CppToolsRelatedFilesApi').resolves(cppToolsRelatedFilesApi); + activeClientStub.getIncludes.resolves(getIncludeFiles); + sinon.stub(activeClientStub, 'RootUri').get(() => rootUri); + registerCommandStub.callsFake((command: string, callback: (maxDepth: number) => Promise) => { + getIncludesResult = callback(1); + }); + } + + it('should register C_Cpp.getIncludes command if CppToolsRelatedFilesApi is enabled', async () => { + arrange({ cppToolsRelatedFilesApi: true }); + + await moduleUnderTest.registerRelatedFilesCommands(commandDisposables, true); + + ok(getIsRelatedFilesApiEnabledStub.calledOnce, 'getIsRelatedFilesApiEnabled should be called once'); + ok(registerCommandStub.calledOnce, 'registerCommand should be called once'); + ok(commandDisposables.length === 1, 'commandDisposables should have one disposable'); + ok(registerCommandStub.calledWithMatch('C_Cpp.getIncludes', sinon.match.func), 'registerCommand should be called with "C_Cpp.getIncludes" and a function'); + }); + + it('should register C_Cpp.getIncludes command that can handle requests properly', async () => { + arrange({ + cppToolsRelatedFilesApi: true, + getIncludeFiles: { includedFiles: ['c:\\system\\include\\vector', 'c:\\system\\include\\string', 'C:\\src\\my_project\\foo1.h', 'C:\\src\\my_project\\foo2.h'] }, + rootUri: vscode.Uri.file('C:\\src\\my_project') + }); + await moduleUnderTest.registerRelatedFilesCommands(commandDisposables, true); + + const includesResult = await getIncludesResult; + + ok(getIsRelatedFilesApiEnabledStub.calledOnce, 'getIsRelatedFilesApiEnabled should be called once'); + ok(getActiveClientStub.calledOnce, 'getActiveClient should be called once'); + ok(includesResult, 'includesResult should be defined'); + ok(includesResult?.includedFiles.length === 2, 'includesResult should have 2 included files'); + ok(includesResult?.includedFiles[0] === 'C:\\src\\my_project\\foo1.h', 'includesResult should have "C:\\src\\my_project\\foo1.h"'); + ok(includesResult?.includedFiles[1] === 'C:\\src\\my_project\\foo2.h', 'includesResult should have "C:\\src\\my_project\\foo2.h"'); + }); + + it('should not register C_Cpp.getIncludes command if CppToolsRelatedFilesApi is not enabled', async () => { + arrange({ + cppToolsRelatedFilesApi: false + }); + + await moduleUnderTest.registerRelatedFilesCommands(commandDisposables, true); + + ok(getIsRelatedFilesApiEnabledStub.calledOnce, 'getIsRelatedFilesApiEnabled should be called once'); + ok(registerCommandStub.notCalled, 'registerCommand should not be called'); + ok(commandDisposables.length === 0, 'commandDisposables should be empty'); + }); + + it('should register C_Cpp.getIncludes as no-op command if enabled is false', async () => { + arrange({ + cppToolsRelatedFilesApi: true, + getIncludeFiles: { includedFiles: ['c:\\system\\include\\vector', 'c:\\system\\include\\string', 'C:\\src\\my_project\\foo.h'] }, + rootUri: vscode.Uri.file('C:\\src\\my_project') + }); + await moduleUnderTest.registerRelatedFilesCommands(commandDisposables, false); + + const includesResult = await getIncludesResult; + + ok(getIsRelatedFilesApiEnabledStub.calledOnce, 'getIsRelatedFilesApiEnabled should be called once'); + ok(registerCommandStub.calledOnce, 'registerCommand should be called once'); + ok(commandDisposables.length === 1, 'commandDisposables should have one disposable'); + ok(includesResult === undefined, 'includesResult should be undefined'); + ok(registerCommandStub.calledWithMatch('C_Cpp.getIncludes', sinon.match.func), 'registerCommand should be called with "C_Cpp.getIncludes" and a function'); + }); +})