Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions .github/workflows/codescan_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,49 @@ on:
pull_request:
branches:
- main
jobs:
codescan_ci:
// AI FIX START
runs-on: ${{ matrix.os }}
permissions:
contents: read
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.10", "3.11", "3.12"]

steps:
- name: Check out code
uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"

- name: Install dependencies
run: |
pip install build
pip install -r requirements.txt
pip install -r requirements_ci.txt

- name: Enforce proper code formatting with isort
run: |
isort --profile=black --check-only core/ core_tests/

- name: Check linting with pylint
run: |
pylint core/

- name: Run tests with unittest
run: |
python -m unittest discover -s core_tests -p "*.py"

- name: Build the project
run: |
python -m build
// AI FIX END

jobs:
codescan_ci:
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/release-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ jobs:

publish-to-testpypi:
needs: release
permissions:
contents: read # Explicitly set minimal permissions
runs-on: ubuntu-latest
if: ${{ needs.release.outputs.release_created }}
steps:
Expand Down Expand Up @@ -56,6 +58,8 @@ jobs:

publish-to-pypi:
needs: publish-to-testpypi
permissions:
contents: read # Explicitly set minimal permissions
runs-on: ubuntu-latest
if: ${{ needs.publish-to-testpypi.result == 'success' }}
steps:
Expand Down
11 changes: 9 additions & 2 deletions core/code_scanner/code_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def _scan_changes(self):
return "No changes detected in the directory."

code_summary = generate_code_summary(self.args.directory, changed_files)
if not code_summary.strip():
logging.info("No readable source files found in the detected changes.")
return "No readable source files found in the detected changes."

return self.provider.scan_code(code_summary)

Expand All @@ -71,10 +74,14 @@ def _scan_files(self):
file_paths.append(os.path.join(root, file))

code_summary = read_files_and_extract_code_summary(file_paths)
if not code_summary.strip():
logging.info("No readable files found in the specified directory.")
return "No readable files found in the specified directory."

return self.provider.scan_code(code_summary)

def _is_repo_valid(self):
return len(self.args.repo) > 0
return bool(getattr(self.args, "repo", ""))

def _is_pr_number_valid(self):
return self.args.pr_number > 0
return getattr(self.args, "pr_number", 0) > 0
18 changes: 16 additions & 2 deletions core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
This is the runner of the codescan-ai CLI tool.
"""

from IPython.display import display_markdown
try:
from IPython.display import display_markdown
except ImportError: # pragma: no cover - exercised via fallback behavior
display_markdown = None

from core.code_scanner.code_scanner import CodeScanner
from core.utils.argument_parser import parse_arguments
Expand All @@ -17,6 +20,17 @@ def format_as_markdown(result):
return output


def display_scan_result(result):
"""
Displays the scan result in notebook environments and falls back to stdout for CLI use.
"""
formatted_result = format_as_markdown(result)
if display_markdown is not None:
display_markdown(formatted_result)
return
print(formatted_result)


def main():
"""
Main entry point for the CLI. Parses arguments, calls the centralized CodeScanner
Expand All @@ -25,7 +39,7 @@ def main():
"""
args = parse_arguments()
scan_result = CodeScanner(args).scan()
display_markdown(format_as_markdown(scan_result))
display_scan_result(scan_result)


if __name__ == "__main__":
Expand Down
15 changes: 10 additions & 5 deletions core/utils/file_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
"""

import logging
import os
import subprocess

from github import Github

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
Expand Down Expand Up @@ -53,6 +50,13 @@ def get_changed_files_in_pr(repo_name, pr_number, github_token):
logging.error("GitHub token is required for scanning PR changes.")
raise ValueError("GitHub token is required for scanning PR changes.")

try:
from github import Github
except ImportError as exc:
raise ImportError(
"PyGithub is required for scanning pull request changes."
) from exc

files = Github(github_token).get_repo(repo_name).get_pull(pr_number).get_files()

changed_files = [file.filename for file in files]
Expand Down Expand Up @@ -81,8 +85,9 @@ def get_changed_files_in_repo(directory):

