Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ env.sh
.mypy_cache
notebooks/output
notebooks/repos
.venv/
Comment thread
bzz marked this conversation as resolved.
Outdated
.vscode/
72 changes: 52 additions & 20 deletions notebooks/codesearchnet-opennmt.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
"""
CLI tool for converting CodeSearchNet dataset to OpenNMT format for
function name suggestion task.

Usage example:
wget 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip'
unzip java.zip
python notebooks/codesearchnet-opennmt.py \
--data_dir='java/final/jsonl/valid' \
Comment thread
bzz marked this conversation as resolved.
Outdated
--newline='\\n'
Comment thread
m09 marked this conversation as resolved.
"""
from argparse import ArgumentParser, Namespace
import logging
from pathlib import Path
from time import time
from typing import List, Tuple

import pandas as pd
from torch.utils.data import Dataset


logging.basicConfig(level=logging.INFO)


class CodeSearchNetRAM(Dataset):
"""Stores one split of CodeSearchNet data in memory

Usage example:
wget 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip'
unzip java.zip
python notebooks/codesearchnet-opennmt.py \
--data_dir='java/final/jsonl/valid' \
--newline='\\n'
"""
class CodeSearchNetRAM(object):
Comment thread
bzz marked this conversation as resolved.
Outdated
"""Stores one split of CodeSearchNet data in memory"""

def __init__(self, split_path: Path, newline_repl: str):
super().__init__()
self.pd = pd
self.newline_repl = newline_repl

files = sorted(split_path.glob("**/*.gz"))
logging.info(f"Total number of files: {len(files):,}")
assert files, "could not find files under %s" % split_path

columns_list = ["code", "func_name"]
columns_list = ["code", "func_name", "code_tokens"]

start = time()
self.pd = self._jsonl_list_to_dataframe(files, columns_list)
Expand Down Expand Up @@ -61,10 +64,21 @@ def __getitem__(self, idx: int) -> Tuple[str, str]:

# drop fn signature
code = row["code"]
fn_body = code[code.find("{") + 1 : code.rfind("}")].lstrip().rstrip()
fn_body = fn_body.replace("\n", "\\n")
fn_body = (
code[
code.find("{", code.find(fn_name) + len(fn_name)) + 1 : code.rfind("}")
]
.lstrip()
.rstrip()
Comment thread
bzz marked this conversation as resolved.
Outdated
)
fn_body = fn_body.replace("\n", self.newline_repl)
# fn_body_enc = self.enc.encode(fn_body)
return (fn_name, fn_body)

tokens = row["code_tokens"]
body_tokens = tokens[tokens.index(fn_name) + 2 :]
fn_body_tokens = body_tokens[body_tokens.index("{") + 1 : len(body_tokens) - 1]

return (fn_name, fn_body, fn_body_tokens)

def __len__(self) -> int:
    """Return the number of samples in this split.

    Delegates to the length of the backing pandas DataFrame stored on
    ``self.pd`` (one row per CodeSearchNet function record).
    """
    dataset_frame = self.pd
    return len(dataset_frame)
Expand All @@ -76,11 +90,16 @@ def main(args: Namespace) -> None:
with open(args.src_file % split_name, mode="w", encoding="utf8") as s, open(
args.tgt_file % split_name, mode="w", encoding="utf8"
) as t:
for fn_name, fn_body in dataset:
for fn_name, fn_body, fn_body_tokens in dataset:
if not fn_name or not fn_body:
continue
print(fn_body, file=s)
print(fn_name if args.word_level_targets else " ".join(fn_name), file=t)
src = " ".join(fn_body_tokens) if args.token_level_sources else fn_body
tgt = fn_name if args.word_level_targets else " ".join(fn_name)
if args.print:
print(f"'{tgt[:40]:40}' - '{src[:40]:40}'")
else:
print(src, file=s)
print(tgt, file=t)


if __name__ == "__main__":
Expand All @@ -96,18 +115,31 @@ def main(args: Namespace) -> None:
"--newline", type=str, default="\\n", help="Replace newline with this"
Comment thread
bzz marked this conversation as resolved.
Outdated
)

parser.add_argument(
"--token-level-sources",
action="store_true",
help="Use language-specific token sources instead of word level ones",
)

parser.add_argument(
"--word-level-targets",
action="store_true",
help="Use word level targets instead of char level ones",
)

parser.add_argument(
"--src_file", type=str, default="src-%s.txt", help="File with function bodies",
"--src_file",
Comment thread
bzz marked this conversation as resolved.
Outdated
type=str,
Comment thread
bzz marked this conversation as resolved.
Outdated
default="src-%s.token",
help="File with function bodies",
)

parser.add_argument(
"--tgt_file", type=str, default="tgt-%s.token", help="File with function texts"
Comment thread
bzz marked this conversation as resolved.
Outdated
)

parser.add_argument(
"--tgt_file", type=str, default="tgt-%s.txt", help="File with function texts"
"--print", action="store_true", help="Print data preview to the STDOUT"
)

args = parser.parse_args()
Expand Down