Skip to content

Commit e0c30e1

Browse files
authored
Specify torch_type and device_map when loading the model
1 parent e732bcf commit e0c30e1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

text_generation_demo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111
def init_model(args):
1212
global tokenizer, model
13-
tokenizer = AutoTokenizer.from_pretrained(args.model_path, truncation_side="left", padding_side="left")
14-
model = AutoModelForCausalLM.from_pretrained(args.tokenizer_path, trust_remote_code=True).half().cuda()
13+
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, truncation_side="left", padding_side="left")
14+
model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map='auto')
1515
model = model.eval()
1616

1717

@@ -79,4 +79,4 @@ def get_args():
7979
temperature, top_k, top_p],
8080
outputs=outputs)
8181

82-
demo.launch(server_name="0.0.0.0", server_port=args.port)
82+
demo.launch(server_name="0.0.0.0", server_port=args.port)

0 commit comments

Comments
 (0)