上一节用PaddleOCR提取出了,火车票图片的文字和位置信息,并手动打上了对应标签,本节我们将根据带标签的文件中的位置信息,创建一个图结构,并计算标准化的邻接矩阵。
因为考虑到火车票或者其他票据信息,关联关系一般为左右结构,或者上下结构,比如 起点-终点(左右),总金额:¥100.00(上下或左右),所以我们从当前节点开始,向右和向下找最近的关联节点,建立无向图。
代码示例
1、创建配置文件
# config.py import os ROOT_PATH = os.path.dirname(__file__) TRAIN_CSV_DIR = ROOT_PATH + '/output/train/csv_label/' TRAIN_GRAPH_DIR = ROOT_PATH + '/output/train/graph/' TEST_CSV_DIR = ROOT_PATH + '/output/test/csv_label/' TEST_GRAPH_DIR = ROOT_PATH + '/output/test/graph/'
2、创建图生成文件
# process/graph.py
from glob import glob
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import sys
from tqdm import tqdm
sys.path.append('..')
from config import *
3、构建图结构
当前节点与节点右边、下边最近的节点,建立无向边。
class Graph():
# 创建链接
def connect(self, file_path):
graph_dict = {}
df = pd.read_csv(file_path, index_col=0)
for src_idx, src_row in df.iterrows():
neighbor_x = [] # 同一行节点
neighbor_y = [] # 同一列节点
# 再次遍历,两两比较
for dest_idx, dest_row in df.iterrows():
if src_idx == dest_idx:
continue
# 右边的节点
if src_row.x2 < dest_row.x1 and \
src_row.y1 < dest_row.y2 and src_row.y2 > dest_row.y1:
# (距离, 节点id),距离在前方便直接比较大小
neighbor_x.append((dest_row.x1 - src_row.x2, dest_idx))
# 下边的节点
if src_row.y2 < dest_row.y1 and \
src_row.x1 < dest_row.x2 and src_row.x2 > dest_row.x1:
neighbor_y.append((dest_row.y1 - src_row.y2, dest_idx))
# 取最近的节点,其他的忽略
min_x = [min(neighbor_x)[1]] if neighbor_x else []
min_y = [min(neighbor_y)[1]] if neighbor_y else []
graph_dict[src_idx] = min_x + min_y
# 过滤空节点
graph_dict = {k: v for k, v in graph_dict.items() if v}
# 找出孤立点,键和值中都未出现过
node_idx = set(graph_dict.keys())
node_idx.update([i for v in graph_dict.values() for i in v])
loss_idx = set(df.index) - node_idx
return graph_dict, list(loss_idx)
4、图结构可视化
# 画图 G = nx.from_dict_of_lists(graph_dict) fig, ax = plt.subplots() nx.draw(G, ax=ax, with_labels=True) # show node label plt.show() exit()
5、计算A矩阵
参考文档:https://ai.plainenglish.io/graph-convolutional-networks-gcn-baf337d5cb6b
class Graph():
# 计算A矩阵
def get_adjacency_norm(self, graph_dict):
G = nx.from_dict_of_lists(graph_dict)
A = nx.adjacency_matrix(G)
A_new = A + np.eye(*A.shape)
D = np.array(A_new.sum(1)).flatten()
# D^-0.5 A D^-0.5
return np.diag(D**(-0.5)) @ A_new @ np.diag(D**(-0.5))
6、生成缓存文件
if __name__ == '__main__':
graph = Graph()
for file_path in tqdm(glob(TRAIN_CSV_DIR + '*.csv')):
graph_dict, loss_idx = graph.connect(file_path)
adj = graph.get_adjacency_norm(graph_dict)
file_name = os.path.split(file_path)[1][:-3] + 'pkl'
file_dump([adj, loss_idx], TRAIN_GRAPH_DIR + file_name)
for file_path in tqdm(glob(TEST_CSV_DIR + '*.csv')):
graph_dict, loss_idx = graph.connect(file_path)
adj = graph.get_adjacency_norm(graph_dict)
file_name = os.path.split(file_path)[1][:-3] + 'pkl'
file_dump([adj, loss_idx], TEST_GRAPH_DIR + file_name)
7、文件生成和读取方法
# utils.py
import pickle
def file_dump(obj, file_path):
pickle.dump(obj, open(file_path, 'wb'))
def file_load(file_path):
return pickle.load(open(file_path, 'rb'))
本文为 陈华 原创,欢迎转载,但请注明出处:http://www.chenhuax.com/read/329