"""
A simple example to illustrate the use of FAISS for ranking documents
and T5 for generating answers.
"""

# Copyright (c) 2025, Carnegie Mellon University.  All Rights Reserved.

import faiss
import numpy
import torch

from transformers import AutoTokenizer, AutoModel
from transformers import T5Tokenizer, T5ForConditionalGeneration

from Idx import Idx

# ------------------ Global variables ---------------------- #

model_max_sequence_length = 512         # Max WordPiece tokens
trecEvalOutputLength = 10               # Number of documents FAISS returns

index_path = "INPUT_DIR/index-cw09"
dense_index_path = "INPUT_DIR/index-cw09-faiss-t32b300-Fp"
dpr_model_path = "INPUT_DIR/co-condenser-marco-retriever/"
rag_model_path = "INPUT_DIR/flan-t5-base"

# ------------------ Methods ------------------------------- #


def dense_encode(input_dict):
    """
    Encode a token sequence.

    Relies on the module-level `dense_model`, which the script body
    must create before this function is called.

    Input: the tokenized sequence (a dictionary of tensors, as produced
    by dense_tokenize_string).
    Output: the dense representation, as a plain Python list of floats.
    """
    # Inference only, so do not track gradients.
    with torch.no_grad():
        outputs = dense_model(**input_dict)

    rep = outputs.last_hidden_state[:, 0]   # The hidden state of [CLS]
    rep = rep.squeeze()                     # e.g., [1, 768] -> [768]
    rep = rep.tolist()                      # Avoid tensor memory leak
    return rep


def dense_tokenize_string(s):
    """
    Use the model's tokenizer to tokenize s, convert to token ids, and
    return as tensors.

    Relies on the module-level `dense_tokenizer`, which the script body
    must create before this function is called.

    Input: a text string.
    Output: a dictionary of tensors that BERT understands.
      "input_ids": The ids for each token.
      "token_type_ids": The token type (sequence) id of each token.
      "attention_mask": For each token, mask (0) or don't mask (1).
        No padding is requested here, so nothing should be masked.
    """
    # Calling the tokenizer directly is the documented replacement for
    # the legacy encode_plus() method; the returned encoding is the same.
    return dense_tokenizer(
        s,                                      # sequence
        max_length=model_max_sequence_length,
        truncation=True,                        # Truncate if too long
        return_tensors="pt")                    # Return PyTorch tensors


def text_truncate(s):
    """
    Truncate a string to its first model_max_sequence_length
    whitespace-delimited words.

    Words are counted by splitting on whitespace, not by running the
    model's tokenizer, and runs of whitespace are collapsed to single
    spaces in the result.
    """
    return " ".join(s.split()[:model_max_sequence_length])


# ------------------ Script body --------------------------- #

question = "Do cigarettes cause cancer?"
print('==> Retrieval <==', flush=True)
print(f'Query: {question}', flush=True)

# Initialize retrieval
Idx.open(index_path)
faiss_index = faiss.read_index(dense_index_path)
dense_tokenizer = AutoTokenizer.from_pretrained(dpr_model_path)
dense_model = AutoModel.from_pretrained(dpr_model_path)
dense_model.eval()

# Tokenizing and encoding a string is similar to HW4
encoded_query = dense_encode(dense_tokenize_string(question))

# FAISS evaluates a matrix of queries (one row per query). Our matrix
# has just 1 row. FAISS works on float32 data, but numpy would build a
# float64 array from Python floats, so the dtype is set explicitly.
query_matrix = numpy.array([encoded_query], dtype=numpy.float32)
scores, docids = faiss_index.search(query_matrix, trecEvalOutputLength)
print(f'Internal docids: {docids[0]}', flush=True)
print(f'Scores: {scores[0]}', flush=True)

# Get (a very simple) first passage from the first document.
# The docid is converted from a numpy integer to a plain int in case
# Idx does not accept numpy scalar types.
top_docid = int(docids[0][0])
body = Idx.getAttribute("body-string", top_docid)
passage = body[:600]

print("\n==> Retrieval augmented generation <==", flush=True)
print(f'Question: {question}', flush=True)

# Initialize the generator
rag_tokenizer = T5Tokenizer.from_pretrained(rag_model_path)
rag_model = T5ForConditionalGeneration.from_pretrained(rag_model_path)


def _generate_answer(prompt):
    """
    Generate one answer for a prompt with the T5 model.

    Input: the prompt string (truncated to model_max_sequence_length
    tokens by the tokenizer).
    Output: the decoded answer string, without special tokens.
    """
    input_ids = rag_tokenizer.encode(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=model_max_sequence_length)
    output = rag_model.generate(input_ids, num_return_sequences=1)
    return rag_tokenizer.decode(output[0], skip_special_tokens=True)


# Do generation with no retrieval.
prompt = f"question: {question} \n answer: \n"
answer = _generate_answer(prompt)
print(f'Answer 1 (no retrieval): {answer}', flush=True)

# Do RAG.
prompt = f"question: {question} \n context: \n {passage}"
answer = _generate_answer(prompt)
print(f'Answer 2 (w/retrieval): {answer}', flush=True)