-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathingest_movies.py
More file actions
121 lines (102 loc) · 3.44 KB
/
ingest_movies.py
File metadata and controls
121 lines (102 loc) · 3.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import mlcroissant as mlc
from typer import Typer
import json
import pandas as pd
from tqdm import tqdm
from movie_record import make_movie_with_all_connections, create_indexes, DATASET_NAME
from movie_parser import MovieParser
from aperturedb.ParallelLoader import ParallelLoader
from aperturedb.CommonLibrary import (
create_connector,
execute_query
)
from aperturedb.Utils import Utils
from embeddings import Embedder, DEFAULT_MODEL
# constants
# Croissant JSON-LD descriptor URL for the Kaggle "TMDB 5000 Movies" dataset;
# mlcroissant uses it to locate and download the underlying CSV archives.
CROISSANT_URL = "https://www.kaggle.com/datasets/tmdb/tmdb-movie-metadata/croissant/download"
# Typer CLI application; commands are registered below with @app.command().
app = Typer()
def deserialize_record(record):
    """Decode a raw record value into a plain Python object.

    Bytes are decoded as UTF-8. If the resulting (or given) value is a
    string containing JSON, it is parsed; a non-JSON string is returned
    unchanged. Any other type passes through untouched.

    Args:
        record: the raw value from a croissant record (bytes, str, or other).

    Returns:
        The deserialized value (dict/list/str/number/...) or the input as-is.
    """
    deserialized = record.decode('utf-8') if isinstance(record, bytes) else record
    if isinstance(deserialized, str):
        try:
            deserialized = json.loads(deserialized)
        except json.JSONDecodeError:
            # Narrowed from a bare `except:`: only "not valid JSON" is an
            # expected outcome here; anything else should surface.
            pass
    return deserialized
def cleanup_movies(db):
    """
    Remove every entity, image, and descriptor set tagged with this
    dataset's name from ApertureDB, leaving a clean slate for re-ingestion.
    """
    # All three delete commands share the same dataset_name constraint.
    commands = ("DeleteEntity", "DeleteImage", "DeleteDescriptorSet")
    query = [
        {command: {"constraints": {"dataset_name": ["==", DATASET_NAME]}}}
        for command in commands
    ]
    execute_query(db, query=query)
@app.command()
def ingest_movies(ingest_posters: bool = False, embed_tagline: bool = False, sample_count: int = -1):
    """
    Ingest the movies dataset into ApertureDB.

    Args:
        ingest_posters: also fetch and ingest movie poster images.
        embed_tagline: also compute embeddings for movie taglines.
        sample_count: if > 0, ingest only the first N merged records;
            any non-positive value ingests everything.
    """
    # Fetch the Croissant JSON-LD describing the Kaggle TMDB dataset.
    croissant_dataset = mlc.Dataset(CROISSANT_URL)
    # TMDB publishes two record sets (described as movies and credits/casts);
    # the croissant library manages download, extraction, and loading.
    record_sets = croissant_dataset.metadata.record_sets
    record_set_df_0 = pd.DataFrame(croissant_dataset.records(record_set=record_sets[0].uuid))
    record_set_df_1 = pd.DataFrame(croissant_dataset.records(record_set=record_sets[1].uuid))

    # Join the two record sets on the movie id.
    # NOTE(review): left_on names a credits column while the left frame is
    # described as the movies record set — either the description or the
    # key assignment looks swapped; verify against the actual column names.
    records = record_set_df_0.merge(
        record_set_df_1,
        right_on="tmdb_5000_movies.csv/id",
        left_on="tmdb_5000_credits.csv/movie_id")
    if sample_count > 0:
        records = records.head(sample_count)

    db = create_connector()
    # Remove any previous copy of this dataset before re-ingesting.
    cleanup_movies(db)

    descriptor_set = "wf_embeddings_clip"
    embedder = Embedder.from_new_descriptor_set(
        db, descriptor_set,
        provider="clip",
        model_name="ViT-B/16",
        properties={"type": "text", "source_type": "movie", "dataset_name": DATASET_NAME})

    # Build one movie object (with all its connections) per merged row.
    collection = []
    columns = records.columns  # hoisted: identical for every row
    for _, row in tqdm(records.iterrows(), total=len(records)):
        j = {c: deserialize_record(row[c]) for c in columns}
        movie = make_movie_with_all_connections(j, embedder, ingest_posters, embed_tagline)
        collection.append(movie)

    parser = MovieParser(collection)
    utils = Utils(db)
    create_indexes(utils)

    loader = ParallelLoader(db)
    # Bug fix: setSuccessStatus was called on the ParallelLoader class, which
    # passes the list as `self` and raises TypeError; call it on the instance.
    loader.setSuccessStatus([0, 2])
    loader.ingest(parser, batchsize=10, numthreads=8, stats=True)
# Run the Typer CLI when this file is executed as a script.
if __name__ == "__main__":
    app()