import os
import pickle
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn
from torch.utils import data
from torch.utils.data.dataset import T_co
from tqdm import tqdm

sys.path.append('../')
from common.log_utils import logFactory

logger = logFactory("UserDataset").log

# Directory load_obj reads the cached pickles from unless told otherwise.
# Kept at the original hard-coded location for backward compatibility.
DEFAULT_OBJ_DIR = "C:\\project\\code\\unsubscribe_predict\\lstm\\obj\\"
# Directory save_obj writes to unless told otherwise (original behaviour).
DEFAULT_SAVE_DIR = "./obj/"


class UserDataset(data.Dataset):
    """Map-style dataset of per-key feature DataFrames and integer labels.

    Samples come from two pickled dicts that share the same keys:
    ``<file_type>_label_dict`` (key -> label, any value ``int()`` accepts)
    and ``<file_type>_path_df`` (key -> pandas DataFrame of features).
    Sample ``i`` is the ``i``-th key of ``label_dict`` in insertion order,
    which is also the order returned by ``get_labels()``.
    """

    def __init__(self, dict_file_path, file_type, obj_dir=None):
        """Load the cached label and feature dicts for ``file_type``.

        Args:
            dict_file_path: Path of the text file the caches were built from.
                Only stored and logged here; data is read from the pickles.
            file_type: Prefix of the cache names, e.g. ``<file_type>_label_dict``.
            obj_dir: Directory holding the pickles. Defaults to
                ``DEFAULT_OBJ_DIR``.
        """
        self.data_path = dict_file_path
        self.obj_dir = DEFAULT_OBJ_DIR if obj_dir is None else obj_dir
        self.label_dict = self.load_obj(file_type + "_label_dict", self.obj_dir)
        self.path_df = self.load_obj(file_type + "_path_df", self.obj_dir)
        # Snapshot the key order once so __getitem__ is O(1) instead of
        # rebuilding list(keys()) / list(values()) on every access.
        self._keys = list(self.label_dict)
        # The pickles were produced by an earlier version of this constructor.
        # It read dict_file_path line by line as "<pickle path> <label>" and
        # keyed each entry by the file name without extension. It dropped the
        # 'uuid' column from each DataFrame and stored both dicts via save_obj.
        logger.info(f"所有{dict_file_path}中的dataframe加载完成")

    def save_obj(self, obj, name, obj_dir=DEFAULT_SAVE_DIR):
        """Pickle ``obj`` to ``<obj_dir>/<name>.pkl`` (default dir ``./obj/``)."""
        with open(os.path.join(obj_dir, name + '.pkl'), 'wb') as f:
            pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

    def load_obj(self, name, obj_dir=None):
        """Unpickle and return ``<obj_dir>/<name>.pkl``.

        ``obj_dir`` defaults to ``DEFAULT_OBJ_DIR``.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only point this
        at cache files you produced yourself.
        """
        directory = DEFAULT_OBJ_DIR if obj_dir is None else obj_dir
        with open(os.path.join(directory, name + ".pkl"), "rb") as f:
            return pickle.load(f)

    def __getitem__(self, index) -> T_co:
        """Return ``(features, label)`` for sample ``index``.

        ``features`` is the key's DataFrame converted to a float32 tensor.
        ``label`` is a 1-element int64 tensor. Index ``i`` maps to the
        ``i``-th key, matching ``get_labels()`` (the previous ``index - 1``
        shifted every sample by one).
        """
        key = self._keys[index]
        label = int(self.label_dict[key])
        # np.array copies by default, so the returned tensor never aliases
        # the cached DataFrame's memory.
        content = np.array(self.path_df[key])
        return torch.from_numpy(content).float(), torch.tensor([label], dtype=torch.long)

    def get_labels(self):
        """Return all labels in sample-index order (raw values, not int-converted)."""
        return list(self.label_dict.values())

    def __len__(self) -> int:
        """Number of samples, i.e. number of entries in ``label_dict``."""
        return len(self.label_dict)


if __name__ == "__main__":
    # Smoke test: python <this file> <dict_file_path> <file_type>
    if len(sys.argv) != 3:
        sys.exit(f"usage: {sys.argv[0]} <dict_file_path> <file_type>")
    custom_dataset = UserDataset(sys.argv[1], sys.argv[2])