changed_files = []
try:
os.chdir(directory)
result = subprocess.check_output(["git", "diff", "--name-only"], text=True)
result = subprocess.check_output(
["git", "-C", directory, "diff", "--name-only"], text=True
)
if result.strip():
changed_files = result.strip().split("\n")
logging.info(
Expand Down
29 changes: 18 additions & 11 deletions core/utils/provider_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,28 @@
This module provides util methods used for initializing an AIProvider based on the user args.
"""

from core.providers.custom_ai_provider import CustomAIProvider
from core.providers.google_gemini_ai_provider import GoogleGeminiAIProvider
from core.providers.open_ai_provider import OpenAIProvider

PROVIDERS = {
"openai": OpenAIProvider,
"gemini": GoogleGeminiAIProvider,
"custom": CustomAIProvider,
}

DEFAULT_MODELS = {
"openai": "gpt-4o-mini",
"gemini": "gemini-pro",
}


def _get_provider_class(provider):
if provider == "openai":
from core.providers.open_ai_provider import OpenAIProvider

return OpenAIProvider
if provider == "gemini":
from core.providers.google_gemini_ai_provider import GoogleGeminiAIProvider

return GoogleGeminiAIProvider
if provider == "custom":
from core.providers.custom_ai_provider import CustomAIProvider

return CustomAIProvider
raise ValueError(f"Unsupported provider: {provider}")


def init_provider(provider, model, host=None, port=None, token=None, endpoint=None): # pylint: disable=R0917
"""
Initializes and returns the appropriate AI client based on the provider.
Expand All @@ -36,4 +42,5 @@ def init_provider(provider, model, host=None, port=None, token=None, endpoint=No
"model": model if model else DEFAULT_MODELS[provider],
}

return PROVIDERS[provider](**client_params)
provider_class = _get_provider_class(provider)
return provider_class(**client_params)
122 changes: 102 additions & 20 deletions core_tests/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import os
import subprocess
import sys
import unittest
from types import ModuleType
from unittest.mock import MagicMock, mock_open, patch

from core.code_scanner.code_scanner import CodeScanner
from core.runner import display_scan_result, format_as_markdown
from core.utils.code_summary_extractor import (
generate_code_summary,
read_files_and_extract_code_summary,
Expand Down Expand Up @@ -31,8 +35,8 @@ def test__isGitRepo__valid(self, mock_check_output):
def test__isGitRepo__invalid(self, mock_check_output):
self.assertFalse(is_git_repo(os.path.join("test", "invalid", "repo")))

@patch("core.utils.file_extractor.Github")
def test__getChangedFilesInPr(self, mock_github):
def test__getChangedFilesInPr(self):
mock_github = MagicMock()
mock_pr = MagicMock()
mock_pr.get_files.return_value = [
MagicMock(filename="file_one.py"),
Expand All @@ -42,19 +46,22 @@ def test__getChangedFilesInPr(self, mock_github):
mock_repo.get_pull.return_value = mock_pr
mock_github.return_value.get_repo.return_value = mock_repo

files = get_changed_files_in_pr("some/repo", 1, "fake_token")
github_module = ModuleType("github")
github_module.Github = mock_github
with patch.dict(sys.modules, {"github": github_module}):
files = get_changed_files_in_pr("some/repo", 1, "fake_token")
self.assertEqual(files, ["file_one.py", "file_two.py"])

@patch("subprocess.check_output")
@patch("os.chdir")
@patch("core.utils.file_extractor.is_git_repo", return_value=True)
def test__getChangedFilesInRepo(
self, mock_is_git_repo, mock_chdir, mock_check_output
):
def test__getChangedFilesInRepo(self, mock_is_git_repo, mock_check_output):
mock_check_output.return_value = "file_one.py\nfile_two.py\n"
files = get_changed_files_in_repo(os.path.join("some", "repo"))
self.assertEqual(files, ["file_one.py", "file_two.py"])
mock_chdir.assert_called_once_with(os.path.join("some", "repo"))
mock_check_output.assert_called_once_with(
["git", "-C", os.path.join("some", "repo"), "diff", "--name-only"],
text=True,
)

@patch("core.utils.file_extractor.is_git_repo", return_value=False)
def test__getChangedFilesInRepo_inValid(self, mock_is_git_repo):
Expand Down Expand Up @@ -134,22 +141,26 @@ def test__generateCodeSummary__isValid(self, mock_isfile, mock_open):
Provider Creator Tests
"""

@patch("core.utils.provider_creator.OpenAIProvider")
@patch("os.getenv")
def test__initOpenAIProvider(self, mock_getenv, mock_openai_provider):
mock_getenv.return_value = "test_openai_api_key"
@patch("core.utils.provider_creator._get_provider_class")
def test__initOpenAIProvider(self, mock_get_provider_class):
mock_provider_class = MagicMock()
mock_get_provider_class.return_value = mock_provider_class
init_provider("openai", None)
mock_getenv.assert_called_once_with("OPENAI_API_KEY")
mock_get_provider_class.assert_called_once_with("openai")
mock_provider_class.assert_called_once_with(model="gpt-4o-mini")

@patch("core.utils.provider_creator.GoogleGeminiAIProvider")
@patch("os.getenv")
def test__initGoogleGeminiAIProvider(self, mock_getenv, mock_google_client):
mock_getenv.return_value = "test_gemini_api_key"
@patch("core.utils.provider_creator._get_provider_class")
def test__initGoogleGeminiAIProvider(self, mock_get_provider_class):
mock_provider_class = MagicMock()
mock_get_provider_class.return_value = mock_provider_class
init_provider("gemini", None)
mock_getenv.assert_called_once_with("GEMINI_API_KEY")
mock_get_provider_class.assert_called_once_with("gemini")
mock_provider_class.assert_called_once_with(model="gemini-pro")

@patch("core.utils.provider_creator.CustomAIProvider")
def test_initialize_client_custom(self, mock_custom_client):
@patch("core.utils.provider_creator._get_provider_class")
def test_initialize_client_custom(self, mock_get_provider_class):
mock_provider_class = MagicMock()
mock_get_provider_class.return_value = mock_provider_class
init_provider(
"custom",
"custom-model",
Expand All @@ -158,6 +169,77 @@ def test_initialize_client_custom(self, mock_custom_client):
"custom-token",
"/api/v1/scan",
)
mock_get_provider_class.assert_called_once_with("custom")
mock_provider_class.assert_called_once_with(
model="custom-model",
host="http://localhost",
port=5000,
token="custom-token",
endpoint="/api/v1/scan",
)

"""
Runner Tests
"""

def test__formatAsMarkdown(self):
result = "This is a test result"
expected_output = "## Code Security Analysis Results\nThis is a test result"
self.assertEqual(format_as_markdown(result), expected_output)

@patch("builtins.print")
@patch("core.runner.display_markdown", None)
def test__displayScanResult__fallsBackToPrint(self, mock_print):
display_scan_result("scan result")
mock_print.assert_called_once_with(
"## Code Security Analysis Results\nscan result"
)

"""
Code Scanner Tests
"""

@patch("core.code_scanner.code_scanner.init_provider")
@patch("core.code_scanner.code_scanner.read_files_and_extract_code_summary")
@patch("os.walk")
def test__scanFiles__skipsProviderCallWhenNoReadableFiles(
self, mock_walk, mock_read_files, mock_init_provider
):
mock_walk.return_value = [("repo", (), ("binary.dat",))]
mock_read_files.return_value = " "
mock_provider = MagicMock()
mock_init_provider.return_value = mock_provider
mock_args = MagicMock(changes_only=False, directory=".", repo=None, pr_number=None)

scan_result = CodeScanner(args=mock_args).scan()

self.assertEqual(scan_result, "No readable files found in the specified directory.")
mock_provider.scan_code.assert_not_called()

@patch("core.code_scanner.code_scanner.get_changed_files_in_pr")
@patch("core.code_scanner.code_scanner.generate_code_summary")
@patch("core.code_scanner.code_scanner.init_provider")
def test__scanChanges__skipsProviderCallWhenChangedFilesAreUnreadable(
self, mock_init_provider, mock_generate_summary, mock_get_changed_files_in_pr
):
mock_get_changed_files_in_pr.return_value = ["binary.dat"]
mock_generate_summary.return_value = ""
mock_provider = MagicMock()
mock_init_provider.return_value = mock_provider
mock_args = MagicMock(
changes_only=True,
directory=".",
repo="owner/repo",
pr_number=1,
github_token="token",
)

scan_result = CodeScanner(args=mock_args).scan()

self.assertEqual(
scan_result, "No readable source files found in the detected changes."
)
mock_provider.scan_code.assert_not_called()


if __name__ == "__main__":
Expand Down
Loading