Skip to content

Commit 74931d9

Browse files
committed
transformers-pipeline-demo.py
1 parent a5510c0 commit 74931d9

1 file changed

Lines changed: 21 additions & 0 deletions

File tree

llm/transformers-pipeline-demo.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#!/usr/bin/env python3
2+
3+
print("initializing")
4+
5+
from transformers import pipeline, GPT2LMHeadModel, AutoTokenizer
6+
7+
length = 4
8+
input = " ".join(str(i) for i in range(length))
9+
print("input:", input)
10+
input_list = input.split()
11+
12+
model = GPT2LMHeadModel.from_pretrained('trained_model')
13+
tokenizer = AutoTokenizer.from_pretrained("gpt2")
14+
tokenizer.pad_token = tokenizer.eos_token
15+
16+
print("inference")
17+
gen = pipeline("text-generation", model=model, tokenizer=tokenizer, truncation=True)
18+
output = gen(input_list[0], max_new_tokens=length - 1, num_return_sequences=1)[0]['generated_text']
19+
print(output)
20+
print(gen('0 1', max_new_tokens=length - 2, num_return_sequences=1)[0]['generated_text'])
21+
exit(output != input)

0 commit comments

Comments
 (0)